Commit 4994faf

Merge pull request microsoft#280 from yongzhengqi/main
Implement Enhanced Indexing as a Portfolio Optimizer
2 parents 18592b8 + 5f642b2 commit 4994faf

14 files changed: +624 -224 lines changed

qlib/contrib/strategy/strategy.py

Lines changed: 3 additions & 4 deletions
@@ -7,7 +7,6 @@
 import pandas as pd

 from ..backtest.order import Order
-from ...utils import get_pre_trading_date
 from .order_generator import OrderGenWInteract


@@ -390,11 +389,11 @@ def filter_stock(l):
         current_stock_list = current_temp.get_stock_list()
         value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0

-        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
-        # as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
+        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
+        # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
         # value = value / (1+trade_exchange.open_cost) # set open_cost limit
         for code in buy:
-            # check is stock supended
+            # check is stock suspended
             if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
                 continue
             # buy order
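The hunk above only rewraps the open_cost comment and fixes a typo, but it touches the per-stock budgeting logic, which is easy to restate on its own. A minimal, self-contained sketch (not qlib code; the helper name and the 0.0015 cost are made up for illustration) of how the even cash split and the commented-out open-cost adjustment fit together:

```python
def per_stock_budget(cash: float, risk_degree: float, n_buy: int, open_cost: float = 0.0) -> float:
    """Split the tradable cash evenly across the buy list.

    Mirrors `value = cash * self.risk_degree / len(buy) ...` from the diff above;
    the optional `open_cost` term corresponds to the commented-out
    `value = value / (1 + trade_exchange.open_cost)` line that a real trading
    environment would enable.
    """
    if n_buy == 0:
        return 0.0
    value = cash * risk_degree / n_buy
    return value / (1 + open_cost)


print(per_stock_budget(1_000_000, 0.95, 10, open_cost=0.0015))  # ~94857.71 per stock
```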

qlib/model/base.py

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ def fit(self, dataset: Dataset):

         # get weights
         try:
-            wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
+            wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"],
+                                                   data_key=DataHandlerLP.DK_L)
             w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
         except KeyError as e:
             w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
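For context, the wrapped call above sits in a try/except whose intent is that a dataset without a `"weight"` column falls back to uniform sample weights. A small stand-alone sketch of that fallback (plain pandas, not the qlib API; `resolve_weight` is a hypothetical helper standing in for the except branch):

```python
from typing import Optional

import numpy as np
import pandas as pd


def resolve_weight(y: pd.DataFrame, w: Optional[pd.DataFrame]) -> pd.DataFrame:
    """Return the provided sample weights, or unit weights matching `y`.

    Passing `w=None` plays the role of the KeyError branch in `fit`, where
    `dataset.prepare(..., col_set=["weight"])` finds no weight column.
    """
    if w is None:
        return pd.DataFrame(np.ones_like(y.values), index=y.index)
    return w


y_train = pd.DataFrame({"label": [0.1, -0.2, 0.05]})
print(resolve_weight(y_train, None))  # three rows of 1.0
```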

qlib/model/riskmodel/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .base import RiskModel
+from .poet import POETCovEstimator
+from .shrink import ShrinkCovEstimator
+from .structured import StructuredCovEstimator

qlib/model/riskmodel/base.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import inspect
+import numpy as np
+import pandas as pd
+from typing import Union
+
+from qlib.model.base import BaseModel
+
+
+class RiskModel(BaseModel):
+    """Risk Model
+
+    A risk model is used to estimate the covariance matrix of stock returns.
+    """
+
+    MASK_NAN = "mask"
+    FILL_NAN = "fill"
+    IGNORE_NAN = "ignore"
+
+    def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
+        """
+        Args:
+            nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
+            assume_centered (bool): whether the data is assumed to be centered.
+            scale_return (bool): whether scale returns as percentage.
+        """
+        # nan
+        assert nan_option in [
+            self.MASK_NAN,
+            self.FILL_NAN,
+            self.IGNORE_NAN,
+        ], f"`nan_option={nan_option}` is not supported"
+        self.nan_option = nan_option
+
+        self.assume_centered = assume_centered
+        self.scale_return = scale_return
+
+    def predict(
+        self,
+        X: Union[pd.Series, pd.DataFrame, np.ndarray],
+        return_corr: bool = False,
+        is_price: bool = True,
+        return_decomposed_components=False,
+    ) -> Union[pd.DataFrame, np.ndarray, tuple]:
+        """
+        Args:
+            X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
+                with variables as columns and observations as rows.
+            return_corr (bool): whether return the correlation matrix.
+            is_price (bool): whether `X` contains price (if not assume stock returns).
+            return_decomposed_components (bool): whether return decomposed components of the covariance matrix.
+
+        Returns:
+            pd.DataFrame or np.ndarray: estimated covariance (or correlation).
+        """
+        assert (
+            not return_corr or not return_decomposed_components
+        ), "Can only return either correlation matrix or decomposed components."
+
+        # transform input into 2D array
+        if not isinstance(X, (pd.Series, pd.DataFrame)):
+            columns = None
+        else:
+            if isinstance(X.index, pd.MultiIndex):
+                if isinstance(X, pd.DataFrame):
+                    X = X.iloc[:, 0].unstack(level="instrument")  # always use the first column
+                else:
+                    X = X.unstack(level="instrument")
+            else:
+                # X is 2D DataFrame
+                pass
+            columns = X.columns  # will be used to restore dataframe
+            X = X.values
+
+        # calculate pct_change
+        if is_price:
+            X = X[1:] / X[:-1] - 1  # NOTE: resulting `n - 1` rows
+
+        # scale return
+        if self.scale_return:
+            X *= 100
+
+        # handle nan and centered
+        X = self._preprocess(X)
+
+        # return decomposed components if needed
+        if return_decomposed_components:
+            assert (
+                "return_decomposed_components" in inspect.getfullargspec(self._predict).args
+            ), "This risk model does not support return decomposed components of the covariance matrix "
+
+            F, cov_b, var_u = self._predict(X, return_decomposed_components=True)
+            return F, cov_b, var_u
+
+        # estimate covariance
+        S = self._predict(X)
+
+        # return correlation if needed
+        if return_corr:
+            vola = np.sqrt(np.diag(S))
+            corr = S / np.outer(vola, vola)
+            if columns is None:
+                return corr
+            return pd.DataFrame(corr, index=columns, columns=columns)

+        # return covariance
+        if columns is None:
+            return S
+        return pd.DataFrame(S, index=columns, columns=columns)
+
+    def _predict(self, X: np.ndarray) -> np.ndarray:
+        """covariance estimation implementation
+
+        This method should be overridden by child classes.
+
+        By default, this method implements the empirical covariance estimation.
+
+        Args:
+            X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
+
+        Returns:
+            np.ndarray: covariance matrix.
+        """
+        xTx = np.asarray(X.T.dot(X))
+        N = len(X)
+        if isinstance(X, np.ma.MaskedArray):
+            M = 1 - X.mask
+            N = M.T.dot(M)  # each pair has distinct number of samples
+        return xTx / N
+
+    def _preprocess(self, X: np.ndarray) -> Union[np.ndarray, np.ma.MaskedArray]:
+        """handle nan and centerize data
+
+        Note:
+            if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
+        """
+        # handle nan
+        if self.nan_option == self.FILL_NAN:
+            X = np.nan_to_num(X)
+        elif self.nan_option == self.MASK_NAN:
+            X = np.ma.masked_invalid(X)
+        # centralize
+        if not self.assume_centered:
+            X = X - np.nanmean(X, axis=0)
+        return X
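Because the default `_predict` already implements empirical covariance estimation, the new `RiskModel` class can be exercised directly. A minimal usage sketch (the tickers and synthetic prices below are illustrative only, not part of the commit):

```python
import numpy as np
import pandas as pd

from qlib.model.riskmodel import RiskModel

# synthetic daily prices: 60 observations (rows) x 4 instruments (columns)
rng = np.random.default_rng(0)
dates = pd.date_range("2020-01-01", periods=60)
prices = pd.DataFrame(
    100 * np.cumprod(1 + 0.01 * rng.standard_normal((60, 4)), axis=0),
    index=dates,
    columns=["SH600000", "SH600519", "SZ000001", "SZ000651"],
)

model = RiskModel(nan_option="fill")            # fill NaNs before estimating
cov = model.predict(prices, is_price=True)      # prices -> returns -> covariance
corr = model.predict(prices, return_corr=True)  # same pipeline, correlation output
print(cov.shape, corr.round(2))
```

Subclasses such as `POETCovEstimator` below only override `_predict`; the input reshaping, NaN handling and price-to-return conversion stay in this base class.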

qlib/model/riskmodel/poet.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import numpy as np
+
+from qlib.model.riskmodel import RiskModel
+
+
+class POETCovEstimator(RiskModel):
+    """Principal Orthogonal Complement Thresholding Estimator (POET)
+
+    Reference:
+        [1] Fan, J., Liao, Y., & Mincheva, M. (2013). Large covariance estimation by thresholding principal orthogonal complements.
+            Journal of the Royal Statistical Society. Series B: Statistical Methodology, 75(4), 603–680. https://doi.org/10.1111/rssb.12016
+        [2] http://econweb.rutgers.edu/yl1114/papers/poet/POET.m
+    """
+
+    THRESH_SOFT = "soft"
+    THRESH_HARD = "hard"
+    THRESH_SCAD = "scad"
+
+    def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
+        """
+        Args:
+            num_factors (int): number of factors (if set to zero, no factor model will be used).
+            thresh (float): the positive constant for thresholding.
+            thresh_method (str): thresholding method, which can be
+                - 'soft': soft thresholding.
+                - 'hard': hard thresholding.
+                - 'scad': scad thresholding.
+            kwargs: see `RiskModel` for more information.
+        """
+        super().__init__(**kwargs)
+
+        assert num_factors >= 0, "`num_factors` requires a positive integer"
+        self.num_factors = num_factors
+
+        assert thresh >= 0, "`thresh` requires a positive float number"
+        self.thresh = thresh
+
+        assert thresh_method in [
+            self.THRESH_HARD,
+            self.THRESH_SOFT,
+            self.THRESH_SCAD,
+        ], "`thresh_method` should be `soft`/`hard`/`scad`"
+        self.thresh_method = thresh_method
+
+    def _predict(self, X: np.ndarray) -> np.ndarray:
+
+        Y = X.T  # NOTE: to match POET's implementation
+        p, n = Y.shape
+
+        if self.num_factors > 0:
+            Dd, V = np.linalg.eig(Y.T.dot(Y))
+            V = V[:, np.argsort(Dd)]
+            F = V[:, -self.num_factors :][:, ::-1] * np.sqrt(n)
+            LamPCA = Y.dot(F) / n
+            uhat = np.asarray(Y - LamPCA.dot(F.T))
+            Lowrank = np.asarray(LamPCA.dot(LamPCA.T))
+            rate = 1 / np.sqrt(p) + np.sqrt(np.log(p) / n)
+        else:
+            uhat = np.asarray(Y)
+            rate = np.sqrt(np.log(p) / n)
+            Lowrank = 0
+
+        lamb = rate * self.thresh
+        SuPCA = uhat.dot(uhat.T) / n
+        SuDiag = np.diag(np.diag(SuPCA))
+        R = np.linalg.inv(SuDiag ** 0.5).dot(SuPCA).dot(np.linalg.inv(SuDiag ** 0.5))
+
+        if self.thresh_method == self.THRESH_HARD:
+            M = R * (np.abs(R) > lamb)
+        elif self.thresh_method == self.THRESH_SOFT:
+            res = np.abs(R) - lamb
+            res = (res + np.abs(res)) / 2
+            M = np.sign(R) * res
+        else:
+            M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb)
+            M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7
+            M3 = (np.abs(R) >= 3.7 * lamb) * R
+            M = M1 + M2 + M3
+
+        Rthresh = M - np.diag(np.diag(M)) + np.eye(p)
+        SigmaU = (SuDiag ** 0.5).dot(Rthresh).dot(SuDiag ** 0.5)
+        SigmaY = SigmaU + Lowrank
+
+        return SigmaY
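A usage sketch for the estimator above: with `num_factors > 0`, POET keeps a low-rank part recovered by PCA and applies soft/hard/SCAD thresholding to the residual correlations; with `num_factors=0` it thresholds the sample correlation matrix directly. The returns below are synthetic and only illustrate the call pattern (the factor structure is made up):

```python
import numpy as np

from qlib.model.riskmodel import POETCovEstimator

# synthetic returns: 120 days x 50 stocks driven by 3 common factors plus noise
rng = np.random.default_rng(42)
n_days, n_stocks, n_factors = 120, 50, 3
factors = rng.standard_normal((n_days, n_factors))
loadings = rng.standard_normal((n_stocks, n_factors))
returns = factors @ loadings.T + 0.5 * rng.standard_normal((n_days, n_stocks))

est = POETCovEstimator(num_factors=3, thresh=0.5, thresh_method="soft")
cov = est.predict(returns, is_price=False)  # input is already returns, so skip pct_change
print(cov.shape)  # (50, 50)
```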
