diff --git a/adaptive/learner/balancing_learner.py b/adaptive/learner/balancing_learner.py index 0215b3af6..593331792 100644 --- a/adaptive/learner/balancing_learner.py +++ b/adaptive/learner/balancing_learner.py @@ -1,11 +1,13 @@ from __future__ import annotations import itertools +import numbers from collections import defaultdict from collections.abc import Iterable from contextlib import suppress from functools import partial from operator import itemgetter +from typing import Any, Callable, Dict, Sequence, Tuple, Union import numpy as np @@ -13,20 +15,33 @@ from adaptive.notebook_integration import ensure_holoviews from adaptive.utils import cache_latest, named_product, restore +try: + from typing import Literal, TypeAlias +except ImportError: + from typing_extensions import Literal, TypeAlias + try: import pandas with_pandas = True - except ModuleNotFoundError: with_pandas = False -def dispatch(child_functions, arg): +def dispatch(child_functions: list[Callable], arg: Any) -> Any: index, x = arg return child_functions[index](x) +STRATEGY_TYPE: TypeAlias = Literal["loss_improvements", "loss", "npoints", "cycle"] + +CDIMS_TYPE: TypeAlias = Union[ + Sequence[Dict[str, Any]], + Tuple[Sequence[str], Sequence[Tuple[Any, ...]]], + None, +] + + class BalancingLearner(BaseLearner): r"""Choose the optimal points from a set of learners. @@ -78,13 +93,19 @@ class BalancingLearner(BaseLearner): behave in an undefined way. Change the `strategy` in that case. """ - def __init__(self, learners, *, cdims=None, strategy="loss_improvements"): + def __init__( + self, + learners: list[BaseLearner], + *, + cdims: CDIMS_TYPE = None, + strategy: STRATEGY_TYPE = "loss_improvements", + ) -> None: self.learners = learners # Naively we would make 'function' a method, but this causes problems # when using executors from 'concurrent.futures' because we have to # pickle the whole learner. - self.function = partial(dispatch, [l.function for l in self.learners]) + self.function = partial(dispatch, [l.function for l in self.learners]) # type: ignore self._ask_cache = {} self._loss = {} @@ -96,7 +117,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"): "A BalacingLearner can handle only one type" " of learners." ) - self.strategy = strategy + self.strategy: STRATEGY_TYPE = strategy def new(self) -> BalancingLearner: """Create a new `BalancingLearner` with the same parameters.""" @@ -107,21 +128,21 @@ def new(self) -> BalancingLearner: ) @property - def data(self): + def data(self) -> dict[tuple[int, Any], Any]: data = {} for i, l in enumerate(self.learners): data.update({(i, p): v for p, v in l.data.items()}) return data @property - def pending_points(self): + def pending_points(self) -> set[tuple[int, Any]]: pending_points = set() for i, l in enumerate(self.learners): pending_points.update({(i, p) for p in l.pending_points}) return pending_points @property - def npoints(self): + def npoints(self) -> int: return sum(l.npoints for l in self.learners) @property @@ -134,7 +155,7 @@ def nsamples(self): ) @property - def strategy(self): + def strategy(self) -> STRATEGY_TYPE: """Can be either 'loss_improvements' (default), 'loss', 'npoints', or 'cycle'. The points that the `BalancingLearner` choses can be either based on: the best 'loss_improvements', the smallest total 'loss' of @@ -145,7 +166,7 @@ def strategy(self): return self._strategy @strategy.setter - def strategy(self, strategy): + def strategy(self, strategy: STRATEGY_TYPE) -> None: self._strategy = strategy if strategy == "loss_improvements": self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements @@ -162,7 +183,9 @@ def strategy(self, strategy): ' strategy="npoints", or strategy="cycle" is implemented.' ) - def _ask_and_tell_based_on_loss_improvements(self, n): + def _ask_and_tell_based_on_loss_improvements( + self, n: int + ) -> tuple[list[tuple[int, Any]], list[float]]: selected = [] # tuples ((learner_index, point), loss_improvement) total_points = [l.npoints + len(l.pending_points) for l in self.learners] for _ in range(n): @@ -185,7 +208,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n): points, loss_improvements = map(list, zip(*selected)) return points, loss_improvements - def _ask_and_tell_based_on_loss(self, n): + def _ask_and_tell_based_on_loss( + self, n: int + ) -> tuple[list[tuple[int, Any]], list[float]]: selected = [] # tuples ((learner_index, point), loss_improvement) total_points = [l.npoints + len(l.pending_points) for l in self.learners] for _ in range(n): @@ -206,7 +231,9 @@ def _ask_and_tell_based_on_loss(self, n): points, loss_improvements = map(list, zip(*selected)) return points, loss_improvements - def _ask_and_tell_based_on_npoints(self, n): + def _ask_and_tell_based_on_npoints( + self, n: numbers.Integral + ) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]: selected = [] # tuples ((learner_index, point), loss_improvement) total_points = [l.npoints + len(l.pending_points) for l in self.learners] for _ in range(n): @@ -222,7 +249,9 @@ def _ask_and_tell_based_on_npoints(self, n): points, loss_improvements = map(list, zip(*selected)) return points, loss_improvements - def _ask_and_tell_based_on_cycle(self, n): + def _ask_and_tell_based_on_cycle( + self, n: int + ) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]: points, loss_improvements = [], [] for _ in range(n): index = next(self._cycle) @@ -233,7 +262,9 @@ def _ask_and_tell_based_on_cycle(self, n): return points, loss_improvements - def ask(self, n, tell_pending=True): + def ask( + self, n: int, tell_pending: bool = True + ) -> tuple[list[tuple[numbers.Integral, Any]], list[float]]: """Chose points for learners.""" if n == 0: return [], [] @@ -244,20 +275,20 @@ def ask(self, n, tell_pending=True): else: return self._ask_and_tell(n) - def tell(self, x, y): + def tell(self, x: tuple[numbers.Integral, Any], y: Any) -> None: index, x = x self._ask_cache.pop(index, None) self._loss.pop(index, None) self._pending_loss.pop(index, None) self.learners[index].tell(x, y) - def tell_pending(self, x): + def tell_pending(self, x: tuple[numbers.Integral, Any]) -> None: index, x = x self._ask_cache.pop(index, None) self._loss.pop(index, None) self.learners[index].tell_pending(x) - def _losses(self, real=True): + def _losses(self, real: bool = True) -> list[float]: losses = [] loss_dict = self._loss if real else self._pending_loss @@ -269,11 +300,16 @@ def _losses(self, real=True): return losses @cache_latest - def loss(self, real=True): + def loss(self, real: bool = True) -> float: losses = self._losses(real) return max(losses) - def plot(self, cdims=None, plotter=None, dynamic=True): + def plot( + self, + cdims: CDIMS_TYPE = None, + plotter: Callable[[BaseLearner], Any] | None = None, + dynamic: bool = True, + ): """Returns a DynamicMap with sliders. Parameters @@ -346,13 +382,19 @@ def plot_function(*args): vals = {d.name: d.values for d in dm.dimensions() if d.values} return hv.HoloMap(dm.select(**vals)) - def remove_unfinished(self): + def remove_unfinished(self) -> None: """Remove uncomputed data from the learners.""" for learner in self.learners: learner.remove_unfinished() @classmethod - def from_product(cls, f, learner_type, learner_kwargs, combos): + def from_product( + cls, + f, + learner_type: BaseLearner, + learner_kwargs: dict[str, Any], + combos: dict[str, Sequence[Any]], + ) -> BalancingLearner: """Create a `BalancingLearner` with learners of all combinations of named variables’ values. The `cdims` will be set correctly, so calling `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels. @@ -448,7 +490,11 @@ def load_dataframe( for i, gr in df.groupby(index_name): self.learners[i].load_dataframe(gr, **kwargs) - def save(self, fname, compress=True): + def save( + self, + fname: Callable[[BaseLearner], str] | Sequence[str], + compress: bool = True, + ) -> None: """Save the data of the child learners into pickle files in a directory. @@ -486,7 +532,11 @@ def save(self, fname, compress=True): for l in self.learners: l.save(fname(l), compress=compress) - def load(self, fname, compress=True): + def load( + self, + fname: Callable[[BaseLearner], str] | Sequence[str], + compress: bool = True, + ) -> None: """Load the data of the child learners from pickle files in a directory. @@ -510,20 +560,20 @@ def load(self, fname, compress=True): for l in self.learners: l.load(fname(l), compress=compress) - def _get_data(self): + def _get_data(self) -> list[Any]: return [l._get_data() for l in self.learners] - def _set_data(self, data): + def _set_data(self, data: list[Any]): for l, _data in zip(self.learners, data): l._set_data(_data) - def __getstate__(self): + def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]: return ( self.learners, self._cdims_default, self.strategy, ) - def __setstate__(self, state): + def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]): learners, cdims, strategy = state self.__init__(learners, cdims=cdims, strategy=strategy)