@@ -240,12 +240,14 @@ def transform(self, X):
240240 X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
241241 The embedded data points.
242242 """
243+ check_is_fitted (self , ['preprocessor_' , 'components_' ])
243244 X_checked = check_input (X , type_of_inputs = 'classic' , estimator = self ,
244245 preprocessor = self .preprocessor_ ,
245246 accept_sparse = True )
246247 return X_checked .dot (self .components_ .T )
247248
248249 def get_metric (self ):
250+ check_is_fitted (self , 'components_' )
249251 components_T = self .components_ .T .copy ()
250252
251253 def metric_fun (u , v , squared = False ):
@@ -298,6 +300,7 @@ def get_mahalanobis_matrix(self):
298300 M : `numpy.ndarray`, shape=(n_features, n_features)
299301 The copy of the learned Mahalanobis matrix.
300302 """
303+ check_is_fitted (self , 'components_' )
301304 return self .components_ .T .dot (self .components_ )
302305
303306
@@ -333,7 +336,10 @@ def predict(self, pairs):
333336 y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
334337 The predicted learned metric value between samples in every pair.
335338 """
336- check_is_fitted (self , ['threshold_' , 'components_' ])
339+ if "threshold_" not in vars (self ):
340+ msg = ("A threshold for this estimator has not been set,"
341+ "call its set_threshold or calibrate_threshold method." )
342+ raise AttributeError (msg )
337343 return 2 * (- self .decision_function (pairs ) <= self .threshold_ ) - 1
338344
339345 def decision_function (self , pairs ):
@@ -357,6 +363,7 @@ def decision_function(self, pairs):
357363 y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
358364 The predicted decision function value for each pair.
359365 """
366+ check_is_fitted (self , 'preprocessor_' )
360367 pairs = check_input (pairs , type_of_inputs = 'tuples' ,
361368 preprocessor = self .preprocessor_ ,
362369 estimator = self , tuple_size = self ._tuple_size )
@@ -599,7 +606,7 @@ def predict(self, quadruplets):
599606 prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
600607 Predictions of the ordering of pairs, for each quadruplet.
601608 """
602- check_is_fitted (self , 'components_ ' )
609+ check_is_fitted (self , 'preprocessor_ ' )
603610 quadruplets = check_input (quadruplets , type_of_inputs = 'tuples' ,
604611 preprocessor = self .preprocessor_ ,
605612 estimator = self , tuple_size = self ._tuple_size )
@@ -628,6 +635,7 @@ def decision_function(self, quadruplets):
628635 decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
629636 Metric differences.
630637 """
638+ check_is_fitted (self , 'preprocessor_' )
631639 quadruplets = check_input (quadruplets , type_of_inputs = 'tuples' ,
632640 preprocessor = self .preprocessor_ ,
633641 estimator = self , tuple_size = self ._tuple_size )
0 commit comments