Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Add UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
- Add batch FUGW loss to `ot.batch` (PR #775)

#### Closed issues

Expand Down
2 changes: 1 addition & 1 deletion ot/batch/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def loss_linear_batch(M, T, nx=None):
return nx.sum(M * T, axis=(1, 2))


def loss_linear_samples_batch(X, Y, T, metric="l2"):
def loss_linear_samples_batch(X, Y, T, metric="sqeuclidean"):
r"""Computes the linear optimal transport loss given samples and transport plan. This is the equivalent of
calling `dist_batch` and then `loss_linear_batch`.

Expand Down
281 changes: 279 additions & 2 deletions ot/batch/_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from ..utils import OTResult
from ot.backend import get_backend
from ot.batch._linear import loss_linear_batch
from ot.batch._linear import loss_linear_batch, loss_linear_samples_batch
from ot.batch._utils import bmv, bop, bregman_log_projection_batch
from ot.utils import list_to_array


def tensor_batch(
Expand Down Expand Up @@ -152,6 +153,90 @@ def h2(C2):
return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric)


def div_to_product_batch(
    T, a, b, T1=None, T2=None, divergence="kl", mass=True, nx=None
):
    r"""Fast computation of the Bregman divergence between a batch of arbitrary
    measures and a batch of product measures.

    Only the Kullback-Leibler and half-squared L2 divergences are supported.

    - For half-squared L2 divergence:

    .. math::
        \frac{1}{2} || \pi - a \otimes b ||^2
        = \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big]

    - For Kullback-Leibler divergence:

    .. math::
        KL(\pi | a \otimes b)
        = \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle
        - \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b)

    where :

    - :math:`\pi` is the (`n`, `m`) transport plan of one problem (argument `T`)
    - :math:`\pi_1` and :math:`\pi_2` are the marginal distributions of the plan
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - :math:`m` denotes the mass of the measure

    Parameters
    ----------
    T : array-like (B, n, m)
        Transport plan for each problem in the batch
    a : array-like (B, n)
        Unnormalized histogram of dimension `n` for each problem in the batch
    b : array-like (B, m)
        Unnormalized histogram of dimension `m` for each problem in the batch
    T1 : array-like (B, n), optional (default = None)
        Marginal distribution with respect to the first dimension of the transport plan
        for each problem in the batch. Computed as the sum of `T` over its last axis
        when None. Only used in case of Kullback-Leibler divergence.
    T2 : array-like (B, m), optional (default = None)
        Marginal distribution with respect to the second dimension of the transport plan
        for each problem in the batch. Computed as the sum of `T` over its middle axis
        when None. Only used in case of Kullback-Leibler divergence.
    divergence : string, default = "kl"
        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
    mass : bool, optional. Default is True.
        Only used in case of Kullback-Leibler divergence.
        If False, calculate the relative entropy (the mass terms
        :math:`- m(\pi) + m(a) m(b)` are left out).
        If True, calculate the Kullback-Leibler divergence.
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    Bregman divergence between an arbitrary measure and a product measure for each problem in the batch.

    Raises
    ------
    ValueError
        If `divergence` is neither "kl" nor "l2".
    """
    # Fail early with a clear message instead of reaching `return` with an
    # unassigned result when the divergence name is not recognised.
    if divergence not in ("kl", "l2"):
        raise ValueError(
            f"Unknown divergence '{divergence}'. Only 'kl' and 'l2' are supported."
        )

    if nx is None:
        # T1 / T2 may still be None at this point; each array is passed once.
        nx = get_backend(T, a, b, T1, T2)

    if divergence == "l2":
        # sum_{ij} a_i T_ij b_j is evaluated as <a, T b> with a batched mat-vec product.
        Tb = (T @ b[:, :, None]).squeeze(-1)
        return (
            nx.sum(T**2, (1, 2))
            + nx.sum(a**2, 1) * nx.sum(b**2, 1)
            - 2 * nx.sum((a * Tb), 1)
        ) / 2

    # Kullback-Leibler branch: marginals default to those of the plan itself.
    if T1 is None:
        T1 = nx.sum(T, 2)
    if T2 is None:
        T2 = nx.sum(T, 1)

    # Zero entries of T are shifted to 1 inside the log so that they
    # contribute 0 * log(1) = 0 rather than 0 * log(0).
    res = (
        nx.sum((T * nx.log(T + 1.0 * (T == 0))), (1, 2))
        - nx.sum(T1 * nx.log(a), 1)
        - nx.sum(T2 * nx.log(b), 1)
    )
    if mass:
        res = res - nx.sum(T1, 1) + nx.sum(a, 1) * nx.sum(b, 1)

    return res


def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None):
r"""
Computes the gromov-wasserstein cost given a cost tensor and transport plan. Batched version.
Expand Down Expand Up @@ -205,7 +290,7 @@ def loss_quadratic_samples_batch(
C2,
T,
loss="sqeuclidean",
symmetric=None,
symmetric=True,
nx=None,
logits=None,
recompute_const=False,
Expand Down Expand Up @@ -266,6 +351,198 @@ def loss_quadratic_samples_batch(
)


def loss_fugw_batch(
    a,
    b,
    L,
    M,
    T,
    alpha=0.5,
    reg_marginals=1,
    symmetric=True,
    divergence="kl",
    recompute_const=True,
    nx=None,
):
    r"""
    Evaluate the fused unbalanced Gromov-Wasserstein cost of a batch of transport plans.

    The value returned is
    ``(1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced`` where
    `linear` comes from `loss_linear_batch(M, T)`, `quadratic` from
    `loss_quadratic_batch(L, T)` and `unbalanced` from
    `div_to_product_batch(T, a, b)` (computed with its mass terms).

    Parameters
    ----------
    a : array-like
        Source histograms, forwarded to `div_to_product_batch` as the first
        factor of the product measure.
    b : array-like
        Target histograms, forwarded to `div_to_product_batch` as the second
        factor of the product measure.
    L : dict
        Cost tensor as returned by `tensor_batch`.
    M : array-like, shape (B, n, m)
        Cost matrix between features across domains.
    T : array-like, shape (B, n, m)
        Transport plan.
    alpha : float, array-like or list (B,), optional
        Weight of the quadratic term (alpha*Gromov) against the linear term
        ((1-alpha)*Wass). If alpha is a scalar it is used for all problems
        in the batch.
    reg_marginals : float, array-like or list (B,), optional
        Marginal relaxation weight. If reg_marginals is a scalar it is used
        for all problems in the batch.
    symmetric : bool, optional
        Whether to use symmetric version. Default is True.
    divergence : string, default = "kl"
        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
    recompute_const : bool, optional
        Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints.
    nx : module, optional
        Backend to use. Default is None, in which case it is inferred from `T`.

    Returns
    -------
    Fused unbalanced Gromov-Wasserstein cost for each problem in the batch.

    Raises
    ------
    ValueError
        If `alpha` or `reg_marginals` is a non-scalar array whose shape is not (B,).
    """
    nx = get_backend(T) if nx is None else nx
    B = T.shape[0]

    # Plain Python lists of per-problem weights are turned into backend arrays.
    if isinstance(alpha, list):
        alpha = list_to_array(alpha, nx=nx)
    if isinstance(reg_marginals, list):
        reg_marginals = list_to_array(reg_marginals, nx=nx)

    # Non-scalar weights must provide exactly one value per problem.
    alpha_is_batched = hasattr(alpha, "ndim") and alpha.ndim > 0
    if alpha_is_batched and not (alpha.ndim == 1 and alpha.shape[0] == B):
        raise ValueError(
            f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}"
        )

    reg_is_batched = hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0
    if reg_is_batched and not (
        reg_marginals.ndim == 1 and reg_marginals.shape[0] == B
    ):
        raise ValueError(
            f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}"
        )

    # Gromov (quadratic) part, weighted by alpha.
    gromov_term = alpha * loss_quadratic_batch(
        L, T, recompute_const=recompute_const, symmetric=symmetric, nx=nx
    )

    # Feature (linear) part, weighted by 1 - alpha.
    wasserstein_term = (1 - alpha) * loss_linear_batch(M, T, nx=nx)

    # Marginal relaxation penalty: divergence of the plan to a ⊗ b.
    marginal_term = reg_marginals * div_to_product_batch(
        T,
        a,
        b,
        divergence=divergence,
        mass=True,
        nx=nx,
    )

    return wasserstein_term + gromov_term + marginal_term


def loss_fugw_samples_batch(
    a,
    b,
    C1,
    C2,
    X,
    Y,
    T,
    alpha=0.5,
    reg_marginals=1,
    symmetric=True,
    divergence="kl",
    recompute_const=True,
    metric_linear="sqeuclidean",
    metric_quadratic="sqeuclidean",
    logits=None,
    nx=None,
):
    r"""
    Evaluate the fused unbalanced Gromov-Wasserstein cost of a batch of transport
    plans directly from structure matrices and feature samples.

    The value returned is
    ``(1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced`` where
    `linear` comes from `loss_linear_samples_batch(X, Y, T)`, `quadratic` from
    `loss_quadratic_samples_batch(a, b, C1, C2, T)` and `unbalanced` from
    `div_to_product_batch(T, a, b)` (computed with its mass terms).

    Parameters
    ----------
    a : array-like, shape (B, n)
        Source distributions.
    b : array-like, shape (B, m)
        Target distributions.
    C1 : array-like, shape (B, n, n) or (B, n, n, d)
        Source cost matrices for the quadratic term.
    C2 : array-like, shape (B, m, m) or (B, m, m, d)
        Target cost matrices for the quadratic term.
    X : array-like, shape (B, n, d)
        Samples from source distribution for the linear term
    Y : array-like, shape (B, m, d)
        Samples from target distribution for the linear term
    T : array-like, shape (B, n, m)
        Transport plan.
    alpha : float or array-like or list (B,), optional
        Weight of the quadratic term (alpha*Gromov) against the linear term
        ((1-alpha)*Wass). If alpha is a scalar it is used for all problems
        in the batch.
    reg_marginals : float or array-like or list (B,), optional
        Marginal relaxation weight. If reg_marginals is a scalar it is used
        for all problems in the batch.
    symmetric : bool, optional
        Whether to use symmetric version. Default is True.
    divergence : string, default = "kl"
        Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence)
    recompute_const : bool, optional
        Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints.
    metric_linear : str, optional
        Metric for the linear term, 'sqeuclidean', 'euclidean', 'minkowski' or 'kl'
    metric_quadratic : str, optional
        Metric to use for the quadratic term. Supported values: 'sqeuclidean', 'kl'.
        Default is 'sqeuclidean'.
    logits : bool, optional
        For KL divergence, whether inputs are logits (unnormalized log probabilities).
        If True, inputs are treated as logits. Default is None.
        Forwarded to `loss_quadratic_samples_batch`.
    nx : module, optional
        Backend to use. Default is None, in which case it is inferred from `T`.

    Returns
    -------
    Fused unbalanced Gromov-Wasserstein cost for each problem in the batch.

    Raises
    ------
    ValueError
        If `alpha` or `reg_marginals` is a non-scalar array whose shape is not (B,).
    """
    nx = get_backend(T) if nx is None else nx
    B = T.shape[0]

    # Plain Python lists of per-problem weights are turned into backend arrays.
    if isinstance(alpha, list):
        alpha = list_to_array(alpha, nx=nx)
    if isinstance(reg_marginals, list):
        reg_marginals = list_to_array(reg_marginals, nx=nx)

    # Non-scalar weights must provide exactly one value per problem.
    alpha_is_batched = hasattr(alpha, "ndim") and alpha.ndim > 0
    if alpha_is_batched and not (alpha.ndim == 1 and alpha.shape[0] == B):
        raise ValueError(
            f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}"
        )

    reg_is_batched = hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0
    if reg_is_batched and not (
        reg_marginals.ndim == 1 and reg_marginals.shape[0] == B
    ):
        raise ValueError(
            f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}"
        )

    # Gromov (quadratic) part built from the structure matrices, weighted by alpha.
    gromov_term = alpha * loss_quadratic_samples_batch(
        a,
        b,
        C1,
        C2,
        T,
        loss=metric_quadratic,
        symmetric=symmetric,
        nx=nx,
        logits=logits,
        recompute_const=recompute_const,
    )

    # Feature (linear) part built from the samples, weighted by 1 - alpha.
    wasserstein_term = (1 - alpha) * loss_linear_samples_batch(
        X, Y, T, metric=metric_linear
    )

    # Marginal relaxation penalty: divergence of the plan to a ⊗ b.
    marginal_term = reg_marginals * div_to_product_batch(
        T,
        a,
        b,
        divergence=divergence,
        mass=True,
        nx=nx,
    )

    return wasserstein_term + gromov_term + marginal_term


def solve_gromov_batch(
C1,
C2,
Expand Down
2 changes: 1 addition & 1 deletion test/batch/test_solve_batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for module bregman on OT with bregman projections"""
"""Tests for module batch"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
Expand Down
Loading
Loading