@@ -69,9 +69,13 @@ def _fit(self, pairs, y, bounds=None):
6969 X = np .vstack ({tuple (row ) for row in pairs .reshape (- 1 , pairs .shape [2 ])})
7070 self .bounds_ = np .percentile (pairwise_distances (X ), (5 , 95 ))
7171 else :
72- assert len (bounds ) == 2
72+ bounds = check_array (bounds , allow_nd = False , ensure_min_samples = 0 ,
73+ ensure_2d = False )
74+ bounds = bounds .ravel ()
75+ if bounds .size != 2 :
76+ raise ValueError ("`bounds` should be an array-like of two elements." )
7377 self .bounds_ = bounds
74- self .bounds_ [self .bounds_ == 0 ] = 1e-9
78+ self .bounds_ [self .bounds_ == 0 ] = 1e-9
7579 # init metric
7680 if self .A0 is None :
7781 A = np .identity (pairs .shape [2 ])
@@ -134,7 +138,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
134138
135139 Attributes
136140 ----------
137- bounds_ : array-like , shape=(2,)
141+ bounds_ : `numpy.ndarray` , shape=(2,)
138142 Bounds on similarity, aside slack variables, s.t.
139143 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
140144 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -171,7 +175,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
171175 preprocessor.
172176 y: array-like, of shape (n_constraints,)
173177 Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
174- bounds : `list` of two numbers
178+ bounds : array-like of two numbers
175179 Bounds on similarity, aside slack variables, s.t.
176180 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
177181 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -192,7 +196,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None):
192196 calibration_params = (calibration_params if calibration_params is not
193197 None else dict ())
194198 self ._validate_calibration_params (** calibration_params )
195- self ._fit (pairs , y )
199+ self ._fit (pairs , y , bounds = bounds )
196200 self .calibrate_threshold (pairs , y , ** calibration_params )
197201 return self
198202
@@ -202,7 +206,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
202206
203207 Attributes
204208 ----------
205- bounds_ : array-like , shape=(2,)
209+ bounds_ : `numpy.ndarray` , shape=(2,)
206210 Bounds on similarity, aside slack variables, s.t.
207211 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
208212 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
@@ -275,7 +279,7 @@ def fit(self, X, y, random_state=np.random, bounds=None):
275279 random_state : numpy.random.RandomState, optional
276280 If provided, controls random number generation.
277281
278- bounds : `list` of two numbers
282+ bounds : array-like of two numbers
279283 Bounds on similarity, aside slack variables, s.t.
280284 ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
281285 and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
0 commit comments