diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a52a4548dc12..aa092c9b71a9 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -398,11 +398,13 @@ def is_device_rocm(): def get_rocm_version(): rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") version_path = Path(rocm_path) / ".info" / "version" - if not version_path.exists(): - raise FileNotFoundError(f"Expected ROCm version file at {version_path}") - version_str = version_path.read_text().strip() - major, minor, *_ = version_str.split(".") - return int(major), int(minor) + try: + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + except FileNotFoundError: + warnings.warn("ROCm was not installed") + raise unittest.SkipTest("ROCm was not installed") def is_device_cuda(): return 'cuda' in xla_bridge.get_backend().platform_version