2 changes: 1 addition & 1 deletion examples/plot_metric_learning_examples.py
@@ -139,7 +139,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
#

# setting up LMNN
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6, init='random')
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)

# fit the data!
lmnn.fit(X, y)
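For context, here is a minimal sketch of how the updated example snippet is used end to end. The iris loading is a stand-in for the example script's own data setup; only the `LMNN(k=5, learn_rate=1e-6)` call is taken from the diff above.

```python
import metric_learn
from sklearn.datasets import load_iris

# stand-in data; the example script builds its own X, y
X, y = load_iris(return_X_y=True)

# setting up LMNN with the estimator's default initialization
# (the explicit init='random' argument was dropped in the updated example)
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)

# fit the data and embed it under the learned metric
lmnn.fit(X, y)
X_lmnn = lmnn.transform(X)
```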
77 changes: 26 additions & 51 deletions metric_learn/lmnn.py
@@ -1,8 +1,6 @@
"""
Large Margin Nearest Neighbor Metric learning (LMNN)
"""
# TODO: periodic recalculation of impostors, PCA initialization

from __future__ import print_function, absolute_import
import numpy as np
import warnings
@@ -208,31 +206,19 @@ def fit(self, X, y):
' (smallest class has %d)' % required_k)

target_neighbors = self._select_targets(X, label_inds)
impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
if len(impostors) == 0:
# L has already been initialized to an identity matrix
return

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
np.repeat(np.arange(X.shape[0]), k))
df = np.zeros_like(dfG)

# storage
a1 = [None]*k
a2 = [None]*k
for nn_idx in xrange(k):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize L
L = self.transformer_

# first iteration: we compute variables (including objective and gradient)
# at the initialization point
G, objective, total_active, df, a1, a2 = (
self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df,
a1, a2))
G, objective, total_active = self._loss_grad(X, L, dfG, k,
reg, target_neighbors,
label_inds)

it = 1 # we already made one iteration

@@ -246,10 +232,9 @@ def fit(self, X, y):
# we compute the objective at the next point
# we copy variables that can be modified by _loss_grad, because if we
# retry we don't want to modify them several times
(G_next, objective_next, total_active_next, df_next, a1_next,
a2_next) = (
self._loss_grad(X, L_next, dfG, impostors, it, k, reg,
target_neighbors, df.copy(), list(a1), list(a2)))
(G_next, objective_next, total_active_next) = (
self._loss_grad(X, L_next, dfG, k, reg, target_neighbors,
label_inds))
assert not np.isnan(objective)
delta_obj = objective_next - objective
if delta_obj > 0:
@@ -264,8 +249,7 @@ def fit(self, X, y):
# old variables to these new ones before next iteration and we
# slightly increase the learning rate
L = L_next
G, df, objective, total_active, a1, a2 = (
G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
G, objective, total_active = G_next, objective_next, total_active_next
learn_rate *= 1.01

if self.verbose:
@@ -285,54 +269,45 @@ def fit(self, X, y):
self.n_iter_ = it
return self
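To make the retry logic in the fit loop above easier to follow, here is a minimal standalone sketch of the step-size policy the comments describe: a step that increases the objective is rejected and the learning rate shrinks before retrying, otherwise the step is accepted and the rate grows by 1%. The update rule `L - learn_rate * G` and the back-off factor are assumptions (those lines are collapsed in this view); only the 1.01 growth factor appears in the visible code.

```python
def gradient_step_with_backoff(L, objective, G, learn_rate, loss_grad,
                               shrink=0.5, grow=1.01):
    """One illustrative step of the adaptive scheme used in fit().

    `loss_grad` maps a transformation L to (gradient, objective, n_active);
    `shrink` is a hypothetical back-off factor, `grow` comes from the diff.
    """
    L_next = L - learn_rate * G
    G_next, objective_next, _ = loss_grad(L_next)
    if objective_next - objective > 0:
        # the objective went up: keep the old point, retry with a smaller rate
        return L, objective, G, learn_rate * shrink
    # the objective went down: accept the step and slightly increase the rate
    return L_next, objective_next, G_next, learn_rate * grow
```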

def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
a1, a2):
def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
# Compute pairwise distances under current metric
Lx = L.dot(X.T).T
g0 = _inplace_paired_L2(*Lx[impostors])

# we need to find the furthest neighbor:
Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
furthest_neighbors = np.take_along_axis(target_neighbors,
Ni.argmax(axis=1)[:, None], 1)
impostors = self._find_impostors(furthest_neighbors.ravel(), X,
label_inds, L)

g0 = _inplace_paired_L2(*Lx[impostors])

# we reorder the target neighbors
g1, g2 = Ni[impostors]
# compute the gradient
total_active = 0
for nn_idx in reversed(xrange(k)):
df = np.zeros((X.shape[1], X.shape[1]))
for nn_idx in reversed(xrange(k)): # note: reverse not useful here
act1 = g0 < g1[:, nn_idx]
act2 = g0 < g2[:, nn_idx]
total_active += act1.sum() + act2.sum()

if it > 1:
plus1 = act1 & ~a1[nn_idx]
minus1 = a1[nn_idx] & ~act1
plus2 = act2 & ~a2[nn_idx]
minus2 = a2[nn_idx] & ~act2
else:
plus1 = act1
plus2 = act2
minus1 = np.zeros(0, dtype=int)
minus2 = np.zeros(0, dtype=int)

targets = target_neighbors[:, nn_idx]
PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
PLUS, pweight = _count_edges(act1, act2, impostors, targets)
df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight)

in_imp, out_imp = impostors
df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1])
df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2])

df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1])
df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2])
df -= _sum_outer_products(X, in_imp[act1], out_imp[act1])
df -= _sum_outer_products(X, in_imp[act2], out_imp[act2])

a1[nn_idx] = act1
a2[nn_idx] = act2
# do the gradient update
assert not np.isnan(df).any()
G = dfG * reg + df * (1 - reg)
G = L.dot(G)
# compute the objective function
objective = total_active * (1 - reg)
objective += G.flatten().dot(L.flatten())
return 2 * G, objective, total_active, df, a1, a2
return 2 * G, objective, total_active
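For reference, `_loss_grad` evaluates the standard LMNN objective of Weinberger and Saul; writing it out makes the two terms that `regularization` trades off explicit (how `reg` maps onto \mu below follows the library's convention and is not spelled out in this hunk).

```latex
% eta_{ij} = 1 iff x_j is a target neighbour of x_i,
% y_{il} = 1 iff x_i and x_l share a label, [z]_+ = max(z, 0)
\varepsilon(L) = (1 - \mu) \sum_{i,j} \eta_{ij} \lVert L(x_i - x_j) \rVert^2
  + \mu \sum_{i,j,l} \eta_{ij} (1 - y_{il})
    \bigl[ 1 + \lVert L(x_i - x_j) \rVert^2 - \lVert L(x_i - x_l) \rVert^2 \bigr]_+
```

`total_active` counts the hinge terms in the second sum whose bracket is positive at the current `L`.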

def _select_targets(self, X, label_inds):
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
Expand All @@ -344,8 +319,8 @@ def _select_targets(self, X, label_inds):
target_neighbors[inds] = inds[nn]
return target_neighbors

def _find_impostors(self, furthest_neighbors, X, label_inds):
Lx = self.transform(X)
def _find_impostors(self, furthest_neighbors, X, label_inds, L):
Lx = X.dot(L.T)
margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
impostors = []
for label in self.labels_[:-1]:
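The `_find_impostors` signature above now takes the current transformation `L` explicitly, so impostors are recomputed under the metric of the current iterate rather than under the stored `transformer_`. As a rough illustration of the definition being used (a sketch, not the library's vectorized implementation): an impostor for a point is any differently-labelled point that falls within its margin radius, i.e. within 1 plus the squared distance to its furthest target neighbour.

```python
import numpy as np
from sklearn.metrics import euclidean_distances

def naive_impostor_pairs(Lx, y, furthest_target_neighbors):
    """Illustrative O(n^2) impostor search on already-transformed points Lx.

    `furthest_target_neighbors[i]` is the index of x_i's furthest target
    (same-class) neighbour, as computed in _loss_grad above.
    """
    # margin radius of each point: 1 + squared distance to its furthest target
    margin_radii = 1 + np.sum((Lx - Lx[furthest_target_neighbors]) ** 2, axis=1)
    d2 = euclidean_distances(Lx, squared=True)
    return [(i, l) for i in range(len(y)) for l in range(len(y))
            if y[i] != y[l] and d2[i, l] < margin_radii[i]]
```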
159 changes: 142 additions & 17 deletions test/metric_learn_test.py
@@ -2,9 +2,10 @@
import re
import pytest
import numpy as np
import scipy
from scipy.optimize import check_grad, approx_fprime
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.metrics import pairwise_distances, euclidean_distances
from sklearn.datasets import (load_iris, make_classification, make_regression,
make_spd_matrix)
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -242,25 +243,15 @@ def test_loss_grad_lbfgs(self):
lmnn.transformer_ = np.eye(n_components)

target_neighbors = lmnn._select_targets(X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
np.repeat(np.arange(X.shape[0]), k))
df = np.zeros_like(dfG)

# storage
a1 = [None]*k
a2 = [None]*k
for nn_idx in xrange(k):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize L
def loss_grad(flat_L):
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
1, k, reg, target_neighbors, df.copy(),
list(a1), list(a2))
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG,
k, reg, target_neighbors, label_inds)

def fun(x):
return loss_grad(x)[1]
@@ -292,6 +283,141 @@ def test_changed_behaviour_warning(self):
assert any(msg == str(wrn.message) for wrn in raised_warning)


def test_loss_func(capsys):
"""Test the loss function (and its gradient) on a simple example,
by comparing the results with the actual implementation of metric-learn,
with a very simple (but nonperformant) implementation"""

# toy dataset to use
X, y = make_classification(n_samples=10, n_classes=2,
n_features=6,
n_redundant=0, shuffle=True,
scale=[1, 1, 20, 20, 20, 20], random_state=42)

def hinge(a):
if a > 0:
return a, 1
else:
return 0, 0

def loss_fn(L, X, y, target_neighbors, reg):
L = L.reshape(-1, X.shape[1])
Lx = np.dot(X, L.T)
loss = 0
total_active = 0
grad = np.zeros_like(L)
for i in range(X.shape[0]):
for j in target_neighbors[i]:
loss += (1 - reg) * np.sum((Lx[i] - Lx[j]) ** 2)
grad += (1 - reg) * np.outer(Lx[i] - Lx[j], X[i] - X[j])
for l in range(X.shape[0]):
if y[i] != y[l]:
hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) -
np.sum((Lx[i] - Lx[l])**2))
total_active += active
if active:
loss += reg * hin
grad += (reg * (np.outer(Lx[i] - Lx[j], X[i] - X[j]) -
np.outer(Lx[i] - Lx[l], X[i] - X[l])))
grad = 2 * grad
return grad, loss, total_active

# we check that the gradient we have computed in the non-performant
# implementation is indeed the true gradient on a toy example:

def _select_targets(X, y, k):
target_neighbors = np.empty((X.shape[0], k), dtype=int)
for label in np.unique(y):
inds, = np.nonzero(y == label)
dd = euclidean_distances(X[inds], squared=True)
np.fill_diagonal(dd, np.inf)
nn = np.argsort(dd)[..., :k]
target_neighbors[inds] = inds[nn]
return target_neighbors

target_neighbors = _select_targets(X, y, 2)
regularization = 0.5
n_features = X.shape[1]
x0 = np.random.randn(1, n_features)

def loss(x0):
return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
regularization)[1]

def grad(x0):
return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
regularization)[0].ravel()

scipy.optimize.check_grad(loss, grad, x0.ravel())

class LMNN_with_callback(LMNN):
""" We will use a callback to get the gradient (see later)
"""

def __init__(self, callback, *args, **kwargs):
self.callback = callback
super(LMNN_with_callback, self).__init__(*args, **kwargs)

def _loss_grad(self, *args, **kwargs):
grad, objective, total_active = (
super(LMNN_with_callback, self)._loss_grad(*args, **kwargs))
self.callback.append(grad)
return grad, objective, total_active

class LMNN_nonperformant(LMNN_with_callback):

def fit(self, X, y):
self.y = y
return super(LMNN_nonperformant, self).fit(X, y)

def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
grad, loss, total_active = loss_fn(L.ravel(), X, self.y,
target_neighbors, self.regularization)
self.callback.append(grad)
return grad, loss, total_active

mem1, mem2 = [], []
lmnn_perf = LMNN_with_callback(verbose=True, random_state=42,
init='identity', max_iter=30, callback=mem1)
lmnn_nonperf = LMNN_nonperformant(verbose=True, random_state=42,
init='identity', max_iter=30,
callback=mem2)
objectives, obj_diffs, learn_rate, total_active = (dict(), dict(), dict(),
dict())
for algo, name in zip([lmnn_perf, lmnn_nonperf], ['perf', 'nonperf']):
algo.fit(X, y)
out, _ = capsys.readouterr()
lines = re.split("\n+", out)
# we parse every variable that the algorithm prints in verbose mode
# (the assumed line format is sketched just after this test)
num = '(-?\d+.?\d*(e[+|-]\d+)?)'
strings = [re.search("\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})"
.format(num, num, num), s) for s in lines]
objectives[name] = [float(match.group(1)) for match in strings if match is
not None]
obj_diffs[name] = [float(match.group(3)) for match in strings if match is
not None]
total_active[name] = [float(match.group(5)) for match in strings if
match is not
None]
learn_rate[name] = [float(match.group(6)) for match in strings if match is
not None]
assert len(strings) >= 10 # we ensure that we actually did at least 10
# iterations
assert total_active[name][0] >= 2 # we ensure that we have some active
# constraints (that's the case we want to test)
# we remove the last element because it can be equal to the penultimate
# if the last gradient update is null
for i in range(len(mem1)):
np.testing.assert_allclose(lmnn_perf.callback[i],
lmnn_nonperf.callback[i],
err_msg='Gradient different at position '
'{}'.format(i))
np.testing.assert_allclose(objectives['perf'], objectives['nonperf'])
np.testing.assert_allclose(obj_diffs['perf'], obj_diffs['nonperf'])
np.testing.assert_allclose(total_active['perf'], total_active['nonperf'])
np.testing.assert_allclose(learn_rate['perf'], learn_rate['nonperf'])
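The regex bookkeeping above assumes a specific layout for each verbose line printed during `fit`: iteration number, objective, objective delta, number of active constraints, learning rate. The sketch below shows how the capture groups map onto such a line; the sample string is illustrative, not real captured output.

```python
import re

num = r'(-?\d+.?\d*(e[+|-]\d+)?)'
pattern = r"\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})".format(num, num, num)

# assumed layout: "<it> <objective> <delta_obj> <total_active> <learn_rate>"
sample = "3 42.17 -0.25 18 1.0303e-06"
m = re.search(pattern, sample)
print(m.group(1), m.group(3), m.group(5), m.group(6))
# -> 42.17 -0.25 18 1.0303e-06  (objective, delta, active count, learning rate)
```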


@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
[1, 1, 0, 0], 3.0),
(np.array([[0], [1], [2], [3]]),
@@ -312,7 +438,7 @@ def test_toy_ex_lmnn(X, y, loss):
lmnn.transformer_ = np.eye(n_components)

target_neighbors = lmnn._select_targets(X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds, L)

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
Expand All @@ -327,9 +453,8 @@ def test_toy_ex_lmnn(X, y, loss):
a2[nn_idx] = np.array([])

# assert that the loss equals the one computed by hand
assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
reg, target_neighbors, df, a1, a2)[1] == loss

assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, k,
reg, target_neighbors, label_inds)[1] == loss

def test_convergence_simple_example(capsys):
# LMNN should converge on this simple example, which it did not with