# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Callable, Sequence, cast

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by the RL framework.

    Experimental API. Parameters are subject to change.

    Parameters
    ----------
    simulator_fn
        Callable that takes an initial state (seed) and returns a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, such as ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, such as ``finite_env_type`` and ``concurrency``.
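
    Examples
    --------
    A minimal sketch of how this function might be invoked. ``MySimulator``,
    ``MyStateInterpreter``, ``MyActionInterpreter``, ``MyReward`` and ``MyPolicy``
    are user-defined placeholders rather than classes shipped with this module;
    the keyword arguments shown are only those mentioned in the descriptions above.

    >>> train(
    ...     simulator_fn=lambda init: MySimulator(init),
    ...     state_interpreter=MyStateInterpreter(),
    ...     action_interpreter=MyActionInterpreter(),
    ...     initial_states=initial_states,
    ...     policy=MyPolicy(),
    ...     reward=MyReward(),
    ...     vessel_kwargs={"episode_per_iter": 100},
    ...     trainer_kwargs={"finite_env_type": "subproc", "concurrency": 2},
    ... )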
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by the RL framework.

    Experimental API. Parameters are subject to change.

    Parameters
    ----------
    simulator_fn
        Callable that takes an initial state (seed) and returns a simulator.
    state_interpreter
        Interprets the state of simulators.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test.
    logger
        Logger to record the backtest results. A logger is required because,
        without one, all backtest information would be lost.
    reward
        Optional reward function. In backtest, rewards are only computed and logged;
        they do not affect the results.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Number of parallel workers.
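
    Examples
    --------
    A minimal sketch of a backtest call. ``MySimulator``, the interpreters and
    ``MyPolicy`` are user-defined placeholders, and ``ConsoleWriter`` is assumed
    to be a concrete :class:`LogWriter` implementation.

    >>> backtest(
    ...     simulator_fn=lambda init: MySimulator(init),
    ...     state_interpreter=MyStateInterpreter(),
    ...     action_interpreter=MyActionInterpreter(),
    ...     initial_states=initial_states,
    ...     policy=MyPolicy(),
    ...     logger=ConsoleWriter(),
    ...     finite_env_type="subproc",
    ...     concurrency=2,
    ... )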
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # cast only to satisfy typing; None is acceptable in backtest
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)