Skip to content

Commit b7ac0af

Browse files
committed
Add type-hints to adaptive/learner/data_saver.py
1 parent 1b7e84d commit b7ac0af

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

adaptive/learner/data_saver.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import functools
44
from collections import OrderedDict
5+
from operator import itemgetter
6+
from typing import Any
57

68
from adaptive.learner.base_learner import BaseLearner
79
from adaptive.utils import copy_docstring_from
@@ -39,7 +41,7 @@ class DataSaver:
3941
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4042
"""
4143

42-
def __init__(self, learner, arg_picker):
44+
def __init__(self, learner: BaseLearner, arg_picker: itemgetter) -> None:
4345
self.learner = learner
4446
self.extra_data = OrderedDict()
4547
self.function = learner.function
@@ -49,21 +51,21 @@ def new(self) -> DataSaver:
4951
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5052
return DataSaver(self.learner.new(), self.arg_picker)
5153

52-
def __getattr__(self, attr):
54+
def __getattr__(self, attr: str) -> Any:
5355
return getattr(self.learner, attr)
5456

5557
@copy_docstring_from(BaseLearner.tell)
56-
def tell(self, x, result):
58+
def tell(self, x: Any, result: Any) -> None:
5759
y = self.arg_picker(result)
5860
self.extra_data[x] = result
5961
self.learner.tell(x, y)
6062

6163
@copy_docstring_from(BaseLearner.tell_pending)
62-
def tell_pending(self, x):
64+
def tell_pending(self, x: Any) -> None:
6365
self.learner.tell_pending(x)
6466

6567
def to_dataframe(
66-
self, extra_data_name: str = "extra_data", **kwargs
68+
self, extra_data_name: str = "extra_data", **kwargs: Any
6769
) -> pandas.DataFrame:
6870
"""Return the data as a concatenated `pandas.DataFrame` from child learners.
6971
@@ -98,7 +100,7 @@ def load_dataframe(
98100
extra_data_name: str = "extra_data",
99101
input_names: tuple[str] = (),
100102
**kwargs,
101-
):
103+
) -> None:
102104
"""Load the data from a `pandas.DataFrame` into the learner.
103105
104106
Parameters
@@ -122,33 +124,36 @@ def load_dataframe(
122124
key = _to_key(x[:-1])
123125
self.extra_data[key] = x[-1]
124126

125-
def _get_data(self):
127+
def _get_data(self) -> tuple[Any, OrderedDict]:
126128
return self.learner._get_data(), self.extra_data
127129

128-
def _set_data(self, data):
130+
def _set_data(
131+
self,
132+
data: tuple[Any, OrderedDict],
133+
) -> None:
129134
learner_data, self.extra_data = data
130135
self.learner._set_data(learner_data)
131136

132-
def __getstate__(self):
137+
def __getstate__(self) -> tuple[BaseLearner, itemgetter, OrderedDict]:
133138
return (
134139
self.learner,
135140
self.arg_picker,
136141
self.extra_data,
137142
)
138143

139-
def __setstate__(self, state):
144+
def __setstate__(self, state: tuple[BaseLearner, itemgetter, OrderedDict]) -> None:
140145
learner, arg_picker, extra_data = state
141146
self.__init__(learner, arg_picker)
142147
self.extra_data = extra_data
143148

144149
@copy_docstring_from(BaseLearner.save)
145-
def save(self, fname, compress=True):
150+
def save(self, fname, compress=True) -> None:
146151
# We copy this method because the 'DataSaver' is not a
147152
# subclass of the 'BaseLearner'.
148153
BaseLearner.save(self, fname, compress)
149154

150155
@copy_docstring_from(BaseLearner.load)
151-
def load(self, fname, compress=True):
156+
def load(self, fname, compress=True) -> None:
152157
# We copy this method because the 'DataSaver' is not a
153158
# subclass of the 'BaseLearner'.
154159
BaseLearner.load(self, fname, compress)

0 commit comments

Comments
 (0)