22
33import functools
44from collections import OrderedDict
5+ from operator import itemgetter
6+ from typing import Any
57
68from adaptive .learner .base_learner import BaseLearner
79from adaptive .utils import copy_docstring_from
@@ -39,7 +41,7 @@ class DataSaver:
3941 >>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
4042 """
4143
42- def __init__ (self , learner , arg_picker ) :
44+ def __init__ (self , learner : BaseLearner , arg_picker : itemgetter ) -> None :
4345 self .learner = learner
4446 self .extra_data = OrderedDict ()
4547 self .function = learner .function
@@ -49,21 +51,21 @@ def new(self) -> DataSaver:
4951 """Return a new `DataSaver` with the same `arg_picker` and `learner`."""
5052 return DataSaver (self .learner .new (), self .arg_picker )
5153
52- def __getattr__ (self , attr ) :
54+ def __getattr__ (self , attr : str ) -> Any :
5355 return getattr (self .learner , attr )
5456
5557 @copy_docstring_from (BaseLearner .tell )
56- def tell (self , x , result ) :
58+ def tell (self , x : Any , result : Any ) -> None :
5759 y = self .arg_picker (result )
5860 self .extra_data [x ] = result
5961 self .learner .tell (x , y )
6062
6163 @copy_docstring_from (BaseLearner .tell_pending )
62- def tell_pending (self , x ) :
64+ def tell_pending (self , x : Any ) -> None :
6365 self .learner .tell_pending (x )
6466
6567 def to_dataframe (
66- self , extra_data_name : str = "extra_data" , ** kwargs
68+ self , extra_data_name : str = "extra_data" , ** kwargs : Any
6769 ) -> pandas .DataFrame :
6870 """Return the data as a concatenated `pandas.DataFrame` from child learners.
6971
@@ -98,7 +100,7 @@ def load_dataframe(
98100 extra_data_name : str = "extra_data" ,
99101 input_names : tuple [str ] = (),
100102 ** kwargs ,
101- ):
103+ ) -> None :
102104 """Load the data from a `pandas.DataFrame` into the learner.
103105
104106 Parameters
@@ -122,33 +124,36 @@ def load_dataframe(
122124 key = _to_key (x [:- 1 ])
123125 self .extra_data [key ] = x [- 1 ]
124126
125- def _get_data (self ):
127+ def _get_data (self ) -> tuple [ Any , OrderedDict ] :
126128 return self .learner ._get_data (), self .extra_data
127129
128- def _set_data (self , data ):
130+ def _set_data (
131+ self ,
132+ data : tuple [Any , OrderedDict ],
133+ ) -> None :
129134 learner_data , self .extra_data = data
130135 self .learner ._set_data (learner_data )
131136
132- def __getstate__ (self ):
137+ def __getstate__ (self ) -> tuple [ BaseLearner , itemgetter , OrderedDict ] :
133138 return (
134139 self .learner ,
135140 self .arg_picker ,
136141 self .extra_data ,
137142 )
138143
139- def __setstate__ (self , state ) :
144+ def __setstate__ (self , state : tuple [ BaseLearner , itemgetter , OrderedDict ]) -> None :
140145 learner , arg_picker , extra_data = state
141146 self .__init__ (learner , arg_picker )
142147 self .extra_data = extra_data
143148
144149 @copy_docstring_from (BaseLearner .save )
145- def save (self , fname , compress = True ):
150+ def save (self , fname , compress = True ) -> None :
146151 # We copy this method because the 'DataSaver' is not a
147152 # subclass of the 'BaseLearner'.
148153 BaseLearner .save (self , fname , compress )
149154
150155 @copy_docstring_from (BaseLearner .load )
151- def load (self , fname , compress = True ):
156+ def load (self , fname , compress = True ) -> None :
152157 # We copy this method because the 'DataSaver' is not a
153158 # subclass of the 'BaseLearner'.
154159 BaseLearner .load (self , fname , compress )
0 commit comments