Qlib RL framework (stage 2) - trainer #1125

Status: Merged
Changes from all commits (34 commits, all by ultmaster):

- 542d295 checkpoint
- d5e15ac Not a workable version
- 4acb0c2 vessel
- 2d1d8cb ckpt
- 1f85487 .
- ea40fdf vessel
- 6816c7e .
- 319766f .
- f2c02e0 checkpoint callback
- 4db8567 .
- 163bc2a cleanup
- b76b810 logger
- 647dc76 .
- fc7eb9a test
- 38ecc21 .
- 67a53fb add test
- 6b391de .
- c73ec3a .
- 5980e45 .
- 85710ef .
- 26883c8 New reward
- 30afc6c Add train API
- 68716e4 fix mypy
- cbf2577 fix lint
- 8e479c2 Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
- 4aa421f More comment
- 54d2342 3.7 compat
- f432525 fix test
- 5344846 fix test
- 3123f1a .
- 1bb307d Resolve comments
- e8c6f4f Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
- 5fe8bff fix typehint
- 4b5fcb0 Merge branch 'main' of https://github.com/microsoft/qlib into rl-trai…
(This PR also deletes three files; their names and contents are not shown in this capture.)
New file (+46 lines):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import cast

import numpy as np
from qlib.rl.reward import Reward

from .simulator_simple import SAOEState, SAOEMetrics

__all__ = ["PAPenaltyReward"]


class PAPenaltyReward(Reward[SAOEState]):
    """Encourage higher PAs, but penalize stacking all the amounts within a very short time.
    Formally, for each time step, the reward is :math:`(PA_t * vol_t / target - vol_t^2 * penalty)`.

    Parameters
    ----------
    penalty
        The penalty for large volume in a short time.
    """

    def __init__(self, penalty: float = 100.0):
        self.penalty = penalty

    def reward(self, simulator_state: SAOEState) -> float:
        whole_order = simulator_state.order.amount
        assert whole_order > 0
        last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict())
        pa = last_step["pa"] * last_step["amount"] / whole_order

        # Inspect the "break-down" of the latest step: trading amount at every tick
        last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"]:]
        penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum()

        reward = pa + penalty

        # Throw error in case of NaN
        assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}"

        self.log("reward/pa", pa)
        self.log("reward/penalty", penalty)
        return reward
```
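Because the per-tick penalty is quadratic in traded volume, spreading the same amount across more ticks always costs less than dumping it at once. A small arithmetic sketch of this, using hypothetical numbers that are not from the PR:

```python
# Illustrative arithmetic only: a hypothetical order of 100 shares
# (whole_order = 100) and the default penalty = 100.0.
penalty = 100.0

# Case A: the last step trades 10 shares spread over 5 ticks (2 shares each).
case_a = -penalty * sum((2 / 100) ** 2 for _ in range(5))  # -100 * 5 * 0.0004 = -0.2

# Case B: the same 10 shares traded in a single tick.
case_b = -penalty * (10 / 100) ** 2  # -100 * 0.01 = -1.0

# Both cases earn the same PA term (pa * 10 / 100), but case B is penalized
# five times harder -- exactly the "stacking" the docstring discourages.
print(case_a, case_b)
```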
New file (+9 lines):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Train, test, inference utilities."""

from .api import backtest, train
from .callbacks import EarlyStopping, Checkpoint
from .trainer import Trainer
from .vessel import TrainingVessel, TrainingVesselBase
```
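For orientation, here is a minimal sketch of how these exports compose by hand, mirroring the body of `train()` in api.py below. Constructing `Trainer()` with no arguments is an assumption, and `Checkpoint`/`EarlyStopping` are omitted because this diff shows only their names, not their signatures:

```python
# A hand-rolled equivalent of the train() convenience function shown below.
# Trainer() with no arguments is an assumption; this diff only demonstrates
# Trainer(finite_env_type=..., concurrency=..., loggers=...).
from qlib.rl.trainer import Trainer, TrainingVessel

def fit_policy(simulator_fn, state_interpreter, action_interpreter, policy, initial_states, reward):
    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
    )
    trainer = Trainer()
    trainer.fit(vessel)
    return trainer
```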
New file (+118 lines):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Callable, Sequence, cast, Any

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable that receives an initial seed and returns a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train against.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, like ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, like ``finite_env_type``, ``concurrency``.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by the RL framework.

    Experimental API. Parameters might change shortly.

    Parameters
    ----------
    simulator_fn
        Callable that receives an initial seed and returns a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test against.
    logger
        Logger to record the backtest results. A logger must be present,
        because without one all information will be lost.
    reward
        Optional reward function. For backtest, this is used only to compute
        and log the rewards.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Number of parallel workers.
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # may be None; cast to satisfy the type checker
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)
```
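To make the call pattern concrete, here is a hedged end-to-end sketch. `MySimulator`, `MyStateInterpreter`, `MyActionInterpreter`, `make_policy`, `orders`, `test_orders`, and `my_log_writer` are hypothetical placeholders for user-defined components; only the `train`/`backtest` signatures and `PAPenaltyReward` come from the diffs above:

```python
# Hypothetical wiring of the two APIs above. All names ending in "My"/"my_"
# plus make_policy, orders, and test_orders are placeholders for user code.
from qlib.rl.trainer import train, backtest

reward = PAPenaltyReward(penalty=100.0)

train(
    simulator_fn=lambda order: MySimulator(order),
    state_interpreter=MyStateInterpreter(),
    action_interpreter=MyActionInterpreter(),
    initial_states=orders,                    # e.g. a list of orders to execute
    policy=make_policy(),                     # any tianshou BasePolicy
    reward=reward,
    vessel_kwargs={"episode_per_iter": 100},  # forwarded to TrainingVessel
    trainer_kwargs={"finite_env_type": "subproc", "concurrency": 2},  # forwarded to Trainer
)

backtest(
    simulator_fn=lambda order: MySimulator(order),
    state_interpreter=MyStateInterpreter(),
    action_interpreter=MyActionInterpreter(),
    initial_states=test_orders,
    policy=make_policy(),
    logger=my_log_writer,  # a LogWriter; required, or all results are lost
    reward=reward,         # optional here: only computed and logged
)
```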