Commit 25ecb11

Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint (cherry picked from commit 1a8e0bd)
* Not a workable version (cherry picked from commit 3498e18)
* vessel
* ckpt
* .
* vessel
* .
* .
* checkpoint callback
* .
* cleanup
* logger
* .
* test
* .
* add test
* .
* .
* .
* .
* New reward
* Add train API
* fix mypy
* fix lint
* More comment
* 3.7 compat
* fix test
* fix test
* .
* Resolve comments
* fix typehint
1 parent 2ca0d88 commit 25ecb11

File tree

17 files changed (+1410, -145 lines)

qlib/rl/entries/__init__.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

qlib/rl/entries/test.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

qlib/rl/entries/train.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

qlib/rl/order_execution/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@
 from .interpreter import *
 from .network import *
 from .policy import *
+from .reward import *
 from .simulator_simple import *

qlib/rl/order_execution/reward.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import cast

import numpy as np
from qlib.rl.reward import Reward

from .simulator_simple import SAOEState, SAOEMetrics

__all__ = ["PAPenaltyReward"]


class PAPenaltyReward(Reward[SAOEState]):
    """Encourage higher PAs, but penalize stacking all the amounts within a very short time.
    Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        whole_order = simulator_state.order.amount
        assert whole_order > 0
        last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
        pa = last_step["pa"] * last_step["amount"] / whole_order

        # Inspect the "break-down" of the latest step: trading amount at every tick
        last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :]
        penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()

        reward = pa + penalty

        # Throw error in case of NaN
        assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa)
        self.log("reward/penalty", penalty)
        return reward
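
To get a feel for how the two terms of the per-step formula above trade off, here is a small hand-computed sketch; the numbers are made up purely for illustration, and ``target`` stands for the total order amount:

# Hypothetical values, only to illustrate PA_t * vol_t / target - penalty * sum((vol_tick / target) ** 2)
target = 1000.0                       # total order amount
pa_t = 2.0                            # price advantage of the latest step
step_amount = 200.0                   # amount traded in the latest step
tick_amounts = [50.0, 50.0, 100.0]    # breakdown of that step across ticks
penalty = 100.0

pa_term = pa_t * step_amount / target                                  # 2.0 * 200 / 1000 = 0.4
penalty_term = penalty * sum((v / target) ** 2 for v in tick_amounts)  # 100 * 0.015 = 1.5
reward = pa_term - penalty_term                                        # 0.4 - 1.5 = -1.1

Spreading the same 200 shares evenly across more ticks shrinks the quadratic penalty term, which is exactly the behaviour this reward is designed to encourage.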

qlib/rl/order_execution/simulator_simple.py

Lines changed: 5 additions & 2 deletions
@@ -131,11 +131,14 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
     """

     history_exec: pd.DataFrame
-    """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns."""
+    """All execution history at every possible time ticks. See :class:`SAOEMetrics` for available columns.
+    Index is ``datetime``.
+    """

     history_steps: pd.DataFrame
     """Positions at each step. The position before first step is also recorded.
-    See :class:`SAOEMetrics` for available columns."""
+    See :class:`SAOEMetrics` for available columns.
+    Index is ``datetime``, which is the **starting** time of each step."""

     metrics: SAOEMetrics | None
     """Metrics. Only available when done."""

qlib/rl/trainer/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Train, test, inference utilities."""

from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase

qlib/rl/trainer/api.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Callable, Sequence, cast, Any

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train against.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,  # ignore none
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable receiving initial seed, returning a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test against.
    logger
        Logger to record the backtest results. Logger must be present because
        without logger, all information will be lost.
    reward
        Optional reward function. For backtest, this is for testing the rewards
        and logging them only.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Parallel workers.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # ignore none
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)
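
A hypothetical end-to-end call might look like the sketch below. The simulator, interpreter, policy, order and logger objects are user-supplied placeholders (they are not defined in this commit); only the keyword arguments spelled out in the docstrings above (``episode_per_iter``, ``finite_env_type``, ``concurrency``) are taken from this diff.

from qlib.rl.order_execution import PAPenaltyReward
from qlib.rl.trainer import backtest, train

# my_simulator_fn, my_state_interp, my_action_interp, my_policy, my_orders and
# my_log_writer are stand-ins for the types documented above.
train(
    simulator_fn=my_simulator_fn,
    state_interpreter=my_state_interp,
    action_interpreter=my_action_interp,
    initial_states=my_orders,
    policy=my_policy,
    reward=PAPenaltyReward(penalty=100.0),
    vessel_kwargs={"episode_per_iter": 100},
    trainer_kwargs={"finite_env_type": "subproc", "concurrency": 2},
)

backtest(
    simulator_fn=my_simulator_fn,
    state_interpreter=my_state_interp,
    action_interpreter=my_action_interp,
    initial_states=my_orders,
    policy=my_policy,
    logger=my_log_writer,          # a LogWriter (or list of LogWriters)
    reward=PAPenaltyReward(),      # optional: rewards are only logged during backtest
    finite_env_type="subproc",
    concurrency=2,
)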
