Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions metric_learn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import warnings
from collections import Counter
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.utils.validation import check_X_y, check_array
from sklearn.metrics import euclidean_distances

from .base_metric import BaseMetricLearner

Expand Down Expand Up @@ -185,7 +185,7 @@ def _select_targets(self):
target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int)
for label in self.labels_:
inds, = np.nonzero(self.label_inds_ == label)
dd = pairwise_distances(self.X_[inds])
dd = euclidean_distances(self.X_[inds], self.X_[inds], squared=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can leave out the second argument, similar to how pairwise_distances works.

np.fill_diagonal(dd, np.inf)
nn = np.argsort(dd)[..., :self.k]
target_neighbors[inds] = inds[nn]
Expand All @@ -198,7 +198,7 @@ def _find_impostors(self, furthest_neighbors):
for label in self.labels_[:-1]:
in_inds, = np.nonzero(self.label_inds_ == label)
out_inds, = np.nonzero(self.label_inds_ > label)
dist = pairwise_distances(Lx[out_inds], Lx[in_inds])
dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True)
i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None])
i2,j2 = np.nonzero(dist < margin_radii[in_inds])
i = np.hstack((i1,i2))
Expand Down