uniport.model.loss.distance_gmm

uniport.model.loss.distance_gmm(mu_src: Tensor, mu_dst: Tensor, var_src: Tensor, var_dst: Tensor)

Calculate the pairwise Wasserstein distance matrix between two sets of Gaussian distributions (GMM components) with diagonal covariances.
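For two Gaussians with diagonal covariances, the squared 2-Wasserstein distance has a standard closed form. It is shown here for orientation only: this page does not state which Wasserstein order is used, or whether the function returns the squared value or its square root. Here mu_i and sigma_i^2 stand for row i of mu_src and the source variances, and nu_j and tau_j^2 stand for row j of mu_dst and the destination variances:

```latex
W_2^2\big(\mathcal{N}(\mu_i, \operatorname{diag}(\sigma_i^2)),\ \mathcal{N}(\nu_j, \operatorname{diag}(\tau_j^2))\big)
  = \sum_{d=1}^{D} (\mu_{id} - \nu_{jd})^2 + \sum_{d=1}^{D} (\sigma_{id} - \tau_{jd})^2
```

This follows from the general formula, whose covariance term is tr(S_1 + S_2 - 2(S_2^{1/2} S_1 S_2^{1/2})^{1/2}). For diagonal S_1 and S_2, that term reduces to the sum over d of (sigma - tau)^2.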

Parameters:
  • mu_src – [R, D] matrix, the means of R Gaussian distributions

  • mu_dst – [C, D] matrix, the means of C Gaussian distributions

  • var_src – [R, D] matrix, the diagonal variances of R Gaussian distributions

  • var_dst – [C, D] matrix, the diagonal variances of C Gaussian distributions

Returns:

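Distance matrix whose entry (i, j) is the Wasserstein distance between the i-th source Gaussian and the j-th destination Gaussian.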

Return type:

[R, C] matrix
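This page does not include the implementation. The block below is a minimal NumPy sketch of the pairwise computation under stated assumptions. The name `distance_gmm_sketch` is hypothetical and is not part of uniport. The inputs are assumed to be variances, not log-variances, as the `var_src`/`var_dst` argument names suggest. The output is assumed to be the squared 2-Wasserstein distance. uniport itself operates on PyTorch `Tensor`s.

```python
import numpy as np


def distance_gmm_sketch(mu_src, mu_dst, var_src, var_dst):
    """Hypothetical NumPy sketch (not uniport's code): pairwise squared
    2-Wasserstein distances between diagonal Gaussians, shape [R, C]."""
    mu_src = np.asarray(mu_src, float)   # [R, D]
    mu_dst = np.asarray(mu_dst, float)   # [C, D]
    # Assumption: inputs are variances, so take square roots to get std devs.
    std_src = np.sqrt(np.asarray(var_src, float))  # [R, D]
    std_dst = np.sqrt(np.asarray(var_dst, float))  # [C, D]
    # Broadcast [R, 1, D] against [1, C, D], then sum over D to get [R, C].
    mean_term = ((mu_src[:, None, :] - mu_dst[None, :, :]) ** 2).sum(axis=2)
    std_term = ((std_src[:, None, :] - std_dst[None, :, :]) ** 2).sum(axis=2)
    return mean_term + std_term


# Usage: R = 2 source Gaussians, C = 1 destination Gaussian, D = 2 dimensions.
mu_src = np.array([[0.0, 0.0], [1.0, 1.0]])
var_src = np.array([[1.0, 1.0], [4.0, 4.0]])
mu_dst = np.array([[0.0, 0.0]])
var_dst = np.array([[1.0, 1.0]])
d = distance_gmm_sketch(mu_src, mu_dst, var_src, var_dst)
print(d.shape)  # (2, 1)
# d[0, 0] == 0.0 (identical Gaussians); d[1, 0] == 2.0 (means) + 2.0 (stds) == 4.0
```

Broadcasting builds the full [R, C, D] difference tensor before reducing over D. This is simple, but its memory use grows as R * C * D, so very large R and C would call for chunking.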