Skip to content

Commit a02cbd5

Browse files
author
mvargas33
committed
Fix Triplets predict function. Made a test to show the point.
1 parent a92de21 commit a02cbd5

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

metric_learn/base_metric.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def predict(self, triplets):
602602
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
603603
Predictions of the ordering of pairs, for each triplet.
604604
"""
605-
return np.sign(self.decision_function(triplets))
605+
return np.array([-1 if (t <= 0) else 1 for t in
606+
self.decision_function(triplets)])
607+
#return np.sign(self.decision_function(triplets))
606608

607609
def decision_function(self, triplets):
608610
"""Predicts differences between sample distances in input triplets.

test/test_triplets_classifiers.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from metric_learn.sklearn_shims import set_random_state
77
from sklearn import clone
88
import numpy as np
9-
9+
from numpy.testing import assert_array_equal
1010

1111
@pytest.mark.parametrize('with_preprocessor', [True, False])
1212
@pytest.mark.parametrize('estimator, build_dataset', triplets_learners,
@@ -63,3 +63,49 @@ def test_accuracy_toy_example(estimator, build_dataset):
6363
# we force the transformation to be identity so that we control what it does
6464
estimator.components_ = np.eye(X.shape[1])
6565
assert estimator.score(triplets_test) == 0.25
66+
67+
68+
@pytest.mark.parametrize('estimator, build_dataset', triplets_learners,
69+
ids=ids_triplets_learners)
70+
def test_no_zero_prediction(estimator, build_dataset):
71+
"""
72+
Test that all predicted values are in {-1, 1}, even when the
73+
distances d(x,y) and d(x,z) are the same for a triplet of the
74+
form (x, y, z).
75+
"""
76+
# Build the dataset
77+
triplets, _, _, X = build_dataset(with_preprocessor=False)
78+
# Force 3 dimensions only, to use cross product and get easy orthogonal vectors.
79+
triplets = np.array([ [t[0][:3], t[1][:3], t[2][:3]] for t in triplets])
80+
X = np.array([x[:3] for x in X])
81+
# Dummy fit
82+
estimator = clone(estimator)
83+
set_random_state(estimator)
84+
estimator.fit(triplets)
85+
# we force the transformation to be identity, to force euclidean distance
86+
estimator.components_ = np.eye(X.shape[1])
87+
88+
# Get two orthogonal vectors with respect to X[1]
89+
k = X[1]/np.linalg.norm(X[1]) # Normalize first vector
90+
x = X[2] - X[2].dot(k) * k # Get random orthogonal vector
91+
x /= np.linalg.norm(x) # Normalize
92+
y = np.cross(k, x) # Get orthogonal vector to x
93+
# Assert these orthogonal vectors are different
94+
with pytest.raises(AssertionError):
95+
assert_array_equal(X[1], x)
96+
with pytest.raises(AssertionError):
97+
assert_array_equal(X[1], y)
98+
# Assert the distance is the same for both
99+
assert estimator.get_metric()(X[1], x) == estimator.get_metric()(X[1], y)
100+
101+
# Form the three scenarios where predict() gives 0 with numpy.sign
102+
triplets_test = np.array( # Critical examples
103+
[[X[0], X[2], X[2]],
104+
[X[1], X[1], X[1]],
105+
[X[1], x, y]
106+
])
107+
# Predict
108+
predictions = estimator.predict(triplets_test)
109+
# Collect the predicted values that are neither -1 nor 1
110+
not_valid = [e for e in predictions if e not in [-1, 1]]
111+
assert len(not_valid) == 0

0 commit comments

Comments
 (0)