Commit 2f41ab9
Merge pull request #41 from lishen/master
Add sklearn regressors, optimize hyperparameter search space and enhance cross-validation
2 parents: 79c2564 + eba8173

File tree: 9 files changed, +1852 -802 lines

hpsklearn/components.py

Lines changed: 973 additions & 564 deletions
Large diffs are not rendered by default.

hpsklearn/demo_support.py

Lines changed: 1 addition & 1 deletion
@@ -23,6 +23,7 @@ def plot_minvalid_vs_time(estimator, ax, ylim=None):
 
 
 class PlotHelper(object):
+
     def __init__(self, estimator, mintodate_ylim):
         self.estimator = estimator
         self.fig, self.axs = plt.subplots(1, 2)
@@ -41,4 +42,3 @@ def post_iter(self):
 
     def post_loop(self):
         display.clear_output()
-

hpsklearn/estimator.py

Lines changed: 440 additions & 161 deletions
Large diffs are not rendered by default.

hpsklearn/lagselectors.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+"""Lag selectors that subset time series predictors.
+
+This module defines lag selectors with specified lag sizes for endogenous and
+exogenous predictors, using the same style as the sklearn transformers. They
+can be used in hpsklearn as preprocessors. The module is well suited for time
+series data.
+
+When using a positive integer lag size, it is assumed that the lag=1, 2, ...
+predictors are located at the 1st, 2nd, ... columns. When using a negative
+integer, the predictors are located at the N-th, (N-1)-th, ... columns.
+"""
+from sklearn.base import BaseEstimator, TransformerMixin
+
+
+class LagSelector(BaseEstimator, TransformerMixin):
+    """Subset time series features by choosing the most recent lags.
+
+    Parameters
+    ----------
+    lag_size : int, None by default
+        If None, use all features. If a positive integer, subset X as
+        [:, :lag_size]. If a negative integer, subset X as [:, lag_size:].
+        If 0, discard all features from this dataset.
+
+    Attributes
+    ----------
+    max_lag_size_ : int
+        The largest allowed lag size, inferred from the input.
+    """
+
+    def __init__(self, lag_size=None):
+        self.lag_size = lag_size
+
+    def _reset(self):
+        """Reset internal data-dependent state of the selector, if necessary.
+
+        __init__ parameters are not touched.
+        """
+        if hasattr(self, 'max_lag_size_'):
+            del self.max_lag_size_
+
+    def fit(self, X, y=None):
+        """Infer the maximum lag size.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape [n_samples, n_features]
+            The input time series data with lagged predictors as features.
+
+        y : Passthrough for ``Pipeline`` compatibility.
+        """
+        # Reset internal state before fitting
+        self._reset()
+        self.max_lag_size_ = X.shape[1]
+        return self
+
+    def transform(self, X, y=None):
+        """Subset X, keeping only the selected lags.
+
+        Parameters
+        ----------
+        X : array-like, shape [n_samples, n_features]
+            The input time series data with lagged predictors as features.
+        """
+        if self.lag_size is None:
+            return X
+        proofed_lag_size = min(self.max_lag_size_, abs(self.lag_size))
+        if self.lag_size >= 0:
+            return X[:, :proofed_lag_size]
+        else:
+            return X[:, -proofed_lag_size:]
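For illustration, a minimal usage sketch of the new LagSelector on hypothetical data (the array below is made up; it just exercises the positive and negative lag_size conventions from the module docstring):

    import numpy as np
    from hpsklearn.lagselectors import LagSelector

    # 5 samples, 4 lagged predictors: under the positive convention,
    # columns 0..3 hold the lag=1..4 predictors.
    X = np.arange(20).reshape(5, 4)

    sel = LagSelector(lag_size=2)
    sel.fit(X)                                # max_lag_size_ becomes 4
    assert sel.transform(X).shape == (5, 2)   # keeps X[:, :2]

    sel = LagSelector(lag_size=-2)
    sel.fit(X)                                # negative: count from the end
    assert sel.transform(X).shape == (5, 2)   # keeps X[:, -2:]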

hpsklearn/tests/test_demo.py

Lines changed: 71 additions & 20 deletions
@@ -1,37 +1,88 @@
+from __future__ import print_function
+# import numpy as np
+from sklearn import datasets
+from sklearn.cross_validation import train_test_split
+from hyperopt import tpe
+import hpsklearn
+import sys
 
 def test_demo_iris():
-    import numpy as np
-    import skdata.iris.view
-    import hyperopt.tpe
-    import hpsklearn
 
-    data_view = skdata.iris.view.KfoldClassification(4)
+    iris = datasets.load_iris()
+    X_train, X_test, y_train, y_test = train_test_split(
+        iris.data, iris.target, test_size=.25, random_state=1)
 
     estimator = hpsklearn.HyperoptEstimator(
         preprocessing=hpsklearn.components.any_preprocessing('pp'),
         classifier=hpsklearn.components.any_classifier('clf'),
-        algo=hyperopt.tpe,
-        trial_timeout=15.0,  # seconds
-        max_evals=100,
-        )
+        algo=tpe.suggest,
+        trial_timeout=15.0,  # seconds
+        max_evals=10,
+        seed=1
+    )
 
     # /BEGIN `Demo version of estimator.fit()`
+    print('', file=sys.stderr)
+    print('====Demo classification on Iris dataset====', file=sys.stderr)
 
-    iterator = estimator.fit_iter(
-        data_view.split[0].train.X,
-        data_view.split[0].train.y)
+    iterator = estimator.fit_iter(X_train, y_train)
     next(iterator)
 
+    n_trial = 0
     while len(estimator.trials.trials) < estimator.max_evals:
-        iterator.send(1)  # -- try one more model
-        hpsklearn.demo_support.scatter_error_vs_time(estimator)
-        hpsklearn.demo_support.bar_classifier_choice(estimator)
+        iterator.send(1)  # -- try one more model
+        n_trial += 1
+        print('Trial', n_trial, 'loss:', estimator.trials.losses()[-1],
+              file=sys.stderr)
+        # hpsklearn.demo_support.scatter_error_vs_time(estimator)
+        # hpsklearn.demo_support.bar_classifier_choice(estimator)
 
-    estimator.retrain_best_model_on_full_data(
-        data_view.split[0].train.X,
-        data_view.split[0].train.y)
+    estimator.retrain_best_model_on_full_data(X_train, y_train)
 
     # /END Demo version of `estimator.fit()`
 
-    test_predictions = estimator.predict(data_view.split[0].test.X)
-    print(np.mean(test_predictions == data_view.split[0].test.y))
+    print('Test accuracy:', estimator.score(X_test, y_test), file=sys.stderr)
+    print('====End of demo====', file=sys.stderr)
+
+
+def test_demo_boston():
+
+    boston = datasets.load_boston()
+    X_train, X_test, y_train, y_test = train_test_split(
+        boston.data, boston.target, test_size=.25, random_state=1)
+
+    estimator = hpsklearn.HyperoptEstimator(
+        preprocessing=hpsklearn.components.any_preprocessing('pp'),
+        regressor=hpsklearn.components.any_regressor('reg'),
+        algo=tpe.suggest,
+        trial_timeout=15.0,  # seconds
+        max_evals=10,
+        seed=1
+    )
+
+    # /BEGIN `Demo version of estimator.fit()`
+    print('', file=sys.stderr)
+    print('====Demo regression on Boston dataset====', file=sys.stderr)
+
+    iterator = estimator.fit_iter(X_train, y_train)
+    next(iterator)
+
+    n_trial = 0
+    while len(estimator.trials.trials) < estimator.max_evals:
+        iterator.send(1)  # -- try one more model
+        n_trial += 1
+        print('Trial', n_trial, 'loss:', estimator.trials.losses()[-1],
+              file=sys.stderr)
+        # hpsklearn.demo_support.scatter_error_vs_time(estimator)
+        # hpsklearn.demo_support.bar_classifier_choice(estimator)
+
+    estimator.retrain_best_model_on_full_data(X_train, y_train)
+
+    # /END Demo version of `estimator.fit()`
+
+    print('Test R2:', estimator.score(X_test, y_test), file=sys.stderr)
+    print('====End of demo====', file=sys.stderr)
+
+
+# -- flake8 eof
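Per the /BEGIN and /END markers, the iterate-and-send loop in each test is a demo expansion of estimator.fit(); assuming that reading, either demo collapses to a one-shot sketch like:

    # One-shot equivalent of the demo loop above (a sketch: fit() runs the
    # trials and then retrains the best model, per the /BEGIN..END markers)
    estimator.fit(X_train, y_train)
    print('Test accuracy:', estimator.score(X_test, y_test))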

hpsklearn/tests/test_estimator.py

Lines changed: 22 additions & 15 deletions
@@ -10,43 +10,50 @@
 
 
 class TestIter(unittest.TestCase):
+
     def setUp(self):
         np.random.seed(123)
         self.X = np.random.randn(1000, 2)
         self.Y = (self.X[:, 0] > 0).astype('int')
 
     def test_fit_iter_basic(self):
-        model = hyperopt_estimator(verbose=1, trial_timeout=5.0)
+        model = hyperopt_estimator(
+            classifier=components.any_classifier('classifier'),
+            verbose=1, trial_timeout=5.0)
         for ii, trials in enumerate(model.fit_iter(self.X, self.Y)):
             assert trials is model.trials
             assert len(trials.trials) == ii
             if ii == 10:
                 break
 
     def test_fit(self):
-        model = hyperopt_estimator(verbose=1, max_evals=5, trial_timeout=5.0)
+        model = hyperopt_estimator(
+            classifier=components.any_classifier('classifier'),
+            verbose=1, max_evals=5, trial_timeout=5.0)
         model.fit(self.X, self.Y)
         assert len(model.trials.trials) == 5
 
     def test_fit_biginc(self):
-        model = hyperopt_estimator(verbose=1, max_evals=5, trial_timeout=5.0,
-                                   fit_increment=20)
+        model = hyperopt_estimator(
+            classifier=components.any_classifier('classifier'),
+            verbose=1, max_evals=5, trial_timeout=5.0, fit_increment=20)
         model.fit(self.X, self.Y)
         # -- make sure we only get 5 even with big fit_increment
         assert len(model.trials.trials) == 5
 
 
-class TestSpace(unittest.TestCase):
-    def setUp(self):
-        np.random.seed(123)
-        self.X = np.random.randn(1000, 2)
-        self.Y = (self.X[:, 0] > 0).astype('int')
+# class TestSpace(unittest.TestCase):
 
-    def test_smoke(self):
-        # -- verify the space argument is accepted and runs
-        space = components.generic_space()
-        model = hyperopt_estimator(
-            verbose=1, max_evals=10, trial_timeout=5, space=space)
-        model.fit(self.X, self.Y)
+#     def setUp(self):
+#         np.random.seed(123)
+#         self.X = np.random.randn(1000, 2)
+#         self.Y = (self.X[:, 0] > 0).astype('int')
+
+#     def test_smoke(self):
+#         # -- verify the space argument is accepted and runs
+#         space = components.generic_space()
+#         model = hyperopt_estimator(
+#             verbose=1, max_evals=10, trial_timeout=5, space=space)
+#         model.fit(self.X, self.Y)
 
 # -- flake8 eof
