Tutorial 4: Integrating mouse embryo slices
This tutorial demonstrates STAligner’s ablility to integrate four mouse embryo slices sampled at the time stages of E9.5, E10.5, E11.5 and E12.5 profiled by Stereo-seq. The raw data can be downloaded from https://db.cngb.org/stomics/mosta/.
Preparation
[1]:
import warnings
warnings.filterwarnings("ignore")
[2]:
import STAligner
# the location of R (used for the mclust clustering)
import os
os.environ['R_HOME'] = "/mnt/disk1/xzhou/anaconda3/envs/STAligner/lib/R"
os.environ['R_USER'] = "/mnt/disk1/xzhou/anaconda3/envs/STAligner/lib/python3.8/site-packages/rpy2"
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import scipy.sparse as sp
import scipy.linalg
import torch
used_device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
Load Data
[3]:
Batch_list = []
adj_list = []
section_ids = ['E9.5_E1S1', 'E10.5_E2S1', 'E11.5_E1S1', 'E12.5_E1S1']
for section_id in section_ids:
print(section_id)
adata = sc.read_h5ad(os.path.join("Data/" + section_id + ".MOSTA.h5ad"))
adata.X = adata.layers['count']
# make spot name unique
adata.obs_names = [x + '_' + section_id for x in adata.obs_names]
STAligner.Cal_Spatial_Net(adata, rad_cutoff=1.3)
# Normalization
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=5000) #ensure enough common HVGs in the combined matrix
adata = adata[:, adata.var['highly_variable']]
adj_list.append(adata.uns['adj'])
Batch_list.append(adata)
E9.5_E1S1
------Calculating spatial graph...
The graph contains 23166 edges, 5913 cells.
3.9178 neighbors per cell on average.
E10.5_E2S1
------Calculating spatial graph...
The graph contains 33410 edges, 8494 cells.
3.9334 neighbors per cell on average.
E11.5_E1S1
------Calculating spatial graph...
The graph contains 119204 edges, 30124 cells.
3.9571 neighbors per cell on average.
E12.5_E1S1
------Calculating spatial graph...
The graph contains 204168 edges, 51365 cells.
3.9748 neighbors per cell on average.
Concat the scanpy objects for multiple slices
[4]:
adata_concat = ad.concat(Batch_list, label="slice_name", keys=section_ids)
adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')
print('adata_concat.shape: ', adata_concat.shape)
adata_concat.shape: (95896, 693)
Concat the spatial network for multiple slices
[5]:
adj_concat = np.asarray(adj_list[0].todense())
for batch_id in range(1,len(section_ids)):
adj_concat = scipy.linalg.block_diag(adj_concat, np.asarray(adj_list[batch_id].todense()))
adata_concat.uns['edgeList'] = np.nonzero(adj_concat)
Running STAligner
[6]:
# Important parameter:
# "iter_comb" is used to specify the order of integration
# "margin" is used to control the intensity/weight of batch correction
iter_comb = [(0, 3), (1, 3), (2, 3)] ## Fix slice 3 as reference to align
adata_concat = STAligner.train_STAligner(adata_concat, verbose=True, knn_neigh = 100, iter_comb = iter_comb,
margin=2.5, device=used_device)
STAligner(
(conv1): GATConv(693, 512, heads=1)
(conv2): GATConv(512, 30, heads=1)
(conv3): GATConv(30, 512, heads=1)
(conv4): GATConv(512, 693, heads=1)
)
Pretrain with STAGATE...
100%|██████████| 500/500 [00:34<00:00, 14.68it/s]
Train with STAligner...
0%| | 0/500 [00:00<?, ?it/s]
Update spot triplets at epoch 500
20%|█▉ | 99/500 [00:39<00:27, 14.54it/s]
Update spot triplets at epoch 600
40%|███▉ | 199/500 [01:33<00:20, 14.38it/s]
Update spot triplets at epoch 700
60%|█████▉ | 299/500 [02:35<00:13, 14.39it/s]
Update spot triplets at epoch 800
80%|███████▉ | 399/500 [03:38<00:07, 14.40it/s]
Update spot triplets at epoch 900
100%|██████████| 500/500 [04:43<00:00, 1.76it/s]
Clustering
[7]:
sc.pp.neighbors(adata_concat, use_rep='STAligner', random_state=666)
sc.tl.louvain(adata_concat, random_state=666, key_added="louvain", resolution=0.4)
Visualization
[8]:
sc.tl.umap(adata_concat, random_state=666)
colors_default = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
'#8c564b', '#e377c2', '#bcbd22', '#17becf', '#aec7e8',
'#ffbb78', '#98df8a', '#ff9896', '#bec1d4', '#bb7784',
'#0000ff']
adata_concat.uns['louvain_colors'] = [colors_default[0:][i] for i in np.sort(adata_concat.obs['louvain'].unique().astype('int'))]
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = "Arial"
plt.rcParams["figure.figsize"] = (2, 2)
plt.rcParams['font.size'] = 10
sc.pl.umap(adata_concat, color=['batch_name', 'louvain'], ncols=2, wspace=1, show=True)
[9]:
for ss in range(len(section_ids)):
Batch_list[ss].obs['louvain'] = adata_concat[adata_concat.obs['batch_name'] == section_ids[ss]].obs['louvain'].values
Batch_list[ss].uns['louvain_colors'] = [colors_default[0:][i] for i in np.sort(adata_concat[adata_concat.obs['batch_name'] ==
section_ids[ss]].obs['louvain'].unique().astype('int'))]
import matplotlib.pyplot as plt
spot_size = 1
title_size = 10
fig, ax = plt.subplots(1, len(section_ids), figsize=(len(section_ids)*3, 3), gridspec_kw={'wspace':1, 'hspace': 0.1})
for ss in range(len(section_ids)-1):
_sc_0 = sc.pl.spatial(Batch_list[ss], img_key=None, color=['louvain'], title=['louvain'], size=1.5, legend_fontsize=8,
show=False, frameon=False, ax=ax[ss], spot_size=spot_size)
_sc_0[0].set_title(section_ids[ss], size=title_size)
_sc_0 = sc.pl.spatial(Batch_list[-1], img_key=None, color=['louvain'], title=['louvain'], size=1.5, legend_fontsize=8,
show=False, frameon=False, ax=ax[-1], spot_size=spot_size)
_sc_0[0].set_title(section_ids[-1], size=title_size)
plt.show()
[ ]: