88import numpy as np
99from numpy .testing import assert_array_equal
1010
11+
1112@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
1213@pytest .mark .parametrize ('estimator, build_dataset' , triplets_learners ,
1314 ids = ids_triplets_learners )
@@ -75,8 +76,8 @@ def test_no_zero_prediction(estimator, build_dataset):
7576 """
7677 # Dummy fit
7778 triplets , _ , _ , X = build_dataset (with_preprocessor = False )
78- # Force 3 dimentions only, to use cross product and get easy orthogonal vectors .
79- triplets = np .array ([ [t [0 ][:3 ], t [1 ][:3 ], t [2 ][:3 ]] for t in triplets ])
79+ # Force 3 dimentions only, to use cross product and get easy orthogonal vec .
80+ triplets = np .array ([[t [0 ][:3 ], t [1 ][:3 ], t [2 ][:3 ]] for t in triplets ])
8081 X = np .array ([x [:3 ] for x in X ])
8182 # Dummy fit
8283 estimator = clone (estimator )
@@ -102,10 +103,9 @@ def test_no_zero_prediction(estimator, build_dataset):
102103 triplets_test = np .array ( # Critical examples
103104 [[X [0 ], X [2 ], X [2 ]],
104105 [X [1 ], X [1 ], X [1 ]],
105- [X [1 ], x , y ]
106- ])
106+ [X [1 ], x , y ]])
107107 # Predict
108108 predictions = estimator .predict (triplets_test )
109109 # Count non -1 or 1 values
110110 not_valid = [e for e in predictions if e not in [- 1 , 1 ]]
111- assert len (not_valid ) == 0
111+ assert len (not_valid ) == 0
0 commit comments