Skip to content

Commit 44fd427

Browse files
wdevazelhes authored and bellet committed
[MRG] Remove random_seed in fit and use the one in init (#224)
* Remove random_seed in fit and use the one in init
* update tests with the new API
* Update test_RCA in sklearn_compat
* Update test_SDML in sklearn_compat
* Remove testing of pca_comps since it's deprecated
* Fix sklearn test
* Fix random_seed for test_iris in TestRCA
* Relaunch CI
* Augment tolerance rather than fix random_seed
* Add ChangedBehaviorWarning if the random_state is left default
* Update the merge
* Address #224 (review)
1 parent 46a948a commit 44fd427

File tree

11 files changed

+335
-126
lines changed

11 files changed

+335
-126
lines changed

metric_learn/constraints.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from six.moves import xrange
88
from scipy.sparse import coo_matrix
9+
from sklearn.utils import check_random_state
910

1011
__all__ = ['Constraints']
1112

@@ -23,7 +24,8 @@ def __init__(self, partial_labels):
2324
self.known_label_idx, = np.where(partial_labels >= 0)
2425
self.known_labels = partial_labels[self.known_label_idx]
2526

26-
def adjacency_matrix(self, num_constraints, random_state=np.random):
27+
def adjacency_matrix(self, num_constraints, random_state=None):
28+
random_state = check_random_state(random_state)
2729
a, b, c, d = self.positive_negative_pairs(num_constraints,
2830
random_state=random_state)
2931
row = np.concatenate((a, c))
@@ -35,7 +37,8 @@ def adjacency_matrix(self, num_constraints, random_state=np.random):
3537
return adj + adj.T
3638

3739
def positive_negative_pairs(self, num_constraints, same_length=False,
38-
random_state=np.random):
40+
random_state=None):
41+
random_state = check_random_state(random_state)
3942
a, b = self._pairs(num_constraints, same_label=True,
4043
random_state=random_state)
4144
c, d = self._pairs(num_constraints, same_label=False,
@@ -68,13 +71,14 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
6871
ab = np.array(list(ab)[:num_constraints], dtype=int)
6972
return self.known_label_idx[ab.T]
7073

71-
def chunks(self, num_chunks=100, chunk_size=2, random_state=np.random):
74+
def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
7275
"""
7376
the random state object to be passed must be a numpy random seed
7477
"""
78+
random_state = check_random_state(random_state)
7579
chunks = -np.ones_like(self.known_label_idx, dtype=int)
7680
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
77-
all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
81+
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
7882
idx = 0
7983
while idx < num_chunks and all_inds:
8084
if len(all_inds) == 1:

metric_learn/itml.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
import numpy as np
88
from six.moves import xrange
9+
from sklearn.exceptions import ChangedBehaviorWarning
910
from sklearn.metrics import pairwise_distances
1011
from sklearn.utils.validation import check_array
1112
from sklearn.base import TransformerMixin
@@ -298,7 +299,6 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
298299
A positive definite (PD) matrix of shape
299300
(n_features, n_features), that will be used as such to set the
300301
prior.
301-
302302
A0 : Not used
303303
.. deprecated:: 0.5.0
304304
`A0` was deprecated in version 0.5.0 and will
@@ -310,7 +310,9 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
310310
tuples will be formed like this: X[indices].
311311
random_state : int or numpy.RandomState or None, optional (default=None)
312312
A pseudo random number generator object or a seed for it if int. If
313-
``prior='random'``, ``random_state`` is used to set the prior.
313+
``prior='random'``, ``random_state`` is used to set the prior. In any
314+
case, `random_state` is also used to randomly sample constraints from
315+
labels.
314316
315317
316318
Attributes
@@ -350,7 +352,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
350352
self.num_constraints = num_constraints
351353
self.bounds = bounds
352354

353-
def fit(self, X, y, random_state=np.random, bounds=None):
355+
def fit(self, X, y, random_state='deprecated', bounds=None):
354356
"""Create constraints from labels and learn the ITML model.
355357
356358
@@ -362,8 +364,11 @@ def fit(self, X, y, random_state=np.random, bounds=None):
362364
y : (n) array-like
363365
Data labels.
364366
365-
random_state : numpy.random.RandomState, optional
366-
If provided, controls random number generation.
367+
random_state : Not used
368+
.. deprecated:: 0.5.0
369+
`random_state` in the `fit` function was deprecated in version 0.5.0
370+
and will be removed in 0.6.0. Set `random_state` at initialization
371+
instead (when instantiating a new `ITML_Supervised` object).
367372
368373
bounds : array-like of two numbers
369374
Bounds on similarity, aside slack variables, s.t.
@@ -384,6 +389,18 @@ def fit(self, X, y, random_state=np.random, bounds=None):
384389
' It has been deprecated in version 0.5.0 and will be'
385390
' removed in 0.6.0. Use the "bounds" parameter of this '
386391
'fit method instead.', DeprecationWarning)
392+
if random_state != 'deprecated':
393+
warnings.warn('"random_state" parameter in the `fit` function is '
394+
'deprecated. Set `random_state` at initialization '
395+
'instead (when instantiating a new `ITML_Supervised` '
396+
'object).', DeprecationWarning)
397+
else:
398+
warnings.warn('As of v0.5.0, `ITML_Supervised` now uses the '
399+
'`random_state` given at initialization to sample '
400+
'constraints, not the default `np.random` from the `fit` '
401+
'method, since this argument is now deprecated. '
402+
'This warning will disappear in v0.6.0.',
403+
ChangedBehaviorWarning)
387404
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
388405
num_constraints = self.num_constraints
389406
if num_constraints is None:
@@ -392,6 +409,6 @@ def fit(self, X, y, random_state=np.random, bounds=None):
392409

393410
c = Constraints(y)
394411
pos_neg = c.positive_negative_pairs(num_constraints,
395-
random_state=random_state)
412+
random_state=self.random_state)
396413
pairs, y = wrap_pairs(X, pos_neg)
397414
return _BaseITML._fit(self, pairs, y, bounds=bounds)

metric_learn/lsml.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
286286
random_state : int or numpy.RandomState or None, optional (default=None)
287287
A pseudo random number generator object or a seed for it if int. If
288288
``init='random'``, ``random_state`` is used to set the random
289-
prior.
289+
prior. In any case, `random_state` is also used to randomly sample
290+
constraints from labels.
290291
291292
Attributes
292293
----------
@@ -308,7 +309,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
308309
self.num_constraints = num_constraints
309310
self.weights = weights
310311

311-
def fit(self, X, y, random_state=np.random):
312+
def fit(self, X, y, random_state='deprecated'):
312313
"""Create constraints from labels and learn the LSML model.
313314
314315
Parameters
@@ -319,13 +320,28 @@ def fit(self, X, y, random_state=np.random):
319320
y : (n) array-like
320321
Data labels.
321322
322-
random_state : numpy.random.RandomState, optional
323-
If provided, controls random number generation.
323+
random_state : Not used
324+
.. deprecated:: 0.5.0
325+
`random_state` in the `fit` function was deprecated in version 0.5.0
326+
and will be removed in 0.6.0. Set `random_state` at initialization
327+
instead (when instantiating a new `LSML_Supervised` object).
324328
"""
325329
if self.num_labeled != 'deprecated':
326330
warnings.warn('"num_labeled" parameter is not used.'
327331
' It has been deprecated in version 0.5.0 and will be'
328332
' removed in 0.6.0', DeprecationWarning)
333+
if random_state != 'deprecated':
334+
warnings.warn('"random_state" parameter in the `fit` function is '
335+
'deprecated. Set `random_state` at initialization '
336+
'instead (when instantiating a new `LSML_Supervised` '
337+
'object).', DeprecationWarning)
338+
else:
339+
warnings.warn('As of v0.5.0, `LSML_Supervised` now uses the '
340+
'`random_state` given at initialization to sample '
341+
'constraints, not the default `np.random` from the `fit` '
342+
'method, since this argument is now deprecated. '
343+
'This warning will disappear in v0.6.0.',
344+
ChangedBehaviorWarning)
329345
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
330346
num_constraints = self.num_constraints
331347
if num_constraints is None:
@@ -334,6 +350,6 @@ def fit(self, X, y, random_state=np.random):
334350

335351
c = Constraints(y)
336352
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
337-
random_state=random_state)
353+
random_state=self.random_state)
338354
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
339355
weights=self.weights)

metric_learn/mmc.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
538538
random_state : int or numpy.RandomState or None, optional (default=None)
539539
A pseudo random number generator object or a seed for it if int. If
540540
``init='random'``, ``random_state`` is used to initialize the random
541-
Mahalanobis matrix.
541+
Mahalanobis matrix. In any case, `random_state` is also used to
542+
randomly sample constraints from labels.
542543
543544
`MMC_Supervised` creates pairs of similar sample by taking same class
544545
samples, and pairs of dissimilar samples by taking different class
@@ -566,7 +567,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
566567
self.num_labeled = num_labeled
567568
self.num_constraints = num_constraints
568569

569-
def fit(self, X, y, random_state=np.random):
570+
def fit(self, X, y, random_state='deprecated'):
570571
"""Create constraints from labels and learn the MMC model.
571572
572573
Parameters
@@ -575,13 +576,28 @@ def fit(self, X, y, random_state=np.random):
575576
Input data, where each row corresponds to a single instance.
576577
y : (n) array-like
577578
Data labels.
578-
random_state : numpy.random.RandomState, optional
579-
If provided, controls random number generation.
579+
random_state : Not used
580+
.. deprecated:: 0.5.0
581+
`random_state` in the `fit` function was deprecated in version 0.5.0
582+
and will be removed in 0.6.0. Set `random_state` at initialization
583+
instead (when instantiating a new `MMC_Supervised` object).
580584
"""
581585
if self.num_labeled != 'deprecated':
582586
warnings.warn('"num_labeled" parameter is not used.'
583587
' It has been deprecated in version 0.5.0 and will be'
584588
' removed in 0.6.0', DeprecationWarning)
589+
if random_state != 'deprecated':
590+
warnings.warn('"random_state" parameter in the `fit` function is '
591+
'deprecated. Set `random_state` at initialization '
592+
'instead (when instantiating a new `MMC_Supervised` '
593+
'object).', DeprecationWarning)
594+
else:
595+
warnings.warn('As of v0.5.0, `MMC_Supervised` now uses the '
596+
'`random_state` given at initialization to sample '
597+
'constraints, not the default `np.random` from the `fit` '
598+
'method, since this argument is now deprecated. '
599+
'This warning will disappear in v0.6.0.',
600+
ChangedBehaviorWarning)
585601
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
586602
num_constraints = self.num_constraints
587603
if num_constraints is None:
@@ -590,6 +606,6 @@ def fit(self, X, y, random_state=np.random):
590606

591607
c = Constraints(y)
592608
pos_neg = c.positive_negative_pairs(num_constraints,
593-
random_state=random_state)
609+
random_state=self.random_state)
594610
pairs, y = wrap_pairs(X, pos_neg)
595611
return _BaseMMC._fit(self, pairs, y)

metric_learn/rca.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,17 @@ class RCA_Supervised(RCA):
184184
be removed in 0.6.0. Use `n_components` instead.
185185
186186
num_chunks: int, optional
187+
187188
chunk_size: int, optional
189+
188190
preprocessor : array-like, shape=(n_samples, n_features) or callable
189191
The preprocessor to call to get tuples from indices. If array-like,
190192
tuples will be formed like this: X[indices].
191193
194+
random_state : int or numpy.RandomState or None, optional (default=None)
195+
A pseudo random number generator object or a seed for it if int.
196+
It is used to randomly sample constraints from labels.
197+
192198
Attributes
193199
----------
194200
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
@@ -197,13 +203,15 @@ class RCA_Supervised(RCA):
197203

198204
def __init__(self, num_dims='deprecated', n_components=None,
199205
pca_comps='deprecated', num_chunks=100, chunk_size=2,
200-
preprocessor=None):
206+
preprocessor=None, random_state=None):
207+
"""Initialize the supervised version of `RCA`."""
201208
RCA.__init__(self, num_dims=num_dims, n_components=n_components,
202209
pca_comps=pca_comps, preprocessor=preprocessor)
203210
self.num_chunks = num_chunks
204211
self.chunk_size = chunk_size
212+
self.random_state = random_state
205213

206-
def fit(self, X, y, random_state=np.random):
214+
def fit(self, X, y, random_state='deprecated'):
207215
"""Create constraints from labels and learn the RCA model.
208216
Needs num_constraints specified in constructor.
209217
@@ -212,10 +220,26 @@ def fit(self, X, y, random_state=np.random):
212220
X : (n x d) data matrix
213221
each row corresponds to a single instance
214222
y : (n) data labels
215-
random_state : a random.seed object to fix the random_state if needed.
223+
random_state : Not used
224+
.. deprecated:: 0.5.0
225+
`random_state` in the `fit` function was deprecated in version 0.5.0
226+
and will be removed in 0.6.0. Set `random_state` at initialization
227+
instead (when instantiating a new `RCA_Supervised` object).
216228
"""
229+
if random_state != 'deprecated':
230+
warnings.warn('"random_state" parameter in the `fit` function is '
231+
'deprecated. Set `random_state` at initialization '
232+
'instead (when instantiating a new `RCA_Supervised` '
233+
'object).', DeprecationWarning)
234+
else:
235+
warnings.warn('As of v0.5.0, `RCA_Supervised` now uses the '
236+
'`random_state` given at initialization to sample '
237+
'constraints, not the default `np.random` from the `fit` '
238+
'method, since this argument is now deprecated. '
239+
'This warning will disappear in v0.6.0.',
240+
ChangedBehaviorWarning)
217241
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
218242
chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
219243
chunk_size=self.chunk_size,
220-
random_state=random_state)
244+
random_state=self.random_state)
221245
return RCA.fit(self, X, chunks)

metric_learn/sdml.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,8 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
310310
random_state : int or numpy.RandomState or None, optional (default=None)
311311
A pseudo random number generator object or a seed for it if int. If
312312
``init='random'``, ``random_state`` is used to set the random
313-
prior.
313+
prior. In any case, `random_state` is also used to randomly sample
314+
constraints from labels.
314315
315316
Attributes
316317
----------
@@ -336,7 +337,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, prior=None,
336337
self.num_labeled = num_labeled
337338
self.num_constraints = num_constraints
338339

339-
def fit(self, X, y, random_state=np.random):
340+
def fit(self, X, y, random_state='deprecated'):
340341
"""Create constraints from labels and learn the SDML model.
341342
342343
Parameters
@@ -345,9 +346,11 @@ def fit(self, X, y, random_state=np.random):
345346
data matrix, where each row corresponds to a single instance
346347
y : array-like, shape (n,)
347348
data labels, one for each instance
348-
random_state : {numpy.random.RandomState, int}, optional
349-
Random number generator or random seed. If not given, the singleton
350-
numpy.random will be used.
349+
random_state : Not used
350+
.. deprecated:: 0.5.0
351+
`random_state` in the `fit` function was deprecated in version 0.5.0
352+
and will be removed in 0.6.0. Set `random_state` at initialization
353+
instead (when instantiating a new `SDML_Supervised` object).
351354
352355
Returns
353356
-------
@@ -358,6 +361,18 @@ def fit(self, X, y, random_state=np.random):
358361
warnings.warn('"num_labeled" parameter is not used.'
359362
' It has been deprecated in version 0.5.0 and will be'
360363
' removed in 0.6.0', DeprecationWarning)
364+
if random_state != 'deprecated':
365+
warnings.warn('"random_state" parameter in the `fit` function is '
366+
'deprecated. Set `random_state` at initialization '
367+
'instead (when instantiating a new `SDML_Supervised` '
368+
'object).', DeprecationWarning)
369+
else:
370+
warnings.warn('As of v0.5.0, `SDML_Supervised` now uses the '
371+
'`random_state` given at initialization to sample '
372+
'constraints, not the default `np.random` from the `fit` '
373+
'method, since this argument is now deprecated. '
374+
'This warning will disappear in v0.6.0.',
375+
ChangedBehaviorWarning)
361376
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
362377
num_constraints = self.num_constraints
363378
if num_constraints is None:
@@ -366,6 +381,6 @@ def fit(self, X, y, random_state=np.random):
366381

367382
c = Constraints(y)
368383
pos_neg = c.positive_negative_pairs(num_constraints,
369-
random_state=random_state)
384+
random_state=self.random_state)
370385
pairs, y = wrap_pairs(X, pos_neg)
371386
return _BaseSDML._fit(self, pairs, y)

0 commit comments

Comments
 (0)