# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Callable, Sequence, cast

from tianshou.policy import BasePolicy

from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import FiniteEnvType, LogWriter

from .vessel import TrainingVessel
from .trainer import Trainer


def train(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    reward: Reward,
    vessel_kwargs: dict[str, Any],
    trainer_kwargs: dict[str, Any],
) -> None:
    """Train a policy with the parallelism provided by the RL framework.

    Experimental API. Parameters are subject to change.

    Parameters
    ----------
    simulator_fn
        Callable that takes an initial state (seed) and returns a simulator.
    state_interpreter
        Interprets the state of the simulator.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to train.
    reward
        Reward function.
    vessel_kwargs
        Keyword arguments passed to :class:`TrainingVessel`, such as ``episode_per_iter``.
    trainer_kwargs
        Keyword arguments passed to :class:`Trainer`, such as ``finite_env_type`` and ``concurrency``.
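
    Examples
    --------
    A minimal usage sketch. ``MySimulator``, ``MyStateInterpreter``,
    ``MyActionInterpreter``, ``MyReward``, ``my_initial_states`` and
    ``my_policy`` are hypothetical user-defined components, not part of
    this package::

        train(
            simulator_fn=lambda init: MySimulator(init),
            state_interpreter=MyStateInterpreter(),
            action_interpreter=MyActionInterpreter(),
            initial_states=my_initial_states,
            policy=my_policy,
            reward=MyReward(),
            vessel_kwargs={"episode_per_iter": 1000},
            trainer_kwargs={"finite_env_type": "subproc", "concurrency": 4},
        )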
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        train_initial_states=initial_states,
        reward=reward,
        **vessel_kwargs,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(vessel)


def backtest(
    simulator_fn: Callable[[InitialStateType], Simulator],
    state_interpreter: StateInterpreter,
    action_interpreter: ActionInterpreter,
    initial_states: Sequence[InitialStateType],
    policy: BasePolicy,
    logger: LogWriter | list[LogWriter],
    reward: Reward | None = None,
    finite_env_type: FiniteEnvType = "subproc",
    concurrency: int = 2,
) -> None:
    """Backtest with the parallelism provided by the RL framework.

    Experimental API. Parameters are subject to change.

    Parameters
    ----------
    simulator_fn
        Callable that takes an initial state (seed) and returns a simulator.
    state_interpreter
        Interprets the state of the simulator.
    action_interpreter
        Interprets the policy actions.
    initial_states
        Initial states to iterate over. Every state will be run exactly once.
    policy
        Policy to test.
    logger
        Logger(s) to record the backtest results. A logger is required because
        without one, all backtest information would be lost.
    reward
        Optional reward function. In backtest, the reward is only computed for
        logging purposes; it does not affect the policy.
    finite_env_type
        Type of finite env implementation.
    concurrency
        Number of parallel workers.
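
    Examples
    --------
    A minimal usage sketch. ``MySimulator``, ``MyStateInterpreter``,
    ``MyActionInterpreter``, ``MyLogWriter``, ``my_initial_states`` and
    ``my_policy`` are hypothetical placeholders (``MyLogWriter`` stands for
    any :class:`LogWriter` subclass), not part of this package::

        backtest(
            simulator_fn=lambda init: MySimulator(init),
            state_interpreter=MyStateInterpreter(),
            action_interpreter=MyActionInterpreter(),
            initial_states=my_initial_states,
            policy=my_policy,
            logger=MyLogWriter(),
            finite_env_type="subproc",
            concurrency=4,
        )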
    """

    vessel = TrainingVessel(
        simulator_fn=simulator_fn,
        state_interpreter=state_interpreter,
        action_interpreter=action_interpreter,
        policy=policy,
        test_initial_states=initial_states,
        reward=cast(Reward, reward),  # reward is optional for backtest; cast keeps the type checker happy
    )
    trainer = Trainer(
        finite_env_type=finite_env_type,
        concurrency=concurrency,
        loggers=logger,
    )
    trainer.test(vessel)