1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
3+
4+ """Callbacks to insert customized recipes during the training.
5+ Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
6+ """
7+
8+ from typing import Any , Dict , List , Optional , TYPE_CHECKING
9+
10+ if TYPE_CHECKING :
11+ from .trainer import Trainer
12+
13+
14+ class Callback :
15+ """Base class of all callbacks."""
16+
17+ def setup (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , stage : Optional [str ] = None ) -> None :
18+ """Called when fit, validate, test, predict, or tune begins."""
19+
20+ def teardown (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , stage : Optional [str ] = None ) -> None :
21+ """Called when fit, validate, test, predict, or tune ends."""
22+
23+ def on_init_start (self , trainer : "Trainer" ) -> None :
24+ r"""
25+ .. deprecated:: v1.6
26+ This callback hook was deprecated in v1.6 and will be removed in v1.8.
27+ Called when the trainer initialization begins, model has not yet been set.
28+ """
29+
30+ def on_init_end (self , trainer : "Trainer" ) -> None :
31+ r"""
32+ .. deprecated:: v1.6
33+ This callback hook was deprecated in v1.6 and will be removed in v1.8.
34+ Called when the trainer initialization ends, model has not yet been set.
35+ """
36+
37+ def on_fit_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
38+ """Called when fit begins."""
39+
40+ def on_fit_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
41+ """Called when fit ends."""
42+
43+ def on_sanity_check_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
44+ """Called when the validation sanity check starts."""
45+
46+ def on_sanity_check_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
47+ """Called when the validation sanity check ends."""
48+
49+ def on_train_batch_start (
50+ self ,
51+ trainer : "Trainer" ,
52+ pl_module : "pl.LightningModule" ,
53+ batch : Any ,
54+ batch_idx : int ,
55+ unused : int = 0 ,
56+ ) -> None :
57+ """Called when the train batch begins."""
58+
59+ def on_train_batch_end (
60+ self ,
61+ trainer : "Trainer" ,
62+ pl_module : "pl.LightningModule" ,
63+ outputs : STEP_OUTPUT ,
64+ batch : Any ,
65+ batch_idx : int ,
66+ unused : int = 0 ,
67+ ) -> None :
68+ """Called when the train batch ends."""
69+
70+ def on_train_epoch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
71+ """Called when the train epoch begins."""
72+
73+ def on_train_epoch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
74+ """Called when the train epoch ends.
75+ To access all batch outputs at the end of the epoch, either:
76+ 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
77+ 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
78+ """
79+
80+ def on_validation_epoch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
81+ """Called when the val epoch begins."""
82+
83+ def on_validation_epoch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
84+ """Called when the val epoch ends."""
85+
86+ def on_test_epoch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
87+ """Called when the test epoch begins."""
88+
89+ def on_test_epoch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
90+ """Called when the test epoch ends."""
91+
92+ def on_predict_epoch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
93+ """Called when the predict epoch begins."""
94+
95+ def on_predict_epoch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , outputs : List [Any ]) -> None :
96+ """Called when the predict epoch ends."""
97+
98+ def on_epoch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
99+ r"""
100+ .. deprecated:: v1.6
101+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
102+ ``on_<train/validation/test>_epoch_start`` instead.
103+ Called when either of train/val/test epoch begins.
104+ """
105+
106+ def on_epoch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
107+ r"""
108+ .. deprecated:: v1.6
109+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
110+ ``on_<train/validation/test>_epoch_end`` instead.
111+ Called when either of train/val/test epoch ends.
112+ """
113+
114+ def on_batch_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
115+ r"""
116+ .. deprecated:: v1.6
117+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
118+ ``on_train_batch_start`` instead.
119+ Called when the training batch begins.
120+ """
121+
122+ def on_batch_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
123+ r"""
124+ .. deprecated:: v1.6
125+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
126+ ``on_train_batch_end`` instead.
127+ Called when the training batch ends.
128+ """
129+
130+ def on_validation_batch_start (
131+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
132+ ) -> None :
133+ """Called when the validation batch begins."""
134+
135+ def on_validation_batch_end (
136+ self ,
137+ trainer : "Trainer" ,
138+ pl_module : "pl.LightningModule" ,
139+ outputs : Optional [STEP_OUTPUT ],
140+ batch : Any ,
141+ batch_idx : int ,
142+ dataloader_idx : int ,
143+ ) -> None :
144+ """Called when the validation batch ends."""
145+
146+ def on_test_batch_start (
147+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
148+ ) -> None :
149+ """Called when the test batch begins."""
150+
151+ def on_test_batch_end (
152+ self ,
153+ trainer : "Trainer" ,
154+ pl_module : "pl.LightningModule" ,
155+ outputs : Optional [STEP_OUTPUT ],
156+ batch : Any ,
157+ batch_idx : int ,
158+ dataloader_idx : int ,
159+ ) -> None :
160+ """Called when the test batch ends."""
161+
162+ def on_predict_batch_start (
163+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
164+ ) -> None :
165+ """Called when the predict batch begins."""
166+
167+ def on_predict_batch_end (
168+ self ,
169+ trainer : "Trainer" ,
170+ pl_module : "pl.LightningModule" ,
171+ outputs : Any ,
172+ batch : Any ,
173+ batch_idx : int ,
174+ dataloader_idx : int ,
175+ ) -> None :
176+ """Called when the predict batch ends."""
177+
178+ def on_train_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
179+ """Called when the train begins."""
180+
181+ def on_train_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
182+ """Called when the train ends."""
183+
184+ def on_pretrain_routine_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
185+ r"""
186+ .. deprecated:: v1.6
187+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
188+ Called when the pretrain routine begins.
189+ """
190+
191+ def on_pretrain_routine_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
192+ r"""
193+ .. deprecated:: v1.6
194+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
195+ Called when the pretrain routine ends.
196+ """
197+
198+ def on_validation_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
199+ """Called when the validation loop begins."""
200+
201+ def on_validation_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
202+ """Called when the validation loop ends."""
203+
204+ def on_test_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
205+ """Called when the test begins."""
206+
207+ def on_test_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
208+ """Called when the test ends."""
209+
210+ def on_predict_start (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
211+ """Called when the predict begins."""
212+
213+ def on_predict_end (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
214+ """Called when predict ends."""
215+
216+ def on_keyboard_interrupt (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
217+ r"""
218+ .. deprecated:: v1.5
219+ This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
220+ Called when any trainer execution is interrupted by KeyboardInterrupt.
221+ """
222+
223+ def on_exception (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , exception : BaseException ) -> None :
224+ """Called when any trainer execution is interrupted by an exception."""
225+
226+ def state_dict (self ) -> Dict [str , Any ]:
227+ """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
228+ Returns:
229+ A dictionary containing callback state.
230+ """
231+ return {}
232+
233+ def load_state_dict (self , state_dict : Dict [str , Any ]) -> None :
234+ """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
235+ Args:
236+ state_dict: the callback state returned by ``state_dict``.
237+ """
238+ pass
239+
240+ def on_save_checkpoint (
241+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , checkpoint : Dict [str , Any ]
242+ ) -> Optional [dict ]:
243+ r"""
244+ Called when saving a checkpoint to give you a chance to store anything else you might want to save.
245+ Args:
246+ trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
247+ pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
248+ checkpoint: the checkpoint dictionary that will be saved.
249+ Returns:
250+ None or the callback state. Support for returning callback state will be removed in v1.8.
251+ .. deprecated:: v1.6
252+ Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
253+ Implement ``Callback.state_dict`` instead to return state.
254+ In v1.8 ``Callback.on_save_checkpoint`` can only return None.
255+ """
256+
257+ def on_load_checkpoint (
258+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , callback_state : Dict [str , Any ]
259+ ) -> None :
260+ r"""
261+ Called when loading a model checkpoint, use to reload state.
262+ Args:
263+ trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
264+ pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
265+ callback_state: the callback state returned by ``on_save_checkpoint``.
266+ Note:
267+ The ``on_load_checkpoint`` won't be called with an undefined state.
268+ If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
269+ you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
270+ .. deprecated:: v1.6
271+ This callback hook will change its signature and behavior in v1.8.
272+ If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
273+ In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
274+ checkpoint dictionary instead of only the callback state from the checkpoint.
275+ """
276+
277+ def on_before_backward (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , loss : torch .Tensor ) -> None :
278+ """Called before ``loss.backward()``."""
279+
280+ def on_after_backward (self , trainer : "Trainer" , pl_module : "pl.LightningModule" ) -> None :
281+ """Called after ``loss.backward()`` and before optimizers are stepped."""
282+
283+ def on_before_optimizer_step (
284+ self , trainer : "Trainer" , pl_module : "pl.LightningModule" , optimizer : Optimizer , opt_idx : int
285+ ) -> None :
286+ """Called before ``optimizer.step()``."""
287+
288+ def on_before_zero_grad (self , trainer : "Trainer" , pl_module : "pl.LightningModule" , optimizer : Optimizer ) -> None :
289+ """Called before ``optimizer.zero_grad()``."""
0 commit comments