@@ -105,6 +105,70 @@ def stable_init(self, n_components=None, pca_comps=None,
105105
106106# ---------------------- Test scikit-learn compatibility ----------------------
107107
108+ def generate_array_like (input_data , labels = None ):
109+ """Helper function to generate array-like variants of numpy datasets,
110+ for testing purposes."""
111+ list_data = input_data .tolist ()
112+ input_data_changed = [input_data , list_data , tuple (list_data )]
113+ if input_data .ndim >= 2 :
114+ input_data_changed .append (tuple (tuple (x ) for x in list_data ))
115+ if input_data .ndim >= 3 :
116+ input_data_changed .append (tuple (tuple (tuple (x ) for x in y ) for y in
117+ list_data ))
118+ if input_data .ndim == 2 :
119+ pd = pytest .importorskip ('pandas' )
120+ input_data_changed .append (pd .DataFrame (input_data ))
121+ if labels is not None :
122+ labels_changed = [labels , list (labels ), tuple (labels )]
123+ else :
124+ labels_changed = [labels ]
125+ return input_data_changed , labels_changed
126+
127+
128+ @pytest .mark .integration
129+ @pytest .mark .parametrize ('with_preprocessor' , [True , False ])
130+ @pytest .mark .parametrize ('estimator, build_dataset' , metric_learners ,
131+ ids = ids_metric_learners )
132+ def test_array_like_inputs (estimator , build_dataset , with_preprocessor ):
133+ """Test that metric-learners can have as input (of all functions that are
134+ applied on data) any array-like object."""
135+ input_data , labels , preprocessor , X = build_dataset (with_preprocessor )
136+
137+ # we subsample the data for the test to be more efficient
138+ input_data , _ , labels , _ = train_test_split (input_data , labels ,
139+ train_size = 20 )
140+ X = X [:10 ]
141+
142+ estimator = clone (estimator )
143+ estimator .set_params (preprocessor = preprocessor )
144+ set_random_state (estimator )
145+ input_variants , label_variants = generate_array_like (input_data , labels )
146+ for input_variant in input_variants :
147+ for label_variant in label_variants :
148+ estimator .fit (* remove_y_quadruplets (estimator , input_variant ,
149+ label_variant ))
150+ if hasattr (estimator , "predict" ):
151+ estimator .predict (input_variant )
152+ if hasattr (estimator , "predict_proba" ):
153+ estimator .predict_proba (input_variant ) # anticipation in case some
154+ # time we have that, or if ppl want to contribute with new algorithms
155+ # it will be checked automatically
156+ if hasattr (estimator , "decision_function" ):
157+ estimator .decision_function (input_variant )
158+ if hasattr (estimator , "score" ):
159+ for label_variant in label_variants :
160+ estimator .score (* remove_y_quadruplets (estimator , input_variant ,
161+ label_variant ))
162+
163+ X_variants , _ = generate_array_like (X )
164+ for X_variant in X_variants :
165+ estimator .transform (X_variant )
166+
167+ pairs = np .array ([[X [0 ], X [1 ]], [X [0 ], X [2 ]]])
168+ pairs_variants , _ = generate_array_like (pairs )
169+ for pairs_variant in pairs_variants :
170+ estimator .score_pairs (pairs_variant )
171+
108172
109173@pytest .mark .parametrize ('with_preprocessor' , [True , False ])
110174@pytest .mark .parametrize ('estimator, build_dataset' , pairs_learners ,
0 commit comments