Skip to content

Commit 3498e18

Browse files
committed
Not a workable version
1 parent 1a8e0bd commit 3498e18

File tree

2 files changed

+423
-43
lines changed

2 files changed

+423
-43
lines changed

qlib/rl/trainer/callbacks.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Callbacks to insert customized recipes during the training.
5+
Mimicks the hooks of Keras / PyTorch-Lightning, but tailored for the context of RL.
6+
"""
7+
8+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
9+
10+
if TYPE_CHECKING:
11+
from .trainer import Trainer
12+
13+
14+
class Callback:
15+
"""Base class of all callbacks."""
16+
17+
def setup(self, trainer: "Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
18+
"""Called when fit, validate, test, predict, or tune begins."""
19+
20+
def teardown(self, trainer: "Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
21+
"""Called when fit, validate, test, predict, or tune ends."""
22+
23+
def on_init_start(self, trainer: "Trainer") -> None:
24+
r"""
25+
.. deprecated:: v1.6
26+
This callback hook was deprecated in v1.6 and will be removed in v1.8.
27+
Called when the trainer initialization begins, model has not yet been set.
28+
"""
29+
30+
def on_init_end(self, trainer: "Trainer") -> None:
31+
r"""
32+
.. deprecated:: v1.6
33+
This callback hook was deprecated in v1.6 and will be removed in v1.8.
34+
Called when the trainer initialization ends, model has not yet been set.
35+
"""
36+
37+
def on_fit_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
38+
"""Called when fit begins."""
39+
40+
def on_fit_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
41+
"""Called when fit ends."""
42+
43+
def on_sanity_check_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
44+
"""Called when the validation sanity check starts."""
45+
46+
def on_sanity_check_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
47+
"""Called when the validation sanity check ends."""
48+
49+
def on_train_batch_start(
50+
self,
51+
trainer: "Trainer",
52+
pl_module: "pl.LightningModule",
53+
batch: Any,
54+
batch_idx: int,
55+
unused: int = 0,
56+
) -> None:
57+
"""Called when the train batch begins."""
58+
59+
def on_train_batch_end(
60+
self,
61+
trainer: "Trainer",
62+
pl_module: "pl.LightningModule",
63+
outputs: STEP_OUTPUT,
64+
batch: Any,
65+
batch_idx: int,
66+
unused: int = 0,
67+
) -> None:
68+
"""Called when the train batch ends."""
69+
70+
def on_train_epoch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
71+
"""Called when the train epoch begins."""
72+
73+
def on_train_epoch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
74+
"""Called when the train epoch ends.
75+
To access all batch outputs at the end of the epoch, either:
76+
1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
77+
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
78+
"""
79+
80+
def on_validation_epoch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
81+
"""Called when the val epoch begins."""
82+
83+
def on_validation_epoch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
84+
"""Called when the val epoch ends."""
85+
86+
def on_test_epoch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
87+
"""Called when the test epoch begins."""
88+
89+
def on_test_epoch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
90+
"""Called when the test epoch ends."""
91+
92+
def on_predict_epoch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
93+
"""Called when the predict epoch begins."""
94+
95+
def on_predict_epoch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None:
96+
"""Called when the predict epoch ends."""
97+
98+
def on_epoch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
99+
r"""
100+
.. deprecated:: v1.6
101+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
102+
``on_<train/validation/test>_epoch_start`` instead.
103+
Called when either of train/val/test epoch begins.
104+
"""
105+
106+
def on_epoch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
107+
r"""
108+
.. deprecated:: v1.6
109+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
110+
``on_<train/validation/test>_epoch_end`` instead.
111+
Called when either of train/val/test epoch ends.
112+
"""
113+
114+
def on_batch_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
115+
r"""
116+
.. deprecated:: v1.6
117+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
118+
``on_train_batch_start`` instead.
119+
Called when the training batch begins.
120+
"""
121+
122+
def on_batch_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
123+
r"""
124+
.. deprecated:: v1.6
125+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
126+
``on_train_batch_end`` instead.
127+
Called when the training batch ends.
128+
"""
129+
130+
def on_validation_batch_start(
131+
self, trainer: "Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
132+
) -> None:
133+
"""Called when the validation batch begins."""
134+
135+
def on_validation_batch_end(
136+
self,
137+
trainer: "Trainer",
138+
pl_module: "pl.LightningModule",
139+
outputs: Optional[STEP_OUTPUT],
140+
batch: Any,
141+
batch_idx: int,
142+
dataloader_idx: int,
143+
) -> None:
144+
"""Called when the validation batch ends."""
145+
146+
def on_test_batch_start(
147+
self, trainer: "Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
148+
) -> None:
149+
"""Called when the test batch begins."""
150+
151+
def on_test_batch_end(
152+
self,
153+
trainer: "Trainer",
154+
pl_module: "pl.LightningModule",
155+
outputs: Optional[STEP_OUTPUT],
156+
batch: Any,
157+
batch_idx: int,
158+
dataloader_idx: int,
159+
) -> None:
160+
"""Called when the test batch ends."""
161+
162+
def on_predict_batch_start(
163+
self, trainer: "Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
164+
) -> None:
165+
"""Called when the predict batch begins."""
166+
167+
def on_predict_batch_end(
168+
self,
169+
trainer: "Trainer",
170+
pl_module: "pl.LightningModule",
171+
outputs: Any,
172+
batch: Any,
173+
batch_idx: int,
174+
dataloader_idx: int,
175+
) -> None:
176+
"""Called when the predict batch ends."""
177+
178+
def on_train_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
179+
"""Called when the train begins."""
180+
181+
def on_train_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
182+
"""Called when the train ends."""
183+
184+
def on_pretrain_routine_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
185+
r"""
186+
.. deprecated:: v1.6
187+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
188+
Called when the pretrain routine begins.
189+
"""
190+
191+
def on_pretrain_routine_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
192+
r"""
193+
.. deprecated:: v1.6
194+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
195+
Called when the pretrain routine ends.
196+
"""
197+
198+
def on_validation_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
199+
"""Called when the validation loop begins."""
200+
201+
def on_validation_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
202+
"""Called when the validation loop ends."""
203+
204+
def on_test_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
205+
"""Called when the test begins."""
206+
207+
def on_test_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
208+
"""Called when the test ends."""
209+
210+
def on_predict_start(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
211+
"""Called when the predict begins."""
212+
213+
def on_predict_end(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
214+
"""Called when predict ends."""
215+
216+
def on_keyboard_interrupt(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
217+
r"""
218+
.. deprecated:: v1.5
219+
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
220+
Called when any trainer execution is interrupted by KeyboardInterrupt.
221+
"""
222+
223+
def on_exception(self, trainer: "Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
224+
"""Called when any trainer execution is interrupted by an exception."""
225+
226+
def state_dict(self) -> Dict[str, Any]:
227+
"""Called when saving a checkpoint, implement to generate callback's ``state_dict``.
228+
Returns:
229+
A dictionary containing callback state.
230+
"""
231+
return {}
232+
233+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
234+
"""Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
235+
Args:
236+
state_dict: the callback state returned by ``state_dict``.
237+
"""
238+
pass
239+
240+
def on_save_checkpoint(
241+
self, trainer: "Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
242+
) -> Optional[dict]:
243+
r"""
244+
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
245+
Args:
246+
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
247+
pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
248+
checkpoint: the checkpoint dictionary that will be saved.
249+
Returns:
250+
None or the callback state. Support for returning callback state will be removed in v1.8.
251+
.. deprecated:: v1.6
252+
Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
253+
Implement ``Callback.state_dict`` instead to return state.
254+
In v1.8 ``Callback.on_save_checkpoint`` can only return None.
255+
"""
256+
257+
def on_load_checkpoint(
258+
self, trainer: "Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
259+
) -> None:
260+
r"""
261+
Called when loading a model checkpoint, use to reload state.
262+
Args:
263+
trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
264+
pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
265+
callback_state: the callback state returned by ``on_save_checkpoint``.
266+
Note:
267+
The ``on_load_checkpoint`` won't be called with an undefined state.
268+
If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
269+
you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
270+
.. deprecated:: v1.6
271+
This callback hook will change its signature and behavior in v1.8.
272+
If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
273+
In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
274+
checkpoint dictionary instead of only the callback state from the checkpoint.
275+
"""
276+
277+
def on_before_backward(self, trainer: "Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
278+
"""Called before ``loss.backward()``."""
279+
280+
def on_after_backward(self, trainer: "Trainer", pl_module: "pl.LightningModule") -> None:
281+
"""Called after ``loss.backward()`` and before optimizers are stepped."""
282+
283+
def on_before_optimizer_step(
284+
self, trainer: "Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
285+
) -> None:
286+
"""Called before ``optimizer.step()``."""
287+
288+
def on_before_zero_grad(self, trainer: "Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
289+
"""Called before ``optimizer.zero_grad()``."""

0 commit comments

Comments
 (0)