66from metric_learn .sklearn_shims import set_random_state
77from sklearn import clone
88import numpy as np
9+ from numpy .testing import assert_array_equal
910
1011
1112@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
@@ -26,6 +27,49 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
2627 assert len (not_valid ) == 0
2728
2829
30+ @pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
31+ ids = ids_triplets_learners )
32+ def test_no_zero_prediction (estimator , build_dataset ):
33+ """
34+ Test that all predicted values are not zero, even when the
35+ distance d(x,y) and d(x,z) is the same for a triplet of the
36+ form (x, y, z). i.e border cases.
37+ """
38+ triplets , _ , _ , X = build_dataset (with_preprocessor = False )
39+ # Force 3 dimentions only, to use cross product and get easy orthogonal vec.
40+ triplets = np .array ([[t [0 ][:3 ], t [1 ][:3 ], t [2 ][:3 ]] for t in triplets ])
41+ X = X [:, :3 ]
42+ # Dummy fit
43+ estimator = clone (estimator )
44+ set_random_state (estimator )
45+ estimator .fit (triplets )
46+ # We force the transformation to be identity, to force euclidean distance
47+ estimator .components_ = np .eye (X .shape [1 ])
48+
49+ # Get two orthogonal vectors in respect to X[1]
50+ k = X [1 ] / np .linalg .norm (X [1 ]) # Normalize first vector
51+ x = X [2 ] - X [2 ].dot (k ) * k # Get random orthogonal vector
52+ x /= np .linalg .norm (x ) # Normalize
53+ y = np .cross (k , x ) # Get orthogonal vector to x
54+ # Assert these orthogonal vectors are different
55+ with pytest .raises (AssertionError ):
56+ assert_array_equal (X [1 ], x )
57+ with pytest .raises (AssertionError ):
58+ assert_array_equal (X [1 ], y )
59+ # Assert the distance is the same for both
60+ assert estimator .get_metric ()(X [1 ], x ) == estimator .get_metric ()(X [1 ], y )
61+
62+ # Form the three scenarios where predict() gives 0 with numpy.sign
63+ triplets_test = np .array ( # Critical examples
64+ [[X [0 ], X [2 ], X [2 ]],
65+ [X [1 ], X [1 ], X [1 ]],
66+ [X [1 ], x , y ]])
67+ # Predict
68+ predictions = estimator .predict (triplets_test )
69+ # Check there are no zero values
70+ assert np .sum (predictions == 0 ) == 0
71+
72+
2973@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
3074@pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
3175 ids = ids_triplets_learners )
0 commit comments