Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cubed/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ def compute(
optimize_graph=optimize_graph,
optimize_function=optimize_function,
resume=resume,
array_names=tuple(a.name for a in arrays),
spec=spec,
**kwargs,
)
Expand Down Expand Up @@ -343,7 +342,6 @@ def visualize(
optimize_graph=optimize_graph,
optimize_function=optimize_function,
show_hidden=show_hidden,
array_names=tuple(a.name for a in arrays),
)


Expand Down
35 changes: 17 additions & 18 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS = 10


def simple_optimize_dag(dag, array_names=None):
def simple_optimize_dag(dag, array_names):
"""Apply map blocks fusion."""

# note there is no need to prune the dag, since the way it is built
Expand All @@ -41,7 +41,7 @@ def can_fuse(n):

# if input is one of the arrays being computed then don't fuse
op2_input = next(dag.predecessors(op2))
if array_names is not None and op2_input in array_names:
if op2_input in array_names:
return False

# if input is used by another node then don't fuse
Expand Down Expand Up @@ -159,7 +159,7 @@ def can_fuse_predecessors(
dag,
name,
*,
array_names=None,
array_names,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
Expand All @@ -182,18 +182,17 @@ def can_fuse_predecessors(
return False

# if a predecessor op produces one of the arrays being computed, then don't fuse
if array_names is not None:
predecessor_array_names = set(
array_name for _, array_name, _ in predecessor_ops_and_arrays(dag, name)
predecessor_array_names = set(
array_name for _, array_name, _ in predecessor_ops_and_arrays(dag, name)
)
array_names_intersect = set(array_names) & predecessor_array_names
if len(array_names_intersect) > 0:
logger.debug(
"can't fuse %s since predecessor ops produce one or more arrays being computed %s",
name,
array_names_intersect,
)
array_names_intersect = set(array_names) & predecessor_array_names
if len(array_names_intersect) > 0:
logger.debug(
"can't fuse %s since predecessor ops produce one or more arrays being computed %s",
name,
array_names_intersect,
)
return False
return False

# if any predecessor ops have multiple outputs then don't fuse
# TODO: implement "child fusion" (where a multiple output op fuses its children)
Expand Down Expand Up @@ -247,7 +246,7 @@ def fuse_predecessors(
dag,
name,
*,
array_names=None,
array_names,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
Expand Down Expand Up @@ -302,7 +301,7 @@ def fuse_predecessors(
def multiple_inputs_optimize_dag(
dag,
*,
array_names=None,
array_names,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
Expand All @@ -324,7 +323,7 @@ def multiple_inputs_optimize_dag(
return dag


def fuse_all_optimize_dag(dag, array_names=None):
def fuse_all_optimize_dag(dag, array_names):
"""Force all operations to be fused."""
dag = dag.copy()
always_fuse = [op for op in dag.nodes() if op.startswith("op-")]
Expand All @@ -333,7 +332,7 @@ def fuse_all_optimize_dag(dag, array_names=None):
)


def fuse_only_optimize_dag(dag, *, array_names=None, only_fuse=None):
def fuse_only_optimize_dag(dag, *, array_names, only_fuse=None):
"""Force only specified operations to be fused, all others will be left even if they are suitable for fusion."""
dag = dag.copy()
always_fuse = only_fuse
Expand Down
38 changes: 15 additions & 23 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional

import networkx as nx

Expand Down Expand Up @@ -76,8 +76,9 @@ class Plan:
as the output of one pipeline, then read back as the input to later pipelines.
"""

def __init__(self, dag):
def __init__(self, dag, array_names):
self.dag = dag
self.array_names = array_names

# args from primitive_op onwards are omitted for creation functions when no computation is needed
@classmethod
Expand Down Expand Up @@ -173,21 +174,20 @@ def _new(
if hasattr(x, "name"):
dag.add_edge(x.name, op_name_unique)

return Plan(dag)
return Plan(dag, (name,))

@classmethod
def arrays_to_plan(cls, *arrays):
return Plan(arrays_to_dag(*arrays))
return Plan(arrays_to_dag(*arrays), tuple(a.name for a in arrays))

def optimize(
self,
optimize_function: Optional[Callable[..., nx.MultiDiGraph]] = None,
array_names: Optional[Tuple[str]] = None,
):
if optimize_function is None:
optimize_function = multiple_inputs_optimize_dag
dag = optimize_function(self.dag, array_names=array_names)
return Plan(dag)
dag = optimize_function(self.dag, array_names=self.array_names)
return Plan(dag, self.array_names)

def _create_lazy_zarr_arrays(self, dag):
# find all lazy zarr arrays in dag
Expand Down Expand Up @@ -272,19 +272,14 @@ def _finalize(
optimize_graph: bool = True,
optimize_function=None,
compile_function: Optional[Decorator] = None,
array_names=None,
) -> "FinalizedPlan":
dag = (
self.optimize(optimize_function, array_names).dag
if optimize_graph
else self.dag
)
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
return FinalizedPlan(nx.freeze(dag))
return FinalizedPlan(nx.freeze(dag), array_names=self.array_names)

def execute(
self,
Expand All @@ -294,12 +289,11 @@ def execute(
optimize_function=None,
compile_function=None,
resume=None,
array_names=None,
spec=None,
**kwargs,
):
finalized_plan = self._finalize(
optimize_graph, optimize_function, compile_function, array_names=array_names
optimize_graph, optimize_function, compile_function
)
dag = finalized_plan.dag

Expand Down Expand Up @@ -334,11 +328,8 @@ def visualize(
optimize_graph=True,
optimize_function=None,
show_hidden=False,
array_names=None,
):
finalized_plan = self._finalize(
optimize_graph, optimize_function, array_names=array_names
)
finalized_plan = self._finalize(optimize_graph, optimize_function)
dag = finalized_plan.dag
dag = dag.copy() # make a copy since we mutate the DAG below

Expand Down Expand Up @@ -507,10 +498,11 @@ class FinalizedPlan:
4. freezing the final DAG so it can't be changed
"""

def __init__(self, dag):
def __init__(self, dag, array_names):
self.dag = dag
self.array_names = array_names

def max_projected_mem(self):
def max_projected_mem(self) -> int:
"""Return the maximum projected memory across all tasks to execute this plan."""
projected_mem_values = [
node["primitive_op"].projected_mem for _, node in visit_nodes(self.dag)
Expand All @@ -525,7 +517,7 @@ def num_primitive_ops(self) -> int:
"""Return the number of primitive operations in this plan."""
return len(list(visit_nodes(self.dag)))

def num_tasks(self):
def num_tasks(self) -> int:
"""Return the number of tasks needed to execute this plan."""
tasks = 0
for _, node in visit_nodes(self.dag):
Expand Down
2 changes: 1 addition & 1 deletion cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_custom_optimize_function(spec):
< num_tasks_with_no_optimization
)

def custom_optimize_function(dag, array_names=None):
def custom_optimize_function(dag, array_names):
# leave DAG unchanged
return dag

Expand Down