|
| 1 | +from numpy.linalg import cholesky |
| 2 | +from scipy.spatial.distance import euclidean |
1 | 3 | from sklearn.base import BaseEstimator |
2 | 4 | from sklearn.utils.validation import _is_arraylike |
3 | 5 | from sklearn.metrics import roc_auc_score |
4 | 6 | import numpy as np |
5 | 7 | from abc import ABCMeta, abstractmethod |
6 | 8 | import six |
7 | | -from ._util import ArrayIndexer, check_input |
| 9 | +from ._util import ArrayIndexer, check_input, validate_vector |
| 10 | +import warnings |
8 | 11 |
|
9 | 12 |
|
10 | 13 | class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)): |
@@ -34,6 +37,14 @@ def score_pairs(self, pairs): |
34 | 37 | ------- |
35 | 38 | scores: `numpy.ndarray` of shape=(n_pairs,) |
36 | 39 | The score of every pair. |
| 40 | +
|
| 41 | + See Also |
| 42 | + -------- |
| 43 | + get_metric : a method that returns a function to compute the metric between |
| 44 | + two points. The difference with `score_pairs` is that it works on two 1D |
| 45 | + arrays and cannot use a preprocessor. Besides, the returned function is |
| 46 | + independent of the metric learner and hence is not modified if the metric |
| 47 | + learner is. |
37 | 48 | """ |
38 | 49 |
|
39 | 50 | def check_preprocessor(self): |
@@ -85,6 +96,47 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic', |
85 | 96 | tuple_size=getattr(self, '_tuple_size', None), |
86 | 97 | **kwargs) |
87 | 98 |
|
| 99 | + @abstractmethod |
| 100 | + def get_metric(self): |
| 101 | + """Returns a function that takes as input two 1D arrays and outputs the |
| 102 | + learned metric score on these two points. |
| 103 | +
|
| 104 | + This function will be independent from the metric learner that learned it |
| 105 | + (it will not be modified if the initial metric learner is modified), |
| 106 | + and it can be directly plugged into the `metric` argument of |
| 107 | + scikit-learn's estimators. |
| 108 | +
|
| 109 | + Returns |
| 110 | + ------- |
| 111 | + metric_fun : function |
| 112 | + The function described above. |
| 113 | +
|
| 114 | +
|
| 115 | + Examples |
| 116 | + -------- |
| 117 | + .. doctest:: |
| 118 | +
|
| 119 | + >>> from metric_learn import NCA |
| 120 | + >>> from sklearn.datasets import make_classification |
| 121 | + >>> from sklearn.neighbors import KNeighborsClassifier |
| 122 | + >>> nca = NCA() |
| 123 | + >>> X, y = make_classification() |
| 124 | + >>> nca.fit(X, y) |
| 125 | + >>> knn = KNeighborsClassifier(metric=nca.get_metric()) |
| 126 | + >>> knn.fit(X, y) # doctest: +NORMALIZE_WHITESPACE |
| 127 | + KNeighborsClassifier(algorithm='auto', leaf_size=30, |
| 128 | + metric=<function MahalanobisMixin.get_metric.<locals>.metric_fun |
| 129 | + at 0x...>, |
| 130 | + metric_params=None, n_jobs=None, n_neighbors=5, p=2, |
| 131 | + weights='uniform') |
| 132 | +
|
| 133 | + See Also |
| 134 | + -------- |
| 135 | + score_pairs : a method that returns the metric score between several pairs |
| 136 | + of points. Unlike `get_metric`, this is a method of the metric learner |
| 137 | + and therefore can change if the metric learner changes. Besides, it can |
| 138 | + use the metric learner's preprocessor, and works on concatenated arrays. |
| 139 | + """ |
88 | 140 |
|
89 | 141 | class MetricTransformer(six.with_metaclass(ABCMeta)): |
90 | 142 |
|
@@ -146,6 +198,17 @@ def score_pairs(self, pairs): |
146 | 198 | ------- |
147 | 199 | scores: `numpy.ndarray` of shape=(n_pairs,) |
148 | 200 | The learned Mahalanobis distance for every pair. |
| 201 | +
|
| 202 | + See Also |
| 203 | + -------- |
| 204 | + get_metric : a method that returns a function to compute the metric between |
| 205 | + two points. The difference with `score_pairs` is that it works on two 1D |
| 206 | + arrays and cannot use a preprocessor. Besides, the returned function is |
| 207 | + independent of the metric learner and hence is not modified if the metric |
| 208 | + learner is. |
| 209 | +
|
| 210 | + :ref:`mahalanobis_distances` : The section of the project documentation |
| 211 | + that describes Mahalanobis Distances. |
149 | 212 | """ |
150 | 213 | pairs = check_input(pairs, type_of_inputs='tuples', |
151 | 214 | preprocessor=self.preprocessor_, |
@@ -177,7 +240,57 @@ def transform(self, X): |
177 | 240 | accept_sparse=True) |
178 | 241 | return X_checked.dot(self.transformer_.T) |
179 | 242 |
|
| 243 | + def get_metric(self): |
| 244 | + transformer_T = self.transformer_.T.copy() |
| 245 | + |
| 246 | + def metric_fun(u, v, squared=False): |
| 247 | + """This function computes the metric between u and v, according to the |
| 248 | + previously learned metric. |
| 249 | +
|
| 250 | + Parameters |
| 251 | + ---------- |
| 252 | + u : array-like, shape=(n_features,) |
| 253 | + The first point involved in the distance computation. |
| 254 | +
|
| 255 | + v : array-like, shape=(n_features,) |
| 256 | + The second point involved in the distance computation. |
| 257 | +
|
| 258 | + squared : `bool` |
| 259 | + If True, the function will return the squared metric between u and |
| 260 | + v, which is faster to compute. |
| 261 | +
|
| 262 | + Returns |
| 263 | + ------- |
| 264 | + distance: float |
| 265 | + The distance between u and v according to the new metric. |
| 266 | + """ |
| 267 | + u = validate_vector(u) |
| 268 | + v = validate_vector(v) |
| 269 | + transformed_diff = (u - v).dot(transformer_T) |
| 270 | + dist = np.dot(transformed_diff, transformed_diff.T) |
| 271 | + if not squared: |
| 272 | + dist = np.sqrt(dist) |
| 273 | + return dist |
| 274 | + |
| 275 | + return metric_fun |
| 276 | + |
| 277 | + get_metric.__doc__ = BaseMetricLearner.get_metric.__doc__ |
| 278 | + |
180 | 279 | def metric(self): |
| 280 | + # TODO: remove this method in version 0.6.0 |
| 281 | + warnings.warn(("`metric` is deprecated since version 0.5.0 and will be " |
| 282 | + "removed in 0.6.0. Use `get_mahalanobis_matrix` instead."), |
| 283 | + DeprecationWarning) |
| 284 | + return self.get_mahalanobis_matrix() |
| 285 | + |
| 286 | + def get_mahalanobis_matrix(self): |
| 287 | + """Returns a copy of the Mahalanobis matrix learned by the metric learner. |
| 288 | +
|
| 289 | + Returns |
| 290 | + ------- |
| 291 | + M : `numpy.ndarray`, shape=(n_components, n_features) |
| 292 | + The copy of the learned Mahalanobis matrix. |
| 293 | + """ |
181 | 294 | return self.transformer_.T.dot(self.transformer_) |
182 | 295 |
|
183 | 296 |
|
|
0 commit comments