@@ -68,9 +68,13 @@ def _fit(self, pairs, y, bounds=None):
6868 X = np .vstack ({tuple (row ) for row in pairs .reshape (- 1 , pairs .shape [2 ])})
6969 self .bounds_ = np .percentile (pairwise_distances (X ), (5 , 95 ))
7070 else :
71- assert len (bounds ) == 2
71+ bounds = check_array (bounds , allow_nd = False , ensure_min_samples = 0 ,
72+ ensure_2d = False )
73+ bounds = bounds .ravel ()
74+ if bounds .size != 2 :
75+ raise ValueError ("`bounds` should be an array-like of two elements." )
7276 self .bounds_ = bounds
73- self .bounds_ [self .bounds_ == 0 ] = 1e-9
77+ self .bounds_ [self .bounds_ == 0 ] = 1e-9
7478 # init metric
7579 if self .A0 is None :
7680 A = np .identity (pairs .shape [2 ])
@@ -133,7 +137,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
133137
134138 Attributes
135139 ----------
136- bounds_ : array-like , shape=(2,)
140+ bounds_ : `numpy.ndarray` , shape=(2,)
137141 Bounds on similarity, aside slack variables, s.t.
138142 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
139143 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -170,7 +174,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
170174 preprocessor.
171175 y: array-like, of shape (n_constraints,)
172176 Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
173- bounds : `list` of two numbers
177+ bounds : array-like of two numbers
174178 Bounds on similarity, aside slack variables, s.t.
175179 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
176180 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -191,7 +195,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
191195 calibration_params = (calibration_params if calibration_params is not
192196 None else dict ())
193197 self ._validate_calibration_params (** calibration_params )
194- self ._fit (pairs , y )
198+ self ._fit (pairs , y , bounds = bounds )
195199 self .calibrate_threshold (pairs , y , ** calibration_params )
196200 return self
197201
@@ -201,7 +205,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
201205
202206 Attributes
203207 ----------
204- bounds_ : array-like , shape=(2,)
208+ bounds_ : `numpy.ndarray` , shape=(2,)
205209 Bounds on similarity, aside slack variables, s.t.
206210 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
207211 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -274,7 +278,7 @@ def fit(self, X, y, random_state=np.random, bounds=None):
274278 random_state : numpy.random.RandomState, optional
275279 If provided, controls random number generation.
276280
277- bounds : `list` of two numbers
281+ bounds : array-like of two numbers
278282 Bounds on similarity, aside slack variables, s.t.
279283 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
280284 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
0 commit comments