diff --git a/RELEASES.md b/RELEASES.md
index d187eec78..4e41edd76 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -14,6 +14,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
 - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
 - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
 - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
+- Add batch FUGW loss to `ot.batch` (PR #775)
 
 #### Closed issues
 
diff --git a/ot/batch/_linear.py b/ot/batch/_linear.py
index a63fcb404..1a9ec1955 100644
--- a/ot/batch/_linear.py
+++ b/ot/batch/_linear.py
@@ -147,7 +147,7 @@ def loss_linear_batch(M, T, nx=None):
     return nx.sum(M * T, axis=(1, 2))
 
 
-def loss_linear_samples_batch(X, Y, T, metric="l2"):
+def loss_linear_samples_batch(X, Y, T, metric="sqeuclidean"):
     r"""Computes the linear optimal transport loss given samples and transport plan.
 
     This is the equivalent of calling `dist_batch` and then `loss_linear_batch`.
diff --git a/ot/batch/_quadratic.py b/ot/batch/_quadratic.py
index 0da4b8962..982ef7cfd 100644
--- a/ot/batch/_quadratic.py
+++ b/ot/batch/_quadratic.py
@@ -10,8 +10,9 @@
 
 from ..utils import OTResult
 from ot.backend import get_backend
-from ot.batch._linear import loss_linear_batch
+from ot.batch._linear import loss_linear_batch, loss_linear_samples_batch
 from ot.batch._utils import bmv, bop, bregman_log_projection_batch
+from ot.utils import list_to_array
 
 
 def tensor_batch(
@@ -152,6 +153,100 @@ def h2(C2):
     return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric)
 
 
+def div_to_product_batch(
+    T, a, b, T1=None, T2=None, divergence="kl", mass=True, nx=None
+):
+    r"""Fast computation of the Bregman divergence between a batch of arbitrary measures and product measures.
+    Only the Kullback-Leibler and half-squared L2 divergences are supported.
+
+    - For half-squared L2 divergence:
+
+    .. math::
+        \frac{1}{2} || \pi - a \otimes b ||^2
+        = \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big]
+
+    - For Kullback-Leibler divergence:
+
+    .. math::
+        KL(\pi | a \otimes b)
+        = \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle
+        - \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b)
+
+    where :
+
+    - :math:`\pi` is the (`n`, `m`) transport plan (argument `T`)
+    - :math:`\pi_1` and :math:`\pi_2` are the marginal distributions (arguments `T1` and `T2`)
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+    - :math:`m` denotes the mass of the measure
+
+    Parameters
+    ----------
+    T : array-like (B, n, m)
+        Transport plan for each problem in the batch
+    a : array-like (B, n)
+        Unnormalized histogram of dimension `n` for each problem in the batch
+    b : array-like (B, m)
+        Unnormalized histogram of dimension `m` for each problem in the batch
+    T1 : array-like (B, n), optional (default = None)
+        Marginal distribution with respect to the first dimension of the transport plan for each problem in the batch
+        Only used in case of Kullback-Leibler divergence.
+    T2 : array-like (B, m), optional (default = None)
+        Marginal distribution with respect to the second dimension of the transport plan for each problem in the batch
+        Only used in case of Kullback-Leibler divergence.
+    divergence : string, default = "kl"
+        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
+    mass : bool, optional. Default is True.
+        Only used in case of Kullback-Leibler divergence.
+        If False, calculate the relative entropy.
+        If True, calculate the Kullback-Leibler divergence.
+    nx : backend, optional
+        If left to its default value None, a backend test will be conducted.
+
+    Returns
+    -------
+    res : array-like (B,)
+        Bregman divergence between an arbitrary measure and a product measure for each problem in the batch.
+
+    Raises
+    ------
+    ValueError
+        If `divergence` is neither "kl" nor "l2".
+    """
+    if nx is None:
+        nx = get_backend(T, a, b, T1, T2)
+
+    if divergence == "kl":
+        # The marginals are only needed by the KL expression.
+        if T1 is None:
+            T1 = nx.sum(T, 2)
+        if T2 is None:
+            T2 = nx.sum(T, 1)
+
+        # Zero entries of T are replaced by 1 inside the log so that the
+        # corresponding terms contribute 0 (0 * log(1)) instead of nan.
+        res = (
+            nx.sum((T * nx.log(T + 1.0 * (T == 0))), (1, 2))
+            - nx.sum(T1 * nx.log(a), 1)
+            - nx.sum(T2 * nx.log(b), 1)
+        )
+        if mass:
+            res = res - nx.sum(T1, 1) + nx.sum(a, 1) * nx.sum(b, 1)
+
+    elif divergence == "l2":
+        res = (
+            nx.sum(T**2, (1, 2))
+            + nx.sum(a**2, 1) * nx.sum(b**2, 1)
+            - 2 * nx.sum((a * (T @ b[:, :, None]).squeeze(-1)), 1)
+        ) / 2
+
+    else:
+        raise ValueError(
+            f"Unknown divergence '{divergence}'. Supported values are 'kl' and 'l2'."
+        )
+
+    return res
+
+
 def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None):
     r"""
     Computes the gromov-wasserstein cost given a cost tensor and transport plan. Batched version.
@@ -205,7 +300,7 @@ def loss_quadratic_samples_batch(
     C2,
     T,
     loss="sqeuclidean",
-    symmetric=None,
+    symmetric=True,
     nx=None,
     logits=None,
     recompute_const=False,
@@ -266,6 +361,209 @@ def loss_quadratic_samples_batch(
     )
 
 
+def _check_batch_weight(weight, name, B, nx):
+    """Convert a list weight to an array and validate the shape of non-scalar weights.
+
+    Scalars are returned unchanged. Lists are converted with `list_to_array`.
+    Any input exposing `ndim > 0` must be one-dimensional with `B` entries,
+    otherwise a ValueError is raised.
+    """
+    if isinstance(weight, list):
+        weight = list_to_array(weight, nx=nx)
+
+    if hasattr(weight, "ndim") and weight.ndim > 0:
+        if weight.ndim != 1 or weight.shape[0] != B:
+            raise ValueError(
+                f"If {name} is not a scalar, it must have shape ({B},), got {weight.shape}"
+            )
+
+    return weight
+
+
+def loss_fugw_batch(
+    a,
+    b,
+    L,
+    M,
+    T,
+    alpha=0.5,
+    reg_marginals=1,
+    symmetric=True,
+    divergence="kl",
+    recompute_const=True,
+    nx=None,
+):
+    r"""
+    Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term) and a transport plan. Batched version.
+
+    Parameters
+    ----------
+    a : array-like, shape (B, n)
+        Source distributions.
+    b : array-like, shape (B, m)
+        Target distributions.
+    L : dict
+        Cost tensor as returned by `tensor_batch`.
+    M : array-like, shape (B, n, m)
+        Cost matrix between features across domains.
+    T : array-like, shape (B, n, m)
+        Transport plan.
+    alpha : float, array-like or list (B,), optional
+        Weight the quadratic term (alpha*Gromov) and the linear term
+        ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha is
+        a scalar it is used for all problems in the batch.
+    reg_marginals : float, array-like or list (B,), optional
+        Marginal relaxation terms. If reg_marginals is
+        a scalar it is used for all problems in the batch.
+    symmetric : bool, optional
+        Whether to use symmetric version. Default is True.
+    divergence : string, default = "kl"
+        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
+    recompute_const : bool, optional
+        Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints.
+    nx : module, optional
+        Backend to use. Default is None.
+
+    Returns
+    -------
+    loss : array-like, shape (B,)
+        Value of `(1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced` for each problem in the batch.
+
+    Raises
+    ------
+    ValueError
+        If a non-scalar `alpha` or `reg_marginals` does not have shape (B,).
+    """
+    if nx is None:
+        nx = get_backend(T)
+
+    B = T.shape[0]
+    alpha = _check_batch_weight(alpha, "alpha", B, nx)
+    reg_marginals = _check_batch_weight(reg_marginals, "reg_marginals", B, nx)
+
+    quadratic = loss_quadratic_batch(
+        L, T, recompute_const=recompute_const, symmetric=symmetric, nx=nx
+    )
+
+    linear = loss_linear_batch(M, T, nx=nx)
+
+    unbalanced = div_to_product_batch(
+        T,
+        a,
+        b,
+        divergence=divergence,
+        mass=True,
+        nx=nx,
+    )
+
+    return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced
+
+
+def loss_fugw_samples_batch(
+    a,
+    b,
+    C1,
+    C2,
+    X,
+    Y,
+    T,
+    alpha=0.5,
+    reg_marginals=1,
+    symmetric=True,
+    divergence="kl",
+    recompute_const=True,
+    metric_linear="sqeuclidean",
+    metric_quadratic="sqeuclidean",
+    logits=None,
+    nx=None,
+):
+    r"""
+    Computes the fused unbalanced gromov-wasserstein cost given the cost matrices of both domains (quadratic term), samples from both domains (linear term) and a transport plan. Batched version.
+
+    Parameters
+    ----------
+    a : array-like, shape (B, n)
+        Source distributions.
+    b : array-like, shape (B, m)
+        Target distributions.
+    C1 : array-like, shape (B, n, n) or (B, n, n, d)
+        Source cost matrices for the quadratic term.
+    C2 : array-like, shape (B, m, m) or (B, m, m, d)
+        Target cost matrices for the quadratic term.
+    X : array-like, shape (B, n, d)
+        Samples from source distribution for the linear term
+    Y : array-like, shape (B, m, d)
+        Samples from target distribution for the linear term
+    T : array-like, shape (B, n, m)
+        Transport plan.
+    alpha : float or array-like or list (B,), optional
+        Weight the quadratic term (alpha*Gromov) and the linear term
+        ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha is
+        a scalar it is used for all problems in the batch.
+    reg_marginals : float or array-like or list (B,), optional
+        Marginal relaxation terms. If reg_marginals is
+        a scalar it is used for all problems in the batch.
+    symmetric : bool, optional
+        Whether to use symmetric version. Default is True.
+    divergence : string, default = "kl"
+        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
+    recompute_const : bool, optional
+        Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints.
+    metric_linear : str, optional
+        Metric for the linear term, 'sqeuclidean', 'euclidean', 'minkowski' or 'kl'
+    metric_quadratic : str, optional
+        Metric to use for the quadratic term. Supported values: 'sqeuclidean', 'kl'.
+        Default is 'sqeuclidean'.
+    logits : bool, optional
+        For KL divergence, whether inputs are logits (unnormalized log probabilities).
+        If True, inputs are treated as logits. Default is None.
+    nx : module, optional
+        Backend to use. Default is None.
+
+    Returns
+    -------
+    loss : array-like, shape (B,)
+        Value of `(1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced` for each problem in the batch.
+
+    Raises
+    ------
+    ValueError
+        If a non-scalar `alpha` or `reg_marginals` does not have shape (B,).
+    """
+    if nx is None:
+        nx = get_backend(T)
+
+    B = T.shape[0]
+    alpha = _check_batch_weight(alpha, "alpha", B, nx)
+    reg_marginals = _check_batch_weight(reg_marginals, "reg_marginals", B, nx)
+
+    quadratic = loss_quadratic_samples_batch(
+        a,
+        b,
+        C1,
+        C2,
+        T,
+        loss=metric_quadratic,
+        symmetric=symmetric,
+        nx=nx,
+        logits=logits,
+        recompute_const=recompute_const,
+    )
+
+    linear = loss_linear_samples_batch(X, Y, T, metric=metric_linear)
+
+    unbalanced = div_to_product_batch(
+        T,
+        a,
+        b,
+        divergence=divergence,
+        mass=True,
+        nx=nx,
+    )
+
+    return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced
+
+
 def solve_gromov_batch(
     C1,
     C2,
diff --git a/test/batch/test_solve_batch.py b/test/batch/test_solve_batch.py
index 45a7e69fe..fc9a74492 100644
--- a/test/batch/test_solve_batch.py
+++ b/test/batch/test_solve_batch.py
@@ -1,4 +1,4 @@
-"""Tests for module bregman on OT with bregman projections"""
+"""Tests for module batch"""
 
 # Author: Remi Flamary
 #         Kilian Fatras
diff --git a/test/batch/test_solve_gromov_batch.py b/test/batch/test_solve_gromov_batch.py
index e0029689b..1518b41d8 100644
--- a/test/batch/test_solve_gromov_batch.py
+++ b/test/batch/test_solve_gromov_batch.py
@@ -1,19 +1,28 @@
-"""Tests for module bregman on OT with bregman projections"""
+"""Tests for module batch"""
 
 # Author: Remi Flamary
-#         Kilian Fatras
-#         Quang Huy Tran
-#         Eduardo Fernandes Montesuma
+#         Sonia Mazelet
 #
 # License: MIT License
 
 import numpy as np
-from ot.batch import solve_gromov_batch, loss_quadratic_samples_batch
+from ot.batch import (
+    solve_gromov_batch,
+    loss_quadratic_batch,
+    loss_linear_batch,
+    loss_quadratic_samples_batch,
+)
 from ot import solve_gromov
 from ot.batch._linear import dist_batch
 import pytest
 from itertools import product
 from ot.backend import torch
+from ot.batch._quadratic import (
+    div_to_product_batch,
+    loss_fugw_batch,
+    loss_fugw_samples_batch,
+    tensor_batch,
+)
 
 
 def test_solve_gromov_batch():
@@ -133,3 +142,72 @@ def test_backend(nx):
     C = np.random.randn(batchsize, n, n, d)
     C = nx.from_numpy(C)
     solve_gromov_batch(C1=C, C2=C, a=None, b=None, loss="sqeuclidean", logits=False)
+
+
+def test_fugw_loss():
+    """Check that loss_fugw_batch and loss_fugw_samples_batch run without error."""
+    batchsize = 2
+    n = 4
+    d = 2
+    rng = np.random.RandomState(0)
+    C1 = rng.rand(batchsize, n, n, d)
+    C2 = rng.rand(batchsize, n, n, d)
+    X = rng.rand(batchsize, n, d)
+    Y = rng.rand(batchsize, n, d)
+    M = rng.rand(batchsize, n, n)
+    a = np.ones((batchsize, n))
+    T = rng.rand(batchsize, n, n)
+    L = tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean")
+    alpha = rng.rand()
+    reg_marginals = rng.rand()
+
+    loss_fugw = loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals)
+    loss_fugw_sample = loss_fugw_samples_batch(
+        a, a, C1, C2, X, Y, T, alpha=alpha, reg_marginals=reg_marginals
+    )
+    assert np.isfinite(loss_fugw).all()
+    assert np.isfinite(loss_fugw_sample).all()
+
+    alpha = rng.rand(batchsize)
+    reg_marginals = rng.rand(batchsize)
+    loss_fugw = loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals)
+    loss_fugw_sample = loss_fugw_samples_batch(
+        a, a, C1, C2, X, Y, T, alpha=alpha, reg_marginals=reg_marginals
+    )
+    assert np.isfinite(loss_fugw).all()
+    assert np.isfinite(loss_fugw_sample).all()
+
+
+def test_valid_fugw_loss_endpoints():
+    """Check that loss_fugw_batch reduces to loss_linear_batch for alpha=0 and to loss_quadratic_batch for alpha=1 when reg_marginals=0."""
+    batchsize = 2
+    n = 4
+    d = 2
+    rng = np.random.RandomState(0)
+    C1 = rng.rand(batchsize, n, n, d)
+    C2 = rng.rand(batchsize, n, n, d)
+    M = rng.rand(batchsize, n, n)
+    a = np.ones((batchsize, n))
+    reg_marginals = 0
+    T = rng.rand(batchsize, n, n)
+    L = tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean")
+
+    loss_fugw = loss_fugw_batch(
+        a, a, L, M, T, alpha=0.0, divergence="l2", reg_marginals=reg_marginals
+    )
+    loss_linear = loss_linear_batch(M, T)
+    np.testing.assert_allclose(loss_fugw, loss_linear, atol=1e-5)
+
+    loss_fugw = loss_fugw_batch(
+        a, a, L, M, T, alpha=1.0, divergence="l2", reg_marginals=reg_marginals
+    )
+    loss_gromov = loss_quadratic_batch(L, T, recompute_const=True)
+    np.testing.assert_allclose(loss_fugw, loss_gromov, atol=1e-5)
+
+
+def test_div_to_product_batch_unknown_divergence():
+    """Check that div_to_product_batch rejects unsupported divergences."""
+    T = np.ones((2, 3, 3))
+    a = np.ones((2, 3))
+    with pytest.raises(ValueError):
+        div_to_product_batch(T, a, a, divergence="unknown")