Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/build_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ jobs:
python -m pip install --upgrade pip setuptools
pip install -r requirements_all.txt
pip install pytest pytest-cov
- name: test POT import
run: |
python -c "import ot; print(ot.__version__)"
- name: Run tests
run: |
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov=./ --cov-report=xml
Expand Down
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Fix test of the version of jax in `ot.backend` (PR #794)
- Reverting the openmp fix on macOS (PR #789) for macOS (PR #797)
- Align documentation build dependencies and doc extras (PR #801)

- Debug Debug linux test core dump (PR #815)

## 0.9.6.post1

Expand Down
2 changes: 1 addition & 1 deletion ot/gnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
# All submodules and packages

from ._utils import FGW_distance_to_templates, wasserstein_distance_to_templates

from ._layers import TFGWPooling, TWPooling


__all__ = [
"FGW_distance_to_templates",
"wasserstein_distance_to_templates",
Expand Down
10 changes: 9 additions & 1 deletion ot/gnn/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
from ..utils import dist
from ..gromov import fused_gromov_wasserstein2
from ..lp import emd2
from torch_geometric.utils import subgraph
import warnings

try:
from torch_geometric.utils import subgraph
except ImportError:
warnings.warn(
"torch_geometric is not installed. The ot.gnn module requires torch_geometric to be installed."
)
pass


def TFGW_template_initialization(
Expand Down
2 changes: 1 addition & 1 deletion requirements_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ autograd
pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master
cvxopt
scikit-learn
torch
torch<=2.11
jax
jaxlib
tensorflow
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"networkx",
"memory_profiler",
],
"tests": ["pytest", "pytest-cov"],
"all": [
"jax",
"jaxlib",
Expand Down
Loading