diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index c74d9a8..e3f4db6 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -77,7 +77,7 @@ def __call__(self, *args, **kwargs): try: - # jax>=0.8.3 + # jax>=0.8.0 # The import may fail if the JAX version is not new enough. from jaxlib import _pathways as jaxlib_pathways # pylint: disable=g-import-not-at-top @@ -86,10 +86,10 @@ def __call__(self, *args, **kwargs): del jaxlib_pathways except ImportError: - # jax<0.8.3 + # jax<0.8.0 transfer_to_shardings = _FakeJaxFunction( "jax.jaxlib._pathways._transfer_to_shardings", - "0.8.3", + "0.8.0", )