#!/usr/bin/env python
"""
# Author: Kai Cao
"""
import torch
import numpy as np
import os
import scanpy as sc
from anndata import AnnData
import scipy
import sklearn
import pandas as pd
from scipy.sparse import issparse
from .model.vae import VAE
from .model.utils import EarlyStopping
from .logger import create_logger
from .data_loader import load_data
from .metrics import *
from anndata import AnnData
from sklearn.preprocessing import MaxAbsScaler
from glob import glob
# `np.warnings` was only an accidental alias of the stdlib module and is gone
# from NumPy >= 1.24, so silence warnings through `warnings` directly.
import warnings

warnings.filterwarnings('ignore')

# Directory that `load_file` searches for bundled `<name>.h5ad` datasets.
DATA_PATH = os.path.expanduser("~")+'/.uniport/'
# Default number of rows scaled at once by `batch_scale`.
CHUNK_SIZE = 20000
def read_mtx(path):
    """
    Read mtx format data folder including:

        * matrix file: e.g. count.mtx or matrix.mtx or their gz format
        * barcode file: e.g. barcode.txt
        * feature file: e.g. feature.txt

    Files are recognised by substrings of their globbed path
    (``path + '/' + name``), so directory names take part in the match too.

    Parameters
    ----------
    path
        the path store the mtx files

    Return
    ------
    AnnData

    Raises
    ------
    FileNotFoundError
        If no file under ``path`` looks like a matrix file.
    """
    filenames = glob(path+'/*')

    # A matrix file has 'mtx' plus one of 'count' / 'matrix' / 'data' in its
    # path. If several match, the last one globbed wins.
    adata = None
    for filename in filenames:
        if ('count' in filename or 'matrix' in filename or 'data' in filename) and ('mtx' in filename):
            adata = sc.read_mtx(filename).T

    # Without this guard the annotation loop below would fail with an opaque
    # UnboundLocalError on `adata`.
    if adata is None:
        raise FileNotFoundError("No mtx matrix file found in {}".format(path))

    for filename in filenames:
        if 'barcode' in filename:
            # barcodes are taken from the last column
            barcode = pd.read_csv(filename, sep='\t', header=None).iloc[:, -1].values
            adata.obs = pd.DataFrame(index=barcode)
        if 'gene' in filename or 'peaks' in filename:
            # gene / peak names are taken from the last column
            gene = pd.read_csv(filename, sep='\t', header=None).iloc[:, -1].values
            adata.var = pd.DataFrame(index=gene)
        elif 'feature' in filename:
            # feature files use the second column as the name
            gene = pd.read_csv(filename, sep='\t', header=None).iloc[:, 1].values
            adata.var = pd.DataFrame(index=gene)

    return adata
def load_file(path):
    """
    Load single cell dataset from file

    Resolution order:

        1. ``DATA_PATH + path + '.h5ad'`` if that file exists
        2. ``path`` is a directory: read as an mtx folder via ``read_mtx``
        3. ``path`` is a file: dispatch on its extension
           (``.csv[.gz]``, ``.txt/.tsv[.gz]``, ``.h5ad``)

    ``.csv`` and tab-separated inputs are transposed after reading. The
    returned ``adata.X`` is always sparse: dense matrices are converted to CSR.

    Parameters
    ----------
    path
        the path store the file

    Return
    ------
    AnnData

    Raises
    ------
    ValueError
        If ``path`` does not exist, or it is a file whose extension is not
        one of the supported formats.
    """
    if os.path.exists(DATA_PATH+path+'.h5ad'):
        adata = sc.read_h5ad(DATA_PATH+path+'.h5ad')
    elif os.path.isdir(path): # mtx format
        adata = read_mtx(path)
    elif os.path.isfile(path):
        if path.endswith(('.csv', '.csv.gz')):
            adata = sc.read_csv(path).T
        elif path.endswith(('.txt', '.txt.gz', '.tsv', '.tsv.gz')):
            df = pd.read_csv(path, sep='\t', index_col=0).T
            adata = AnnData(df.values, dict(obs_names=df.index.values), dict(var_names=df.columns.values))
        elif path.endswith('.h5ad'):
            adata = sc.read_h5ad(path)
        else:
            # Otherwise `adata` stays unbound and the sparse conversion below
            # fails with an UnboundLocalError.
            raise ValueError("File {} has an unsupported format".format(path))
    else:
        raise ValueError("File {} not exists".format(path))

    if not issparse(adata.X):
        adata.X = scipy.sparse.csr_matrix(adata.X)
    return adata
def tfidf(X, n_components, binarize=True, random_state=0):
    """
    TF-IDF transform followed by truncated SVD (LSI) of a count matrix.

    Parameters
    ----------
    X
        Array-like count matrix; it is copied with ``np.copy`` before use.
    n_components
        Number of SVD components to compute.
    binarize
        If True, every entry that is not below 1 is set to 1 first.
        Default: True
    random_state
        Seed passed to the truncated SVD. Default: 0

    Return
    ------
    numpy.ndarray
        The SVD embedding without its first component, i.e. with
        ``n_components - 1`` columns.
    """
    # Import the estimators explicitly: a bare `import sklearn` does not
    # guarantee that the `sklearn.decomposition` submodule is loaded.
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfTransformer

    counts = np.copy(X)
    if binarize:
        counts = np.where(counts < 1, counts, 1)

    transformer = TfidfTransformer(norm='l2', sublinear_tf=True)
    normed_count = transformer.fit_transform(counts)

    svd = TruncatedSVD(n_components=n_components, random_state=random_state)
    embedding = svd.fit_transform(normed_count)

    # drop the first component
    return embedding[:, 1:]
def TFIDF_LSI(adata, n_comps=50, binarize=True, random_state=0):
    '''
    Computes LSI based on a TF-IDF transformation of the data from MultiMap. Putative dimensionality
    reduction for scATAC-seq data. Adds an ``.obsm['X_lsi']`` field to the object it was ran on.

    Input
    -----
    adata : ``AnnData``
        The object to run TFIDF + LSI on. Will use ``.X`` as the input data; a sparse ``.X``
        is densified first.
    n_comps : ``int``
        The number of components requested from ``tfidf``. Since ``tfidf`` drops the first
        component, ``.obsm['X_lsi']`` ends up with ``n_comps - 1`` columns. Default: 50
    binarize : ``bool``
        Whether to binarize the data prior to the computation. Often done during scATAC-seq
        processing. Default: True
    random_state : ``int``
        The seed to use for random number generation. Default: 0
    '''
    # thin AnnData wrapper around the array-level `tfidf` function
    counts = adata.X
    if scipy.sparse.issparse(counts):
        counts = counts.todense()
    adata.obsm['X_lsi'] = tfidf(
        counts,
        n_components=n_comps,
        binarize=binarize,
        random_state=random_state,
    )
def filter_data(
        adata: AnnData,
        min_features: int = 0,
        min_cells: int = 0,
        log=None
    ):
    """
    Filter cells and genes in place with scanpy.

    Cells are filtered first, then genes. Nothing is returned.

    Parameters
    ----------
    adata
        An AnnData matrice of shape n_obs × n_vars. Rows correspond to cells and columns to genes.
    min_features
        Filtered out cells that are detected in less than n genes. Default: 0.
    min_cells
        Filtered out genes that are detected in less than n cells. Default: 0.
    log
        Optional logger; when truthy, its ``info`` method is called before each step.
    """
    if log:
        log.info('Filtering cells')
    sc.pp.filter_cells(adata, min_genes=min_features)

    if log:
        log.info('Filtering features')
    sc.pp.filter_genes(adata, min_cells=min_cells)
def batch_scale(adata, use_rep='X', chunk_size=CHUNK_SIZE):
    """
    Batch-specific scale data

    For every distinct value of ``adata.obs['source']`` a ``MaxAbsScaler`` is
    fitted on that batch's rows and applied to them in place, ``chunk_size``
    rows at a time. Nothing is returned.

    Parameters
    ----------
    adata
        AnnData
    use_rep
        use '.X' or '.obsm'
    chunk_size
        chunk large data into small chunks
    """
    def _matrix():
        # Look the matrix up afresh on every access instead of caching a
        # reference, so each read/write goes through the AnnData accessor.
        return adata.X if use_rep == 'X' else adata.obsm[use_rep]

    for b in adata.obs['source'].unique():
        idx = np.where(adata.obs['source']==b)[0]
        scaler = MaxAbsScaler(copy=False).fit(_matrix()[idx])

        # Step through the batch by start offset. Counting chunks as
        # `len(idx)//chunk_size+1` produces an empty trailing chunk whenever
        # the batch size is an exact multiple of `chunk_size`, and the scaler
        # rejects 0-sample input.
        for start in range(0, len(idx), chunk_size):
            rows = idx[start:start+chunk_size]
            _matrix()[rows] = scaler.transform(_matrix()[rows])
def get_prior(celltype1, celltype2, alpha=2):
    """
    Create a prior correspondence matrix according to cell labels

    Entry ``[i, j]`` is ``1/alpha`` when cell ``i`` of dataset X and cell ``j``
    of dataset Y carry the same label, and ``alpha`` otherwise.

    Parameters
    ----------
    celltype1
        cell labels of dataset X (any 1-D sequence: list, array, Series, ...)
    celltype2
        cell labels of dataset Y (any 1-D sequence: list, array, Series, ...)
    alpha
        the confidence of label, ranges from (1, inf). Higher alpha means better confidence. Default: 2.0

    Return
    ------
    torch.tensor
        a prior correspondence matrix between cells, of shape
        ``(len(celltype1), len(celltype2))``
    """
    # Convert to arrays so that `labels == value` is an element-wise
    # comparison even for plain Python lists.
    labels1 = np.asarray(celltype1)
    labels2 = np.asarray(celltype2)

    Couple = alpha*torch.ones(len(labels1), len(labels2))

    shared = set(labels1.tolist()) & set(labels2.tolist())
    for label in shared:
        rows = torch.from_numpy(np.flatnonzero(labels1 == label))
        cols = torch.from_numpy(np.flatnonzero(labels2 == label))
        # fill the whole (rows x cols) block in one indexed assignment
        Couple[rows[:, None], cols] = 1/alpha

    return Couple
def label_reweight(celltype):
    """
    Reweight labels to make all cell types share the same total weight

    A cell whose label occurs ``c`` times among ``n`` cells with ``k`` distinct
    labels gets weight ``n / (k * c)``.

    Parameters
    ----------
    celltype
        cell labels (any 1-D sequence: list, array, Series, ...)

    Return
    ------
    torch.tensor
        a vector of weights of cells, of shape ``(n, 1)``
    """
    n = len(celltype)

    # Work on a plain array: positional `celltype[i]` is label-based on a
    # pandas Series and breaks for non-default indexes.
    labels = np.asarray(celltype)
    unique, inverse, count = np.unique(labels, return_inverse=True, return_counts=True)

    # `inverse` maps every cell to the position of its label in `unique`, so
    # `count[inverse]` is the per-cell frequency of its own label.
    per_cell = 1/(len(unique)*count[inverse.reshape(-1)])

    p = torch.from_numpy(per_cell.astype(np.float32)).reshape(n, 1)
    weights = p * n
    return weights
# @profile
def Run(
        adatas=None,
        adata_cm=None,
        mode='h',
        lambda_s=0.5,
        lambda_recon=1.0,
        lambda_kl=0.5,
        lambda_ot=1.0,
        iteration=30000,
        ref_id=None,
        save_OT=False,
        use_rep=['X', 'X'],
        out='latent',
        label_weight=None,
        reg=0.1,
        reg_m=1.0,
        batch_size=256,
        lr=2e-4,
        enc=None,
        gpu=0,
        prior=None,
        loss_type='BCE',
        outdir='output/',
        input_id=0,
        pred_id=1,
        seed=124,
        num_workers=4,
        patience=30,
        batch_key='domain_id',
        source_name='source',
        model_info=False,
        verbose=False,
    ):
    """
    Run data integration

    Parameters
    ----------
    adatas
        List of AnnData matrices, e.g. [adata1, adata2]. If None, ``adata_cm`` is split
        into one AnnData per value of ``adata_cm.obs[source_name]``.
    adata_cm
        AnnData matrices containing common genes.
    mode
        Choose from ['h', 'v', 'd']
        If 'h', integrate data with common genes (Horizontal integration)
        If 'v', integrate data profiled from the same cells (Vertical integration)
        If 'd', integrate data without common genes (Diagonal integration)
        Default: 'h'.
    lambda_s
        Balanced parameter for common and specific genes. Default: 0.5
    lambda_recon
        Balanced parameter for reconstruct term. Default: 1.0
    lambda_kl
        Balanced parameter for KL divergence. Default: 0.5
    lambda_ot
        Balanced parameter for OT. Default: 1.0
    iteration
        Max iterations for training. Training one batch_size samples is one iteration. Default: 30000
    ref_id
        Id of reference dataset. Default: None, meaning the last dataset.
    save_OT
        If True, output a global OT plan. Need more memory. Default: False
    use_rep
        Use '.X' or '.obsm'. For mode='d' only.
        If use_rep=['X','X'], use 'adatas[0].X' and 'adatas[1].X' for integration.
        If use_rep=['X','X_lsi'], use 'adatas[0].X' and 'adatas[1].obsm['X_lsi']' for integration.
        If use_rep=['X_pca', 'X_lsi'], use 'adatas[0].obsm['X_pca']' and 'adatas[1].obsm['X_lsi']' for integration.
        When fewer entries than datasets are given, the missing ones default to 'X'.
        Default: ['X','X']
    out
        Output of uniPort. Choose from ['latent', 'project', 'predict'].
        If out=='latent', train the network and output cell embeddings.
        If out=='project', project data into the latent space and output cell embeddings.
        If out=='predict', project data into the latent space and output cell embeddings through a specified decoder.
        Default: 'latent'.
    label_weight
        Prior-guided weighted vectors. Default: None
    reg
        Entropy regularization parameter in OT. Default: 0.1
    reg_m
        Unbalanced OT parameter. Larger values means more balanced OT. Default: 1.0
    batch_size
        Number of samples per batch to load. Default: 256
    lr
        Learning rate. Default: 2e-4
    enc
        Structure of encoder. Default: None, meaning
        [['fc', 1024, 1, 'relu'], ['fc', 16, '', '']]
    gpu
        Index of GPU to use if GPU is available. Default: 0
    prior
        Prior correspondence matrix. Default: None
    loss_type
        type of loss. 'BCE', 'MSE' or 'L1'. Default: 'BCE'
    outdir
        Output directory. Default: 'output/'
    input_id
        Only used when mode=='d' and out=='predict' to choose a encoder to project data. Default: 0
    pred_id
        Only used when out=='predict' to choose a decoder to predict data. Default: 1
    seed
        Random seed for torch and numpy. Default: 124
    num_workers
        Passed to ``load_data`` when training. Default: 4
    patience
        early stopping patience. Default: 30
    batch_key
        Name of batch in AnnData. Default: domain_id
    source_name
        Name of source in AnnData. Default: source
    model_info
        If True, show structures of encoder and decoders.
    verbose
        Verbosity, True or False. Default: False

    Returns
    -------
    AnnData, or the tuple ``(AnnData, tran)`` when ``save_OT`` is True and mode is 'h' or 'd'
        The representation is stored at ``.obsm[out]`` of the returned object:

        * mode 'h': ``adata_cm``
        * mode 'v': ``adatas[0]`` (returned alone, also when ``save_OT`` is True)
        * mode 'd', out 'latent'/'project': the concatenation of all ``adatas``
        * mode 'd', out 'predict': ``adatas[0]``

        ``tran`` is a dict of transport plans keyed by dataset index; it is only
        initialised when training with ``save_OT=True``.

    Raises
    ------
    AssertionError
        If ``mode`` is invalid, ``adata_cm`` is missing for mode 'h', or neither
        ``adatas`` nor ``adata_cm`` is given.
    ValueError
        If mode is 'd' and ``out`` is not one of 'latent', 'project', 'predict'.

    Notes
    -----
    When ``out == 'latent'`` the model is trained and its configuration is saved to
    ``<outdir>/checkpoint/config.pt``. For any other ``out`` that file and
    ``<outdir>/checkpoint/model.pt`` are loaded instead, so a training run into the same
    ``outdir`` must come first. ``<outdir>/log.txt`` is handed to the logger.
    """
    if mode == 'h' and adata_cm is None:
        raise AssertionError('adata_cm is needed when mode == "h"!')
    if mode not in ['h', 'd', 'v']:
        raise AssertionError('mode must be "h", "v" or "d" ')
    if adatas is None and adata_cm is None:
        raise AssertionError('at least one of adatas and adata_cm should be given!')

    np.random.seed(seed) # seed
    torch.manual_seed(seed)

    if torch.cuda.is_available(): # cuda device
        device = 'cuda'
        torch.cuda.set_device(gpu)
    else:
        device = 'cpu'
    print('Device:', device)

    outdir = outdir+'/'
    os.makedirs(outdir+'/checkpoint', exist_ok=True)
    log = create_logger('', fh=outdir+'log.txt')

    use_specific = True

    # split adata_cm to adatas, keeping the sources in order of first appearance
    if adatas is None:
        use_specific = False
        _, idx = np.unique(adata_cm.obs[source_name], return_index=True)
        # `.iloc`: positional access through plain `Series[...]` is deprecated in pandas
        batches = adata_cm.obs[source_name].iloc[np.sort(idx)]
        adatas = [adata_cm[adata_cm.obs[source_name]==batch].copy() for batch in batches]

    n_domain = len(adatas)

    # give reference datasets
    if ref_id is None:
        ref_id = n_domain-1

    # One representation name per dataset is needed below; pad with 'X' so
    # the two-entry default also works for more than two datasets.
    use_rep = list(use_rep) + ['X'] * (n_domain - len(use_rep))

    tran = {}
    num_cell = []
    num_gene = []
    for rep, adata in zip(use_rep, adatas):
        data = adata.X if rep == 'X' else adata.obsm[rep]
        num_cell.append(data.shape[0])
        num_gene.append(data.shape[1])

    # training
    if out == 'latent':
        for i, adata in enumerate(adatas):
            print('Dataset {}:'.format(i), adata.obs[source_name].iloc[0])
            print(adata)
        print('Reference dataset is dataset {}'.format(ref_id))
        print('\n')

        if adata_cm is not None:
            print('Data with common HVG')
            print(adata_cm)
            print('\n')

        if save_OT:
            # start every plan towards the reference as a uniform coupling
            for i in range(n_domain):
                if i != ref_id:
                    ns = num_cell[i]
                    nt = num_cell[ref_id]
                    tran_tmp = np.ones((ns, nt)) / (ns * nt)
                    tran[i] = tran_tmp.astype(np.float32)
                    print('Size of transport plan between datasets {} and {}:'.format(i, ref_id), np.shape(tran[i]))

        trainloader, testloader = load_data(
            adatas=adatas,
            mode=mode,
            use_rep=use_rep,
            num_cell=num_cell,
            max_gene=max(num_gene),
            adata_cm=adata_cm,
            use_specific=use_specific,
            domain_name=batch_key,
            batch_size=batch_size,
            num_workers=num_workers
        )

        early_stopping = EarlyStopping(patience=patience, checkpoint_file=outdir+'/checkpoint/model.pt')

        # encoder structure
        if enc is None:
            enc = [['fc', 1024, 1, 'relu'], ['fc', 16, '', '']]

        # decoder structure
        dec = {}
        if mode == 'd':
            for i in range(n_domain):
                dec[i] = [['fc', num_gene[i], 1, 'sigmoid']]
        elif mode == 'h':
            # the common-gene width is appended after the per-dataset widths
            num_gene.append(adata_cm.X.shape[1])
            dec[0] = [['fc', num_gene[n_domain], n_domain, 'sigmoid']] # common decoder
            if use_specific:
                for i in range(1, n_domain+1):
                    dec[i] = [['fc', num_gene[i-1], 1, 'sigmoid']] # dataset-specific decoder
        else:
            for i in range(n_domain):
                dec[i] = [['fc', num_gene[i], 1, 'sigmoid']] # dataset-specific decoder

        # init model
        model = VAE(enc, dec, ref_id=ref_id, n_domain=n_domain, mode=mode)
        if model_info:
            log.info('model\n'+model.__repr__())

        model.fit(
            trainloader,
            tran,
            num_cell,
            num_gene,
            mode=mode,
            label_weight=label_weight,
            Prior=prior,
            save_OT=save_OT,
            use_specific=use_specific,
            lambda_s=lambda_s,
            lambda_recon=lambda_recon,
            lambda_kl=lambda_kl,
            lambda_ot=lambda_ot,
            reg=reg,
            reg_m=reg_m,
            lr=lr,
            max_iteration=iteration,
            device=device,
            early_stopping=early_stopping,
            verbose=verbose,
            loss_type=loss_type,
        )
        torch.save({'enc':enc, 'dec':dec, 'n_domain':n_domain, 'ref_id':ref_id, 'num_gene':num_gene}, outdir+'/checkpoint/config.pt')

    # project or predict
    else:
        state = torch.load(outdir+'/checkpoint/config.pt')
        enc, dec, n_domain, ref_id, num_gene = state['enc'], state['dec'], state['n_domain'], state['ref_id'], state['num_gene']
        model = VAE(enc, dec, ref_id=ref_id, n_domain=n_domain, mode=mode)
        model.load_model(outdir+'/checkpoint/model.pt')
        model.to(device)

        _, testloader = load_data(
            adatas=adatas,
            max_gene=max(num_gene),
            num_cell=num_cell,
            adata_cm=adata_cm,
            domain_name=batch_key,
            batch_size=batch_size,
            mode=mode
        )

    if mode == 'v':
        adatas[0].obsm[out] = model.encodeBatch(testloader, num_gene, pred_id=pred_id, device=device, mode=mode, out=out)
        return adatas[0]

    if mode == 'h':
        if out == 'latent' or out == 'project':
            adata_cm.obsm[out] = model.encodeBatch(testloader, num_gene, device=device, mode=mode, out=out) # save latent rep
        elif out == 'predict':
            adata_cm.obsm[out] = model.encodeBatch(testloader, num_gene, pred_id=pred_id, device=device, mode=mode, out=out)
        result = adata_cm
    else: # mode == 'd'
        if out == 'latent' or out == 'project':
            for i in range(n_domain):
                adatas[i].obsm[out] = model.encodeBatch(testloader, num_gene, batch_id=i, device=device, mode=mode, out=out)
            # Join every dataset in one call; concatenating neighbouring pairs
            # in a loop would keep only the last pair.
            if n_domain > 1:
                result = adatas[0].concatenate(*adatas[1:n_domain])
            else:
                result = adatas[0]
        elif out == 'predict':
            adatas[0].obsm[out] = model.encodeBatch(testloader, num_gene, batch_id=input_id, pred_id=pred_id, device=device, mode=mode, out=out)
            # NOTE(review): the prediction is stored on adatas[0], so that is
            # what gets returned; confirm this is the intended return value.
            result = adatas[0]
        else:
            raise ValueError('out must be "latent", "project" or "predict" when mode == "d"')

    if save_OT:
        return result, tran
    return result