diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index e964aaff2..b46dab81b 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -83,6 +83,9 @@ jobs: python -m pip install --upgrade pip setuptools pip install -r requirements_all.txt pip install pytest pytest-cov + - name: test POT import + run: | + python -c "import ot; print(ot.__version__)" - name: Run tests run: | python -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --cov=./ --cov-report=xml diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..d187eec78 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -29,7 +29,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Fix test of the version of jax in `ot.backend` (PR #794) - Reverting the openmp fix on macOS (PR #789) for macOS (PR #797) - Align documentation build dependencies and doc extras (PR #801) - +- Debug Debug linux test core dump (PR #815) ## 0.9.6.post1 diff --git a/ot/gnn/__init__.py b/ot/gnn/__init__.py index 5f3a93fed..9567f1ecc 100644 --- a/ot/gnn/__init__.py +++ b/ot/gnn/__init__.py @@ -17,9 +17,9 @@ # All submodules and packages from ._utils import FGW_distance_to_templates, wasserstein_distance_to_templates - from ._layers import TFGWPooling, TWPooling + __all__ = [ "FGW_distance_to_templates", "wasserstein_distance_to_templates", diff --git a/ot/gnn/_utils.py b/ot/gnn/_utils.py index 16487c210..dc1342bf6 100644 --- a/ot/gnn/_utils.py +++ b/ot/gnn/_utils.py @@ -12,7 +12,15 @@ from ..utils import dist from ..gromov import fused_gromov_wasserstein2 from ..lp import emd2 -from torch_geometric.utils import subgraph +import warnings + +try: + from torch_geometric.utils import subgraph +except ImportError: + warnings.warn( + "torch_geometric is not installed. The ot.gnn module requires torch_geometric to be installed." + ) + pass def TFGW_template_initialization( diff --git a/requirements_all.txt b/requirements_all.txt index a015855f6..890d6ceb4 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -5,7 +5,7 @@ autograd pymanopt @ git+https://github.com/pymanopt/pymanopt.git@master cvxopt scikit-learn -torch +torch<=2.11 jax jaxlib tensorflow diff --git a/setup.py b/setup.py index 0452fcf25..a2324a41d 100644 --- a/setup.py +++ b/setup.py @@ -116,6 +116,7 @@ "networkx", "memory_profiler", ], + "tests": ["pytest", "pytest-cov"], "all": [ "jax", "jaxlib",