Skip to content

Commit 0669475

Browse files
authored
Allow an arbitrary Zarr store to be used for intermediate storage (#801)
* Allow an arbitrary Zarr store to be used for intermediate storage
* Don't run test_arbitrary_zarr_store on tensorstore
* Fix typo
* Rename backend_storage_name as get_storage_name
* Put `intermediate_store` param next to `work_dir`
* Fix test
1 parent 41e506f commit 0669475

File tree

5 files changed

+56
-18
lines changed

5 files changed

+56
-18
lines changed

cubed/core/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, numpy_array_to_backend_array
1818
from cubed.backend_array_api import namespace as nxp
1919
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
20-
from cubed.core.plan import Plan, context_dir_path
20+
from cubed.core.plan import Plan, intermediate_store
2121
from cubed.primitive.blockwise import blockwise as primitive_blockwise
2222
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
2323
from cubed.primitive.memory import get_buffer_copies
@@ -363,7 +363,7 @@ def blockwise(
363363
spec = check_array_specs(arrays)
364364
buffer_copies = get_buffer_copies(spec)
365365
if target_store is None:
366-
target_store = context_dir_path(spec=spec)
366+
target_store = intermediate_store(spec=spec)
367367
op = primitive_blockwise(
368368
func,
369369
out_ind,
@@ -517,14 +517,14 @@ def _general_blockwise(
517517
if isinstance(target_stores, list) and len(target_stores) > 1: # multiple outputs
518518
name = [gensym() for _ in range(len(target_stores))]
519519
target_stores = [
520-
ts if ts is not None else context_dir_path(spec=spec)
520+
ts if ts is not None else intermediate_store(spec=spec)
521521
for ts in target_stores
522522
]
523523
target_names = name
524524
else: # single output
525525
name = gensym()
526526
if target_stores is None:
527-
target_stores = [context_dir_path(spec=spec)]
527+
target_stores = [intermediate_store(spec=spec)]
528528
target_names = [name]
529529

530530
op = primitive_general_blockwise(
@@ -951,7 +951,7 @@ def rechunk(x, chunks, *, target_store=None, min_mem=None, use_new_impl=True):
951951
name = gensym()
952952
spec = x.spec
953953
if target_store is None:
954-
target_store = context_dir_path(spec=spec)
954+
target_store = intermediate_store(spec=spec)
955955
name_int = f"{name}-int"
956956
ops = primitive_rechunk(
957957
x._zarray,

cubed/core/plan.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,14 @@ def arrays_to_plan(*arrays):
558558
return plans[0].arrays_to_plan(*arrays)
559559

560560

561-
def context_dir_path(spec=None):
561+
def intermediate_store(spec=None):
562+
"""Return a file path or a store object that is used for storing
563+
intermediate data.
564+
565+
By default returns a temporary file path, which may be local or remote.
566+
"""
567+
if spec.intermediate_store is not None:
568+
return spec.intermediate_store
562569
work_dir = spec.work_dir if spec is not None else None
563570
if work_dir is None:
564571
work_dir = tempfile.gettempdir()
@@ -567,16 +574,6 @@ def context_dir_path(spec=None):
567574
return context_dir
568575

569576

570-
def new_temp_path(name, suffix=".zarr", spec=None):
571-
"""Return a string path for a temporary file path, which may be local or remote.
572-
573-
Note that this function does not create the file or any directories (and they
574-
may never be created, if for example the file doesn't need to be materialized).
575-
"""
576-
context_dir = context_dir_path(spec)
577-
return join_path(context_dir, f"{name}{suffix}")
578-
579-
580577
def create_zarr_array(lazy_zarr_array, *, config=None):
581578
"""Stage function for create."""
582579
lazy_zarr_array.create(mode="a")

cubed/spec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from cubed.runtime.create import create_executor
88
from cubed.runtime.types import Executor
9+
from cubed.types import T_Store
910
from cubed.utils import convert_to_bytes
1011

1112

@@ -15,6 +16,8 @@ class Spec:
1516
def __init__(
1617
self,
1718
work_dir: Union[str, None] = None,
19+
*,
20+
intermediate_store: Union[T_Store, None] = None,
1821
allowed_mem: Union[int, str, None] = None,
1922
reserved_mem: Union[int, str, None] = 0,
2023
executor: Union[Executor, None] = None,
@@ -30,6 +33,8 @@ def __init__(
3033
----------
3134
work_dir : str or None
3235
The directory path (specified as an fsspec or obstore URL) used for storing intermediate data.
36+
intermediate_store : store, optional
37+
The Zarr store for intermediate data. Takes precedence over ``work_dir``.
3338
allowed_mem : int or str, optional
3439
The total memory available to a worker for running a task, in bytes.
3540
@@ -65,6 +70,7 @@ def __init__(
6570
self._executor_options = executor_options
6671
self._storage_options = storage_options
6772
self._zarr_compressor = zarr_compressor
73+
self._intermediate_store = intermediate_store
6874

6975
@property
7076
def work_dir(self) -> Optional[str]:
@@ -118,16 +124,22 @@ def zarr_compressor(self) -> Union[dict, str, None]:
118124
"""The compressor used by Zarr for intermediate data."""
119125
return self._zarr_compressor
120126

127+
@property
128+
def intermediate_store(self) -> Union[dict, str, None]:
129+
"""The Zarr store for intermediate data. Takes precedence over ``work_dir``."""
130+
return self._intermediate_store
131+
121132
def __repr__(self) -> str:
122133
return (
123-
f"cubed.Spec(work_dir={self._work_dir}, allowed_mem={self._allowed_mem}, "
134+
f"cubed.Spec(work_dir={self._work_dir}, intermediate_store={self._intermediate_store}, allowed_mem={self._allowed_mem}, "
124135
f"reserved_mem={self._reserved_mem}, executor={self._executor}, storage_options={self._storage_options}, zarr_compressor={self._zarr_compressor})"
125136
)
126137

127138
def __eq__(self, other):
128139
if isinstance(other, Spec):
129140
return (
130141
self.work_dir == other.work_dir
142+
and self.intermediate_store == other.intermediate_store
131143
and self.allowed_mem == other.allowed_mem
132144
and self.reserved_mem == other.reserved_mem
133145
and self.executor == other.executor

cubed/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def modal_executor(request):
5454

5555

5656
def test_object_bool(tmp_path, executor):
57-
spec = cubed.Spec(tmp_path, 100000, executor=executor)
57+
spec = cubed.Spec(tmp_path, allowed_mem=100000, executor=executor)
5858
a = xp.asarray(
5959
[[False, False, False], [False, False, False], [False, False, False]],
6060
chunks=(2, 2),

cubed/tests/test_store.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
import pytest
3+
import zarr
4+
from numpy.testing import assert_array_equal
5+
6+
import cubed
7+
import cubed.array_api as xp
8+
from cubed.storage.store import get_storage_name
9+
10+
ZARR_PYTHON_V2 = zarr.__version__[0] == "2"
11+
12+
13+
@pytest.mark.skipif(
14+
ZARR_PYTHON_V2 or get_storage_name() == "tensorstore",
15+
reason="setting an arbitrary Zarr store is not supported for Zarr Python v2, or tensorstore",
16+
)
17+
def test_arbitrary_zarr_store():
18+
store = zarr.storage.MemoryStore()
19+
spec = cubed.Spec(intermediate_store=store, allowed_mem="100kB")
20+
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
21+
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
22+
c = xp.add(a, b)
23+
assert_array_equal(c, np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]))
24+
25+
# check store was used
26+
z = zarr.open_group(store)
27+
array_keys = list(z.array_keys())
28+
assert len(array_keys) == 1
29+
assert array_keys[0].startswith("array-")

0 commit comments

Comments
 (0)