Source code for uniport.model.vae

'''
Author: Kai Cao
Modified from SCALEX
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from itertools import cycle
import sys
import time
from .layer import *
from .loss import *


class VAE(nn.Module):
    """
    Variational Autoencoder framework
    """
    def __init__(self, enc, dec, ref_id, n_domain, mode):
        """
        Parameters
        ----------
        enc
            Encoder structure config
        dec
            Decoder structure config
        ref_id
            ID of the reference dataset
        n_domain
            The number of different domains
        mode
            Choose from ['h', 'v', 'd']
        """
        super().__init__()

        x_dim = {}
        for key in dec.keys():
            x_dim[key] = dec[key][-1][1]

        self.z_dim = enc[-1][1]
        self.encoder = Encoder(x_dim, enc, mode)
        self.decoder = Decoder(self.z_dim, dec)
        self.n_domain = n_domain
        self.ref_id = ref_id
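    # A hedged construction sketch (not part of the original source): the exact
    # layer-config format is defined in .layer, but from the usage above, `enc`
    # is a list of layer specs whose last entry holds the latent dimension at
    # index 1, and `dec` maps each decoder id to such a list whose last entry
    # holds the output dimension. Values below are illustrative only:
    #
    #     enc = [['fc', 1024, 1, 'relu'], ['fc', 16, '', '']]   # z_dim = 16
    #     dec = {0: [['fc', 2000, 1, 'sigmoid']]}               # x_dim[0] = 2000
    #     model = VAE(enc, dec, ref_id=0, n_domain=2, mode='h')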
    def load_model(self, path):
        """
        Load trained model parameters dictionary.

        Parameters
        ----------
        path
            File path that stores the model parameters.
        """
        pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def encodeBatch(
            self,
            dataloader,
            num_gene,
            mode='h',
            out='latent',
            batch_id=0,
            pred_id=1,
            device='cuda',
            eval=False,
        ):
        """
        Inference

        Parameters
        ----------
        dataloader
            An iterable over the given dataset for inference.
        num_gene
            List of the number of genes in each dataset.
        mode
            Choose from ['h', 'v', 'd'].
            If 'h', integrate data with common genes.
            If 'v', integrate data profiled from the same cells.
            If 'd', integrate data without common genes.
            Default: 'h'.
        out
            Output of uniPort. Choose from ['latent', 'project', 'predict'].
            If out='latent', train the network and output cell embeddings.
            If out='project', project data into the latent space and output cell embeddings.
            If out='predict', project data into the latent space and reconstruct it through a specified decoder.
            Default: 'latent'.
        batch_id
            Choose which encoder to project data with when mode='d'. Default: 0.
        pred_id
            Choose which decoder to reconstruct data with when out='predict'. Default: 1.
        device
            'cuda' or 'cpu' for inference. Default: 'cuda'.
        eval
            If True, set the model to evaluation mode; otherwise set it to train mode. Default: False.

        Returns
        -------
        output
            Cell embeddings (if out='latent' or 'project') or predicted data (if out='predict').
        """
        self.to(device)
        if eval:
            self.eval()
        else:
            self.train()

        output = []
        if out == 'latent' or out == 'project':
            if mode == 'v':
                for x, idx in dataloader:
                    x = x.float().to(device)
                    z = self.encoder(x[:, 0:num_gene[0]], 0)[1]  # encoder returns (z, mu, var)
                    output.append(z.detach().cpu())
                output = torch.cat(output).numpy()

            elif mode == 'd':
                for x, y, idx in dataloader:
                    x, y = x[:, 0:num_gene[batch_id]].float().to(device), y.long().to(device)
                    loc = torch.where(y == batch_id)[0]
                    x = x[loc]
                    z = self.encoder(x, batch_id)[1]  # encoder returns (z, mu, var)
                    output.append(z.detach().cpu())
                output = torch.cat(output).numpy()

            elif mode == 'h':
                output = np.zeros((dataloader.dataset.shape[0], self.z_dim))
                for x, y, idx in dataloader:
                    x_c = x[:, 0:num_gene[self.n_domain]].float().to(device)
                    z = self.encoder(x_c, 0)[1]
                    output[idx] = z.detach().cpu().numpy()

        elif out == 'predict':
            if mode == 'v':
                for x, idx in dataloader:
                    x = x.float().to(device)
                    z = self.encoder(x[:, 0:num_gene[0]], 0)[1]
                    recon = self.decoder(z, pred_id)
                    output.append(recon.detach().cpu())
                output = torch.cat(output).numpy()

            elif mode == 'd':
                for x, y, idx in dataloader:
                    x, y = x[:, 0:num_gene[batch_id]].float().to(device), y.long().to(device)
                    loc = torch.where(y == batch_id)[0]
                    x = x[loc]
                    z = self.encoder(x, batch_id)[1]  # encoder returns (z, mu, var)
                    recon = self.decoder(z, pred_id)
                    output.append(recon.detach().cpu())
                output = torch.cat(output).numpy()

            elif mode == 'h':
                output = np.zeros((dataloader.dataset.shape[0], num_gene[pred_id]))
                for x, y, idx in dataloader:
                    x_c = x[:, 0:num_gene[self.n_domain]].float().to(device)
                    z = self.encoder(x_c, 0)[1]
                    recon = self.decoder(z, pred_id+1)  # decoder 0 is the common-gene decoder
                    output[idx] = recon.detach().cpu().numpy()

        return output
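    # Hedged usage sketch for encodeBatch (the loader and variable names are
    # illustrative; in mode 'h' the dataloader is expected to yield (x, y, idx)
    # batches, as in the loops above):
    #
    #     latent = model.encodeBatch(testloader, num_gene, mode='h',
    #                                out='project', device='cuda', eval=True)
    #     recon = model.encodeBatch(testloader, num_gene, mode='h',
    #                               out='predict', pred_id=1, eval=True)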
    def fit(
            self,
            dataloader,
            tran,
            num_cell,
            num_gene,
            mode='h',
            loss_type='BCE',
            label_weight=None,
            Prior=None,
            save_OT=False,
            use_specific=True,
            lambda_s=0.5,
            lambda_kl=0.5,
            lambda_recon=1.0,
            lambda_ot=1.0,
            reg=0.1,
            reg_m=1.0,
            lr=2e-4,
            max_iteration=30000,
            early_stopping=None,
            device='cuda:0',
            verbose=False,
        ):
        """
        Train the VAE.

        Parameters
        ----------
        dataloader
            An iterable over the given dataset for training.
        tran
            A global OT plan. tran={} if save_OT=False in function.py.
        num_cell
            List of the number of cells in each dataset.
        num_gene
            List of the number of genes in each dataset.
        mode
            Choose from ['h', 'v', 'd'].
            If 'h', integrate data with common genes.
            If 'v', integrate data profiled from the same cells.
            If 'd', integrate data without common genes.
            Default: 'h'.
        loss_type
            Type of reconstruction loss. Choose from ['BCE', 'MSE', 'L1']. Default: 'BCE'.
        label_weight
            Prior-guided weight vectors. Default: None.
        Prior
            Prior correspondence matrix. Default: None.
        save_OT
            If True, output a global OT plan. Default: False.
        use_specific
            If True, specific genes in each dataset will be considered. Default: True.
        lambda_s
            Balancing parameter for specific genes. Default: 0.5.
        lambda_kl
            Balancing parameter for the KL divergence. Default: 0.5.
        lambda_recon
            Balancing parameter for reconstruction. Default: 1.0.
        lambda_ot
            Balancing parameter for OT. Default: 1.0.
        reg
            Entropy regularization parameter in OT. Default: 0.1.
        reg_m
            Unbalanced OT parameter. Larger values mean a more balanced OT plan. Default: 1.0.
        lr
            Learning rate. Default: 2e-4.
        max_iteration
            Max iterations for training. Training on one batch of samples is one iteration. Default: 30000.
        early_stopping
            EarlyStopping instance (defined in utils.py) that stops training if the loss does not improve within a given patience. Default: None.
        device
            'cuda' or 'cpu' for training. Default: 'cuda:0'.
        verbose
            Verbosity, True or False. Default: False.
        """
        self.to(device)
        optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)
        n_epoch = int(np.ceil(max_iteration / len(dataloader)))

        if loss_type == 'BCE':
            loss_func = nn.BCELoss()
        elif loss_type == 'MSE':
            loss_func = nn.MSELoss()
        elif loss_type == 'L1':
            loss_func = nn.L1Loss()

        with tqdm(range(n_epoch), total=n_epoch, desc='Epochs') as tq:
            for epoch in tq:
                tk0 = tqdm(enumerate(dataloader), total=len(dataloader), leave=False, desc='Iterations', disable=(not verbose))
                epoch_loss = defaultdict(float)

                if mode == 'v':
                    # mode 'v': datasets profiled from the same cells; one encoder
                    # on the first dataset, one decoder per dataset.
                    for i, (x, idx) in tk0:
                        x = x.float().to(device)
                        idx = idx.to(device)

                        # split the concatenated input into per-dataset blocks
                        x_list = []
                        num_sum = []
                        num_sum.append(num_gene[0])
                        x_list.append(x[:, 0:num_sum[0]])
                        for j in range(1, self.n_domain):
                            num_sum.append(num_sum[-1] + num_gene[j])
                            x_list.append(x[:, num_sum[-2]:num_sum[-1]])

                        recon_loss = torch.tensor(0.0).to(device)
                        kl_loss = torch.tensor(0.0).to(device)

                        z, mu, var = self.encoder(x_list[0], 0)
                        kl_loss += kl_div(mu, var)

                        recon = self.decoder(z, 0)
                        recon_loss = loss_func(recon, x_list[0]) * 2000

                        for j in range(1, self.n_domain):
                            recon = self.decoder(z, j)
                            recon_loss += lambda_s * loss_func(recon, x_list[j]) * 2000  # TODO

                        loss = {'recon_loss': lambda_recon*recon_loss, 'kl_loss': lambda_kl*kl_loss}

                        optim.zero_grad()
                        sum(loss.values()).backward()
                        optim.step()

                        for k, v in loss.items():
                            epoch_loss[k] += loss[k].item()

                        info = ','.join(['{}={:.3f}'.format(k, v) for k, v in loss.items()])
                        tk0.set_postfix_str(info)

                    epoch_loss = {k: v/(i+1) for k, v in epoch_loss.items()}
                    epoch_info = ','.join(['{}={:.3f}'.format(k, v) for k, v in epoch_loss.items()])
                    tq.set_postfix_str(epoch_info)
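                # mode 'd': datasets share no common genes, so each domain has
                # its own encoder/decoder pair; the latent spaces are aligned by
                # unbalanced OT between each query minibatch and the reference
                # minibatch (see unbalanced_ot in loss.py).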
                elif mode == 'd':
                    for i, (x, y, idx) in tk0:
                        x, y = x.float().to(device), y.long().to(device)
                        idx = idx.to(device)

                        # skip batches that do not contain every domain
                        if len(torch.unique(y)) < self.n_domain:
                            continue

                        mu_dict = {}
                        var_dict = {}

                        loc_ref = torch.where(y == self.ref_id)[0]
                        idx_ref = idx[loc_ref] - sum(num_cell[0:self.ref_id])

                        loc_query = {}
                        idx_query = {}
                        tran_batch = {}
                        Prior_batch = None

                        query_id = list(range(self.n_domain))
                        query_id.remove(self.ref_id)

                        for j in query_id:
                            loc_query[j] = torch.where(y == j)[0]
                            idx_query[j] = idx[loc_query[j]] - sum(num_cell[0:j])

                            if save_OT:
                                tran_batch[j] = torch.from_numpy(tran[j]).to(device)[idx_query[j]][:, idx_ref]
                            else:
                                tran_batch[j] = None

                            if Prior is not None:
                                Prior_batch = Prior[j][idx_query[j]][:, idx_ref].to(device)

                        recon_loss = torch.tensor(0.0).to(device)
                        kl_loss = torch.tensor(0.0).to(device)
                        ot_loss = torch.tensor(0.0).to(device)

                        loc = loc_query
                        loc[self.ref_id] = loc_ref

                        for j in range(self.n_domain):
                            z_j, mu_j, var_j = self.encoder(x[loc[j]][:, 0:num_gene[j]], j)
                            mu_dict[j] = mu_j
                            var_dict[j] = var_j

                            recon_j = self.decoder(z_j, j)
                            recon_loss += loss_func(recon_j, x[loc[j]][:, 0:num_gene[j]]) * x[loc[j]].size(-1)  # TODO
                            kl_loss += kl_div(mu_j, var_j)

                        for j in query_id:
                            ot_loss_tmp, tran_batch[j] = unbalanced_ot(
                                tran_batch[j],
                                mu_dict[j],
                                var_dict[j],
                                mu_dict[self.ref_id].detach(),
                                var_dict[self.ref_id].detach(),
                                reg=reg,
                                reg_m=reg_m,
                                Couple=Prior_batch,
                                device=device,
                            )

                            if save_OT:
                                # write the minibatch plan back into the global OT plan
                                t0 = np.repeat(idx_query[j].cpu().numpy(), len(idx_ref)).reshape(len(idx_query[j]), len(idx_ref))
                                t1 = np.tile(idx_ref.cpu().numpy(), (len(idx_query[j]), 1))
                                tran[j][t0, t1] = tran_batch[j].cpu().numpy()

                            ot_loss += ot_loss_tmp

                        loss = {'recon_loss': lambda_recon*recon_loss, 'kl_loss': lambda_kl*kl_loss, 'ot_loss': lambda_ot*ot_loss}

                        optim.zero_grad()
                        sum(loss.values()).backward()
                        optim.step()

                        for k, v in loss.items():
                            epoch_loss[k] += loss[k].item()

                        info = ','.join(['{}={:.3f}'.format(k, v) for k, v in loss.items()])
                        tk0.set_postfix_str(info)

                    epoch_loss = {k: v/(i+1) for k, v in epoch_loss.items()}
                    epoch_info = ','.join(['{}={:.3f}'.format(k, v) for k, v in epoch_loss.items()])
                    tq.set_postfix_str(epoch_info)
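                # mode 'h': datasets share common genes, so one encoder embeds
                # the common-gene block of every cell; decoder 0 reconstructs the
                # common genes while decoder j+1 reconstructs dataset j's
                # specific genes, and minibatch unbalanced OT aligns each query
                # domain to the reference in the shared latent space.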
                elif mode == 'h':
                    for i, (x, y, idx) in tk0:
                        # common-gene block of the concatenated input
                        x_c, y = x[:, 0:num_gene[self.n_domain]].float().to(device), y.long().to(device)
                        idx = idx.to(device)

                        loc_ref = torch.where(y == self.ref_id)[0]
                        idx_ref = idx[loc_ref] - sum(num_cell[0:self.ref_id])

                        loc_query = {}
                        idx_query = {}
                        tran_batch = {}
                        Prior_batch = None

                        query_id = list(range(self.n_domain))
                        query_id.remove(self.ref_id)

                        if len(loc_ref) > 0:
                            for j in query_id:
                                loc_query[j] = torch.where(y == j)[0]
                                idx_query[j] = idx[loc_query[j]] - sum(num_cell[0:j])

                                if save_OT:
                                    if len(idx_query[j]) != 0:
                                        if len(idx_query[j]) == 1:
                                            tran_batch[j] = torch.from_numpy(tran[j]).to(device)[idx_query[j]][idx_ref]
                                        else:
                                            tran_batch[j] = torch.from_numpy(tran[j]).to(device)[idx_query[j]][:, idx_ref]
                                else:
                                    tran_batch[j] = None

                                if Prior is not None:
                                    Prior_batch = Prior[j][idx_query[j]][:, idx_ref].to(device)

                        ot_loss = torch.tensor(0.0).to(device)
                        recon_loss = torch.tensor(0.0).to(device)
                        kl_loss = torch.tensor(0.0).to(device)

                        loc = loc_query
                        loc[self.ref_id] = loc_ref
                        idx = idx_query  # reuse `idx` as a per-domain dict of dataset indices
                        idx[self.ref_id] = idx_ref

                        z, mu, var = self.encoder(x_c, 0)
                        recon_x_c = self.decoder(z, 0, y)

                        if label_weight is None:
                            recon_loss = loss_func(recon_x_c, x_c) * 2000
                        else:
                            # prior-guided reweighting of the common-gene reconstruction
                            for j, weight in enumerate(label_weight):
                                if len(loc[j]) > 0:
                                    if weight is None:
                                        recon_loss += 1/self.n_domain * loss_func(recon_x_c[loc[j]], x_c[loc[j]]) * 2000
                                    else:
                                        weight = weight.to(device)
                                        recon_loss += 1/self.n_domain * F.binary_cross_entropy(recon_x_c[loc[j]], x_c[loc[j]], weight=weight[idx[j]]) * 2000

                        kl_loss = kl_div(mu, var)

                        if use_specific:
                            x_s = x[:, num_gene[self.n_domain]:].float().to(device)

                            for j in range(self.n_domain):
                                if len(loc[j]) > 0:
                                    recon_x_s = self.decoder(z[loc[j]], j+1)
                                    recon_loss += lambda_s * loss_func(recon_x_s, x_s[loc[j]][:, 0:num_gene[j]]) * 2000

                        if len(torch.unique(y)) > 1 and len(loc[self.ref_id]) != 0:
                            mu_dict = {}
                            var_dict = {}
                            for j in range(self.n_domain):
                                if len(loc[j]) > 0:
                                    mu_dict[j] = mu[loc[j]]
                                    var_dict[j] = var[loc[j]]

                            for j in query_id:
                                if len(loc[j]) > 0:
                                    if label_weight is None:
                                        ot_loss_tmp, tran_batch[j] = unbalanced_ot(
                                            tran_batch[j],
                                            mu_dict[j],
                                            var_dict[j],
                                            mu_dict[self.ref_id].detach(),
                                            var_dict[self.ref_id].detach(),
                                            reg=reg,
                                            reg_m=reg_m,
                                            idx_q=idx_query[j],
                                            idx_r=idx_ref,
                                            Couple=Prior_batch,
                                            device=device,
                                        )
                                    else:
                                        ot_loss_tmp, tran_batch[j] = unbalanced_ot(
                                            tran_batch[j],
                                            mu_dict[j],
                                            var_dict[j],
                                            mu_dict[self.ref_id].detach(),
                                            var_dict[self.ref_id].detach(),
                                            reg=reg,
                                            reg_m=reg_m,
                                            idx_q=idx_query[j],
                                            idx_r=idx_ref,
                                            Couple=Prior_batch,
                                            device=device,
                                            query_weight=label_weight[j],
                                            ref_weight=label_weight[self.ref_id],
                                        )

                                    if save_OT:
                                        # write the minibatch plan back into the global OT plan
                                        t0 = np.repeat(idx_query[j].cpu().numpy(), len(idx_ref)).reshape(len(idx_query[j]), len(idx_ref))
                                        t1 = np.tile(idx_ref.cpu().numpy(), (len(idx_query[j]), 1))
                                        tran[j][t0, t1] = tran_batch[j].cpu().numpy()

                                    ot_loss += ot_loss_tmp

                        loss = {'recloss': lambda_recon*recon_loss, 'klloss': lambda_kl*kl_loss, 'otloss': lambda_ot*ot_loss}

                        optim.zero_grad()
                        sum(loss.values()).backward()
                        optim.step()

                        for k, v in loss.items():
                            epoch_loss[k] += loss[k].item()

                        info = ','.join(['{}={:.2f}'.format(k, v) for k, v in loss.items()])
                        tk0.set_postfix_str(info)

                    epoch_loss = {k: v/(i+1) for k, v in epoch_loss.items()}
                    epoch_info = ','.join(['{}={:.2f}'.format(k, v) for k, v in epoch_loss.items()])
                    tq.set_postfix_str(epoch_info)

                if early_stopping is not None:
                    early_stopping(sum(epoch_loss.values()), self)
                    if early_stopping.early_stop:
                        print('EarlyStopping: run {} epoch'.format(epoch+1))
                        break
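

# -----------------------------------------------------------------------------
# Hedged usage sketch (illustrative only; `model`, `trainloader`, `testloader`,
# `tran`, `num_cell` and `num_gene` are assumed to be prepared as in
# function.py, which normally drives this class):
#
#     model.fit(trainloader, tran, num_cell, num_gene, mode='h',
#               save_OT=True, device='cuda:0')
#     latent = model.encodeBatch(testloader, num_gene, mode='h',
#                                out='latent', eval=True)
#
# The partial state-dict loading pattern used by VAE.load_model can be
# exercised on any nn.Module. A minimal self-contained sketch, using plain
# nn.Linear modules rather than the uniPort Encoder/Decoder (all names here
# are illustrative):
def _demo_partial_load():
    source = nn.Sequential(nn.Linear(8, 4))                  # pretrained model
    target = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))  # adds an extra layer

    pretrained_dict = source.state_dict()  # keys: '0.weight', '0.bias'
    model_dict = target.state_dict()       # also has '1.weight', '1.bias'

    # keep only parameters whose names exist in the target, then merge
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    target.load_state_dict(model_dict)     # layer 1 keeps its fresh initialization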