uniport.model.loss.unbalanced_ot
- uniport.model.loss.unbalanced_ot(tran, mu1, var1, mu2, var2, reg=0.1, reg_m=1.0, Couple=None, device='cpu', idx_q=None, idx_r=None, query_weight=None, ref_weight=None)
Calculate an unbalanced optimal transport matrix between mini-batches.
- Parameters:
tran – transport matrix between the two batches, sampled from the global OT matrix.
mu1 – mean vector of batch 1 from the encoder
var1 – standard deviation vector of batch 1 from the encoder
mu2 – mean vector of batch 2 from the encoder
var2 – standard deviation vector of batch 2 from the encoder
reg – Entropy regularization parameter in OT. Default: 0.1
reg_m – Marginal relaxation parameter of unbalanced OT. Larger values mean more balanced OT. Default: 1.0
Couple – prior information about the weights of cell correspondences. Default: None
device – device used for training. Default: 'cpu'
idx_q – domain_id of the query batch
idx_r – domain_id of the reference batch
query_weight – reweighting vector of the query batch
ref_weight – reweighting vector of the reference batch
- Returns:
float – minibatch unbalanced optimal transport loss
matrix – minibatch unbalanced optimal transport matrix
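The entry above lists inputs and outputs but not how the loss and matrix are obtained. The following is a minimal sketch of generic entropic unbalanced OT (scaling iterations with KL-relaxed marginals), written only to illustrate what `reg` and `reg_m` control; it is not uniport's implementation. Assumptions: the helper names `gaussian_cost`, `marginal`, and `unbalanced_ot_sketch` are hypothetical; the cost between two cells is taken to be the squared 2-Wasserstein distance between diagonal Gaussians built from the encoder means and standard deviations; the cost is rescaled by its maximum before exponentiation; and `tran`, `Couple`, `device`, `idx_q`, and `idx_r` are left out. Only the standard library is used, so the matrices are plain nested lists.

```python
import math


def gaussian_cost(mu1, std1, mu2, std2):
    """Hypothetical cost: pairwise squared 2-Wasserstein distance between
    diagonal Gaussians, ||mu_a - mu_b||^2 + ||std_a - std_b||^2."""
    cost = []
    for m_a, s_a in zip(mu1, std1):
        row = []
        for m_b, s_b in zip(mu2, std2):
            d_mean = sum((x - y) ** 2 for x, y in zip(m_a, m_b))
            d_std = sum((x - y) ** 2 for x, y in zip(s_a, s_b))
            row.append(d_mean + d_std)
        cost.append(row)
    return cost


def marginal(n, weights=None):
    """Uniform marginal, or the given per-cell weights normalized to sum to 1."""
    if weights is None:
        return [1.0 / n] * n
    total = sum(weights)
    return [w / total for w in weights]


def unbalanced_ot_sketch(mu1, std1, mu2, std2, reg=0.1, reg_m=1.0,
                         query_weight=None, ref_weight=None, n_iter=200):
    """Entropic unbalanced OT via scaling iterations (not log-domain stabilized).
    Returns (loss, plan), where loss = sum(cost * plan)."""
    cost = gaussian_cost(mu1, std1, mu2, std2)
    n, m = len(cost), len(cost[0])
    a = marginal(n, query_weight)
    b = marginal(m, ref_weight)

    # Assumption: rescale the cost by its maximum so exp() does not underflow.
    max_cost = max(max(row) for row in cost) or 1.0
    kernel = [[math.exp(-c / (reg * max_cost)) for c in row] for row in cost]

    # f -> 1 as reg_m grows, which recovers the balanced Sinkhorn updates.
    f = reg_m / (reg_m + reg)
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        kv = [sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(a[i] / kv[i]) ** f for i in range(n)]
        ktu = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(b[j] / ktu[j]) ** f for j in range(m)]

    # Transport plan diag(u) K diag(v) and its cost-weighted total.
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]
    loss = sum(cost[i][j] * plan[i][j] for i in range(n) for j in range(m))
    return loss, plan


# Usage: two query cells against three reference cells, the last one an outlier.
mu1 = [[0.0, 0.0], [1.0, 1.0]]
std1 = [[1.0, 1.0], [1.0, 1.0]]
mu2 = [[0.0, 0.1], [1.0, 0.9], [5.0, 5.0]]
std2 = [[1.0, 1.0]] * 3

loss_u, plan_u = unbalanced_ot_sketch(mu1, std1, mu2, std2, reg=0.1, reg_m=0.05)
loss_b, plan_b = unbalanced_ot_sketch(mu1, std1, mu2, std2, reg=0.1, reg_m=1e6)
outlier_u = sum(row[2] for row in plan_u)  # mass sent to the outlier, small reg_m
outlier_b = sum(row[2] for row in plan_b)  # mass sent to the outlier, large reg_m
```

With a very large `reg_m` the plan nearly satisfies the reference marginal, so the outlier still receives about a third of the mass; with a small `reg_m` the plan is allowed to drop most of the mass on that costly column. This is the sense in which larger values of `reg_m` give more balanced OT.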