@@ -332,7 +332,7 @@ def predict(self, pairs):
332332 The predicted learned metric value between samples in every pair.
333333 """
334334 check_is_fitted (self , ['threshold_' , 'transformer_' ])
335- return - 2 * (self .decision_function (pairs ) > self .threshold_ ) + 1
335+ return 2 * (self .decision_function (pairs ) > self .threshold_ ) - 1
336336
337337 def decision_function (self , pairs ):
338338 """Returns the decision function used to classify the pairs.
@@ -387,13 +387,13 @@ def score(self, pairs, y):
387387 return roc_auc_score (y , self .decision_function (pairs ))
388388
389389 def set_default_threshold (self , pairs , y ):
390- """Returns a threshold that is the mean between the similar metrics
391- mean, and the dissimilar metrics mean"""
392- similar_threshold = np .mean (self .decision_function (
390+ """Returns a threshold that is the opposite of the mean between the similar
391+ metrics mean and the dissimilar metrics mean"""
392+ similar_threshold = np .mean (self .score_pairs (
393393 pairs [(y == 1 ).ravel ()]))
394- dissimilar_threshold = np .mean (self .decision_function (
394+ dissimilar_threshold = np .mean (self .score_pairs (
395395 pairs [(y == - 1 ).ravel ()]))
396- self .threshold_ = np .mean ([similar_threshold , dissimilar_threshold ])
396+ self .threshold_ = - np .mean ([similar_threshold , dissimilar_threshold ])
397397
398398
399399class _QuadrupletsClassifierMixin (BaseMetricLearner ):
0 commit comments