66from metric_learn .sklearn_shims import set_random_state
77from sklearn import clone
88import numpy as np
9-
9+ from numpy . testing import assert_array_equal
1010
1111@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
1212@pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
@@ -63,3 +63,49 @@ def test_accuracy_toy_example(estimator, build_dataset):
6363 # we force the transformation to be identity so that we control what it does
6464 estimator .components_ = np .eye (X .shape [1 ])
6565 assert estimator .score (triplets_test ) == 0.25
66+
67+
68+ @pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
69+ ids = ids_triplets_learners )
70+ def test_no_zero_prediction (estimator , build_dataset ):
71+ """
72+ Test that all predicted values are in {-1, 1}, even when the
73+ distance d(x,y) and d(x,z) is the same for a triplet of the
74+ form (x, y, z).
75+ """
76+ # Dummy fit
77+ triplets , _ , _ , X = build_dataset (with_preprocessor = False )
78+ # Force 3 dimentions only, to use cross product and get easy orthogonal vectors.
79+ triplets = np .array ([ [t [0 ][:3 ], t [1 ][:3 ], t [2 ][:3 ]] for t in triplets ])
80+ X = np .array ([x [:3 ] for x in X ])
81+ # Dummy fit
82+ estimator = clone (estimator )
83+ set_random_state (estimator )
84+ estimator .fit (triplets )
85+ # we force the transformation to be identity, to force euclidean distance
86+ estimator .components_ = np .eye (X .shape [1 ])
87+
88+ # Get two orthogonal vectors in respect to X[1]
89+ k = X [1 ]/ np .linalg .norm (X [1 ]) # Normalize first vector
90+ x = X [2 ] - X [2 ].dot (k ) * k # Get random orthogonal vector
91+ x /= np .linalg .norm (x ) # Normalize
92+ y = np .cross (k , x ) # Get orthogonal vector to x
93+ # Assert these orthogonal vectors are different
94+ with pytest .raises (AssertionError ):
95+ assert_array_equal (X [1 ], x )
96+ with pytest .raises (AssertionError ):
97+ assert_array_equal (X [1 ], y )
98+ # Assert the distance is the same for both
99+ assert estimator .get_metric ()(X [1 ], x ) == estimator .get_metric ()(X [1 ], y )
100+
101+ # Form the three scenarios where predict() gives 0 with numpy.sign
102+ triplets_test = np .array ( # Critical examples
103+ [[X [0 ], X [2 ], X [2 ]],
104+ [X [1 ], X [1 ], X [1 ]],
105+ [X [1 ], x , y ]
106+ ])
107+ # Predict
108+ predictions = estimator .predict (triplets_test )
109+ # Count non -1 or 1 values
110+ not_valid = [e for e in predictions if e not in [- 1 , 1 ]]
111+ assert len (not_valid ) == 0
0 commit comments