uniport.model.loss.unbalanced_ot

uniport.model.loss.unbalanced_ot(tran, mu1, var1, mu2, var2, reg=0.1, reg_m=1.0, Couple=None, device='cpu', idx_q=None, idx_r=None, query_weight=None, ref_weight=None)

Calculate an unbalanced optimal transport matrix between minibatches.

Parameters:
  • tran – transport matrix between the two batches, sampled from the global OT matrix.

  • mu1 – mean vector of batch 1 from the encoder

  • var1 – standard deviation vector of batch 1 from the encoder

  • mu2 – mean vector of batch 2 from the encoder

  • var2 – standard deviation vector of batch 2 from the encoder

  • reg – Entropy regularization parameter in OT. Default: 0.1

  • reg_m – Unbalanced OT parameter. Larger values mean more balanced OT. Default: 1.0

  • Couple – prior information about the weights of cell correspondences. Default: None

  • device – training device. Default: 'cpu'

  • idx_q – domain_id of the query batch. Default: None

  • idx_r – domain_id of the reference batch. Default: None

  • query_weight – reweighted vectors of the query batch. Default: None

  • ref_weight – reweighted vectors of the reference batch. Default: None
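To illustrate how inputs of this shape can feed an OT cost, here is a minimal NumPy sketch that treats each cell as a diagonal Gaussian and computes the pairwise squared 2-Wasserstein cost between two batches, plus a helper that turns an optional reweighting vector into a normalized marginal. It assumes var1/var2 hold standard deviations, as described above. The helpers gaussian_w2_cost and batch_marginal are hypothetical and are not uniport functions; uniport may build its cost and marginals differently.

```python
import numpy as np


def gaussian_w2_cost(mu1, std1, mu2, std2):
    """Pairwise squared 2-Wasserstein distance between diagonal Gaussians.

    Hypothetical helper. mu1, std1 have shape (n1, d); mu2, std2 have
    shape (n2, d). Returns an (n1, n2) cost matrix.
    """
    # ||mu1_i - mu2_j||^2 for every pair (i, j), via broadcasting
    mean_term = ((mu1[:, None, :] - mu2[None, :, :]) ** 2).sum(axis=2)
    # For diagonal covariances the covariance part reduces to ||std1_i - std2_j||^2
    std_term = ((std1[:, None, :] - std2[None, :, :]) ** 2).sum(axis=2)
    return mean_term + std_term


def batch_marginal(n, weights=None):
    """Uniform marginal over n cells, or per-cell weights normalized to sum to 1.

    Hypothetical helper: assumes weights are non-negative with a positive sum.
    """
    if weights is None:
        return np.full(n, 1.0 / n)
    w = np.asarray(weights, dtype=float)
    if w.shape != (n,):
        raise ValueError("weights must have shape (n,)")
    return w / w.sum()
```

The Gaussian cost uses the closed form for diagonal covariances, where the trace term of the general 2-Wasserstein formula collapses to a squared difference of standard deviations.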

Returns:

  • float – minibatch unbalanced optimal transport loss

  • matrix – minibatch unbalanced optimal transport matrix
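A minimal NumPy sketch of the standard generalized Sinkhorn iteration for entropic unbalanced OT, in which the hard marginal constraints are replaced by KL penalties weighted by reg_m. It returns a (loss, plan) pair shaped like the Returns above. This is not uniport's implementation: the name unbalanced_sinkhorn, the prior argument (used here as the reference measure of the entropy term, one way a plan such as tran could enter), the rescaling of the cost by its maximum, and the fixed iteration count are all assumptions.

```python
import numpy as np


def unbalanced_sinkhorn(cost, a, b, reg=0.1, reg_m=1.0, prior=None, n_iter=200):
    """Entropic unbalanced OT via generalized Sinkhorn scaling (hypothetical sketch).

    cost: (n1, n2) cost matrix; a: (n1,) and b: (n2,) target marginals.
    reg: entropic regularization; reg_m: weight of the KL marginal penalties.
    prior: optional (n1, n2) reference plan for the entropy term (default a b^T).
    Returns (loss, plan) with loss = sum(cost * plan).
    """
    cost = np.asarray(cost, dtype=float)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if prior is None:
        prior = np.outer(a, b)
    # Assumption: rescale the cost by its max so exp() does not underflow;
    # the iteration therefore solves the problem for cost / max|cost|.
    scale = np.abs(cost).max()
    if scale == 0:
        scale = 1.0
    # Gibbs kernel, weighted by the reference plan of the entropy term
    kernel = np.exp(-cost / (reg * scale)) * prior
    # Exponent from the KL-relaxed marginals; reg_m -> inf gives f -> 1 (balanced Sinkhorn)
    f = reg_m / (reg_m + reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (kernel @ v)) ** f
        v = (b / (kernel.T @ u)) ** f
    # plan = diag(u) K diag(v)
    plan = u[:, None] * kernel * v[None, :]
    # Loss is reported with the original (unscaled) cost
    loss = float((cost * plan).sum())
    return loss, plan
```

Because f = reg_m / (reg_m + reg) tends to 1 as reg_m grows, large reg_m recovers ordinary balanced Sinkhorn and the row and column sums of the plan approach a and b; with small reg_m the plan can carry less total mass when transport is costly.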