66import warnings
77import numpy as np
88from six .moves import xrange
9+ from sklearn .exceptions import ChangedBehaviorWarning
910from sklearn .metrics import pairwise_distances
1011from sklearn .utils .validation import check_array
1112from sklearn .base import TransformerMixin
@@ -298,7 +299,6 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
298299 A positive definite (PD) matrix of shape
299300 (n_features, n_features), that will be used as such to set the
300301 prior.
301-
302302 A0 : Not used
303303 .. deprecated:: 0.5.0
304304 `A0` was deprecated in version 0.5.0 and will
@@ -310,7 +310,9 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
310310 tuples will be formed like this: X[indices].
311311 random_state : int or numpy.RandomState or None, optional (default=None)
312312 A pseudo random number generator object or a seed for it if int. If
313- ``prior='random'``, ``random_state`` is used to set the prior.
313+ ``prior='random'``, ``random_state`` is used to set the prior. In any
314+ case, `random_state` is also used to randomly sample constraints from
315+ labels.
314316
315317
316318 Attributes
@@ -350,7 +352,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
350352 self .num_constraints = num_constraints
351353 self .bounds = bounds
352354
353- def fit (self , X , y , random_state = np . random , bounds = None ):
355+ def fit (self , X , y , random_state = 'deprecated' , bounds = None ):
354356 """Create constraints from labels and learn the ITML model.
355357
356358
@@ -362,8 +364,11 @@ def fit(self, X, y, random_state=np.random, bounds=None):
362364 y : (n) array-like
363365 Data labels.
364366
365- random_state : numpy.random.RandomState, optional
366- If provided, controls random number generation.
367+ random_state : Not used
368+ .. deprecated:: 0.5.0
369+ `random_state` in the `fit` function was deprecated in version 0.5.0
370+ and will be removed in 0.6.0. Set `random_state` at initialization
371+ instead (when instantiating a new `ITML_Supervised` object).
367372
368373 bounds : array-like of two numbers
369374 Bounds on similarity, aside slack variables, s.t.
@@ -384,6 +389,18 @@ def fit(self, X, y, random_state=np.random, bounds=None):
384389 ' It has been deprecated in version 0.5.0 and will be'
385390 ' removed in 0.6.0. Use the "bounds" parameter of this '
386391 'fit method instead.' , DeprecationWarning )
392+ if random_state != 'deprecated' :
393+ warnings .warn ('"random_state" parameter in the `fit` function is '
394+ 'deprecated. Set `random_state` at initialization '
395+ 'instead (when instantiating a new `ITML_Supervised` '
396+ 'object).' , DeprecationWarning )
397+ else :
398+ warnings .warn ('As of v0.5.0, `ITML_Supervised` now uses the '
399+ '`random_state` given at initialization to sample '
400+ 'constraints, not the default `np.random` from the `fit` '
401+ 'method, since this argument is now deprecated. '
402+ 'This warning will disappear in v0.6.0.' ,
403+ ChangedBehaviorWarning )
387404 X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
388405 num_constraints = self .num_constraints
389406 if num_constraints is None :
@@ -392,6 +409,6 @@ def fit(self, X, y, random_state=np.random, bounds=None):
392409
393410 c = Constraints (y )
394411 pos_neg = c .positive_negative_pairs (num_constraints ,
395- random_state = random_state )
412+ random_state = self . random_state )
396413 pairs , y = wrap_pairs (X , pos_neg )
397414 return _BaseITML ._fit (self , pairs , y , bounds = bounds )
0 commit comments