diff --git a/.gitignore b/.gitignore index 4c81e9fa..c532a6cb 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist/ .coverage htmlcov/ .cache/ +doc/auto_examples/* diff --git a/README.rst b/README.rst index 22b3e7e3..b1893cc6 100644 --- a/README.rst +++ b/README.rst @@ -34,27 +34,8 @@ package installed). **Usage** -For full usage examples, see the `sphinx documentation`_. - -Each metric is a subclass of ``BaseMetricLearner``, which provides -default implementations for the methods ``metric``, ``transformer``, and -``transform``. Subclasses must provide an implementation for either -``metric`` or ``transformer``. - -For an instance of a metric learner named ``foo`` learning from a set of -``d``-dimensional points, ``foo.metric()`` returns a ``d x d`` -matrix ``M`` such that the distance between vectors ``x`` and ``y`` is -expressed ``sqrt((x-y).dot(M).dot(x-y))``. -Using scipy's ``pdist`` function, this would look like -``pdist(X, metric='mahalanobis', VI=foo.metric())``. - -In the same scenario, ``foo.transformer()`` returns a ``d x d`` -matrix ``L`` such that a vector ``x`` can be represented in the learned -space as the vector ``x.dot(L.T)``. - -For convenience, the function ``foo.transform(X)`` is provided for -converting a matrix of points (``X``) into the learned space, in which -standard Euclidean distance can be used. +See the `sphinx documentation`_ for full documentation about installation, API, + usage, and examples. **Notes** diff --git a/bench/benchmarks/iris.py b/bench/benchmarks/iris.py index d0b76895..305c3a0f 100644 --- a/bench/benchmarks/iris.py +++ b/bench/benchmarks/iris.py @@ -10,7 +10,7 @@ 'LMNN': metric_learn.LMNN(k=5, learn_rate=1e-6, verbose=False), 'LSML_Supervised': metric_learn.LSML_Supervised(num_constraints=200), 'MLKR': metric_learn.MLKR(), - 'NCA': metric_learn.NCA(max_iter=700, learning_rate=0.01, num_dims=2), + 'NCA': metric_learn.NCA(max_iter=700, num_dims=2), 'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30, chunk_size=2), 'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500), diff --git a/doc/conf.py b/doc/conf.py index dff9ce47..ed476edd 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -7,6 +7,7 @@ 'sphinx.ext.viewcode', 'sphinx.ext.mathjax', 'numpydoc', + 'sphinx_gallery.gen_gallery' ] templates_path = ['_templates'] @@ -31,3 +32,6 @@ html_static_path = ['_static'] htmlhelp_basename = 'metric-learndoc' +# Option to only need single backticks to refer to symbols +default_role = 'any' + diff --git a/doc/getting_started.rst b/doc/getting_started.rst new file mode 100644 index 00000000..040adedc --- /dev/null +++ b/doc/getting_started.rst @@ -0,0 +1,42 @@ +############### +Getting started +############### + +Installation and Setup +====================== + +Run ``pip install metric-learn`` to download and install from PyPI. + +Alternately, download the source repository and run: + +- ``python setup.py install`` for default installation. +- ``python setup.py test`` to run all tests. + +**Dependencies** + +- Python 2.7+, 3.4+ +- numpy, scipy, scikit-learn +- (for running the examples only: matplotlib) + +**Notes** + +If a recent version of the Shogun Python modular (``modshogun``) library +is available, the LMNN implementation will use the fast C++ version from +there. The two implementations differ slightly, and the C++ version is +more complete. + + +Quick start +=========== + +This example loads the iris dataset, and evaluates a k-nearest neighbors +algorithm on an embedding space learned with `NCA`. 
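+
+.. note:: The snippet below also relies on scikit-learn's nearest neighbors
+   classifier, so add ``from sklearn.neighbors import KNeighborsClassifier``
+   alongside the other imports.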
+ +>>> from metric_learn import NCA +>>> from sklearn.datasets import load_iris +>>> from sklearn.model_selection import cross_val_score +>>> from sklearn.pipeline import make_pipeline +>>> +>>> X, y = load_iris(return_X_y=True) +>>> clf = make_pipeline(NCA(), KNeighborsClassifier()) +>>> cross_val_score(clf, X, y) diff --git a/doc/index.rst b/doc/index.rst index 36a6e80c..9dbcd9b0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -2,103 +2,31 @@ metric-learn: Metric Learning in Python ======================================= |License| |PyPI version| -Distance metrics are widely used in the machine learning literature. -Traditionally, practicioners would choose a standard distance metric -(Euclidean, City-Block, Cosine, etc.) using a priori knowledge of -the domain. -Distance metric learning (or simply, metric learning) is the sub-field of -machine learning dedicated to automatically constructing optimal distance -metrics. - -This package contains efficient Python implementations of several popular -metric learning algorithms. - -Supervised Algorithms ---------------------- -Supervised metric learning algorithms take as inputs points `X` and target -labels `y`, and learn a distance matrix that make points from the same class -(for classification) or with close target value (for regression) close to -each other, and points from different classes or with distant target values -far away from each other. +Welcome to metric-learn's documentation ! +----------------------------------------- .. toctree:: - :maxdepth: 1 - - metric_learn.covariance - metric_learn.lmnn - metric_learn.nca - metric_learn.lfda - metric_learn.mlkr + :maxdepth: 2 -Weakly-Supervised Algorithms --------------------------- -Weakly supervised algorithms work on weaker information about the data points -than supervised algorithms. Rather than labeled points, they take as input -similarity judgments on tuples of data points, for instance pairs of similar -and dissimilar points. Refer to the documentation of each algorithm for its -particular form of input data. + getting_started .. toctree:: - :maxdepth: 1 - - metric_learn.itml - metric_learn.lsml - metric_learn.sdml - metric_learn.rca - metric_learn.mmc - -Note that each weakly-supervised algorithm has a supervised version of the form -`*_Supervised` where similarity constraints are generated from -the labels information and passed to the underlying algorithm. - -Each metric learning algorithm supports the following methods: - -- ``fit(...)``, which learns the model. -- ``transformer()``, which returns a transformation matrix - :math:`L \in \mathbb{R}^{D \times d}`, which can be used to convert a - data matrix :math:`X \in \mathbb{R}^{n \times d}` to the - :math:`D`-dimensional learned metric space :math:`X L^{\top}`, - in which standard Euclidean distances may be used. -- ``transform(X)``, which applies the aforementioned transformation. -- ``metric()``, which returns a Mahalanobis matrix - :math:`M = L^{\top}L` such that distance between vectors ``x`` and - ``y`` can be computed as :math:`\left(x-y\right)M\left(x-y\right)`. - - -Installation and Setup -====================== - -Run ``pip install metric-learn`` to download and install from PyPI. + :maxdepth: 2 -Alternately, download the source repository and run: + user_guide -- ``python setup.py install`` for default installation. -- ``python setup.py test`` to run all tests. - -**Dependencies** - -- Python 2.7+, 3.4+ -- numpy, scipy, scikit-learn -- (for running the examples only: matplotlib) +.. 
toctree:: + :maxdepth: 2 -**Notes** + Package Overview -If a recent version of the Shogun Python modular (``modshogun``) library -is available, the LMNN implementation will use the fast C++ version from -there. The two implementations differ slightly, and the C++ version is -more complete. +.. toctree:: + :maxdepth: 2 -Navigation ----------- + auto_examples/index :ref:`genindex` | :ref:`modindex` | :ref:`search` -.. toctree:: - :maxdepth: 4 - :hidden: - - Package Overview - .. |PyPI version| image:: https://badge.fury.io/py/metric-learn.svg :target: http://badge.fury.io/py/metric-learn .. |License| image:: http://img.shields.io/:license-mit-blue.svg?style=flat diff --git a/doc/introduction.rst b/doc/introduction.rst new file mode 100644 index 00000000..9f2b4165 --- /dev/null +++ b/doc/introduction.rst @@ -0,0 +1,38 @@ +============ +Introduction +============ + +Distance metrics are widely used in the machine learning literature. +Traditionally, practitioners would choose a standard distance metric +(Euclidean, City-Block, Cosine, etc.) using a priori knowledge of +the domain. +Distance metric learning (or simply, metric learning) is the sub-field of +machine learning dedicated to automatically construct task-specific distance +metrics from (weakly) supervised data. +The learned distance metric often corresponds to a Euclidean distance in a new +embedding space, hence distance metric learning can be seen as a form of +representation learning. + +This package contains a efficient Python implementations of several popular +metric learning algorithms, compatible with scikit-learn. This allows to use +all the scikit-learn routines for pipelining and model selection for +metric learning algorithms. + + +Currently, each metric learning algorithm supports the following methods: + +- ``fit(...)``, which learns the model. +- ``metric()``, which returns a Mahalanobis matrix + :math:`M = L^{\top}L` such that distance between vectors ``x`` and + ``y`` can be computed as :math:`\sqrt{\left(x-y\right)M\left(x-y\right)}`. +- ``transformer_from_metric(metric)``, which returns a transformation matrix + :math:`L \in \mathbb{R}^{D \times d}`, which can be used to convert a + data matrix :math:`X \in \mathbb{R}^{n \times d}` to the + :math:`D`-dimensional learned metric space :math:`X L^{\top}`, + in which standard Euclidean distances may be used. +- ``transform(X)``, which applies the aforementioned transformation. +- ``score_pairs(pairs)`` which returns the distance between pairs of + points. ``pairs`` should be a 3D array-like of pairs of shape ``(n_pairs, + 2, n_features)``, or it can be a 2D array-like of pairs indicators of + shape ``(n_pairs, 2)`` (see section :ref:`preprocessor_section` for more + details). diff --git a/doc/metric_learn.nca.rst b/doc/metric_learn.nca.rst index 7a4ee2c4..00bc4eac 100644 --- a/doc/metric_learn.nca.rst +++ b/doc/metric_learn.nca.rst @@ -21,7 +21,7 @@ Example Code X = iris_data['data'] Y = iris_data['target'] - nca = NCA(max_iter=1000, learning_rate=0.01) + nca = NCA(max_iter=1000) nca.fit(X, Y) References diff --git a/doc/metric_learn.rst b/doc/metric_learn.rst index 70a99a04..c2472408 100644 --- a/doc/metric_learn.rst +++ b/doc/metric_learn.rst @@ -1,8 +1,8 @@ metric_learn package ==================== -Submodules ----------- +Module Contents +--------------- .. toctree:: @@ -16,11 +16,3 @@ Submodules metric_learn.nca metric_learn.rca metric_learn.sdml - -Module contents ---------------- - -.. 
automodule:: metric_learn - :members: - :undoc-members: - :show-inheritance: diff --git a/doc/preprocessor.rst b/doc/preprocessor.rst new file mode 100644 index 00000000..ad1ffd8f --- /dev/null +++ b/doc/preprocessor.rst @@ -0,0 +1,111 @@ +.. _preprocessor_section: + +============ +Preprocessor +============ + +Estimators in metric-learn all have a ``preprocessor`` option at instantiation. +Filling this argument allows them to take more compact input representation +when fitting, predicting etc... + +If ``preprocessor=None``, no preprocessor will be used and the user must +provide the classical representation to the fit/predict/score/etc... methods of +the estimators (see the documentation of the particular estimator to know the +type of input it accepts). Otherwise, two types of objects can be put in this +argument: + +Array-like +---------- +You can specify ``preprocessor=X`` where ``X`` is an array-like containing the +dataset of points. In this case, the fit/predict/score/etc... methods of the +estimator will be able to take as inputs an array-like of indices, replacing +under the hood each index by the corresponding sample. + + +Example with a supervised metric learner: + +>>> from metric_learn import NCA +>>> +>>> X = np.array([[-0.7 , -0.23], +>>> [-0.43, -0.49], +>>> [ 0.14, -0.37]]) # array of 3 samples of 2 features +>>> points_indices = np.array([2, 0, 1, 0]) +>>> y = np.array([1, 0, 1, 1]) +>>> +>>> nca = NCA(preprocessor=X) +>>> nca.fit(points_indices, y) +>>> # under the hood the algorithm will create +>>> # points = np.array([[ 0.14, -0.37], +>>> # [-0.7 , -0.23], +>>> # [-0.43, -0.49], +>>> # [ 0.14, -0.37]]) and fit on it + + +Example with a weakly supervised metric learner: + +>>> from metric_learn import MMC +>>> X = np.array([[-0.7 , -0.23], +>>> [-0.43, -0.49], +>>> [ 0.14, -0.37]]) # array of 3 samples of 2 features +>>> pairs_indices = np.array([[2, 0], [1, 0]]) +>>> y_pairs = np.array([1, -1]) +>>> +>>> mmc = MMC(preprocessor=X) +>>> mmc.fit(pairs_indices, y_pairs) +>>> # under the hood the algorithm will create +>>> # pairs = np.array([[[ 0.14, -0.37], [-0.7 , -0.23]], +>>> # [[-0.43, -0.49], [-0.7 , -0.23]]]) and fit on it + +Callable +-------- +Alternatively, you can provide a callable as ``preprocessor``. Then the +estimator will accept indicators of points instead of points. Under the hood, +the estimator will call this callable on the indicators you provide as input +when fitting, predicting etc... Using a callable can be really useful to +represent lazily a dataset of images stored on the file system for instance. +The callable should take as an input a 1D array-like, and return a 2D +array-like. For supervised learners it will be applied on the whole 1D array of +indicators at once, and for weakly supervised learners it will be applied on +each column of the 2D array of tuples. 
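+
+For instance, passing an array-like ``X`` as ``preprocessor`` behaves roughly
+like passing the callable ``lambda indices: X[indices]``. A minimal sketch of
+this equivalence (not the actual implementation):
+
+>>> X = np.array([[-0.7 , -0.23],
+>>>               [-0.43, -0.49],
+>>>               [ 0.14, -0.37]])
+>>> nca = NCA(preprocessor=lambda indices: X[indices])
+>>> nca.fit(np.array([2, 0, 1, 0]), np.array([1, 0, 1, 1]))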
+ +Example with a supervised metric learner: + +>>> def find_images(file_paths): +>>> # each file contains a small image to use as an input datapoint +>>> return np.row_stack([imread(f).ravel() for f in file_paths]) +>>> +>>> nca = NCA(preprocessor=find_images) +>>> nca.fit(['img01.png', 'img00.png', 'img02.png'], [1, 0, 1]) +>>> # under the hood preprocessor(indicators) will be called + + +Example with a weakly supervised metric learner: + +>>> pairs_images_paths = [['img02.png', 'img00.png'], +>>> ['img01.png', 'img00.png']] +>>> y_pairs = np.array([1, -1]) +>>> +>>> mmc = NCA(preprocessor=find_images) +>>> mmc.fit(pairs_images_paths, y_pairs) +>>> # under the hood preprocessor(pairs_indicators[i]) will be called for each +>>> # i in [0, 1] + + +.. note:: Note that when you fill the ``preprocessor`` option, it allows you + to give more compact inputs, but the classical way of providing inputs + stays valid (2D array-like for supervised learners and 3D array-like of + tuples for weakly supervised learners). If a classical input + is provided, the metric learner will not use the preprocessor. + + Example: This will work: + + >>> from metric_learn import MMC + >>> def preprocessor_wip(array): + >>> raise NotImplementedError("This preprocessor does nothing yet.") + >>> + >>> pairs = np.array([[[ 0.14, -0.37], [-0.7 , -0.23]], + >>> [[-0.43, -0.49], [-0.7 , -0.23]]]) + >>> y_pairs = np.array([1, -1]) + >>> + >>> mmc = MMC(preprocessor=preprocessor_wip) + >>> mmc.fit(pairs, y_pairs) # preprocessor_wip will not be called here diff --git a/doc/supervised.rst b/doc/supervised.rst new file mode 100644 index 00000000..26934a47 --- /dev/null +++ b/doc/supervised.rst @@ -0,0 +1,209 @@ +========================== +Supervised Metric Learning +========================== + +Supervised metric learning algorithms take as inputs points `X` and target +labels `y`, and learn a distance matrix that make points from the same class +(for classification) or with close target value (for regression) close to each +other, and points from different classes or with distant target values far away +from each other. + +Scikit-learn compatibility +========================== + +All supervised algorithms are scikit-learn `Estimators`, so they are +compatible with Pipelining and scikit-learn model selection routines. + +Algorithms +========== + +Covariance +---------- + +.. todo:: Covariance is unsupervised, so its doc should not be here. + +`Covariance` does not "learn" anything, rather it calculates +the covariance matrix of the input data. This is a simple baseline method. + +.. topic:: Example Code: + +:: + + from metric_learn import Covariance + from sklearn.datasets import load_iris + + iris = load_iris()['data'] + + cov = Covariance().fit(iris) + x = cov.transform(iris) + +.. topic:: References: + + .. [1] On the Generalized Distance in Statistics, P.C.Mahalanobis, 1936 + +LMNN +----- + +Large-margin nearest neighbor metric learning. + +`LMNN` learns a Mahanalobis distance metric in the kNN classification +setting using semidefinite programming. The learned metric attempts to keep +k-nearest neighbors in the same class, while keeping examples from different +classes separated by a large margin. This algorithm makes no assumptions about +the distribution of the data. + +.. 
topic:: Example Code: + +:: + + import numpy as np + from metric_learn import LMNN + from sklearn.datasets import load_iris + + iris_data = load_iris() + X = iris_data['data'] + Y = iris_data['target'] + + lmnn = LMNN(k=5, learn_rate=1e-6) + lmnn.fit(X, Y, verbose=False) + +If a recent version of the Shogun Python modular (``modshogun``) library +is available, the LMNN implementation will use the fast C++ version from +there. Otherwise, the included pure-Python version will be used. +The two implementations differ slightly, and the C++ version is more complete. + +.. topic:: References: + + .. [1] `Distance Metric Learning for Large Margin Nearest Neighbor + Classification + `_ Kilian Q. Weinberger, John + Blitzer, Lawrence K. Saul + +NCA +--- + +Neighborhood Components Analysis (`NCA`) is a distance metric learning +algorithm which aims to improve the accuracy of nearest neighbors +classification compared to the standard Euclidean distance. The algorithm +directly maximizes a stochastic variant of the leave-one-out k-nearest +neighbors (KNN) score on the training set. It can also learn a low-dimensional +linear embedding of data that can be used for data visualization and fast +classification. + +.. topic:: Example Code: + +:: + + import numpy as np + from metric_learn import NCA + from sklearn.datasets import load_iris + + iris_data = load_iris() + X = iris_data['data'] + Y = iris_data['target'] + + nca = NCA(max_iter=1000) + nca.fit(X, Y) + +.. topic:: References: + + .. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov. + "Neighbourhood Components Analysis". Advances in Neural Information + Processing Systems. 17, 513-520, 2005. + http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf + + .. [2] Wikipedia entry on Neighborhood Components Analysis + https://en.wikipedia.org/wiki/Neighbourhood_components_analysis + +LFDA +---- + +Local Fisher Discriminant Analysis (LFDA) + +`LFDA` is a linear supervised dimensionality reduction method. It is +particularly useful when dealing with multimodality, where one ore more classes +consist of separate clusters in input space. The core optimization problem of +LFDA is solved as a generalized eigenvalue problem. + +.. topic:: Example Code: + +:: + + import numpy as np + from metric_learn import LFDA + from sklearn.datasets import load_iris + + iris_data = load_iris() + X = iris_data['data'] + Y = iris_data['target'] + + lfda = LFDA(k=2, dim=2) + lfda.fit(X, Y) + +.. topic:: References: + + .. [1] `Dimensionality Reduction of Multimodal Labeled Data by Local + Fisher Discriminant Analysis `_ Masashi Sugiyama. + + .. [2] `Local Fisher Discriminant Analysis on Beer Style Clustering + `_ Yuan Tang. + + +MLKR +---- + +Metric Learning for Kernel Regression. + +`MLKR` is an algorithm for supervised metric learning, which learns a +distance function by directly minimising the leave-one-out regression error. +This algorithm can also be viewed as a supervised variation of PCA and can be +used for dimensionality reduction and high dimensional data visualization. + +.. topic:: Example Code: + +:: + + from metric_learn import MLKR + from sklearn.datasets import load_iris + + iris_data = load_iris() + X = iris_data['data'] + Y = iris_data['target'] + + mlkr = MLKR() + mlkr.fit(X, Y) + +.. topic:: References: + + .. [1] `Metric Learning for Kernel Regression `_ Kilian Q. 
Weinberger, + Gerald Tesauro + + +Supervised versions of weakly-supervised algorithms +--------------------------------------------------- + +Note that each :ref:`weakly-supervised algorithm ` +has a supervised version of the form `*_Supervised` where similarity tuples are +generated from the labels information and passed to the underlying algorithm. + +.. todo:: add more details about that (see issue ``_) + + +.. topic:: Example Code: + +:: + + from metric_learn import MMC_Supervised + from sklearn.datasets import load_iris + + iris_data = load_iris() + X = iris_data['data'] + Y = iris_data['target'] + + mmc = MMC_Supervised(num_constraints=200) + mmc.fit(X, Y) diff --git a/doc/user_guide.rst b/doc/user_guide.rst new file mode 100644 index 00000000..fb7060ce --- /dev/null +++ b/doc/user_guide.rst @@ -0,0 +1,15 @@ +.. title:: User guide: contents + +.. _user_guide: + +========== +User Guide +========== + +.. toctree:: + :numbered: + + introduction.rst + supervised.rst + weakly_supervised.rst + preprocessor.rst \ No newline at end of file diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst new file mode 100644 index 00000000..deae9b40 --- /dev/null +++ b/doc/weakly_supervised.rst @@ -0,0 +1,345 @@ +.. _weakly_supervised_section: + +================================= +Weakly Supervised Metric Learning +================================= + +Weakly supervised algorithms work on weaker information about the data points +than supervised algorithms. Rather than labeled points, they take as input +similarity judgments on tuples of data points, for instance pairs of similar +and dissimilar points. Refer to the documentation of each algorithm for its +particular form of input data. + + +Input data +========== + +In the following paragraph we talk about tuples for sake of generality. These +can be pairs, triplets, quadruplets etc, depending on the particular metric +learning algorithm we use. + +Basic form +---------- +Every weakly supervised algorithm will take as input tuples of points, and if +needed labels for theses tuples. + + +The `tuples` argument is the first argument of every method (like the X +argument for classical algorithms in scikit-learn). The second argument is the +label of the tuple: its semantic depends on the algorithm used. For instance +for pairs learners ``y`` is a label indicating whether the pair is of similar +samples or dissimilar samples. + +Then one can fit a Weakly Supervised Metric Learner on this tuple, like this: + +>>> my_algo.fit(tuples, y) + +Like in a classical setting we split the points ``X`` between train and test, +here we split the ``tuples`` between train and test. + +>>> from sklearn.model_selection import train_test_split +>>> pairs_train, pairs_test, y_train, y_test = train_test_split(pairs, y) + +These are two data structures that can be used to represent tuple in metric +learn: + +3D array of tuples +------------------ + +The most intuitive way to represent tuples is to provide the algorithm with a +3D array-like of tuples of shape ``(n_tuples, t, n_features)``, where +``n_tuples`` is the number of tuples, ``tuple_size`` is the number of elements +in a tuple (2 for pairs, 3 for triplets for instance), and ``n_features`` is +the number of features of each point. + +.. 
topic:: Example: + Here is an artificial dataset of 4 pairs of 2 points of 3 features each: + +>>> import numpy as np +>>> tuples = np.array([[[-0.12, -1.21, -0.20], +>>> [+0.05, -0.19, -0.05]], +>>> +>>> [[-2.16, +0.11, -0.02], +>>> [+1.58, +0.16, +0.93]], +>>> +>>> [[+1.58, +0.16, +0.93 ], # same as tuples[1, 1, :] +>>> [+0.89, -0.34, +2.41]], +>>> +>>> [[-0.12, -1.21, -0.20 ], # same as tuples[0, 0, :] +>>> [-2.16, +0.11, -0.02]]]) # same as tuples[1, 0, :] +>>> y = np.array([-1, 1, 1, -1]) + +.. warning:: This way of specifying pairs is not recommended for a large number + of tuples, as it is redundant (see the comments in the example) and hence + takes a lot of memory. Indeed each feature vector of a point will be + replicated as many times as a point is involved in a tuple. The second way + to specify pairs is more efficient + + +2D array of indicators + preprocessor +------------------------------------- + +Instead of forming each point in each tuple, a more efficient representation +would be to keep the dataset of points ``X`` aside, and just represent tuples +as a collection of tuples of *indices* from the points in ``X``. Since we loose +the feature dimension there, the resulting array is 2D. + +.. topic:: Example: An equivalent representation of the above pairs would be: + +>>> X = np.array([[-0.12, -1.21, -0.20], +>>> [+0.05, -0.19, -0.05], +>>> [-2.16, +0.11, -0.02], +>>> [+1.58, +0.16, +0.93], +>>> [+0.89, -0.34, +2.41]]) +>>> +>>> tuples_indices = np.array([[0, 1], +>>> [2, 3], +>>> [3, 4], +>>> [0, 2]]) +>>> y = np.array([-1, 1, 1, -1]) + +In order to fit metric learning algorithms with this type of input, we need to +give the original dataset of points ``X`` to the estimator so that it knows +the points the indices refer to. We do this when initializing the estimator, +through the argument `preprocessor`. + +.. topic:: Example: + +>>> from metric_learn import MMC +>>> mmc = MMC(preprocessor=X) +>>> mmc.fit(pairs_indice, y) + + +.. note:: + + Instead of an array-like, you can give a callable in the argument + ``preprocessor``, which will go fetch and form the tuples. This allows to + give more general indicators than just indices from an array (for instance + paths in the filesystem, name of records in a database etc...) See section + :ref:`preprocessor_section` for more details on how to use the preprocessor. + + +Scikit-learn compatibility +========================== + +Weakly supervised estimators are compatible with scikit-learn routines for +model selection (grid-search, cross-validation etc). See the scoring section +for more details on the scoring used in the case of Weakly Supervised +Metric Learning. + +.. topic:: Example + +>>> from metric_learn import MMC +>>> from sklearn.datasets import load_iris +>>> from sklearn.model_selection import cross_val_score +>>> rng = np.random.RandomState(42) +>>> X, _ = load_iris(return_X_y=True) +>>> # let's sample 30 random pairs and labels of pairs +>>> pairs_indices = rng.randint(X.shape[0], size=(30, 2)) +>>> y = rng.randint(2, size=30) +>>> mmc = MMC(preprocessor=X) +>>> cross_val_score(mmc, pairs_indices, y) + +Scoring +======= + +Some default scoring are implemented in metric-learn, depending on the kind of +tuples you're working with (pairs, triplets...). See the docstring of the +`score` method of the estimator you use. 
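+
+For instance, for pairs learners the default ``score`` is the ROC AUC of the
+pair predictions. A minimal sketch, reusing ``X``, ``pairs_indices`` and ``y``
+from the example above:
+
+>>> pairs_train, pairs_test, y_train, y_test = train_test_split(pairs_indices, y)
+>>> mmc = MMC(preprocessor=X)
+>>> mmc.fit(pairs_train, y_train)
+>>> mmc.score(pairs_test, y_test)  # ROC AUC of the pair predictions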
+ + +Algorithms +================== + +ITML +---- + +Information Theoretic Metric Learning, Davis et al., ICML 2007 + +`ITML` minimizes the differential relative entropy between two multivariate +Gaussians under constraints on the distance function, which can be formulated +into a Bregman optimization problem by minimizing the LogDet divergence subject +to linear constraints. This algorithm can handle a wide variety of constraints +and can optionally incorporate a prior on the distance function. Unlike some +other methods, ITML does not rely on an eigenvalue computation or semi-definite +programming. + +.. topic:: Example Code: + +:: + + from metric_learn import ITML + + pairs = [[[1.2, 7.5], [1.3, 1.5]], + [[6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6]], + [[6.2, 5.5], [5.4, 5.4]]] + y = [1, 1, -1, -1] + + # in this task we want points where the first feature is close to be closer + # to each other, no matter how close the second feature is + + + itml = ITML() + itml.fit(pairs, y) + +.. topic:: References: + + .. [1] `Information-theoretic Metric Learning `_ Jason V. Davis, + et al. + + .. [2] Adapted from Matlab code at http://www.cs.utexas.edu/users/pjain/ + itml/ + + +LSML +---- + +`LSML`: Metric Learning from Relative Comparisons by Minimizing Squared +Residual + +.. topic:: Example Code: + +:: + + from metric_learn import LSML + + quadruplets = [[[1.2, 7.5], [1.3, 1.5], [6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6], [6.2, 5.5], [5.4, 5.4]], + [[3.2, 7.5], [3.3, 1.5], [8.4, 2.6], [8.2, 9.7]], + [[3.3, 4.5], [5.2, 4.6], [8.2, 5.5], [7.4, 5.4]]] + + # we want to make closer points where the first feature is close, and + # further if the second feature is close + + lsml = LSML() + lsml.fit(quadruplets) + +.. topic:: References: + + .. [1] Liu et al. + "Metric Learning from Relative Comparisons by Minimizing Squared + Residual". ICDM 2012. http://www.cs.ucla.edu/~weiwang/paper/ICDM12.pdf + + .. [2] Adapted from https://gist.github.com/kcarnold/5439917 + + +SDML +---- + +`SDML`: An efficient sparse metric learning in high-dimensional space via +L1-penalized log-determinant regularization + +.. topic:: Example Code: + +:: + + from metric_learn import SDML + + pairs = [[[1.2, 7.5], [1.3, 1.5]], + [[6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6]], + [[6.2, 5.5], [5.4, 5.4]]] + y = [1, 1, -1, -1] + + # in this task we want points where the first feature is close to be closer + # to each other, no matter how close the second feature is + + sdml = SDML() + sdml.fit(pairs, y) + +.. topic:: References: + + .. [1] Qi et al. + An efficient sparse metric learning in high-dimensional space via + L1-penalized log-determinant regularization. ICML 2009. + http://lms.comp.nus.edu.sg/sites/default/files/publication-attachments/ + icml09-guojun.pdf + + .. [2] Adapted from https://gist.github.com/kcarnold/5439945 + + +RCA +--- + +Relative Components Analysis (RCA) + +`RCA` learns a full rank Mahalanobis distance metric based on a weighted sum of +in-class covariance matrices. It applies a global linear transformation to +assign large weights to relevant dimensions and low weights to irrelevant +dimensions. Those relevant dimensions are estimated using "chunklets", subsets +of points that are known to belong to the same class. + +.. 
topic:: Example Code: + +:: + + from metric_learn import RCA + + pairs = [[[1.2, 7.5], [1.3, 1.5]], + [[6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6]], + [[6.2, 5.5], [5.4, 5.4]]] + y = [1, 1, -1, -1] + + # in this task we want points where the first feature is close to be closer + # to each other, no matter how close the second feature is + + rca = RCA() + rca.fit(pairs, y) + + +.. topic:: References: + + .. [1] `Adjustment learning and relevant component analysis + `_ Noam Shental, et al. + + .. [2] 'Learning distance functions using equivalence relations', ICML 2003 + + .. [3]'Learning a Mahalanobis metric from equivalence constraints', JMLR + 2005 + +MMC +--- + +Mahalanobis Metric Learning with Application for Clustering with +Side-Information, Xing et al., NIPS 2002 + +`MMC` minimizes the sum of squared distances between similar examples, while +enforcing the sum of distances between dissimilar examples to be greater than a +certain margin. This leads to a convex and, thus, local-minima-free +optimization problem that can be solved efficiently. However, the algorithm +involves the computation of eigenvalues, which is the main speed-bottleneck. +Since it has initially been designed for clustering applications, one of the +implicit assumptions of MMC is that all classes form a compact set, i.e., +follow a unimodal distribution, which restricts the possible use-cases of this +method. However, it is one of the earliest and a still often cited technique. + +.. topic:: Example Code: + +:: + + from metric_learn import MMC + + pairs = [[[1.2, 7.5], [1.3, 1.5]], + [[6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6]], + [[6.2, 5.5], [5.4, 5.4]]] + y = [1, 1, -1, -1] + + # in this task we want points where the first feature is close to be closer + # to each other, no matter how close the second feature is + + mmc = MMC() + mmc.fit(pairs, y) + +.. topic:: References: + + .. [1] `Distance metric learning with application to clustering with + side-information `_ Xing, Jordan, Russell, Ng. + .. [2] Adapted from Matlab code `here `_. diff --git a/examples/README.txt b/examples/README.txt new file mode 100644 index 00000000..10dbe0d5 --- /dev/null +++ b/examples/README.txt @@ -0,0 +1,4 @@ +Examples +======== + +Below is a gallery of example metric-learn use cases. 
\ No newline at end of file diff --git a/examples/sandwich.py b/examples/plot_sandwich.py similarity index 97% rename from examples/sandwich.py rename to examples/plot_sandwich.py index 34b48a00..0e7658d3 100644 --- a/examples/sandwich.py +++ b/examples/plot_sandwich.py @@ -1,4 +1,8 @@ +# -*- coding: utf-8 -*- """ +Sandwich demo +============= + Sandwich demo based on code from http://nbviewer.ipython.org/6576096 """ @@ -30,7 +34,7 @@ def sandwich_demo(): for ax_num, ml in enumerate(mls, start=3): ml.fit(x, y) - tx = ml.transform() + tx = ml.transform(x) ml_knn = nearest_neighbors(tx, k=2) ax = plt.subplot(3, 2, ax_num) plot_sandwich_data(tx, y, axis=ax) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index b34860d6..27707be9 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -1,5 +1,8 @@ import numpy as np - +import six +from sklearn.utils import check_array +from sklearn.utils.validation import check_X_y +from metric_learn.exceptions import PreprocessorError # hack around lack of axis kwarg in older numpy versions try: @@ -9,4 +12,313 @@ def vector_norm(X): return np.apply_along_axis(np.linalg.norm, 1, X) else: def vector_norm(X): - return np.linalg.norm(X, axis=1) \ No newline at end of file + return np.linalg.norm(X, axis=1) + + +def check_input(input_data, y=None, preprocessor=None, + type_of_inputs='classic', tuple_size=None, accept_sparse=False, + dtype='numeric', order=None, + copy=False, force_all_finite=True, + multi_output=False, ensure_min_samples=1, + ensure_min_features=1, y_numeric=False, + warn_on_dtype=False, estimator=None): + """Checks that the input format is valid, and converts it if specified + (this is the equivalent of scikit-learn's `check_array` or `check_X_y`). + All arguments following tuple_size are scikit-learn's `check_X_y` + arguments that will be enforced on the data and labels array. If + indicators are given as an input data array, the returned data array + will be the formed points/tuples, using the given preprocessor. + + Parameters + ---------- + input : array-like + The input data array to check. + + y : array-like + The input labels array to check. + + preprocessor : callable (default=`None`) + The preprocessor to use. If None, no preprocessor is used. + + type_of_inputs : `str` {'classic', 'tuples'} + The type of inputs to check. If 'classic', the input should be + a 2D array-like of points or a 1D array like of indicators of points. If + 'tuples', the input should be a 3D array-like of tuples or a 2D + array-like of indicators of tuples. + + accept_sparse : `bool` + Set to true to allow sparse inputs (only works for sparse inputs with + dim < 3). + + tuple_size : int + The number of elements in a tuple (e.g. 2 for pairs). + + dtype : string, type, list of types or None (default='numeric') + Data type of result. If None, the dtype of the input is preserved. + If 'numeric', dtype is preserved unless array.dtype is object. + If dtype is a list of types, conversion on the first type is only + performed if the dtype of the input is not in the list. + + order : 'F', 'C' or None (default=`None`) + Whether an array will be forced to be fortran or c-style. + + copy : boolean (default=False) + Whether a forced copy will be triggered. If copy=False, a copy might + be triggered by a conversion. + + force_all_finite : boolean or 'allow-nan', (default=True) + Whether to raise an error on np.inf and np.nan in X. This parameter + does not influence whether y can have np.inf or np.nan values. 
+ The possibilities are: + - True: Force all values of X to be finite. + - False: accept both np.inf and np.nan in X. + - 'allow-nan': accept only np.nan values in X. Values cannot be + infinite. + + ensure_min_samples : int (default=1) + Make sure that X has a minimum number of samples in its first + axis (rows for a 2D array). + + ensure_min_features : int (default=1) + Make sure that the 2D array has some minimum number of features + (columns). The default value of 1 rejects empty datasets. + This check is only enforced when X has effectively 2 dimensions or + is originally 1D and ``ensure_2d`` is True. Setting to 0 disables + this check. + + warn_on_dtype : boolean (default=False) + Raise DataConversionWarning if the dtype of the input data structure + does not match the requested dtype, causing a memory copy. + + estimator : str or estimator instance (default=`None`) + If passed, include the name of the estimator in warning messages. + + Returns + ------- + X : `numpy.ndarray` + The checked input data array. + + y: `numpy.ndarray` (optional) + The checked input labels array. + """ + + context = make_context(estimator) + + args_for_sk_checks = dict(accept_sparse=accept_sparse, + dtype=dtype, order=order, + copy=copy, force_all_finite=force_all_finite, + ensure_min_samples=ensure_min_samples, + ensure_min_features=ensure_min_features, + warn_on_dtype=warn_on_dtype, estimator=estimator) + + # We need to convert input_data into a numpy.ndarray if possible, before + # any further checks or conversions, and deal with y if needed. Therefore + # we use check_array/check_X_y with fixed permissive arguments. + if y is None: + input_data = check_array(input_data, ensure_2d=False, allow_nd=True, + copy=False, force_all_finite=False, + accept_sparse=True, dtype=None, + ensure_min_features=0, ensure_min_samples=0) + else: + input_data, y = check_X_y(input_data, y, ensure_2d=False, allow_nd=True, + copy=False, force_all_finite=False, + accept_sparse=True, dtype=None, + ensure_min_features=0, ensure_min_samples=0, + multi_output=multi_output, + y_numeric=y_numeric) + + if type_of_inputs == 'classic': + input_data = check_input_classic(input_data, context, preprocessor, + args_for_sk_checks) + + elif type_of_inputs == 'tuples': + input_data = check_input_tuples(input_data, context, preprocessor, + args_for_sk_checks, tuple_size) + + else: + raise ValueError("Unknown value {} for type_of_inputs. Valid values are " + "'classic' or 'tuples'.".format(type_of_inputs)) + + return input_data if y is None else (input_data, y) + + +def check_input_tuples(input_data, context, preprocessor, args_for_sk_checks, + tuple_size): + preprocessor_has_been_applied = False + if input_data.ndim == 2: + if preprocessor is not None: + input_data = preprocess_tuples(input_data, preprocessor) + preprocessor_has_been_applied = True + else: + make_error_input(201, input_data, context) + elif input_data.ndim == 3: + pass + else: + if preprocessor is not None: + make_error_input(420, input_data, context) + else: + make_error_input(200, input_data, context) + input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + **args_for_sk_checks) + # we need to check num_features because check_array does not check it + # for 3D inputs: + if args_for_sk_checks['ensure_min_features'] > 0: + n_features = input_data.shape[2] + if n_features < args_for_sk_checks['ensure_min_features']: + raise ValueError("Found array with {} feature(s) (shape={}) while" + " a minimum of {} is required{}." 
+ .format(n_features, input_data.shape, + args_for_sk_checks['ensure_min_features'], + context)) + # normally we don't need to check_tuple_size too because tuple_size + # shouldn't be able to be modified by any preprocessor + if input_data.ndim != 3: + # we have to ensure this because check_array above does not + if preprocessor_has_been_applied: + make_error_input(211, input_data, context) + else: + make_error_input(201, input_data, context) + check_tuple_size(input_data, tuple_size, context) + return input_data + + +def check_input_classic(input_data, context, preprocessor, args_for_sk_checks): + preprocessor_has_been_applied = False + if input_data.ndim == 1: + if preprocessor is not None: + input_data = preprocess_points(input_data, preprocessor) + preprocessor_has_been_applied = True + else: + make_error_input(101, input_data, context) + elif input_data.ndim == 2: + pass # OK + else: + if preprocessor is not None: + make_error_input(320, input_data, context) + else: + make_error_input(100, input_data, context) + + input_data = check_array(input_data, allow_nd=True, ensure_2d=False, + **args_for_sk_checks) + if input_data.ndim != 2: + # we have to ensure this because check_array above does not + if preprocessor_has_been_applied: + make_error_input(111, input_data, context) + else: + make_error_input(101, input_data, context) + return input_data + + +def make_error_input(code, input_data, context): + code_str = {'expected_input': {'1': '2D array of formed points', + '2': '3D array of formed tuples', + '3': ('1D array of indicators or 2D array of ' + 'formed points'), + '4': ('2D array of indicators or 3D array ' + 'of formed tuples')}, + 'additional_context': {'0': '', + '2': ' when using a preprocessor', + '1': (' after the preprocessor has been ' + 'applied')}, + 'possible_preprocessor': {'0': '', + '1': ' and/or use a preprocessor' + }} + code_list = str(code) + err_args = dict(expected_input=code_str['expected_input'][code_list[0]], + additional_context=code_str['additional_context'] + [code_list[1]], + possible_preprocessor=code_str['possible_preprocessor'] + [code_list[2]], + input_data=input_data, context=context, + found_size=input_data.ndim) + err_msg = ('{expected_input} expected' + '{context}{additional_context}. Found {found_size}D array ' + 'instead:\ninput={input_data}. Reshape your data' + '{possible_preprocessor}.\n') + raise ValueError(err_msg.format(**err_args)) + + +def preprocess_tuples(tuples, preprocessor): + try: + tuples = np.column_stack([preprocessor(tuples[:, i])[:, np.newaxis] for + i in range(tuples.shape[1])]) + except Exception as e: + raise PreprocessorError(e) + return tuples + + +def preprocess_points(points, preprocessor): + """form points if there is a preprocessor else keep them as such (assumes + that check_points has already been called)""" + try: + points = preprocessor(points) + except Exception as e: + raise PreprocessorError(e) + return points + + +def make_context(estimator): + """Helper function to create a string with the estimator name. + Taken from check_array function in scikit-learn. 
+ Will return the following for instance: + NCA: ' by NCA' + 'NCA': ' by NCA' + None: '' + """ + estimator_name = make_name(estimator) + context = (' by ' + estimator_name) if estimator_name is not None else '' + return context + + +def make_name(estimator): + """Helper function that returns the name of estimator or the given string + if a string is given + """ + if estimator is not None: + if isinstance(estimator, six.string_types): + estimator_name = estimator + else: + estimator_name = estimator.__class__.__name__ + else: + estimator_name = None + return estimator_name + + +def check_tuple_size(tuples, tuple_size, context): + """Helper function to check that the number of points in each tuple is + equal to tuple_size (e.g. 2 for pairs), and raise a `ValueError` otherwise""" + if tuple_size is not None and tuples.shape[1] != tuple_size: + msg_t = (("Tuples of {} element(s) expected{}. Got tuples of {} " + "element(s) instead (shape={}):\ninput={}.\n") + .format(tuple_size, context, tuples.shape[1], tuples.shape, + tuples)) + raise ValueError(msg_t) + + +class ArrayIndexer: + + def __init__(self, X): + # we check the array-like preprocessor here, and we as much permissive + # as possible (because the user will check for the desired + # format with arguments in check_input, and only this latter function + # should return the appropriate errors). We do this only to have a numpy + # array object which can be indexed by another numpy array object. + X = check_array(X, + accept_sparse=True, dtype=None, + force_all_finite=False, + ensure_2d=False, allow_nd=True, + ensure_min_samples=0, + ensure_min_features=0, + warn_on_dtype=False, estimator=None) + self.X = X + + def __call__(self, indices): + return self.X[indices] + + +def check_collapsed_pairs(pairs): + num_ident = (vector_norm(pairs[:, 0] - pairs[:, 1]) < 1e-9).sum() + if num_ident: + raise ValueError("{} collapsed pairs found (where the left element is " + "the same as the right element), out of {} pairs " + "in total.".format(num_ident, pairs.shape[0])) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 02519de1..9af79ecc 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -1,51 +1,359 @@ -from numpy.linalg import inv, cholesky -from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.utils.validation import check_array +from numpy.linalg import cholesky +from sklearn.base import BaseEstimator +from sklearn.utils.validation import _is_arraylike +from sklearn.metrics import roc_auc_score +import numpy as np +from abc import ABCMeta, abstractmethod +import six +from ._util import ArrayIndexer, check_input -class BaseMetricLearner(BaseEstimator, TransformerMixin): - def __init__(self): - raise NotImplementedError('BaseMetricLearner should not be instantiated') +class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)): - def metric(self): - """Computes the Mahalanobis matrix from the transformation matrix. + def __init__(self, preprocessor=None): + """ + + Parameters + ---------- + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be gotten like this: X[indices]. + """ + self.preprocessor = preprocessor - .. 
math:: M = L^{\\top} L + @abstractmethod + def score_pairs(self, pairs): + """Returns the score between pairs + (can be a similarity, or a distance/metric depending on the algorithm) + + Parameters + ---------- + pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features) + 3D array of pairs. Returns ------- - M : (d x d) matrix + scores: `numpy.ndarray` of shape=(n_pairs,) + The score of every pair. """ - L = self.transformer() - return L.T.dot(L) - def transformer(self): - """Computes the transformation matrix from the Mahalanobis matrix. + def check_preprocessor(self): + """Initializes the preprocessor""" + if _is_arraylike(self.preprocessor): + self.preprocessor_ = ArrayIndexer(self.preprocessor) + elif callable(self.preprocessor) or self.preprocessor is None: + self.preprocessor_ = self.preprocessor + else: + raise ValueError("Invalid type for the preprocessor: {}. You should " + "provide either None, an array-like object, " + "or a callable.".format(type(self.preprocessor))) + + def _prepare_inputs(self, X, y=None, type_of_inputs='classic', + **kwargs): + """Initializes the preprocessor and processes inputs. See `check_input` + for more details. + + Parameters + ---------- + input: array-like + The input data array to check. - L = cholesky(M).T + y : array-like + The input labels array to check. + + type_of_inputs: `str` {'classic', 'tuples'} + The type of inputs to check. If 'classic', the input should be + a 2D array-like of points or a 1D array like of indicators of points. If + 'tuples', the input should be a 3D array-like of tuples or a 2D + array-like of indicators of tuples. + + **kwargs: dict + Arguments to pass to check_input. Returns ------- - L : upper triangular (d x d) matrix + X : `numpy.ndarray` + The checked input data array. + + y: `numpy.ndarray` (optional) + The checked input labels array. """ - return cholesky(self.metric()).T + self.check_preprocessor() + return check_input(X, y, + type_of_inputs=type_of_inputs, + preprocessor=self.preprocessor_, + estimator=self, + tuple_size=getattr(self, '_tuple_size', None), + **kwargs) + - def transform(self, X=None): +class MetricTransformer(six.with_metaclass(ABCMeta)): + + @abstractmethod + def transform(self, X): """Applies the metric transformation. Parameters ---------- - X : (n x d) matrix, optional - Data to transform. If not supplied, the training data will be used. + X : (n x d) matrix + Data to transform. Returns ------- transformed : (n x d) matrix Input data transformed to the metric space by :math:`XL^{\\top}` """ - if X is None: - X = self.X_ + + +class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner, + MetricTransformer)): + """Mahalanobis metric learning algorithms. + + Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`, + defined between two column vectors :math:`x` and :math:`x'` by: :math:`d_M(x, + x') = \sqrt{(x-x')^T M (x-x')}`, where :math:`M` is a learned symmetric + positive semi-definite (PSD) matrix. The metric between points can then be + expressed as the euclidean distance between points embedded in a new space + through a linear transformation. Indeed, the above matrix can be decomposed + into the product of two transpose matrices (through SVD or Cholesky + decomposition): :math:`d_M(x, x')^2 = (x-x')^T M (x-x') = (x-x')^T L^T L + (x-x') = (L x - L x')^T (L x- L x')` + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. 
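+
+  A hypothetical usage sketch (any Mahalanobis learner, e.g. `NCA`, exposes
+  this interface; ``X`` and ``y`` stand for a data matrix and labels)::
+
+    from metric_learn import NCA
+    nca = NCA().fit(X, y)
+    X_embedded = nca.transform(X)  # embed the points with the learned L
+    two_dists = nca.score_pairs(X[[[0, 1], [2, 3]]])  # d_M(x0, x1), d_M(x2, x3)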
+ """ + + def score_pairs(self, pairs): + """Returns the learned Mahalanobis distance between pairs. + + This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}` + where ``M`` is the learned Mahalanobis matrix, for every pair of points + ``x`` and ``x'``. This corresponds to the euclidean distance between + embeddings of the points in a new space, obtained through a linear + transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e - + x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See + :class:`MahalanobisMixin`). + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + for 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + scores: `numpy.ndarray` of shape=(n_pairs,) + The learned Mahalanobis distance for every pair. + """ + pairs = check_input(pairs, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=2) + pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :]) + # (for MahalanobisMixin, the embedding is linear so we can just embed the + # difference) + return np.sqrt(np.sum(pairwise_diffs**2, axis=-1)) + + def transform(self, X): + """Embeds data points in the learned linear embedding space. + + Transforms samples in ``X`` into ``X_embedded``, samples inside a new + embedding space such that: ``X_embedded = X.dot(L.T)``, where ``L`` is + the learned linear transformation (See :class:`MahalanobisMixin`). + + Parameters + ---------- + X : `numpy.ndarray`, shape=(n_samples, n_features) + The data points to embed. + + Returns + ------- + X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims) + The embedded data points. + """ + X_checked = check_input(X, type_of_inputs='classic', estimator=self, + preprocessor=self.preprocessor_, + accept_sparse=True) + return X_checked.dot(self.transformer_.T) + + def metric(self): + return self.transformer_.T.dot(self.transformer_) + + def transformer_from_metric(self, metric): + """Computes the transformation matrix from the Mahalanobis matrix. + + Since by definition the metric `M` is positive semi-definite (PSD), it + admits a Cholesky decomposition: L = cholesky(M).T. However, currently the + computation of the Cholesky decomposition used does not support + non-definite matrices. If the metric is not definite, this method will + return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector + decomposition of M with the eigenvalues in the diagonal matrix w and the + columns of V being the eigenvectors. If M is diagonal, this method will + just return its elementwise square root (since the diagonalization of + the matrix is itself). + + Returns + ------- + L : (d x d) matrix + """ + + if np.allclose(metric, np.diag(np.diag(metric))): + return np.sqrt(metric) + elif not np.isclose(np.linalg.det(metric), 0): + return cholesky(metric).T else: - X = check_array(X, accept_sparse=True) - L = self.transformer() - return X.dot(L.T) + w, V = np.linalg.eigh(metric) + return V.T * np.sqrt(np.maximum(0, w[:, None])) + + +class _PairsClassifierMixin(BaseMetricLearner): + + _tuple_size = 2 # number of points in a tuple, 2 for pairs + + def predict(self, pairs): + """Predicts the learned metric between input pairs. (For now it just + calls decision function). + + Returns the learned metric value between samples in every pair. It should + ideally be low for similar samples and high for dissimilar samples. 
+ + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to predict, with each row corresponding to two + points, or 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) + The predicted learned metric value between samples in every pair. + """ + return self.decision_function(pairs) + + def decision_function(self, pairs): + """Returns the learned metric between input pairs. + + Returns the learned metric value between samples in every pair. It should + ideally be low for similar samples and high for dissimilar samples. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to predict, with each row corresponding to two + points, or 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) + The predicted learned metric value between samples in every pair. + """ + pairs = check_input(pairs, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=self._tuple_size) + return self.score_pairs(pairs) + + def score(self, pairs, y): + """Computes score of pairs similarity prediction. + + Returns the ``roc_auc`` score of the fitted metric learner. It is + computed in the following way: for every value of a threshold + ``t`` we classify all pairs of samples where the predicted distance is + inferior to ``t`` as belonging to the "similar" class, and the other as + belonging to the "dissimilar" class, and we count false positive and + true positives as in a classical ``roc_auc`` curve. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs, with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + + y : array-like, shape=(n_constraints,) + The corresponding labels. + + Returns + ------- + score : float + The ``roc_auc`` score. + """ + return roc_auc_score(y, self.decision_function(pairs)) + + +class _QuadrupletsClassifierMixin(BaseMetricLearner): + + _tuple_size = 4 # number of points in a tuple, 4 for quadruplets + + def predict(self, quadruplets): + """Predicts the ordering between sample distances in input quadruplets. + + For each quadruplet, returns 1 if the quadruplet is in the right order ( + first pair is more similar than second pair), and -1 if not. + + Parameters + ---------- + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to predict, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. + + Returns + ------- + prediction : `numpy.ndarray` of floats, shape=(n_constraints,) + Predictions of the ordering of pairs, for each quadruplet. + """ + quadruplets = check_input(quadruplets, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=self._tuple_size) + return np.sign(self.decision_function(quadruplets)) + + def decision_function(self, quadruplets): + """Predicts differences between sample distances in input quadruplets. + + For each quadruplet of samples, computes the difference between the learned + metric of the first pair minus the learned metric of the second pair. 
+ + Parameters + ---------- + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to predict, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. + + Returns + ------- + decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) + Metric differences. + """ + return (self.score_pairs(quadruplets[:, :2]) - + self.score_pairs(quadruplets[:, 2:])) + + def score(self, quadruplets, y=None): + """Computes score on input quadruplets + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that of the last two. + + Parameters + ---------- + quadruplets : array-like, shape=(n_quadruplets, 4, n_features) or + (n_quadruplets, 4) + 3D Array of quadruplets to score, with each row corresponding to four + points, or 2D array of indices of quadruplets if the metric learner + uses a preprocessor. + + y : Ignored, for scikit-learn compatibility. + + Returns + ------- + score : float + The quadruplets score. + """ + return -np.mean(self.predict(quadruplets)) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 8824450a..17523a46 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -100,3 +100,13 @@ def random_subset(all_labels, num_preserved=np.inf, random_state=np.random): partial_labels = np.array(all_labels, copy=True) partial_labels[idx] = -1 return Constraints(partial_labels) + +def wrap_pairs(X, constraints): + a = np.array(constraints[0]) + b = np.array(constraints[1]) + c = np.array(constraints[2]) + d = np.array(constraints[3]) + constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d)))) + y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))]) + pairs = X[constraints] + return pairs, y \ No newline at end of file diff --git a/metric_learn/covariance.py b/metric_learn/covariance.py index 8fc07873..10bc9582 100644 --- a/metric_learn/covariance.py +++ b/metric_learn/covariance.py @@ -10,27 +10,35 @@ from __future__ import absolute_import import numpy as np -from sklearn.utils.validation import check_array +from sklearn.base import TransformerMixin -from .base_metric import BaseMetricLearner +from .base_metric import MahalanobisMixin -class Covariance(BaseMetricLearner): - def __init__(self): - pass +class Covariance(MahalanobisMixin, TransformerMixin): + """Covariance metric (baseline method) - def metric(self): - return self.M_ + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) 
+ """ + + def __init__(self, preprocessor=None): + super(Covariance, self).__init__(preprocessor) def fit(self, X, y=None): """ X : data matrix, (n x d) y : unused """ - self.X_ = check_array(X, ensure_min_samples=2) - self.M_ = np.cov(self.X_, rowvar = False) - if self.M_.ndim == 0: - self.M_ = 1./self.M_ + X = self._prepare_inputs(X, ensure_min_samples=2) + M = np.cov(X, rowvar = False) + if M.ndim == 0: + M = 1./M else: - self.M_ = np.linalg.inv(self.M_) + M = np.linalg.inv(M) + + self.transformer_ = self.transformer_from_metric(np.atleast_2d(M)) return self diff --git a/metric_learn/exceptions.py b/metric_learn/exceptions.py new file mode 100644 index 00000000..424d2c4f --- /dev/null +++ b/metric_learn/exceptions.py @@ -0,0 +1,12 @@ +""" +The :mod:`metric_learn.exceptions` module includes all custom warnings and +error classes used across metric-learn. +""" + + +class PreprocessorError(Exception): + + def __init__(self, original_error): + err_msg = ("An error occurred when trying to use the " + "preprocessor: {}").format(repr(original_error)) + super(PreprocessorError, self).__init__(err_msg) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 7b218895..48e71f56 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -17,17 +17,20 @@ import numpy as np from six.moves import xrange from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_array, check_X_y - -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from sklearn.utils.validation import check_array +from sklearn.base import TransformerMixin +from .base_metric import _PairsClassifierMixin, MahalanobisMixin +from .constraints import Constraints, wrap_pairs from ._util import vector_norm -class ITML(BaseMetricLearner): +class _BaseITML(MahalanobisMixin): """Information Theoretic Metric Learning (ITML)""" + + _tuple_size = 2 # constraints are pairs + def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, - A0=None, verbose=False): + A0=None, verbose=False, preprocessor=None): """Initialize ITML. Parameters @@ -44,23 +47,24 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, verbose : bool, optional if True, prints information while learning + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. 
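[Editor's aside] A quick usage sketch for the refactored ``Covariance`` estimator above; iris is used purely as a convenient example dataset:

>>> from metric_learn import Covariance
>>> from sklearn.datasets import load_iris
>>> X = load_iris()['data']
>>> cov = Covariance().fit(X)
>>> X_e = cov.transform(X)   # whitened embedding X.dot(L.T)
>>> L = cov.transformer_     # (4, 4) matrix for the 4 iris features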
""" self.gamma = gamma self.max_iter = max_iter self.convergence_threshold = convergence_threshold self.A0 = A0 self.verbose = verbose + super(_BaseITML, self).__init__(preprocessor) - def _process_inputs(self, X, constraints, bounds): - self.X_ = X = check_array(X) - # check to make sure that no two constrained vectors are identical - a,b,c,d = constraints - no_ident = vector_norm(X[a] - X[b]) > 1e-9 - a, b = a[no_ident], b[no_ident] - no_ident = vector_norm(X[c] - X[d]) > 1e-9 - c, d = c[no_ident], d[no_ident] + def _fit(self, pairs, y, bounds=None): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') # init bounds if bounds is None: + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) else: assert len(bounds) == 2 @@ -68,35 +72,20 @@ def _process_inputs(self, X, constraints, bounds): self.bounds_[self.bounds_==0] = 1e-9 # init metric if self.A0 is None: - self.A_ = np.identity(X.shape[1]) + self.A_ = np.identity(pairs.shape[2]) else: self.A_ = check_array(self.A0) - return a,b,c,d - - def fit(self, X, constraints, bounds=None): - """Learn the ITML model. - - Parameters - ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying positive and (c,d) - negative pairs - bounds : list (pos,neg) pairs, optional - bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg - """ - a,b,c,d = self._process_inputs(X, constraints, bounds) gamma = self.gamma - num_pos = len(a) - num_neg = len(c) + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] + num_pos = len(pos_pairs) + num_neg = len(neg_pairs) _lambda = np.zeros(num_pos + num_neg) lambdaold = np.zeros_like(_lambda) gamma_proj = 1. if gamma is np.inf else gamma/(gamma+1.) pos_bhat = np.zeros(num_pos) + self.bounds_[0] neg_bhat = np.zeros(num_neg) + self.bounds_[1] - pos_vv = self.X_[a] - self.X_[b] - neg_vv = self.X_[c] - self.X_[d] + pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] + neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] A = self.A_ for it in xrange(self.max_iter): @@ -134,17 +123,57 @@ def fit(self, X, constraints, bounds=None): if self.verbose: print('itml converged at iter: %d, conv = %f' % (it, conv)) self.n_iter_ = it + + self.transformer_ = self.transformer_from_metric(self.A_) return self - def metric(self): - return self.A_ +class ITML(_BaseITML, _PairsClassifierMixin): + """Information Theoretic Metric Learning (ITML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + + def fit(self, pairs, y, bounds=None): + """Learn the ITML model. + + Parameters + ---------- + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. + bounds : list (pos,neg) pairs, optional + bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + + Returns + ------- + self : object + Returns the instance. 
+ """ + return self._fit(pairs, y, bounds=bounds) + + +class ITML_Supervised(_BaseITML, TransformerMixin): + """Supervised version of Information Theoretic Metric Learning (ITML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See `transformer_from_metric`.) + """ -class ITML_Supervised(ITML): - """Information Theoretic Metric Learning (ITML)""" def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, num_labeled=np.inf, num_constraints=None, bounds=None, A0=None, - verbose=False): + verbose=False, preprocessor=None): """Initialize the supervised version of `ITML`. `ITML_Supervised` creates pairs of similar sample by taking same class @@ -169,10 +198,13 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, initial regularization matrix, defaults to identity verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - ITML.__init__(self, gamma=gamma, max_iter=max_iter, - convergence_threshold=convergence_threshold, - A0=A0, verbose=verbose) + _BaseITML.__init__(self, gamma=gamma, max_iter=max_iter, + convergence_threshold=convergence_threshold, + A0=A0, verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints self.bounds = bounds @@ -180,6 +212,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, def fit(self, X, y, random_state=np.random): """Create constraints from labels and learn the ITML model. + Parameters ---------- X : (n x d) matrix @@ -191,7 +224,7 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. """ - X, y = check_X_y(X, y, ensure_min_samples=2) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) @@ -201,4 +234,5 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return ITML.fit(self, X, pos_neg, bounds=self.bounds) + pairs, y = wrap_pairs(X, pos_neg) + return _BaseITML._fit(self, pairs, y, bounds=self.bounds) diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index 809f092b..2feff211 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -16,17 +16,23 @@ import warnings from six.moves import xrange from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_X_y +from sklearn.base import TransformerMixin +from .base_metric import MahalanobisMixin -from .base_metric import BaseMetricLearner - -class LFDA(BaseMetricLearner): +class LFDA(MahalanobisMixin, TransformerMixin): ''' Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction Sugiyama, ICML 2006 + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. ''' - def __init__(self, num_dims=None, k=None, embedding_type='weighted'): + + def __init__(self, num_dims=None, k=None, embedding_type='weighted', + preprocessor=None): ''' Initialize LFDA. 
@@ -44,20 +50,32 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted'): 'weighted' - weighted eigenvectors 'orthonormalized' - orthonormalized 'plain' - raw eigenvectors + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. ''' if embedding_type not in ('weighted', 'orthonormalized', 'plain'): raise ValueError('Invalid embedding_type: %r' % embedding_type) self.num_dims = num_dims self.embedding_type = embedding_type self.k = k + super(LFDA, self).__init__(preprocessor) + + def fit(self, X, y): + '''Fit the LFDA model. - def transformer(self): - return self.transformer_ + Parameters + ---------- + X : (n, d) array-like + Input data. - def _process_inputs(self, X, y): + y : (n,) array-like + Class labels, one per point of data. + ''' + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) unique_classes, y = np.unique(y, return_inverse=True) - self.X_, y = check_X_y(X, y) - n, d = self.X_.shape + n, d = X.shape num_classes = len(unique_classes) if self.num_dims is None: @@ -74,21 +92,6 @@ def _process_inputs(self, X, y): k = d - 1 else: k = int(self.k) - - return self.X_, y, num_classes, n, d, dim, k - - def fit(self, X, y): - '''Fit the LFDA model. - - Parameters - ---------- - X : (n, d) array-like - Input data. - - y : (n,) array-like - Class labels, one per point of data. - ''' - X, y, num_classes, n, d, dim, k_ = self._process_inputs(X, y) tSb = np.zeros((d,d)) tSw = np.zeros((d,d)) @@ -99,8 +102,8 @@ def fit(self, X, y): # classwise affinity matrix dist = pairwise_distances(Xc, metric='l2', squared=True) # distances to k-th nearest neighbor - k = min(k_, nc-1) - sigma = np.sqrt(np.partition(dist, k, axis=0)[:,k]) + k = min(k, nc - 1) + sigma = np.sqrt(np.partition(dist, k, axis=0)[:, k]) local_scale = np.outer(sigma, sigma) with np.errstate(divide='ignore', invalid='ignore'): diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index d1a41a33..1d7ddf2a 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -14,17 +14,16 @@ import warnings from collections import Counter from six.moves import xrange -from sklearn.utils.validation import check_X_y, check_array from sklearn.metrics import euclidean_distances - -from .base_metric import BaseMetricLearner +from sklearn.base import TransformerMixin +from .base_metric import MahalanobisMixin # commonality between LMNN implementations -class _base_LMNN(BaseMetricLearner): +class _base_LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, - verbose=False): + verbose=False, preprocessor=None): """Initialize the LMNN object. Parameters @@ -34,6 +33,10 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization: float, optional Weighting of pull and push terms, with 0.5 meaning equal weight. + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. 
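[Editor's aside] A brief usage sketch for the LFDA changes above; iris and the hyperparameter values are illustrative only:

>>> from metric_learn import LFDA
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> lfda = LFDA(k=2, num_dims=2).fit(X, y)
>>> X_2d = lfda.transform(X)   # (150, 2) embedding of the iris points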
""" self.k = k self.min_iter = min_iter @@ -43,44 +46,41 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, self.convergence_tol = convergence_tol self.use_pca = use_pca self.verbose = verbose - - def transformer(self): - return self.L_ + super(_base_LMNN, self).__init__(preprocessor) # slower Python version class python_LMNN(_base_LMNN): - def _process_inputs(self, X, labels): - self.X_ = check_array(X, dtype=float, ensure_min_samples=2) - num_pts, num_dims = self.X_.shape - unique_labels, self.label_inds_ = np.unique(labels, return_inverse=True) + def fit(self, X, y): + k = self.k + reg = self.regularization + learn_rate = self.learn_rate + + X, y = self._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) + num_pts, num_dims = X.shape + unique_labels, self.label_inds_ = np.unique(y, return_inverse=True) if len(self.label_inds_) != num_pts: raise ValueError('Must have one label per point.') self.labels_ = np.arange(len(unique_labels)) if self.use_pca: warnings.warn('use_pca does nothing for the python_LMNN implementation') - self.L_ = np.eye(num_dims) + self.transformer_ = np.eye(num_dims) required_k = np.bincount(self.label_inds_).min() if self.k > required_k: raise ValueError('not enough class labels for specified k' ' (smallest class has %d)' % required_k) - def fit(self, X, y): - k = self.k - reg = self.regularization - learn_rate = self.learn_rate - self._process_inputs(X, y) - - target_neighbors = self._select_targets() - impostors = self._find_impostors(target_neighbors[:,-1]) + target_neighbors = self._select_targets(X) + impostors = self._find_impostors(target_neighbors[:, -1], X) if len(impostors) == 0: # L has already been initialized to an identity matrix return # sum outer products - dfG = _sum_outer_products(self.X_, target_neighbors.flatten(), - np.repeat(np.arange(self.X_.shape[0]), k)) + dfG = _sum_outer_products(X, target_neighbors.flatten(), + np.repeat(np.arange(X.shape[0]), k)) df = np.zeros_like(dfG) # storage @@ -91,14 +91,15 @@ def fit(self, X, y): a2[nn_idx] = np.array([]) # initialize L - L = self.L_ + L = self.transformer_ # first iteration: we compute variables (including objective and gradient) # at initialization point G, objective, total_active, df, a1, a2 = ( - self._loss_grad(L, dfG, impostors, 1, k, reg, target_neighbors, df, a1, - a2)) + self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df, + a1, a2)) + # main loop for it in xrange(2, self.max_iter): # then at each iteration, we try to find a value of L that has better # objective than the previous L, following the gradient: @@ -110,7 +111,7 @@ def fit(self, X, y): # retry we don t want to modify them several times (G_next, objective_next, total_active_next, df_next, a1_next, a2_next) = ( - self._loss_grad(L_next, dfG, impostors, it, k, reg, + self._loss_grad(X, L_next, dfG, impostors, it, k, reg, target_neighbors, df.copy(), list(a1), list(a2))) assert not np.isnan(objective) delta_obj = objective_next - objective @@ -143,14 +144,14 @@ def fit(self, X, y): print("LMNN didn't converge in %d steps." 
% self.max_iter) # store the last L - self.L_ = L + self.transformer_ = L self.n_iter_ = it return self - def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1, - a2): + def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, + a1, a2): # Compute pairwise distances under current metric - Lx = L.dot(self.X_.T).T + Lx = L.dot(X.T).T g0 = _inplace_paired_L2(*Lx[impostors]) Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :]) g1, g2 = Ni[impostors] @@ -174,16 +175,16 @@ def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1, targets = target_neighbors[:, nn_idx] PLUS, pweight = _count_edges(plus1, plus2, impostors, targets) - df += _sum_outer_products(self.X_, PLUS[:, 0], PLUS[:, 1], pweight) + df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight) MINUS, mweight = _count_edges(minus1, minus2, impostors, targets) - df -= _sum_outer_products(self.X_, MINUS[:, 0], MINUS[:, 1], mweight) + df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight) in_imp, out_imp = impostors - df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1]) - df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2]) + df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1]) + df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2]) - df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1]) - df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2]) + df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1]) + df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2]) a1[nn_idx] = act1 a2[nn_idx] = act2 @@ -195,18 +196,18 @@ def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1, objective += G.flatten().dot(L.T.dot(L).flatten()) return G, objective, total_active, df, a1, a2 - def _select_targets(self): - target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int) + def _select_targets(self, X): + target_neighbors = np.empty((X.shape[0], self.k), dtype=int) for label in self.labels_: inds, = np.nonzero(self.label_inds_ == label) - dd = euclidean_distances(self.X_[inds], squared=True) + dd = euclidean_distances(X[inds], squared=True) np.fill_diagonal(dd, np.inf) nn = np.argsort(dd)[..., :self.k] target_neighbors[inds] = inds[nn] return target_neighbors - def _find_impostors(self, furthest_neighbors): - Lx = self.transform() + def _find_impostors(self, furthest_neighbors, X): + Lx = self.transform(X) margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) impostors = [] for label in self.labels_[:-1]: @@ -260,11 +261,19 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None): from modshogun import RealFeatures, MulticlassLabels class LMNN(_base_LMNN): + """Large Margin Nearest Neighbor (LMNN) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. 
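[Editor's aside] Similarly for LMNN, whose fit now stores the learned matrix in ``transformer_``; the hyperparameter values below mirror the ones used in the tests and are not a recommendation:

>>> from metric_learn import LMNN
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> lmnn = LMNN(k=5, learn_rate=1e-6).fit(X, y)
>>> X_e = lmnn.transform(X)    # embedding under the learned L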
+ """ def fit(self, X, y): - self.X_, y = check_X_y(X, y, dtype=float) + X, y = self._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) labels = MulticlassLabels(y) - self._lmnn = shogun_LMNN(RealFeatures(self.X_.T), labels, self.k) + self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k) self._lmnn.set_maxiter(self.max_iter) self._lmnn.set_obj_threshold(self.convergence_tol) self._lmnn.set_regularization(self.regularization) @@ -273,7 +282,7 @@ def fit(self, X, y): self._lmnn.train() else: self._lmnn.train(np.eye(X.shape[1])) - self.L_ = self._lmnn.get_linear_transform() + self.transformer_ = self._lmnn.get_linear_transform(X) return self except ImportError: diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 4e315b0b..73296b46 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -11,14 +11,18 @@ import numpy as np import scipy.linalg from six.moves import xrange -from sklearn.utils.validation import check_array, check_X_y +from sklearn.base import TransformerMixin -from .base_metric import BaseMetricLearner +from .base_metric import _QuadrupletsClassifierMixin, MahalanobisMixin from .constraints import Constraints -class LSML(BaseMetricLearner): - def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): +class _BaseLSML(MahalanobisMixin): + + _tuple_size = 4 # constraints are quadruplets + + def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False, + preprocessor=None): """Initialize LSML. Parameters @@ -29,17 +33,23 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): guess at a metric [default: inv(covariance(X))] verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ self.prior = prior self.tol = tol self.max_iter = max_iter self.verbose = verbose + super(_BaseLSML, self).__init__(preprocessor) - def _prepare_inputs(self, X, constraints, weights): - self.X_ = X = check_array(X) - a,b,c,d = constraints - self.vab_ = X[a] - X[b] - self.vcd_ = X[c] - X[d] + def _fit(self, quadruplets, y=None, weights=None): + quadruplets = self._prepare_inputs(quadruplets, + type_of_inputs='tuples') + + # check to make sure that no two constrained vectors are identical + self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :] + self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :] if self.vab_.shape != self.vcd_.shape: raise ValueError('Constraints must have same length') if weights is None: @@ -48,28 +58,14 @@ def _prepare_inputs(self, X, constraints, weights): self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 if self.prior is None: + X = np.vstack({tuple(row) for row in + quadruplets.reshape(-1, quadruplets.shape[2])}) self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False)) self.M_ = np.linalg.inv(self.prior_inv_) else: self.M_ = self.prior self.prior_inv_ = np.linalg.inv(self.prior) - def metric(self): - return self.M_ - - def fit(self, X, constraints, weights=None): - """Learn the LSML model. 
- - Parameters - ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d]) - weights : (m,) array of floats, optional - scale factor for each constraint - """ - self._prepare_inputs(X, constraints, weights) step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. l_best = 0 @@ -103,6 +99,8 @@ def fit(self, X, constraints, weights=None): if self.verbose: print("Didn't converge after", it, "iterations. Final loss:", s_best) self.n_iter_ = it + + self.transformer_ = self.transformer_from_metric(self.M_) return self def _comparison_loss(self, metric): @@ -131,9 +129,52 @@ def _gradient(self, metric): return dMetric -class LSML_Supervised(LSML): +class LSML(_BaseLSML, _QuadrupletsClassifierMixin): + """Least Squared-residual Metric Learning (LSML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + + def fit(self, quadruplets, weights=None): + """Learn the LSML model. + + Parameters + ---------- + quadruplets : array-like, shape=(n_constraints, 4, n_features) or + (n_constraints, 4) + 3D array-like of quadruplets of points or 2D array of quadruplets of + indicators. In order to supervise the algorithm in the right way, we + should have the four samples ordered in a way such that: + d(pairs[i, 0],X[i, 1]) < d(X[i, 2], X[i, 3]) for all 0 <= i < + n_constraints. + weights : (n_constraints,) array of floats, optional + scale factor for each constraint + + Returns + ------- + self : object + Returns the instance. + """ + return self._fit(quadruplets, weights=weights) + + +class LSML_Supervised(_BaseLSML, TransformerMixin): + """Supervised version of Least Squared-residual Metric Learning (LSML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf, - num_constraints=None, weights=None, verbose=False): + num_constraints=None, weights=None, verbose=False, + preprocessor=None): """Initialize the supervised version of `LSML`. `LSML_Supervised` creates quadruplets from labeled samples by taking two @@ -157,9 +198,12 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf, scale factor for each constraint verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - LSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior, - verbose=verbose) + _BaseLSML.__init__(self, tol=tol, max_iter=max_iter, prior=prior, + verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints self.weights = weights @@ -178,7 +222,7 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. 
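[Editor's aside] A hedged sketch of the new quadruplet-based ``LSML.fit``; in each toy quadruplet below the first two points are meant to be closer than the last two:

>>> import numpy as np
>>> from metric_learn import LSML
>>> quadruplets = np.array([[[1.2, 3.2], [1.1, 3.3], [6.4, 2.6], [0.4, 0.2]],
...                         [[6.4, 2.6], [6.3, 2.5], [1.3, 4.5], [7.2, 8.6]]])
>>> lsml = LSML().fit(quadruplets)
>>> diffs = lsml.decision_function(quadruplets)  # d(first pair) - d(second pair)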
""" - X, y = check_X_y(X, y, ensure_min_samples=2) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) @@ -186,6 +230,7 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) - pairs = c.positive_negative_pairs(num_constraints, same_length=True, - random_state=random_state) - return LSML.fit(self, X, pairs, weights=self.weights) + pos_neg = c.positive_negative_pairs(num_constraints, same_length=True, + random_state=random_state) + return _BaseLSML._fit(self, X[np.column_stack(pos_neg)], + weights=self.weights) diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index ddcb698a..6b79638e 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -11,22 +11,31 @@ import sys import warnings import numpy as np +from sklearn.exceptions import ConvergenceWarning from sklearn.utils.fixes import logsumexp from scipy.optimize import minimize +from scipy.spatial.distance import pdist, squareform +from sklearn.base import TransformerMixin from sklearn.decomposition import PCA -from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_X_y -from sklearn.exceptions import ConvergenceWarning -from .base_metric import BaseMetricLearner + +from sklearn.metrics import pairwise_distances +from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps -class MLKR(BaseMetricLearner): - """Metric Learning for Kernel Regression (MLKR)""" +class MLKR(MahalanobisMixin, TransformerMixin): + """Metric Learning for Kernel Regression (MLKR) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000, - verbose=False): + verbose=False, preprocessor=None): """ Initialize MLKR. @@ -46,16 +55,30 @@ def __init__(self, num_dims=None, A0=None, tol=None, max_iter=1000, verbose : bool, optional (default=False) Whether to print progress messages or not. + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. 
""" self.num_dims = num_dims self.A0 = A0 self.tol = tol self.max_iter = max_iter self.verbose = verbose + super(MLKR, self).__init__(preprocessor) - def _process_inputs(self, X, y): - self.X_, y = check_X_y(X, y, y_numeric=True) - n, d = self.X_.shape + def fit(self, X, y): + """ + Fit MLKR model + + Parameters + ---------- + X : (n x d) array of samples + y : (n) data labels + """ + X, y = self._prepare_inputs(X, y, y_numeric=True, + ensure_min_samples=2) + n, d = X.shape if y.shape[0] != n: raise ValueError('Data and label lengths mismatch: %d != %d' % (n, y.shape[0])) @@ -71,18 +94,6 @@ def _process_inputs(self, X, y): elif A.shape != (m, d): raise ValueError('A0 needs shape (%d,%d) but got %s' % ( m, d, A.shape)) - return self.X_, y, A - - def fit(self, X, y): - """ - Fit MLKR model - - Parameters - ---------- - X : (n x d) array of samples - y : (n) data labels - """ - X, y, A = self._process_inputs(X, y) # Measure the total training time train_time = time.time() @@ -105,9 +116,6 @@ def fit(self, X, y): return self - def transformer(self): - return self.transformer_ - def _loss(self, flatA, X, y): if self.n_iter_ == 0 and self.verbose: diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 02974f7e..596f085f 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -19,19 +19,22 @@ from __future__ import print_function, absolute_import, division import numpy as np from six.moves import xrange -from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_array, check_X_y, assert_all_finite +from sklearn.base import TransformerMixin +from sklearn.utils.validation import check_array, assert_all_finite -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .base_metric import _PairsClassifierMixin, MahalanobisMixin +from .constraints import Constraints, wrap_pairs from ._util import vector_norm - -class MMC(BaseMetricLearner): +class _BaseMMC(MahalanobisMixin): """Mahalanobis Metric for Clustering (MMC)""" + + _tuple_size = 2 # constraints are pairs + def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, - A0=None, diagonal=False, diagonal_c=1.0, verbose=False): + A0=None, diagonal=False, diagonal_c=1.0, verbose=False, + preprocessor=None): """Initialize MMC. Parameters ---------- @@ -49,6 +52,9 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, metric learning verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be gotten like this: X[indices]. """ self.max_iter = max_iter self.max_proj = max_proj @@ -57,42 +63,15 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.diagonal = diagonal self.diagonal_c = diagonal_c self.verbose = verbose + super(_BaseMMC, self).__init__(preprocessor) - def fit(self, X, constraints): - """Learn the MMC model. 
- - Parameters - ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) - dissimilar pairs - """ - constraints = self._process_inputs(X, constraints) - if self.diagonal: - return self._fit_diag(X, constraints) - else: - return self._fit_full(X, constraints) - - def _process_inputs(self, X, constraints): - - self.X_ = X = check_array(X) - - # check to make sure that no two constrained vectors are identical - a,b,c,d = constraints - no_ident = vector_norm(X[a] - X[b]) > 1e-9 - a, b = a[no_ident], b[no_ident] - no_ident = vector_norm(X[c] - X[d]) > 1e-9 - c, d = c[no_ident], d[no_ident] - if len(a) == 0: - raise ValueError('No non-trivial similarity constraints given for MMC.') - if len(c) == 0: - raise ValueError('No non-trivial dissimilarity constraints given for MMC.') + def _fit(self, pairs, y): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') # init metric if self.A0 is None: - self.A_ = np.identity(X.shape[1]) + self.A_ = np.identity(pairs.shape[2]) if not self.diagonal: # Don't know why division by 10... it's in the original code # and seems to affect the overall scale of the learned metric. @@ -100,9 +79,12 @@ def _process_inputs(self, X, constraints): else: self.A_ = check_array(self.A0) - return a,b,c,d + if self.diagonal: + return self._fit_diag(pairs, y) + else: + return self._fit_full(pairs, y) - def _fit_full(self, X, constraints): + def _fit_full(self, pairs, y): """Learn full metric using MMC. Parameters @@ -113,17 +95,16 @@ def _fit_full(self, X, constraints): (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) dissimilar pairs """ - a,b,c,d = constraints - num_pos = len(a) - num_neg = len(c) - num_samples, num_dim = X.shape + num_dim = pairs.shape[2] error1 = error2 = 1e10 eps = 0.01 # error-bound of iterative projection on C1 and C2 A = self.A_ + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] + # Create weight vector from similar samples - pos_diff = X[a] - X[b] + pos_diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] w = np.einsum('ij,ik->jk', pos_diff, pos_diff).ravel() # `w` is the sum of all outer products of the rows in `pos_diff`. # The above `einsum` is equivalent to the much more inefficient: @@ -140,9 +121,10 @@ def _fit_full(self, X, constraints): cycle = 1 alpha = 0.1 # initial step size along gradient - - grad1 = self._fS1(X, a, b, A) # gradient of similarity constraint function - grad2 = self._fD1(X, c, d, A) # gradient of dissimilarity constraint function + grad1 = self._fS1(pos_pairs, A) # gradient of similarity + # constraint function + grad2 = self._fD1(neg_pairs, A) # gradient of dissimilarity + # constraint function M = self._grad_projection(grad1, grad2) # gradient of fD1 orthogonal to fS1 A_old = A.copy() @@ -183,8 +165,8 @@ def _fit_full(self, X, constraints): # max: g(A) >= 1 # here we suppose g(A) = fD(A) = \sum_{I,J \in D} sqrt(d_ij' A d_ij) - obj_previous = self._fD(X, c, d, A_old) # g(A_old) - obj = self._fD(X, c, d, A) # g(A) + obj_previous = self._fD(neg_pairs, A_old) # g(A_old) + obj = self._fD(neg_pairs, A) # g(A) if satisfy and (obj > obj_previous or cycle == 0): @@ -193,8 +175,8 @@ def _fit_full(self, X, constraints): # and update from the current A. 
alpha *= 1.05 A_old[:] = A - grad2 = self._fS1(X, a, b, A) - grad1 = self._fD1(X, c, d, A) + grad2 = self._fS1(pos_pairs, A) + grad1 = self._fD1(neg_pairs, A) M = self._grad_projection(grad1, grad2) A += alpha * M @@ -222,9 +204,11 @@ def _fit_full(self, X, constraints): print('mmc converged at iter %d, conv = %f' % (cycle, delta)) self.A_[:] = A_old self.n_iter_ = cycle + + self.transformer_ = self.transformer_from_metric(self.A_) return self - def _fit_diag(self, X, constraints): + def _fit_diag(self, pairs, y): """Learn diagonal metric using MMC. Parameters ---------- @@ -234,12 +218,9 @@ def _fit_diag(self, X, constraints): (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) dissimilar pairs """ - a,b,c,d = constraints - num_pos = len(a) - num_neg = len(c) - num_samples, num_dim = X.shape - - s_sum = np.sum((X[a] - X[b]) ** 2, axis=0) + num_dim = pairs.shape[2] + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] + s_sum = np.sum((pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) ** 2, axis=0) it = 0 error = 1.0 @@ -249,19 +230,20 @@ def _fit_diag(self, X, constraints): while error > self.convergence_threshold and it < self.max_iter: - fD0, fD_1st_d, fD_2nd_d = self._D_constraint(X, c, d, w) + fD0, fD_1st_d, fD_2nd_d = self._D_constraint(neg_pairs, w) obj_initial = np.dot(s_sum, w) + self.diagonal_c * fD0 fS_1st_d = s_sum # first derivative of the similarity constraints gradient = fS_1st_d - self.diagonal_c * fD_1st_d # gradient of the objective hessian = -self.diagonal_c * fD_2nd_d + eps * np.eye(num_dim) # Hessian of the objective - step = np.dot(np.linalg.inv(hessian), gradient); + step = np.dot(np.linalg.inv(hessian), gradient) # Newton-Rapshon update # search over optimal lambda lambd = 1 # initial step-size w_tmp = np.maximum(0, w - lambd * step) - obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp) + obj = (np.dot(s_sum, w_tmp) + self.diagonal_c * + self._D_objective(neg_pairs, w_tmp)) assert_all_finite(obj) obj_previous = obj + 1 # just to get the while-loop started @@ -271,7 +253,8 @@ def _fit_diag(self, X, constraints): w_previous = w_tmp.copy() lambd /= reduction w_tmp = np.maximum(0, w - lambd * step) - obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp) + obj = (np.dot(s_sum, w_tmp) + self.diagonal_c * + self._D_objective(neg_pairs, w_tmp)) inner_it += 1 assert_all_finite(obj) @@ -282,18 +265,20 @@ def _fit_diag(self, X, constraints): it += 1 self.A_ = np.diag(w) + + self.transformer_ = self.transformer_from_metric(self.A_) return self - def _fD(self, X, c, d, A): + def _fD(self, neg_pairs, A): """The value of the dissimilarity constraint function. f = f(\sum_{ij \in D} distance(x_i, x_j)) i.e. distance can be L1: \sqrt{(x_i-x_j)A(x_i-x_j)'} """ - diff = X[c] - X[d] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] return np.log(np.sum(np.sqrt(np.sum(np.dot(diff, A) * diff, axis=1))) + 1e-6) - def _fD1(self, X, c, d, A): + def _fD1(self, neg_pairs, A): """The gradient of the dissimilarity constraint function w.r.t. A. 
For example, let distance by L1 norm: @@ -305,8 +290,8 @@ def _fD1(self, X, c, d, A): df/dA = f'(\sum_{ij \in D} \sqrt{tr(d_ij'*d_ij*A)}) * 0.5*(\sum_{ij \in D} (1/sqrt{tr(d_ij'*d_ij*A)})*(d_ij'*d_ij)) """ - dim = X.shape[1] - diff = X[c] - X[d] + dim = neg_pairs.shape[2] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] # outer products of all rows in `diff` M = np.einsum('ij,ik->ijk', diff, diff) # faster version of: dist = np.sqrt(np.sum(M * A[None,:,:], axis=(1,2))) @@ -316,7 +301,7 @@ def _fD1(self, X, c, d, A): sum_dist = dist.sum() return sum_deri / (sum_dist + 1e-6) - def _fS1(self, X, a, b, A): + def _fS1(self, pos_pairs, A): """The gradient of the similarity constraint function w.r.t. A. f = \sum_{ij}(x_i-x_j)A(x_i-x_j)' = \sum_{ij}d_ij*A*d_ij' @@ -325,8 +310,8 @@ def _fS1(self, X, a, b, A): Note that d_ij*A*d_ij' = tr(d_ij*A*d_ij') = tr(d_ij'*d_ij*A) so, d(d_ij*A*d_ij')/dA = d_ij'*d_ij """ - dim = X.shape[1] - diff = X[a] - X[b] + dim = pos_pairs.shape[2] + diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] return np.einsum('ij,ik->jk', diff, diff) # sum of outer products of all rows in `diff` def _grad_projection(self, grad1, grad2): @@ -335,15 +320,17 @@ def _grad_projection(self, grad1, grad2): gtemp /= np.linalg.norm(gtemp) return gtemp - def _D_objective(self, X, c, d, w): - return np.log(np.sum(np.sqrt(np.sum(((X[c] - X[d]) ** 2) * w[None,:], axis=1) + 1e-6))) + def _D_objective(self, neg_pairs, w): + return np.log(np.sum(np.sqrt(np.sum(((neg_pairs[:, 0, :] - + neg_pairs[:, 1, :]) ** 2) * + w[None,:], axis=1) + 1e-6))) - def _D_constraint(self, X, c, d, w): + def _D_constraint(self, neg_pairs, w): """Compute the value, 1st derivative, second derivative (Hessian) of a dissimilarity constraint function gF(sum_ij distance(d_ij A d_ij)) where A is a diagonal matrix (in the form of a column vector 'w'). """ - diff = X[c] - X[d] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] diff_sq = diff * diff dist = np.sqrt(diff_sq.dot(w)) sum_deri1 = np.einsum('ij,i', diff_sq, 0.5 / np.maximum(dist, 1e-6)) @@ -359,33 +346,52 @@ def _D_constraint(self, X, c, d, w): sum_deri2 / sum_dist - np.outer(sum_deri1, sum_deri1) / (sum_dist * sum_dist) ) - def metric(self): - return self.A_ - def transformer(self): - """Computes the transformation matrix from the Mahalanobis matrix. - L = V.T * w^(-1/2), with A = V*w*V.T being the eigenvector decomposition of A with - the eigenvalues in the diagonal matrix w and the columns of V being the eigenvectors. +class MMC(_BaseMMC, _PairsClassifierMixin): + """Mahalanobis Metric for Clustering (MMC) - The Cholesky decomposition cannot be applied here, since MMC learns only a positive - *semi*-definite Mahalanobis matrix. + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + + def fit(self, pairs, y): + """Learn the MMC model. + + Parameters + ---------- + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. Returns ------- - L : (d x d) matrix + self : object + Returns the instance. 
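[Editor's aside] The diagonal branch refactored above (``_fit_diag``) can be exercised through the same pair-based API; a hedged sketch with toy pairs:

>>> import numpy as np
>>> from metric_learn import MMC
>>> pairs = np.array([[[1.2, 3.2], [1.1, 3.3]],
...                   [[6.4, 2.6], [6.3, 2.5]],
...                   [[1.3, 4.5], [7.2, 8.6]],
...                   [[6.2, 5.5], [0.4, 0.2]]])
>>> y = np.array([1, 1, -1, -1])
>>> mmc_diag = MMC(diagonal=True).fit(pairs, y)  # learns a diagonal metric
>>> L = mmc_diag.transformer_                    # derived from the diagonal A_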
""" - if self.diagonal: - return np.sqrt(self.A_) - else: - w, V = np.linalg.eigh(self.A_) - return V.T * np.sqrt(np.maximum(0, w[:,None])) + return self._fit(pairs, y) -class MMC_Supervised(MMC): - """Mahalanobis Metric for Clustering (MMC)""" +class MMC_Supervised(_BaseMMC, TransformerMixin): + """Supervised version of Mahalanobis Metric for Clustering (MMC) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, num_labeled=np.inf, num_constraints=None, - A0=None, diagonal=False, diagonal_c=1.0, verbose=False): + A0=None, diagonal=False, diagonal_c=1.0, verbose=False, + preprocessor=None): """Initialize the supervised version of `MMC`. `MMC_Supervised` creates pairs of similar sample by taking same class @@ -414,11 +420,14 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6, metric learning verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - MMC.__init__(self, max_iter=max_iter, max_proj=max_proj, - convergence_threshold=convergence_threshold, - A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, - verbose=verbose) + _BaseMMC.__init__(self, max_iter=max_iter, max_proj=max_proj, + convergence_threshold=convergence_threshold, + A0=A0, diagonal=diagonal, diagonal_c=diagonal_c, + verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -434,7 +443,7 @@ def fit(self, X, y, random_state=np.random): random_state : numpy.random.RandomState, optional If provided, controls random number generation. 
""" - X, y = check_X_y(X, y, ensure_min_samples=2) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) @@ -444,4 +453,5 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return MMC.fit(self, X, pos_neg) + pairs, y = wrap_pairs(X, pos_neg) + return _BaseMMC._fit(self, pairs, y) diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 2f15c7af..81045287 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -4,29 +4,32 @@ """ from __future__ import absolute_import - import warnings import time import sys import numpy as np from scipy.optimize import minimize from sklearn.metrics import pairwise_distances -from sklearn.utils.validation import check_X_y from sklearn.exceptions import ConvergenceWarning +from sklearn.utils.fixes import logsumexp +from sklearn.base import TransformerMixin -try: # scipy.misc.logsumexp is deprecated in scipy 1.0.0 - from scipy.special import logsumexp -except ImportError: - from scipy.misc import logsumexp - -from .base_metric import BaseMetricLearner +from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps -class NCA(BaseMetricLearner): - def __init__(self, num_dims=None, max_iter=100, learning_rate='deprecated', - tol=None, verbose=False): +class NCA(MahalanobisMixin, TransformerMixin): + """Neighborhood Components Analysis (NCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + + def __init__(self, num_dims=None, max_iter=100, tol=None, verbose=False, + preprocessor=None): """Neighborhood Components Analysis Parameters @@ -38,13 +41,6 @@ def __init__(self, num_dims=None, max_iter=100, learning_rate='deprecated', max_iter : int, optional (default=100) Maximum number of iterations done by the optimization algorithm. - learning_rate : Not used - - .. deprecated:: 0.4.0 - `learning_rate` was deprecated in version 0.4.0 and will - be removed in 0.5.0. The current optimization algorithm does not need - to fix a learning rate. - tol : float, optional (default=None) Convergence tolerance for the optimization. @@ -53,24 +49,16 @@ def __init__(self, num_dims=None, max_iter=100, learning_rate='deprecated', """ self.num_dims = num_dims self.max_iter = max_iter - self.learning_rate = learning_rate # TODO: remove in v.0.5.0 self.tol = tol self.verbose = verbose - - def transformer(self): - return self.A_ + super(NCA, self).__init__(preprocessor) def fit(self, X, y): """ X: data matrix, (n x d) y: scalar labels, (n) """ - if self.learning_rate != 'deprecated': - warnings.warn('"learning_rate" parameter is not used.' 
- ' It has been deprecated in version 0.4 and will be' - 'removed in 0.5', DeprecationWarning) - - X, labels = check_X_y(X, y) + X, labels = self._prepare_inputs(X, y, ensure_min_samples=2) n, d = X.shape num_dims = self.num_dims if num_dims is None: @@ -98,8 +86,7 @@ def fit(self, X, y): self.n_iter_ = 0 opt_result = minimize(**optimizer_params) - self.X_ = X - self.A_ = opt_result.x.reshape(-1, X.shape[1]) + self.transformer_ = opt_result.x.reshape(-1, X.shape[1]) self.n_iter_ = opt_result.nit # Stop timer diff --git a/metric_learn/rca.py b/metric_learn/rca.py index 327c5002..3380f4c9 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -16,9 +16,9 @@ import warnings from six.moves import xrange from sklearn import decomposition -from sklearn.utils.validation import check_array +from sklearn.base import TransformerMixin -from .base_metric import BaseMetricLearner +from .base_metric import MahalanobisMixin from .constraints import Constraints @@ -35,9 +35,16 @@ def _chunk_mean_centering(data, chunks): return chunk_mask, chunk_data -class RCA(BaseMetricLearner): - """Relevant Components Analysis (RCA)""" - def __init__(self, num_dims=None, pca_comps=None): +class RCA(MahalanobisMixin, TransformerMixin): + """Relevant Components Analysis (RCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + + def __init__(self, num_dims=None, pca_comps=None, preprocessor=None): """Initialize the learner. Parameters @@ -51,29 +58,17 @@ def __init__(self, num_dims=None, pca_comps=None): If ``0 < pca_comps < 1``, it is used as the minimum explained variance ratio. See sklearn.decomposition.PCA for more details. + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ self.num_dims = num_dims self.pca_comps = pca_comps + super(RCA, self).__init__(preprocessor) - def transformer(self): - return self.transformer_ - - def _process_data(self, X): - self.X_ = X = check_array(X) - - # PCA projection to remove noise and redundant information. - if self.pca_comps is not None: - pca = decomposition.PCA(n_components=self.pca_comps) - X = pca.fit_transform(X) - M_pca = pca.components_ - else: - X -= X.mean(axis=0) - M_pca = None - - return X, M_pca - - def _check_dimension(self, rank): - d = self.X_.shape[1] + def _check_dimension(self, rank, X): + d = X.shape[1] if rank < d: warnings.warn('The inner covariance matrix is not invertible, ' 'so the transformation matrix may contain Nan values. ' @@ -92,7 +87,7 @@ def _check_dimension(self, rank): dim = self.num_dims return dim - def fit(self, data, chunks): + def fit(self, X, chunks): """Learn the RCA model. Parameters @@ -103,17 +98,26 @@ def fit(self, data, chunks): When ``chunks[i] == -1``, point i doesn't belong to any chunklet. When ``chunks[i] == j``, point i belongs to chunklet j. """ - data, M_pca = self._process_data(data) + X = self._prepare_inputs(X, ensure_min_samples=2) + + # PCA projection to remove noise and redundant information. 
+ if self.pca_comps is not None: + pca = decomposition.PCA(n_components=self.pca_comps) + X_t = pca.fit_transform(X) + M_pca = pca.components_ + else: + X_t = X - X.mean(axis=0) + M_pca = None chunks = np.asanyarray(chunks, dtype=int) - chunk_mask, chunked_data = _chunk_mean_centering(data, chunks) + chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks) inner_cov = np.cov(chunked_data, rowvar=0, bias=1) - dim = self._check_dimension(np.linalg.matrix_rank(inner_cov)) + dim = self._check_dimension(np.linalg.matrix_rank(inner_cov), X_t) # Fisher Linear Discriminant projection - if dim < data.shape[1]: - total_cov = np.cov(data[chunk_mask], rowvar=0) + if dim < X_t.shape[1]: + total_cov = np.cov(X_t[chunk_mask], rowvar=0) tmp = np.linalg.lstsq(total_cov, inner_cov)[0] vals, vecs = np.linalg.eig(tmp) inds = np.argsort(vals)[:dim] @@ -136,8 +140,16 @@ def _inv_sqrtm(x): class RCA_Supervised(RCA): + """Supervised version of Relevant Components Analysis (RCA) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The learned linear transformation ``L``. + """ + def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, - chunk_size=2): + chunk_size=2, preprocessor=None): """Initialize the supervised version of `RCA`. `RCA_Supervised` creates chunks of similar points by first sampling a @@ -150,8 +162,12 @@ def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, embedding dimension (default: original dimension of data) num_chunks: int, optional chunk_size: int, optional + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps) + RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps, + preprocessor=preprocessor) self.num_chunks = num_chunks self.chunk_size = chunk_size @@ -166,6 +182,7 @@ def fit(self, X, y, random_state=np.random): y : (n) data labels random_state : a random.seed object to fix the random_state if needed. """ + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) chunks = Constraints(y).chunks(num_chunks=self.num_chunks, chunk_size=self.chunk_size, random_state=random_state) diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 1746ec7d..1892d176 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -10,18 +10,20 @@ from __future__ import absolute_import import numpy as np -from scipy.sparse.csgraph import laplacian +from sklearn.base import TransformerMixin from sklearn.covariance import graph_lasso from sklearn.utils.extmath import pinvh -from sklearn.utils.validation import check_array -from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .base_metric import MahalanobisMixin, _PairsClassifierMixin +from .constraints import Constraints, wrap_pairs -class SDML(BaseMetricLearner): +class _BaseSDML(MahalanobisMixin): + + _tuple_size = 2 # constraints are pairs + def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, - verbose=False): + verbose=False, preprocessor=None): """ Parameters ---------- @@ -36,59 +38,88 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose : bool, optional if True, prints information while learning + + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be gotten like this: X[indices]. 
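[Editor's aside] A quick sketch for the refactored ``RCA_Supervised`` above; the chunk settings are illustrative only:

>>> from metric_learn import RCA_Supervised
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2).fit(X, y)
>>> X_2d = rca.transform(X)    # 2-dimensional embedding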
""" self.balance_param = balance_param self.sparsity_param = sparsity_param self.use_cov = use_cov self.verbose = verbose + super(_BaseSDML, self).__init__(preprocessor) + + def _fit(self, pairs, y): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') - def _prepare_inputs(self, X, W): - self.X_ = X = check_array(X) - W = check_array(W, accept_sparse=True) # set up prior M if self.use_cov: + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.M_ = pinvh(np.cov(X, rowvar = False)) else: - self.M_ = np.identity(X.shape[1]) - L = laplacian(W, normed=False) - return X.T.dot(L.dot(X)) + self.M_ = np.identity(pairs.shape[2]) + diff = pairs[:, 0] - pairs[:, 1] + loss_matrix = (diff.T * y).dot(diff) + P = self.M_ + self.balance_param * loss_matrix + emp_cov = pinvh(P) + # hack: ensure positive semidefinite + emp_cov = emp_cov.T.dot(emp_cov) + _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) + + self.transformer_ = self.transformer_from_metric(self.M_) + return self + - def metric(self): - return self.M_ +class SDML(_BaseSDML, _PairsClassifierMixin): + """Sparse Distance Metric Learning (SDML) - def fit(self, X, W): + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + + def fit(self, pairs, y): """Learn the SDML model. Parameters ---------- - X : array-like, shape (n, d) - data matrix, where each row corresponds to a single instance - W : array-like, shape (n, n) - connectivity graph, with +1 for positive pairs and -1 for negative + pairs: array-like, shape=(n_constraints, 2, n_features) or + (n_constraints, 2) + 3D Array of pairs with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. Returns ------- self : object Returns the instance. """ - loss_matrix = self._prepare_inputs(X, W) - P = self.M_ + self.balance_param * loss_matrix - emp_cov = pinvh(P) - # hack: ensure positive semidefinite - emp_cov = emp_cov.T.dot(emp_cov) - _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose) - return self + return self._fit(pairs, y) -class SDML_Supervised(SDML): +class SDML_Supervised(_BaseSDML, TransformerMixin): + """Supervised version of Sparse Distance Metric Learning (SDML) + + Attributes + ---------- + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + The linear transformation ``L`` deduced from the learned Mahalanobis + metric (See :meth:`transformer_from_metric`.) + """ + def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, - num_labeled=np.inf, num_constraints=None, verbose=False): + num_labeled=np.inf, num_constraints=None, verbose=False, + preprocessor=None): """Initialize the supervised version of `SDML`. `SDML_Supervised` creates pairs of similar sample by taking same class samples, and pairs of dissimilar samples by taking different class samples. It then passes these pairs to `SDML` for training. 
- Parameters ---------- balance_param : float, optional @@ -105,10 +136,13 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, number of constraints to generate verbose : bool, optional if True, prints information while learning + preprocessor : array-like, shape=(n_samples, n_features) or callable + The preprocessor to call to get tuples from indices. If array-like, + tuples will be formed like this: X[indices]. """ - SDML.__init__(self, balance_param=balance_param, - sparsity_param=sparsity_param, use_cov=use_cov, - verbose=verbose) + _BaseSDML.__init__(self, balance_param=balance_param, + sparsity_param=sparsity_param, use_cov=use_cov, + verbose=verbose, preprocessor=preprocessor) self.num_labeled = num_labeled self.num_constraints = num_constraints @@ -130,7 +164,7 @@ def fit(self, X, y, random_state=np.random): self : object Returns the instance. """ - y = check_array(y, ensure_2d=False) + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: num_classes = len(np.unique(y)) @@ -138,5 +172,7 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) - adj = c.adjacency_matrix(num_constraints, random_state=random_state) - return SDML.fit(self, X, adj) + pos_neg = c.positive_negative_pairs(num_constraints, + random_state=random_state) + pairs, y = wrap_pairs(X, pos_neg) + return _BaseSDML._fit(self, pairs, y) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index e5bd071c..74bc25de 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -6,14 +6,16 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.datasets import load_iris, make_classification, make_regression -from numpy.testing import assert_array_almost_equal, assert_array_equal +from numpy.testing import assert_array_almost_equal from sklearn.utils.testing import assert_warns_message from sklearn.exceptions import ConvergenceWarning from sklearn.utils.validation import check_X_y -from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, - LSML_Supervised, ITML_Supervised, SDML_Supervised, - RCA_Supervised, MMC_Supervised) +from metric_learn import ( + LMNN, NCA, LFDA, Covariance, MLKR, MMC, + LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) +# Import this specially for testing. 
+from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN @@ -42,7 +44,7 @@ def test_iris(self): cov = Covariance() cov.fit(self.iris_points) - csep = class_separation(cov.transform(), self.iris_labels) + csep = class_separation(cov.transform(self.iris_points), self.iris_labels) # deterministic result self.assertAlmostEqual(csep, 0.72981476) @@ -52,7 +54,7 @@ def test_iris(self): lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.iris_points, self.iris_labels) - csep = class_separation(lsml.transform(), self.iris_labels) + csep = class_separation(lsml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.8) # it's pretty terrible @@ -61,7 +63,7 @@ def test_iris(self): itml = ITML_Supervised(num_constraints=200) itml.fit(self.iris_points, self.iris_labels) - csep = class_separation(itml.transform(), self.iris_labels) + csep = class_separation(itml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) @@ -72,7 +74,8 @@ def test_iris(self): lmnn = LMNN_cls(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.iris_points, self.iris_labels) - csep = class_separation(lmnn.transform(), self.iris_labels) + csep = class_separation(lmnn.transform(self.iris_points), + self.iris_labels) self.assertLess(csep, 0.25) @@ -115,7 +118,7 @@ def test_iris(self): sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.iris_points, self.iris_labels, random_state=rs) - csep = class_separation(sdml.transform(), self.iris_labels) + csep = class_separation(sdml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) @@ -126,137 +129,33 @@ def test_iris(self): # Without dimension reduction nca = NCA(max_iter=(100000//n)) nca.fit(self.iris_points, self.iris_labels) - csep = class_separation(nca.transform(), self.iris_labels) + csep = class_separation(nca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) # With dimension reduction - nca = NCA(max_iter=(100000//n), num_dims=2, tol=1e-9) + nca = NCA(max_iter=(100000//n), num_dims=2) nca.fit(self.iris_points, self.iris_labels) - csep = class_separation(nca.transform(), self.iris_labels) + csep = class_separation(nca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.20) - def test_finite_differences(self): - """Test gradient of loss function - - Assert that the gradient is almost equal to its finite differences - approximation. - """ - # Initialize the transformation `M`, as well as `X` and `y` and `NCA` - X, y = make_classification() - M = np.random.randn(np.random.randint(1, X.shape[1] + 1), X.shape[1]) - mask = y[:, np.newaxis] == y[np.newaxis, :] - nca = NCA() - nca.n_iter_ = 0 - - def fun(M): - return nca._loss_grad_lbfgs(M, X, mask)[0] - - def grad(M): - return nca._loss_grad_lbfgs(M, X, mask)[1].ravel() - - # compute relative error - rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M)) - np.testing.assert_almost_equal(rel_diff, 0., decimal=6) - - def test_simple_example(self): - """Test on a simple example. - - Puts four points in the input space where the opposite labels points are - next to each other. After transform the same labels points should be next - to each other. 
- - """ - X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) - y = np.array([1, 0, 1, 0]) - nca = NCA(num_dims=2,) - nca.fit(X, y) - Xansformed = nca.transform(X) - np.testing.assert_equal(pairwise_distances(Xansformed).argsort()[:, 1], - np.array([2, 3, 0, 1])) - - def test_deprecation(self): - # test that the right deprecation message is thrown. - # TODO: remove in v.0.5 - X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) - y = np.array([1, 0, 1, 0]) - nca = NCA(num_dims=2, learning_rate=0.01) - msg = ('"learning_rate" parameter is not used.' - ' It has been deprecated in version 0.4 and will be' - 'removed in 0.5') - assert_warns_message(DeprecationWarning, msg, nca.fit, X, y) - - def test_singleton_class(self): - X = self.iris_points - y = self.iris_labels - - # one singleton class: test fitting works - singleton_class = 1 - ind_singleton, = np.where(y == singleton_class) - y[ind_singleton] = 2 - y[ind_singleton[0]] = singleton_class - - nca = NCA(max_iter=30) - nca.fit(X, y) - - # One non-singleton class: test fitting works - ind_1, = np.where(y == 1) - ind_2, = np.where(y == 2) - y[ind_1] = 0 - y[ind_1[0]] = 1 - y[ind_2] = 0 - y[ind_2[0]] = 2 - - nca = NCA(max_iter=30) - nca.fit(X, y) - - # Only singleton classes: test fitting does nothing (the gradient - # must be null in this case, so the final matrix must stay like - # the initialization) - ind_0, = np.where(y == 0) - ind_1, = np.where(y == 1) - ind_2, = np.where(y == 2) - X = X[[ind_0[0], ind_1[0], ind_2[0]]] - y = y[[ind_0[0], ind_1[0], ind_2[0]]] - - EPS = np.finfo(float).eps - A = np.zeros((X.shape[1], X.shape[1])) - np.fill_diagonal(A, - 1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS))) - nca = NCA(max_iter=30, num_dims=X.shape[1]) - nca.fit(X, y) - assert_array_equal(nca.A_, A) - - def test_one_class(self): - # if there is only one class the gradient is null, so the final matrix - # must stay like the initialization - X = self.iris_points[self.iris_labels == 0] - y = self.iris_labels[self.iris_labels == 0] - EPS = np.finfo(float).eps - A = np.zeros((X.shape[1], X.shape[1])) - np.fill_diagonal(A, - 1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS))) - nca = NCA(max_iter=30, num_dims=X.shape[1]) - nca.fit(X, y) - assert_array_equal(nca.A_, A) - class TestLFDA(MetricTestCase): def test_iris(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.iris_points, self.iris_labels) - csep = class_separation(lfda.transform(), self.iris_labels) + csep = class_separation(lfda.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) # Sanity checks for learned matrices. 
self.assertEqual(lfda.metric().shape, (4, 4)) - self.assertEqual(lfda.transformer().shape, (2, 4)) + self.assertEqual(lfda.transformer_.shape, (2, 4)) class TestRCA(MetricTestCase): def test_iris(self): rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) rca.fit(self.iris_points, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) def test_feature_null_variance(self): @@ -265,14 +164,14 @@ def test_feature_null_variance(self): # Apply PCA with the number of components rca = RCA_Supervised(num_dims=2, pca_comps=3, num_chunks=30, chunk_size=2) rca.fit(X, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) # Apply PCA with the minimum variance ratio rca = RCA_Supervised(num_dims=2, pca_comps=0.95, num_chunks=30, chunk_size=2) rca.fit(X, self.iris_labels) - csep = class_separation(rca.transform(), self.iris_labels) + csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) @@ -280,7 +179,7 @@ class TestMLKR(MetricTestCase): def test_iris(self): mlkr = MLKR() mlkr.fit(self.iris_points, self.iris_labels) - csep = class_separation(mlkr.transform(), self.iris_labels) + csep = class_separation(mlkr.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) def test_finite_differences(self): @@ -318,30 +217,29 @@ def test_iris(self): # Full metric mmc = MMC(convergence_threshold=0.01) - mmc.fit(self.iris_points, [a,b,c,d]) - expected = [[ 0.000514, 0.000868, -0.001195, -0.001703], - [ 0.000868, 0.001468, -0.002021, -0.002879], - [-0.001195, -0.002021, 0.002782, 0.003964], - [-0.001703, -0.002879, 0.003964, 0.005648]] + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) + expected = [[+0.000514, +0.000868, -0.001195, -0.001703], + [+0.000868, +0.001468, -0.002021, -0.002879], + [-0.001195, -0.002021, +0.002782, +0.003964], + [-0.001703, -0.002879, +0.003964, +0.005648]] assert_array_almost_equal(expected, mmc.metric(), decimal=6) # Diagonal metric mmc = MMC(diagonal=True) - mmc.fit(self.iris_points, [a,b,c,d]) + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) expected = [0, 0, 1.210220, 1.228596] - assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6) # Supervised Full mmc = MMC_Supervised() mmc.fit(self.iris_points, self.iris_labels) - csep = class_separation(mmc.transform(), self.iris_labels) + csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) # Supervised Diagonal mmc = MMC_Supervised(diagonal=True) mmc.fit(self.iris_points, self.iris_labels) - csep = class_separation(mmc.transform(), self.iris_labels) + csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 4b132af4..c9c8fb57 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -5,74 +5,80 @@ class TestStringRepr(unittest.TestCase): def test_covariance(self): - self.assertEqual(str(metric_learn.Covariance()), "Covariance()") + self.assertEqual(str(metric_learn.Covariance()), + "Covariance(preprocessor=None)") def test_lmnn(self): self.assertRegexpMatches( str(metric_learn.LMNN()), r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " - r"max_iter=1000,\n min_iter=50, regularization=0.5, " - r"use_pca=True, verbose=False\)") + 
r"max_iter=1000,\n min_iter=50, preprocessor=None, " + r"regularization=0.5, use_pca=True,\n verbose=False\)") def test_nca(self): self.assertEqual(str(metric_learn.NCA()), - ("NCA(learning_rate='deprecated', max_iter=100, " - "num_dims=None, tol=None,\n verbose=False)")) + "NCA(max_iter=100, num_dims=None, preprocessor=None, " + "tol=None, verbose=False)") def test_lfda(self): self.assertEqual(str(metric_learn.LFDA()), - "LFDA(embedding_type='weighted', k=None, num_dims=None)") + "LFDA(embedding_type='weighted', k=None, num_dims=None, " + "preprocessor=None)") def test_itml(self): self.assertEqual(str(metric_learn.ITML()), """ ITML(A0=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000, - verbose=False) + preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.ITML_Supervised()), """ ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000, num_constraints=None, num_labeled=inf, - verbose=False) + preprocessor=None, verbose=False) """.strip('\n')) def test_lsml(self): self.assertEqual( str(metric_learn.LSML()), - "LSML(max_iter=1000, prior=None, tol=0.001, verbose=False)") + "LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, " + "verbose=False)") self.assertEqual(str(metric_learn.LSML_Supervised()), """ LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled=inf, - prior=None, tol=0.001, verbose=False, weights=None) + preprocessor=None, prior=None, tol=0.001, verbose=False, + weights=None) """.strip('\n')) def test_sdml(self): self.assertEqual(str(metric_learn.SDML()), - "SDML(balance_param=0.5, sparsity_param=0.01, " - "use_cov=True, verbose=False)") + "SDML(balance_param=0.5, preprocessor=None, " + "sparsity_param=0.01, use_cov=True,\n verbose=False)") self.assertEqual(str(metric_learn.SDML_Supervised()), """ SDML_Supervised(balance_param=0.5, num_constraints=None, num_labeled=inf, - sparsity_param=0.01, use_cov=True, verbose=False) + preprocessor=None, sparsity_param=0.01, use_cov=True, + verbose=False) """.strip('\n')) def test_rca(self): self.assertEqual(str(metric_learn.RCA()), - "RCA(num_dims=None, pca_comps=None)") + "RCA(num_dims=None, pca_comps=None, preprocessor=None)") self.assertEqual(str(metric_learn.RCA_Supervised()), "RCA_Supervised(chunk_size=2, num_chunks=100, " - "num_dims=None, pca_comps=None)") + "num_dims=None, pca_comps=None,\n " + "preprocessor=None)") def test_mlkr(self): self.assertEqual(str(metric_learn.MLKR()), - "MLKR(A0=None, max_iter=1000, num_dims=None, tol=None, " - "verbose=False)") + "MLKR(A0=None, max_iter=1000, num_dims=None, " + "preprocessor=None, tol=None,\n verbose=False)") def test_mmc(self): self.assertEqual(str(metric_learn.MMC()), """ MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0, - max_iter=100, max_proj=10000, verbose=False) + max_iter=100, max_proj=10000, preprocessor=None, verbose=False) """.strip('\n')) self.assertEqual(str(metric_learn.MMC_Supervised()), """ MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False, diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None, - num_labeled=inf, verbose=False) + num_labeled=inf, preprocessor=None, verbose=False) """.strip('\n')) if __name__ == '__main__': diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index 707815ec..118f6b90 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -19,7 +19,7 @@ def setUpClass(self): def test_cov(self): cov = Covariance() cov.fit(self.X) - res_1 = cov.transform() + res_1 = 
cov.transform(self.X) cov = Covariance() res_2 = cov.fit_transform(self.X) @@ -30,7 +30,7 @@ def test_lsml_supervised(self): seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.X, self.y, random_state=seed) - res_1 = lsml.transform() + res_1 = lsml.transform(self.X) seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) @@ -42,7 +42,7 @@ def test_itml_supervised(self): seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) itml.fit(self.X, self.y, random_state=seed) - res_1 = itml.transform() + res_1 = itml.transform(self.X) seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) @@ -53,7 +53,7 @@ def test_itml_supervised(self): def test_lmnn(self): lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) - res_1 = lmnn.transform() + res_1 = lmnn.transform(self.X) lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) res_2 = lmnn.fit_transform(self.X, self.y) @@ -64,7 +64,7 @@ def test_sdml_supervised(self): seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.X, self.y, random_state=seed) - res_1 = sdml.transform() + res_1 = sdml.transform(self.X) seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) @@ -74,11 +74,11 @@ def test_sdml_supervised(self): def test_nca(self): n = self.X.shape[0] - nca = NCA(max_iter=(100000//n), learning_rate=0.01) + nca = NCA(max_iter=(100000//n)) nca.fit(self.X, self.y) - res_1 = nca.transform() + res_1 = nca.transform(self.X) - nca = NCA(max_iter=(100000//n), learning_rate=0.01) + nca = NCA(max_iter=(100000//n)) res_2 = nca.fit_transform(self.X, self.y) assert_array_almost_equal(res_1, res_2) @@ -86,7 +86,7 @@ def test_nca(self): def test_lfda(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.X, self.y) - res_1 = lfda.transform() + res_1 = lfda.transform(self.X) lfda = LFDA(k=2, num_dims=2) res_2 = lfda.fit_transform(self.X, self.y) @@ -100,7 +100,7 @@ def test_rca_supervised(self): seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) rca.fit(self.X, self.y, random_state=seed) - res_1 = rca.transform() + res_1 = rca.transform(self.X) seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) @@ -111,7 +111,7 @@ def test_rca_supervised(self): def test_mlkr(self): mlkr = MLKR(num_dims=2) mlkr.fit(self.X, self.y) - res_1 = mlkr.transform() + res_1 = mlkr.transform(self.X) mlkr = MLKR(num_dims=2) res_2 = mlkr.fit_transform(self.X, self.y) @@ -122,7 +122,7 @@ def test_mmc_supervised(self): seed = np.random.RandomState(1234) mmc = MMC_Supervised(num_constraints=200) mmc.fit(self.X, self.y, random_state=seed) - res_1 = mmc.transform() + res_1 = mmc.transform(self.X) seed = np.random.RandomState(1234) mmc = MMC_Supervised(num_constraints=200) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py new file mode 100644 index 00000000..0d834f10 --- /dev/null +++ b/test/test_mahalanobis_mixin.py @@ -0,0 +1,169 @@ +from itertools import product + +import pytest +import numpy as np +from numpy.testing import assert_array_almost_equal +from scipy.spatial.distance import pdist, squareform +from sklearn import clone +from sklearn.utils import check_random_state +from sklearn.utils.testing import set_random_state + +from metric_learn._util import make_context + +from test.test_utils import ids_metric_learners, metric_learners + +RNG = check_random_state(0) + + 
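def _sketch_score_pairs_vs_transform():
    # Editor's illustrative sketch, not part of the PR's test file: it spells
    # out the property that the score_pairs tests below verify, namely that
    # for a Mahalanobis learner the score of a pair equals the Euclidean
    # distance between the two points after mapping them through the learned
    # transformer_. NCA, the iris data and max_iter=100 are arbitrary choices
    # made for this sketch; score_pairs/transform are the API added in this PR.
    from sklearn.datasets import load_iris
    from metric_learn import NCA

    X, y = load_iris(return_X_y=True)
    model = NCA(max_iter=100).fit(X, y)
    pair = np.array([[X[0], X[1]]])          # one pair, shape (1, 2, n_features)
    d_pair = model.score_pairs(pair)[0]      # distance reported for the pair
    d_embedded = np.linalg.norm(model.transform(X[:1]) - model.transform(X[1:2]))
    # the two quantities should agree up to floating-point precision
    assert abs(d_pair - d_embedded) < 1e-7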
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_score_pairs_pairwise(estimator, build_dataset): + # Computing pairwise scores should return a euclidean distance matrix. + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + + pairwise = model.score_pairs(np.array(list(product(X, X))))\ + .reshape(n_samples, n_samples) + + check_is_distance_matrix(pairwise) + + # a necessary condition for euclidean distance matrices: (see + # https://en.wikipedia.org/wiki/Euclidean_distance_matrix) + assert np.linalg.matrix_rank(pairwise**2) <= min(X.shape) + 2 + + # assert that this distance is coherent with pdist on embeddings + assert_array_almost_equal(squareform(pairwise), pdist(model.transform(X))) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_score_pairs_toy_example(estimator, build_dataset): + # Checks that score_pairs works on a toy example + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + pairs = np.stack([X[:10], X[10:20]], axis=1) + embedded_pairs = pairs.dot(model.transformer_.T) + distances = np.sqrt(np.sum((embedded_pairs[:, 1] - + embedded_pairs[:, 0])**2, + axis=-1)) + assert_array_almost_equal(model.score_pairs(pairs), distances) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_score_pairs_finite(estimator, build_dataset): + # tests that the score is finite + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + pairs = np.array(list(product(X, X))) + assert np.isfinite(model.score_pairs(pairs)).all() + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_score_pairs_dim(estimator, build_dataset): + # scoring of 3D arrays should return 1D array (several tuples), + # and scoring of 2D arrays (one tuple) should return an error (like + # scikit-learn's error when scoring 1D arrays) + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + tuples = np.array(list(product(X, X))) + assert model.score_pairs(tuples).shape == (tuples.shape[0],) + context = make_context(estimator) + msg = ("3D array of formed tuples expected{}. Found 2D array " + "instead:\ninput={}. 
Reshape your data and/or use a preprocessor.\n" + .format(context, tuples[1])) + with pytest.raises(ValueError) as raised_error: + model.score_pairs(tuples[1]) + assert str(raised_error.value) == msg + + +def check_is_distance_matrix(pairwise): + assert (pairwise >= 0).all() # positivity + assert np.array_equal(pairwise, pairwise.T) # symmetry + assert (pairwise.diagonal() == 0).all() # identity + # triangular inequality + tol = 1e-15 + assert (pairwise <= pairwise[:, :, np.newaxis] + + pairwise[:, np.newaxis, :] + tol).all() + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_embed_toy_example(estimator, build_dataset): + # Checks that embed works on a toy example + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + embedded_points = X.dot(model.transformer_.T) + assert_array_almost_equal(model.transform(X), embedded_points) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_embed_dim(estimator, build_dataset): + # Checks that the the dimension of the output space is as expected + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + assert model.transform(X).shape == X.shape + + # assert that ValueError is thrown if input shape is 1D + context = make_context(estimator) + err_msg = ("2D array of formed points expected{}. Found 1D array " + "instead:\ninput={}. Reshape your data and/or use a " + "preprocessor.\n".format(context, X[0])) + with pytest.raises(ValueError) as raised_error: + model.score_pairs(model.transform(X[0, :])) + assert str(raised_error.value) == err_msg + # we test that the shape is also OK when doing dimensionality reduction + if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}: + model.set_params(num_dims=2) + model.fit(input_data, labels) + assert model.transform(X).shape == (X.shape[0], 2) + # assert that ValueError is thrown if input shape is 1D + with pytest.raises(ValueError) as raised_error: + model.transform(model.transform(X[0, :])) + assert str(raised_error.value) == err_msg + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_embed_finite(estimator, build_dataset): + # Checks that embed returns vectors with finite values + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + assert np.isfinite(model.transform(X)).all() + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_embed_is_linear(estimator, build_dataset): + # Checks that the embedding is linear + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(input_data, labels) + assert_array_almost_equal(model.transform(X[:10] + X[10:20]), + model.transform(X[:10]) + + model.transform(X[10:20])) + assert_array_almost_equal(model.transform(5 * X[:10]), + 5 * model.transform(X[:10])) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index f1e1a09d..d9dce685 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -1,10 +1,23 @@ -import numpy as np +import pytest import unittest from sklearn.utils.estimator_checks import check_estimator +from sklearn.base import TransformerMixin +from sklearn.pipeline import 
make_pipeline +from sklearn.utils import check_random_state +from sklearn.utils.estimator_checks import is_public_parameter +from sklearn.utils.testing import (assert_allclose_dense_sparse, + set_random_state) -from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, - LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) +from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, + ITML_Supervised, LSML_Supervised, + MMC_Supervised, RCA_Supervised, SDML_Supervised) +from sklearn import clone +import numpy as np +from sklearn.model_selection import (cross_val_score, cross_val_predict, + train_test_split, KFold) +from sklearn.utils.testing import _get_args +from test.test_utils import (metric_learners, ids_metric_learners, + mock_preprocessor) # Wrap the _Supervised methods with a deterministic wrapper for testing. @@ -68,5 +81,263 @@ def test_mmc(self): # check_estimator(RCA_Supervised) +RNG = check_random_state(0) + + +# ---------------------- Test scikit-learn compatibility ---------------------- + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_cross_validation_is_finite(estimator, build_dataset, + with_preprocessor): + """Tests that validation on metric-learn estimators returns something finite + """ + if any(hasattr(estimator, method) for method in ["predict", "score"]): + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + if hasattr(estimator, "score"): + assert np.isfinite(cross_val_score(estimator, input_data, labels)).all() + if hasattr(estimator, "predict"): + assert np.isfinite(cross_val_predict(estimator, + input_data, labels)).all() + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_cross_validation_manual_vs_scikit(estimator, build_dataset, + with_preprocessor): + """Tests that if we make a manual cross-validation, the result will be the + same as scikit-learn's cross-validation (some code for generating the + folds is taken from scikit-learn). 
+ """ + if any(hasattr(estimator, method) for method in ["predict", "score"]): + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + n_splits = 3 + kfold = KFold(shuffle=False, n_splits=n_splits) + n_samples = input_data.shape[0] + fold_sizes = (n_samples // n_splits) * np.ones(n_splits, dtype=np.int) + fold_sizes[:n_samples % n_splits] += 1 + current = 0 + scores, predictions = [], np.zeros(input_data.shape[0]) + for fold_size in fold_sizes: + start, stop = current, current + fold_size + current = stop + test_slice = slice(start, stop) + train_mask = np.ones(input_data.shape[0], bool) + train_mask[test_slice] = False + y_train, y_test = labels[train_mask], labels[test_slice] + estimator.fit(input_data[train_mask], y_train) + if hasattr(estimator, "score"): + scores.append(estimator.score(input_data[test_slice], y_test)) + if hasattr(estimator, "predict"): + predictions[test_slice] = estimator.predict(input_data[test_slice]) + if hasattr(estimator, "score"): + assert all(scores == cross_val_score(estimator, input_data, labels, + cv=kfold)) + if hasattr(estimator, "predict"): + assert all(predictions == cross_val_predict(estimator, input_data, + labels, + cv=kfold)) + + +def check_score(estimator, tuples, y): + if hasattr(estimator, "score"): + score = estimator.score(tuples, y) + assert np.isfinite(score) + + +def check_predict(estimator, tuples): + if hasattr(estimator, "predict"): + y_predicted = estimator.predict(tuples) + assert len(y_predicted), len(tuples) + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_simple_estimator(estimator, build_dataset, with_preprocessor): + """Tests that fit, predict and scoring works. + """ + if any(hasattr(estimator, method) for method in ["predict", "score"]): + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + (tuples_train, tuples_test, y_train, + y_test) = train_test_split(input_data, labels, random_state=RNG) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + + estimator.fit(tuples_train, y_train) + check_score(estimator, tuples_test, y_test) + check_predict(estimator, tuples_test) + + +@pytest.mark.parametrize('estimator', [est[0] for est in metric_learners], + ids=ids_metric_learners) +@pytest.mark.parametrize('preprocessor', [None, mock_preprocessor]) +def test_no_attributes_set_in_init(estimator, preprocessor): + """Check setting during init. Adapted from scikit-learn.""" + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(type(estimator).__init__, "deprecated_original"): + return + + init_params = _get_args(type(estimator).__init__) + parents_init_params = [param for params_parent in + (_get_args(parent) for parent in + type(estimator).__mro__) + for param in params_parent] + + # Test for no setting apart from parameters during init + invalid_attr = (set(vars(estimator)) - set(init_params) - + set(parents_init_params)) + assert not invalid_attr, \ + ("Estimator %s should not set any attribute apart" + " from parameters during init. Found attributes %s." 
+ % (type(estimator).__name__, sorted(invalid_attr))) + # Ensure that each parameter is set in init + invalid_attr = (set(init_params) - set(vars(estimator)) - + set(["self"])) + assert not invalid_attr, \ + ("Estimator %s should store all parameters" + " as an attribute during init. Did not find " + "attributes %s." % (type(estimator).__name__, sorted(invalid_attr))) + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_estimators_fit_returns_self(estimator, build_dataset, + with_preprocessor): + """Check if self is returned when calling fit""" + # Adapted from scikit-learn + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + assert estimator.fit(input_data, labels) is estimator + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_pipeline_consistency(estimator, build_dataset, + with_preprocessor): + # Adapted from scikit learn + # check that make_pipeline(est) gives same score as est + input_data, y, preprocessor, _ = build_dataset(with_preprocessor) + + def make_random_state(estimator, in_pipeline): + rs = {} + name_estimator = estimator.__class__.__name__ + if name_estimator[-11:] == '_Supervised': + name_param = 'random_state' + if in_pipeline: + name_param = name_estimator.lower() + '__' + name_param + rs[name_param] = check_random_state(0) + return rs + + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + pipeline = make_pipeline(estimator) + estimator.fit(input_data, y, **make_random_state(estimator, False)) + pipeline.fit(input_data, y, **make_random_state(estimator, True)) + + if hasattr(estimator, 'score'): + result = estimator.score(input_data, y) + result_pipe = pipeline.score(input_data, y) + assert_allclose_dense_sparse(result, result_pipe) + + if hasattr(estimator, 'predict'): + result = estimator.predict(input_data) + result_pipe = pipeline.predict(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + if issubclass(estimator.__class__, TransformerMixin): + if hasattr(estimator, 'transform'): + result = estimator.transform(input_data) + result_pipe = pipeline.transform(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + +@pytest.mark.parametrize('with_preprocessor',[True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_dict_unchanged(estimator, build_dataset, with_preprocessor): + # Adapted from scikit-learn + (input_data, labels, preprocessor, + to_transform) = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 + estimator.fit(input_data, labels) + + def check_dict(): + assert estimator.__dict__ == dict_before, ( + "Estimator changes __dict__ during %s" % method) + for method in ["predict", "decision_function", "predict_proba"]: + if hasattr(estimator, method): + dict_before = estimator.__dict__.copy() + getattr(estimator, method)(input_data) + check_dict() + if hasattr(estimator, "transform"): + dict_before = estimator.__dict__.copy() + # we transform only dataset of points + estimator.transform(to_transform) + check_dict() + + +@pytest.mark.parametrize('with_preprocessor',[True, 
False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_dont_overwrite_parameters(estimator, build_dataset, + with_preprocessor): + # Adapted from scikit-learn + # check that fit method only changes or sets private attributes + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + if hasattr(estimator, "num_dims"): + estimator.num_dims = 1 + dict_before_fit = estimator.__dict__.copy() + + estimator.fit(input_data, labels) + dict_after_fit = estimator.__dict__ + + public_keys_after_fit = [key for key in dict_after_fit.keys() + if is_public_parameter(key)] + + attrs_added_by_fit = [key for key in public_keys_after_fit + if key not in dict_before_fit.keys()] + + # check that fit doesn't add any public attribute + assert not attrs_added_by_fit, ( + "Estimator adds public attribute(s) during" + " the fit method." + " Estimators are only allowed to add private " + "attributes" + " either started with _ or ended" + " with _ but %s added" % ', '.join(attrs_added_by_fit)) + + # check that fit doesn't change any public attribute + attrs_changed_by_fit = [key for key in public_keys_after_fit + if (dict_before_fit[key] + is not dict_after_fit[key])] + + assert not attrs_changed_by_fit, ( + "Estimator changes public attribute(s) during" + " the fit method. Estimators are only allowed" + " to change attributes started" + " or ended with _, but" + " %s changed" % ', '.join(attrs_changed_by_fit)) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index e027d176..ab38d65e 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -19,60 +19,60 @@ def setUpClass(self): def test_cov(self): cov = Covariance() cov.fit(self.X) - L = cov.transformer() + L = cov.transformer_ assert_array_almost_equal(L.T.dot(L), cov.metric()) def test_lsml_supervised(self): seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.X, self.y, random_state=seed) - L = lsml.transformer() + L = lsml.transformer_ assert_array_almost_equal(L.T.dot(L), lsml.metric()) def test_itml_supervised(self): seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) itml.fit(self.X, self.y, random_state=seed) - L = itml.transformer() + L = itml.transformer_ assert_array_almost_equal(L.T.dot(L), itml.metric()) def test_lmnn(self): lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) - L = lmnn.transformer() + L = lmnn.transformer_ assert_array_almost_equal(L.T.dot(L), lmnn.metric()) def test_sdml_supervised(self): seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.X, self.y, random_state=seed) - L = sdml.transformer() + L = sdml.transformer_ assert_array_almost_equal(L.T.dot(L), sdml.metric()) def test_nca(self): n = self.X.shape[0] - nca = NCA(max_iter=(100000//n), learning_rate=0.01) + nca = NCA(max_iter=(100000//n)) nca.fit(self.X, self.y) - L = nca.transformer() + L = nca.transformer_ assert_array_almost_equal(L.T.dot(L), nca.metric()) def test_lfda(self): lfda = LFDA(k=2, num_dims=2) lfda.fit(self.X, self.y) - L = lfda.transformer() + L = lfda.transformer_ assert_array_almost_equal(L.T.dot(L), lfda.metric()) def test_rca_supervised(self): seed = np.random.RandomState(1234) rca = RCA_Supervised(num_dims=2, num_chunks=30, 
chunk_size=2) rca.fit(self.X, self.y, random_state=seed) - L = rca.transformer() + L = rca.transformer_ assert_array_almost_equal(L.T.dot(L), rca.metric()) def test_mlkr(self): mlkr = MLKR(num_dims=2) mlkr.fit(self.X, self.y) - L = mlkr.transformer() + L = mlkr.transformer_ assert_array_almost_equal(L.T.dot(L), mlkr.metric()) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..de59e9ff --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,1013 @@ +import pytest +from collections import namedtuple +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.exceptions import DataConversionWarning +from sklearn.utils import check_random_state, shuffle +from sklearn.utils.testing import set_random_state +from sklearn.base import clone +from metric_learn._util import (check_input, make_context, preprocess_tuples, + make_name, preprocess_points, + check_collapsed_pairs) +from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, + LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, + MMC_Supervised, RCA_Supervised, SDML_Supervised, + Constraints) +from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, + _PairsClassifierMixin, + _QuadrupletsClassifierMixin) +from metric_learn.exceptions import PreprocessorError +from sklearn.datasets import make_regression, make_blobs, load_iris + + +SEED = 42 +RNG = check_random_state(SEED) + +Dataset = namedtuple('Dataset', ('data target preprocessor to_transform')) +# Data and target are what we will fit on. Preprocessor is the additional +# data if we use a preprocessor (which should be the default ArrayIndexer), +# and to_transform is some additional data that we would want to transform + + +@pytest.fixture +def build_classification(with_preprocessor=False): + """Basic array for testing when using a preprocessor""" + X, y = shuffle(*make_blobs(random_state=SEED), + random_state=SEED) + indices = shuffle(np.arange(X.shape[0]), random_state=SEED).astype(int) + if with_preprocessor: + return Dataset(indices, y[indices], X, indices) + else: + return Dataset(X[indices], y[indices], None, X[indices]) + + +@pytest.fixture +def build_regression(with_preprocessor=False): + """Basic array for testing when using a preprocessor""" + X, y = shuffle(*make_regression(n_samples=100, n_features=5, + random_state=SEED), + random_state=SEED) + indices = shuffle(np.arange(X.shape[0]), random_state=SEED).astype(int) + if with_preprocessor: + return Dataset(indices, y[indices], X, indices) + else: + return Dataset(X[indices], y[indices], None, X[indices]) + + +def build_data(): + input_data, labels = load_iris(return_X_y=True) + X, y = shuffle(input_data, labels, random_state=SEED) + num_constraints = 50 + constraints = ( + Constraints.random_subset(y, random_state=check_random_state(SEED))) + pairs = ( + constraints + .positive_negative_pairs(num_constraints, same_length=True, + random_state=check_random_state(SEED))) + return X, pairs + + +def build_pairs(with_preprocessor=False): + # builds a toy pairs problem + X, indices = build_data() + c = np.vstack([np.column_stack(indices[:2]), np.column_stack(indices[2:])]) + target = np.concatenate([np.ones(indices[0].shape[0]), + - np.ones(indices[0].shape[0])]) + c, target = shuffle(c, target, random_state=SEED) + if with_preprocessor: + # if preprocessor, we build a 2D array of pairs of indices + return Dataset(c, target, X, c[:, 0]) + else: + # if not, we build a 3D array of pairs of samples + return Dataset(X[c], target, None, X[c[:, 0]]) + + 
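def _sketch_preprocessor_convention():
    # Editor's illustrative sketch, not part of the PR's test file: it shows
    # the index/preprocessor convention the fixtures above and below rely on.
    # With a preprocessor, estimators receive a 2D array of indices and the
    # preprocessor (here simply the data array X) resolves them to points,
    # effectively X[indices]; without one, they receive the already-formed 3D
    # array of pairs. ITML is an arbitrary choice of pairs learner here, and
    # the pair construction mirrors build_pairs above.
    X, indices = build_data()
    pairs_idx = np.vstack([np.column_stack(indices[:2]),
                           np.column_stack(indices[2:])])
    y = np.concatenate([np.ones(indices[0].shape[0]),
                        -np.ones(indices[0].shape[0])])
    # indices plus preprocessor ...
    itml_indices = ITML(preprocessor=X).fit(pairs_idx, y)
    # ... should behave like fitting directly on the formed 3D pairs
    itml_formed = ITML().fit(X[pairs_idx], y)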
+def build_quadruplets(with_preprocessor=False): + # builds a toy quadruplets problem + X, indices = build_data() + c = np.column_stack(indices) + target = np.ones(c.shape[0]) # quadruplets targets are not used + # anyways + c, target = shuffle(c, target, random_state=SEED) + if with_preprocessor: + # if preprocessor, we build a 2D array of quadruplets of indices + return Dataset(c, target, X, c[:, 0]) + else: + # if not, we build a 3D array of quadruplets of samples + return Dataset(X[c], target, None, X[c[:, 0]]) + + +quadruplets_learners = [(LSML(), build_quadruplets)] +ids_quadruplets_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + quadruplets_learners])) + +pairs_learners = [(ITML(), build_pairs), + (MMC(max_iter=2), build_pairs), # max_iter=2 for faster + (SDML(), build_pairs), + ] +ids_pairs_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + pairs_learners])) + +classifiers = [(Covariance(), build_classification), + (LFDA(), build_classification), + (LMNN(), build_classification), + (NCA(), build_classification), + (RCA(), build_classification), + (ITML_Supervised(max_iter=5), build_classification), + (LSML_Supervised(), build_classification), + (MMC_Supervised(max_iter=5), build_classification), + (RCA_Supervised(num_chunks=10), build_classification), + (SDML_Supervised(), build_classification) + ] +ids_classifiers = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + classifiers])) + +regressors = [(MLKR(), build_regression)] +ids_regressors = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in regressors])) + +WeaklySupervisedClasses = (_PairsClassifierMixin, + _QuadrupletsClassifierMixin) + +tuples_learners = pairs_learners + quadruplets_learners +ids_tuples_learners = ids_pairs_learners + ids_quadruplets_learners + +supervised_learners = classifiers + regressors +ids_supervised_learners = ids_classifiers + ids_regressors + +metric_learners = tuples_learners + supervised_learners +ids_metric_learners = ids_tuples_learners + ids_supervised_learners + + +def mock_preprocessor(indices): + """A preprocessor for testing purposes that returns an all ones 3D array + """ + return np.ones((indices.shape[0], 3)) + + +@pytest.mark.parametrize('type_of_inputs', ['other', 'tuple', 'classics', 2, + int, NCA()]) +def test_check_input_invalid_type_of_inputs(type_of_inputs): + """Tests that an invalid type of inputs in check_inputs raises an error.""" + with pytest.raises(ValueError) as e: + check_input([[0.2, 2.1], [0.2, .8]], type_of_inputs=type_of_inputs) + msg = ("Unknown value {} for type_of_inputs. 
Valid values are " + "'classic' or 'tuples'.".format(type_of_inputs)) + assert str(e.value) == msg + + +# ---------------- test check_input with 'tuples' type_of_input' ------------ + + +@pytest.fixture +def tuples_prep(): + """Basic array for testing when using a preprocessor""" + tuples = np.array([[1, 2], + [2, 3]]) + return tuples + + +@pytest.fixture +def tuples_no_prep(): + """Basic array for testing when using no preprocessor""" + tuples = np.array([[[1., 2.3], [2.3, 5.3]], + [[2.3, 4.3], [0.2, 0.4]]]) + return tuples + + +@pytest.mark.parametrize('estimator, expected', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +def test_make_context(estimator, expected): + """test the make_name function""" + assert make_context(estimator) == expected + + +@pytest.mark.parametrize('estimator, expected', + [(NCA(), "NCA"), ('NCA', "NCA"), (None, None)]) +def test_make_name(estimator, expected): + """test the make_name function""" + assert make_name(estimator) == expected + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('load_tuples, preprocessor', + [(tuples_prep, mock_preprocessor), + (tuples_no_prep, None), + (tuples_no_prep, mock_preprocessor)]) +def test_check_tuples_invalid_tuple_size(estimator, context, load_tuples, + preprocessor): + """Checks that the exception are raised if tuple_size is not the one + expected""" + tuples = load_tuples() + preprocessed_tuples = (preprocess_tuples(tuples, preprocessor) + if (preprocessor is not None and + tuples.ndim == 2) else tuples) + expected_msg = ("Tuples of 3 element(s) expected{}. Got tuples of 2 " + "element(s) instead (shape={}):\ninput={}.\n" + .format(context, preprocessed_tuples.shape, + preprocessed_tuples)) + with pytest.raises(ValueError) as raised_error: + check_input(tuples, type_of_inputs='tuples', tuple_size=3, + preprocessor=preprocessor, estimator=estimator) + assert str(raised_error.value) == expected_msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('tuples, found, expected, preprocessor', + [(5, '0', '2D array of indicators or 3D array of ' + 'formed tuples', mock_preprocessor), + (5, '0', '3D array of formed tuples', None), + ([1, 2], '1', '2D array of indicators or 3D array ' + 'of formed tuples', mock_preprocessor), + ([1, 2], '1', '3D array of formed tuples', None), + ([[[[5]]]], '4', '2D array of indicators or 3D array' + ' of formed tuples', + mock_preprocessor), + ([[[[5]]]], '4', '3D array of formed tuples', None), + ([[1], [3]], '2', '3D array of formed ' + 'tuples', None)]) +def test_check_tuples_invalid_shape(estimator, context, tuples, found, + expected, preprocessor): + """Checks that a value error with the appropriate message is raised if + shape is invalid (not 2D with preprocessor or 3D with no preprocessor) + """ + tuples = np.array(tuples) + msg = ("{} expected{}{}. Found {}D array instead:\ninput={}. 
Reshape your " + "data{}.\n" + .format(expected, context, ' when using a preprocessor' + if preprocessor else '', found, tuples, + ' and/or use a preprocessor' if + (not preprocessor and tuples.ndim == 2) else '')) + with pytest.raises(ValueError) as raised_error: + check_input(tuples, type_of_inputs='tuples', + preprocessor=preprocessor, ensure_min_samples=0, + estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +def test_check_tuples_invalid_n_features(estimator, context, tuples_no_prep): + """Checks that the right warning is printed if not enough features + Here we only test if no preprocessor (otherwise we don't ensure this) + """ + msg = ("Found array with 2 feature(s) (shape={}) while" + " a minimum of 3 is required{}.".format(tuples_no_prep.shape, + context)) + with pytest.raises(ValueError) as raised_error: + check_input(tuples_no_prep, type_of_inputs='tuples', + preprocessor=None, ensure_min_features=3, + estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('load_tuples, preprocessor', + [(tuples_prep, mock_preprocessor), + (tuples_no_prep, None), + (tuples_no_prep, mock_preprocessor)]) +def test_check_tuples_invalid_n_samples(estimator, context, load_tuples, + preprocessor): + """Checks that the right warning is printed if n_samples is too small""" + tuples = load_tuples() + msg = ("Found array with 2 sample(s) (shape={}) while a minimum of 3 " + "is required{}.".format((preprocess_tuples(tuples, preprocessor) + if (preprocessor is not None and + tuples.ndim == 2) else tuples).shape, + context)) + with pytest.raises(ValueError) as raised_error: + check_input(tuples, type_of_inputs='tuples', + preprocessor=preprocessor, + ensure_min_samples=3, estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('load_tuples, preprocessor', + [(tuples_prep, mock_preprocessor), + (tuples_no_prep, None), + (tuples_no_prep, mock_preprocessor)]) +def test_check_tuples_invalid_dtype_convertible(estimator, context, + load_tuples, preprocessor): + """Checks that a warning is raised if a convertible input is converted to + float""" + tuples = load_tuples().astype(object) # here the object conversion is + # useless for the tuples_prep case, but this allows to test the + # tuples_prep case + + if preprocessor is not None: # if the preprocessor is not None we + # overwrite it to have a preprocessor that returns objects + def preprocessor(indices): # + # preprocessor that returns objects + return np.ones((indices.shape[0], 3)).astype(object) + + msg = ("Data with input dtype object was converted to float64{}." 
+ .format(context)) + with pytest.warns(DataConversionWarning) as raised_warning: + check_input(tuples, type_of_inputs='tuples', + preprocessor=preprocessor, dtype=np.float64, + warn_on_dtype=True, estimator=estimator) + assert str(raised_warning[0].message) == msg + + +def test_check_tuples_invalid_dtype_not_convertible_with_preprocessor( + tuples_prep): + """Checks that a value error is thrown if attempting to convert an + input not convertible to float, when using a preprocessor + """ + + def preprocessor(indices): + # preprocessor that returns objects + return np.full((indices.shape[0], 3), 'a') + + with pytest.raises(ValueError): + check_input(tuples_prep, type_of_inputs='tuples', + preprocessor=preprocessor, dtype=np.float64) + + +def test_check_tuples_invalid_dtype_not_convertible_without_preprocessor( + tuples_no_prep): + """Checks that a value error is thrown if attempting to convert an + input not convertible to float, when using no preprocessor + """ + tuples = np.full_like(tuples_no_prep, 'a', dtype=object) + with pytest.raises(ValueError): + check_input(tuples, type_of_inputs='tuples', + preprocessor=None, dtype=np.float64) + + +@pytest.mark.parametrize('tuple_size', [2, None]) +def test_check_tuples_valid_tuple_size(tuple_size, tuples_prep, tuples_no_prep): + """For inputs that have the right matrix dimension (2D or 3D for instance), + checks that checking the number of tuples (pairs, quadruplets, etc) raises + no warning if there is the right number of points in a tuple. + """ + with pytest.warns(None) as record: + check_input(tuples_prep, type_of_inputs='tuples', + preprocessor=mock_preprocessor, tuple_size=tuple_size) + check_input(tuples_no_prep, type_of_inputs='tuples', preprocessor=None, + tuple_size=tuple_size) + assert len(record) == 0 + + +@pytest.mark.parametrize('tuples', + [np.array([[2.5, 0.1, 2.6], + [1.6, 4.8, 9.1]]), + np.array([[2, 0, 2], + [1, 4, 9]]), + np.array([["img1.png", "img3.png"], + ["img2.png", "img4.png"]]), + [[2, 0, 2], + [1, 4, 9]], + [np.array([2, 0, 2]), + np.array([1, 4, 9])], + ((2, 0, 2), + (1, 4, 9)), + np.array([[[1.2, 2.2], [1.4, 3.3]], + [[2.6, 2.3], [3.4, 5.0]]])]) +def test_check_tuples_valid_with_preprocessor(tuples): + """Test that valid inputs when using a preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(tuples, type_of_inputs='tuples', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + +@pytest.mark.parametrize('tuples', + [np.array([[[2.5], [0.1], [2.6]], + [[1.6], [4.8], [9.1]], + [[5.6], [2.8], [6.1]]]), + np.array([[[2], [0], [2]], + [[1], [4], [9]], + [[1], [5], [3]]]), + [[[2], [0], [2]], + [[1], [4], [9]], + [[3], [4], [29]]], + (((2, 1), (0, 2), (2, 3)), + ((1, 2), (4, 4), (9, 3)), + ((3, 1), (4, 4), (29, 4)))]) +def test_check_tuples_valid_without_preprocessor(tuples): + """Test that valid inputs when using no preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(tuples, type_of_inputs='tuples', preprocessor=None) + assert len(record) == 0 + + +def test_check_tuples_behaviour_auto_dtype(tuples_no_prep): + """Checks that check_tuples allows by default every type if using a + preprocessor, and numeric types if using no preprocessor""" + tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']] + with pytest.warns(None) as record: + check_input(tuples_prep, type_of_inputs='tuples', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + with pytest.warns(None) as record: + check_input(tuples_no_prep, type_of_inputs='tuples') # 
numeric type + assert len(record) == 0 + + # not numeric type + tuples_no_prep = np.array([[['img1.png'], ['img2.png']], + [['img3.png'], ['img5.png']]]) + tuples_no_prep = tuples_no_prep.astype(object) + with pytest.raises(ValueError): + check_input(tuples_no_prep, type_of_inputs='tuples') + + +def test_check_tuples_invalid_complex_data(): + """Checks that the right error message is thrown if given complex data ( + this comes from sklearn's check_array's message)""" + tuples = np.array([[[1 + 2j, 3 + 4j], [5 + 7j, 5 + 7j]], + [[1 + 3j, 2 + 4j], [5 + 8j, 1 + 7j]]]) + msg = ("Complex data not supported\n" + "{}\n".format(tuples)) + with pytest.raises(ValueError) as raised_error: + check_input(tuples, type_of_inputs='tuples') + assert str(raised_error.value) == msg + + +# ------------- test check_input with 'classic' type_of_inputs ---------------- + + +@pytest.fixture +def points_prep(): + """Basic array for testing when using a preprocessor""" + points = np.array([1, 2]) + return points + + +@pytest.fixture +def points_no_prep(): + """Basic array for testing when using no preprocessor""" + points = np.array([[1., 2.3], + [2.3, 4.3]]) + return points + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('points, found, expected, preprocessor', + [(5, '0', '1D array of indicators or 2D array of ' + 'formed points', mock_preprocessor), + (5, '0', '2D array of formed points', None), + ([1, 2], '1', '2D array of formed points', None), + ([[[5]]], '3', '1D array of indicators or 2D ' + 'array of formed points', + mock_preprocessor), + ([[[5]]], '3', '2D array of formed points', None)]) +def test_check_classic_invalid_shape(estimator, context, points, found, + expected, preprocessor): + """Checks that a value error with the appropriate message is raised if + shape is invalid (valid being 1D or 2D with preprocessor or 2D with no + preprocessor) + """ + points = np.array(points) + msg = ("{} expected{}{}. Found {}D array instead:\ninput={}. 
Reshape your " + "data{}.\n" + .format(expected, context, ' when using a preprocessor' + if preprocessor else '', found, points, + ' and/or use a preprocessor' if + (not preprocessor and points.ndim == 1) else '')) + with pytest.raises(ValueError) as raised_error: + check_input(points, type_of_inputs='classic', preprocessor=preprocessor, + ensure_min_samples=0, + estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +def test_check_classic_invalid_n_features(estimator, context, + points_no_prep): + """Checks that the right warning is printed if not enough features + Here we only test if no preprocessor (otherwise we don't ensure this) + """ + msg = ("Found array with 2 feature(s) (shape={}) while" + " a minimum of 3 is required{}.".format(points_no_prep.shape, + context)) + with pytest.raises(ValueError) as raised_error: + check_input(points_no_prep, type_of_inputs='classic', preprocessor=None, + ensure_min_features=3, + estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('load_points, preprocessor', + [(points_prep, mock_preprocessor), + (points_no_prep, None), + (points_no_prep, mock_preprocessor)]) +def test_check_classic_invalid_n_samples(estimator, context, load_points, + preprocessor): + """Checks that the right warning is printed if n_samples is too small""" + points = load_points() + msg = ("Found array with 2 sample(s) (shape={}) while a minimum of 3 " + "is required{}.".format((preprocess_points(points, + preprocessor) + if preprocessor is not None and + points.ndim == 1 else + points).shape, + context)) + with pytest.raises(ValueError) as raised_error: + check_input(points, type_of_inputs='classic', preprocessor=preprocessor, + ensure_min_samples=3, + estimator=estimator) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, context', + [(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")]) +@pytest.mark.parametrize('load_points, preprocessor', + [(points_prep, mock_preprocessor), + (points_no_prep, None), + (points_no_prep, mock_preprocessor)]) +def test_check_classic_invalid_dtype_convertible(estimator, context, + load_points, + preprocessor): + """Checks that a warning is raised if a convertible input is converted to + float""" + points = load_points().astype(object) # here the object conversion is + # useless for the points_prep case, but this allows to test the + # points_prep case + + if preprocessor is not None: # if the preprocessor is not None we + # overwrite it to have a preprocessor that returns objects + def preprocessor(indices): + # preprocessor that returns objects + return np.ones((indices.shape[0], 3)).astype(object) + + msg = ("Data with input dtype object was converted to float64{}." 
+ .format(context)) + with pytest.warns(DataConversionWarning) as raised_warning: + check_input(points, type_of_inputs='classic', + preprocessor=preprocessor, dtype=np.float64, + warn_on_dtype=True, estimator=estimator) + assert str(raised_warning[0].message) == msg + + +@pytest.mark.parametrize('preprocessor, points', + [(mock_preprocessor, np.array([['a', 'b'], + ['e', 'b']])), + (None, np.array([[['b', 'v'], ['a', 'd']], + [['x', 'u'], ['c', 'a']]]))]) +def test_check_classic_invalid_dtype_not_convertible(preprocessor, points): + """Checks that a value error is thrown if attempting to convert an + input not convertible to float + """ + with pytest.raises(ValueError): + check_input(points, type_of_inputs='classic', + preprocessor=preprocessor, dtype=np.float64) + + +@pytest.mark.parametrize('points', + [["img1.png", "img3.png", "img2.png"], + np.array(["img1.png", "img3.png", "img2.png"]), + [2, 0, 2, 1, 4, 9], + range(10), + np.array([2, 0, 2]), + (2, 0, 2), + np.array([[1.2, 2.2], + [2.6, 2.3]])]) +def test_check_classic_valid_with_preprocessor(points): + """Test that valid inputs when using a preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(points, type_of_inputs='classic', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + +@pytest.mark.parametrize('points', + [np.array([[2.5, 0.1, 2.6], + [1.6, 4.8, 9.1], + [5.6, 2.8, 6.1]]), + np.array([[2, 0, 2], + [1, 4, 9], + [1, 5, 3]]), + [[2, 0, 2], + [1, 4, 9], + [3, 4, 29]], + ((2, 1, 0, 2, 2, 3), + (1, 2, 4, 4, 9, 3), + (3, 1, 4, 4, 29, 4))]) +def test_check_classic_valid_without_preprocessor(points): + """Test that valid inputs when using no preprocessor raises no warning""" + with pytest.warns(None) as record: + check_input(points, type_of_inputs='classic', preprocessor=None) + assert len(record) == 0 + + +def test_check_classic_by_default(): + """Checks that 'classic' is the default behaviour of check_input""" + assert (check_input([[2, 3], [3, 2]]) == + check_input([[2, 3], [3, 2]], type_of_inputs='classic')).all() + + +def test_check_classic_behaviour_auto_dtype(points_no_prep): + """Checks that check_input (for points) allows by default every type if + using a preprocessor, and numeric types if using no preprocessor""" + points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png'] + with pytest.warns(None) as record: + check_input(points_prep, type_of_inputs='classic', + preprocessor=mock_preprocessor) + assert len(record) == 0 + + with pytest.warns(None) as record: + check_input(points_no_prep, type_of_inputs='classic') # numeric type + assert len(record) == 0 + + # not numeric type + points_no_prep = np.array(['img1.png', 'img2.png', 'img3.png', + 'img5.png']) + points_no_prep = points_no_prep.astype(object) + with pytest.raises(ValueError): + check_input(points_no_prep, type_of_inputs='classic') + + +def test_check_classic_invalid_complex_data(): + """Checks that the right error message is thrown if given complex data ( + this comes from sklearn's check_array's message)""" + points = np.array([[[1 + 2j, 3 + 4j], [5 + 7j, 5 + 7j]], + [[1 + 3j, 2 + 4j], [5 + 8j, 1 + 7j]]]) + msg = ("Complex data not supported\n" + "{}\n".format(points)) + with pytest.raises(ValueError) as raised_error: + check_input(points, type_of_inputs='classic') + assert str(raised_error.value) == msg + + +# ----------------------------- Test preprocessor ----------------------------- + + +X = np.array([[0.89, 0.11, 1.48, 0.12], + [2.63, 1.08, 1.68, 0.46], + [1.00, 0.59, 0.62, 1.15]]) + + +class MockFileLoader: 
+class MockFileLoader:
+  """Preprocessor that takes a root file path at construction and simulates
+  fetching the file in the specific root folder when given the name of the
+  file"""
+
+  def __init__(self, root):
+    self.root = root
+    self.folders = {'fake_root': {'img0.png': X[0],
+                                  'img1.png': X[1],
+                                  'img2.png': X[2]
+                                  },
+                    'other_folder': {}  # empty folder
+                    }
+
+  def __call__(self, path_list):
+    images = list()
+    for path in path_list:
+      images.append(self.folders[self.root][path])
+    return np.array(images)
+
+
+def mock_id_loader(list_of_indicators):
+  """A preprocessor as a function that takes indicators (strings) and
+  returns the corresponding samples"""
+  points = []
+  for indicator in list_of_indicators:
+    points.append(X[int(indicator[2:])])
+  return np.array(points)
+
+
+tuples_list = [np.array([[0, 1],
+                         [2, 1]]),
+
+               np.array([['img0.png', 'img1.png'],
+                         ['img2.png', 'img1.png']]),
+
+               np.array([['id0', 'id1'],
+                         ['id2', 'id1']])
+               ]
+
+points_list = [np.array([0, 1, 2, 1]),
+
+               np.array(['img0.png', 'img1.png', 'img2.png', 'img1.png']),
+
+               np.array(['id0', 'id1', 'id2', 'id1'])
+               ]
+
+preprocessors = [X, MockFileLoader('fake_root'), mock_id_loader]
+
+
+@pytest.fixture
+def y_tuples():
+  y = [-1, 1]
+  return y
+
+
+@pytest.fixture
+def y_points():
+  y = [0, 1, 0, 0]
+  return y
+
+
+@pytest.mark.parametrize('preprocessor, tuples', zip(preprocessors,
+                                                     tuples_list))
+def test_preprocessor_weakly_supervised(preprocessor, tuples, y_tuples):
+  """Tests different ways to use the preprocessor argument: an array,
+  a class callable, and a function callable, with a weakly supervised
+  algorithm
+  """
+  itml = ITML(preprocessor=preprocessor)
+  itml.fit(tuples, y_tuples)
+
+
+@pytest.mark.parametrize('preprocessor, points', zip(preprocessors,
+                                                     points_list))
+def test_preprocessor_supervised(preprocessor, points, y_points):
+  """Tests different ways to use the preprocessor argument: an array,
+  a class callable, and a function callable, with a supervised algorithm
+  """
+  lfda = LFDA(preprocessor=preprocessor)
+  lfda.fit(points, y_points)
+
+
+@pytest.mark.parametrize('estimator', ['NCA', NCA(), None])
+def test_preprocess_tuples_invalid_message(estimator):
+  """Checks that if the preprocessor does some weird stuff, the preprocessed
+  input is detected as weird. Checks this for preprocess_tuples."""
+
+  context = make_context(estimator) + (' after the preprocessor '
+                                       'has been applied')
+
+  def preprocessor(sequence):
+    return np.ones((len(sequence), 2, 2))  # returns a 3D array instead of 2D
+
+  with pytest.raises(ValueError) as raised_error:
+    check_input(np.ones((3, 2)), type_of_inputs='tuples',
+                preprocessor=preprocessor, estimator=estimator)
+  expected_msg = ("3D array of formed tuples expected{}. Found 4D "
+                  "array instead:\ninput={}. Reshape your data{}.\n"
+                  .format(context, np.ones((3, 2, 2, 2)),
+                          ' and/or use a preprocessor' if preprocessor
+                          is not None else ''))
+  assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.parametrize('estimator', ['NCA', NCA(), None])
+def test_preprocess_points_invalid_message(estimator):
+  """Checks that if the preprocessor does some weird stuff, the preprocessed
+  input is detected as weird. Checks this for preprocess_points."""
+
+  context = make_context(estimator) + (' after the preprocessor '
+                                       'has been applied')
+
+  def preprocessor(sequence):
+    return np.ones((len(sequence), 2, 2))  # returns a 3D array instead of 2D
+
+  with pytest.raises(ValueError) as raised_error:
+    check_input(np.ones((3,)), type_of_inputs='classic',
+                preprocessor=preprocessor, estimator=estimator)
+  expected_msg = ("2D array of formed points expected{}. "
+                  "Found 3D array instead:\ninput={}. Reshape your data{}.\n"
+                  .format(context, np.ones((3, 2, 2)),
+                          ' and/or use a preprocessor' if preprocessor
+                          is not None else ''))
+  assert str(raised_error.value) == expected_msg
+
+
+def test_preprocessor_error_message():
+  """Tests that a PreprocessorError is raised when there is a problem
+  calling the preprocessor on the input
+  """
+  preprocessor = ArrayIndexer(np.array([[1.2, 3.3], [3.1, 3.2]]))
+
+  # with tuples
+  X = np.array([[[2, 3], [3, 3]], [[2, 3], [3, 2]]])
+  # There are fewer samples than the max index we want to preprocess
+  with pytest.raises(PreprocessorError):
+    preprocess_tuples(X, preprocessor)
+
+  # with points
+  X = np.array([[1], [2], [3], [3]])
+  with pytest.raises(PreprocessorError):
+    preprocess_points(X, preprocessor)
+
+
+@pytest.mark.parametrize('input_data', [[[5, 3], [3, 2]],
+                                        ((5, 3), (3, 2))
+                                        ])
+@pytest.mark.parametrize('indices', [[0, 1], (1, 0)])
+def test_array_like_indexer_array_like_valid_classic(input_data, indices):
+  """Checks that any array-like is valid in the 'preprocessor' argument,
+  and in the indices, for a classic input"""
+  class MockMetricLearner(MahalanobisMixin):
+    pass
+
+  mock_algo = MockMetricLearner(preprocessor=input_data)
+  mock_algo._prepare_inputs(indices, type_of_inputs='classic')
+
+
+@pytest.mark.parametrize('input_data', [[[5, 3], [3, 2]],
+                                        ((5, 3), (3, 2))
+                                        ])
+@pytest.mark.parametrize('indices', [[[0, 1], [1, 0]], ((1, 0), (1, 0))])
+def test_array_like_indexer_array_like_valid_tuples(input_data, indices):
+  """Checks that any array-like is valid in the 'preprocessor' argument,
+  and in the indices, for a tuples input"""
+  class MockMetricLearner(MahalanobisMixin):
+    pass
+
+  mock_algo = MockMetricLearner(preprocessor=input_data)
+  mock_algo._prepare_inputs(indices, type_of_inputs='tuples')
+
+
+@pytest.mark.parametrize('preprocessor', [4, NCA()])
+def test_error_message_check_preprocessor(preprocessor):
+  """Checks that if the preprocessor given is not an array-like or a
+  callable, the right error message is returned"""
+  class MockMetricLearner(MahalanobisMixin):
+    pass
+
+  mock_algo = MockMetricLearner(preprocessor=preprocessor)
+  with pytest.raises(ValueError) as e:
+    mock_algo.check_preprocessor()
+  assert str(e.value) == ("Invalid type for the preprocessor: {}. You should "
+                          "provide either None, an array-like object, "
+                          "or a callable.".format(type(preprocessor)))
+
+
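+# Note: the two tests below feed 3-element tuples to learners that expect a
+# different tuple size (pairs for ITML/MMC/SDML and for score_pairs,
+# quadruplets for LSML) and only check the resulting error messages.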
+@pytest.mark.parametrize('estimator', [ITML(), LSML(), MMC(), SDML()],
+                         ids=['ITML', 'LSML', 'MMC', 'SDML'])
+def test_error_message_tuple_size(estimator):
+  """Tests that if a tuples learner is not given the right number of points
+  per tuple, it raises an error with the right message"""
+  estimator = clone(estimator)
+  set_random_state(estimator)
+  invalid_pairs = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]],
+                            [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]])
+  y = [1, 1]
+  with pytest.raises(ValueError) as raised_err:
+    estimator.fit(invalid_pairs, y)
+  expected_msg = ("Tuples of {} element(s) expected{}. Got tuples of 3 "
+                  "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n"
+                  .format(estimator._tuple_size, make_context(estimator),
+                          invalid_pairs))
+  assert str(raised_err.value) == expected_msg
+
+
+@pytest.mark.parametrize('estimator, _', metric_learners,
+                         ids=ids_metric_learners)
+def test_error_message_t_score_pairs(estimator, _):
+  """Tests that score_pairs returns the right error message when given
+  tuples of the wrong size (for instance triplets instead of pairs)
+  """
+  estimator = clone(estimator)
+  set_random_state(estimator)
+  estimator.check_preprocessor()
+  triplets = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]],
+                       [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]])
+  with pytest.raises(ValueError) as raised_err:
+    estimator.score_pairs(triplets)
+  expected_msg = ("Tuples of 2 element(s) expected{}. Got tuples of 3 "
+                  "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n"
+                  .format(make_context(estimator), triplets))
+  assert str(raised_err.value) == expected_msg
+
+
+def test_preprocess_tuples_simple_example():
+  """Test the preprocessor on a very simple example of tuples to ensure the
+  result is as expected"""
+  array = np.array([[1, 2],
+                    [2, 3],
+                    [4, 5]])
+
+  def fun(row):
+    return np.array([[1, 1], [3, 3], [4, 4]])
+
+  expected_result = np.array([[[1, 1], [1, 1]],
+                              [[3, 3], [3, 3]],
+                              [[4, 4], [4, 4]]])
+
+  assert (preprocess_tuples(array, fun) == expected_result).all()
+
+
+def test_preprocess_points_simple_example():
+  """Test the preprocessor on very simple examples of points to ensure the
+  result is as expected"""
+  array = np.array([1, 2, 4])
+
+  def fun(row):
+    return [[1, 1], [3, 3], [4, 4]]
+
+  expected_result = np.array([[1, 1],
+                              [3, 3],
+                              [4, 4]])
+
+  assert (preprocess_points(array, fun) == expected_result).all()
+
+
+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_same_with_or_without_preprocessor(estimator, build_dataset):
+  """Test that algorithms using a preprocessor behave consistently with
+  their no-preprocessor equivalent
+  """
+  dataset_indices = build_dataset(with_preprocessor=True)
+  dataset_formed = build_dataset(with_preprocessor=False)
+  X = dataset_indices.preprocessor
+  indicators_to_transform = dataset_indices.to_transform
+  formed_points_to_transform = dataset_formed.to_transform
+  (indices_train, indices_test, y_train, y_test, formed_train,
+   formed_test) = train_test_split(dataset_indices.data,
+                                   dataset_indices.target,
+                                   dataset_formed.data,
+                                   random_state=SEED)
+
+  def make_random_state(estimator):
+    rs = {}
+    if estimator.__class__.__name__[-11:] == '_Supervised':
+      rs['random_state'] = check_random_state(SEED)
+    return rs
+
+  estimator_with_preprocessor = clone(estimator)
+  set_random_state(estimator_with_preprocessor)
+  estimator_with_preprocessor.set_params(preprocessor=X)
+  estimator_with_preprocessor.fit(indices_train, y_train,
+                                  **make_random_state(estimator))
+
+  estimator_without_preprocessor = clone(estimator)
+  set_random_state(estimator_without_preprocessor)
+  estimator_without_preprocessor.set_params(preprocessor=None)
+  estimator_without_preprocessor.fit(formed_train, y_train,
+                                     **make_random_state(estimator))
+
+  estimator_with_prep_formed = clone(estimator)
+  set_random_state(estimator_with_prep_formed)
+  estimator_with_prep_formed.set_params(preprocessor=X)
+  estimator_with_prep_formed.fit(indices_train, y_train,
+                                 **make_random_state(estimator))
+
+  # test prediction methods
+  for method in ["predict", "decision_function"]:
+    if hasattr(estimator, method):
+      output_with_prep = getattr(estimator_with_preprocessor,
+                                 method)(indices_test)
+      output_without_prep = getattr(estimator_without_preprocessor,
+                                    method)(formed_test)
+      assert np.array(output_with_prep == output_without_prep).all()
+      output_with_prep = getattr(estimator_with_preprocessor,
+                                 method)(indices_test)
+      output_with_prep_formed = getattr(estimator_with_prep_formed,
+                                        method)(formed_test)
+      assert np.array(output_with_prep == output_with_prep_formed).all()
+
+  # test score_pairs
+  output_with_prep = estimator_with_preprocessor.score_pairs(
+      indicators_to_transform[[[[0, 2], [5, 3]]]])
+  output_without_prep = estimator_without_preprocessor.score_pairs(
+      formed_points_to_transform[[[[0, 2], [5, 3]]]])
+  assert np.array(output_with_prep == output_without_prep).all()
+
+  output_with_prep = estimator_with_preprocessor.score_pairs(
+      indicators_to_transform[[[[0, 2], [5, 3]]]])
+  output_without_prep = estimator_with_prep_formed.score_pairs(
+      formed_points_to_transform[[[[0, 2], [5, 3]]]])
+  assert np.array(output_with_prep == output_without_prep).all()
+
+  # test transform
+  output_with_prep = estimator_with_preprocessor.transform(
+      indicators_to_transform)
+  output_without_prep = estimator_without_preprocessor.transform(
+      formed_points_to_transform)
+  assert np.array(output_with_prep == output_without_prep).all()
+
+  output_with_prep = estimator_with_preprocessor.transform(
+      indicators_to_transform)
+  output_without_prep = estimator_with_prep_formed.transform(
+      formed_points_to_transform)
+  assert np.array(output_with_prep == output_without_prep).all()
+
+
+def test_check_collapsed_pairs_raises_no_error():
+  """Checks that check_collapsed_pairs raises no error if no collapsed pairs
+  are present"""
+  pairs_ok = np.array([[[0.1, 3.3], [3.3, 0.1]],
+                       [[0.1, 3.3], [3.3, 0.1]],
+                       [[2.5, 8.1], [0.1, 3.3]]])
+  check_collapsed_pairs(pairs_ok)
+
+
+def test_check_collapsed_pairs_raises_error():
+  """Checks that check_collapsed_pairs raises an error when collapsed pairs
+  are present"""
+  pairs_not_ok = np.array([[[0.1, 3.3], [0.1, 3.3]],
+                           [[0.1, 3.3], [3.3, 0.1]],
+                           [[2.5, 8.1], [2.5, 8.1]]])
+  with pytest.raises(ValueError) as e:
+    check_collapsed_pairs(pairs_not_ok)
+  assert str(e.value) == ("2 collapsed pairs found (where the left element is "
+                          "the same as the right element), out of 3 pairs in"
+                          " total.")
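
For reviewers skimming the patch, here is a minimal sketch of the preprocessor
workflow these tests exercise. The values are illustrative only; the pattern
mirrors ``test_preprocessor_weakly_supervised`` above, where pairs are given as
indices into the preprocessor array:

>>> import numpy as np
>>> from metric_learn import ITML
>>> X = np.array([[0.89, 0.11, 1.48, 0.12],    # the stored points
...               [2.63, 1.08, 1.68, 0.46],
...               [1.00, 0.59, 0.62, 1.15]])
>>> pairs = np.array([[0, 1], [2, 1]])         # each pair refers to rows of X
>>> y_pairs = [-1, 1]                          # dissimilar / similar labels
>>> itml = ITML(preprocessor=X)                # indices are resolved through X
>>> itml.fit(pairs, y_pairs)                   # doctest: +SKIP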