|
25 | 25 | from .base_metric import MahalanobisMixin |
26 | 26 |
|
27 | 27 |
|
28 | | -# commonality between LMNN implementations |
29 | | -class _base_LMNN(MahalanobisMixin, TransformerMixin): |
| 28 | +class LMNN(MahalanobisMixin, TransformerMixin): |
30 | 29 | def __init__(self, init=None, k=3, min_iter=50, max_iter=1000, |
31 | 30 | learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, |
32 | 31 | use_pca=True, verbose=False, preprocessor=None, |
@@ -114,11 +113,7 @@ def __init__(self, init=None, k=3, min_iter=50, max_iter=1000, |
114 | 113 | self.n_components = n_components |
115 | 114 | self.num_dims = num_dims |
116 | 115 | self.random_state = random_state |
117 | | - super(_base_LMNN, self).__init__(preprocessor) |
118 | | - |
119 | | - |
120 | | -# slower Python version |
121 | | -class python_LMNN(_base_LMNN): |
| 116 | + super(LMNN, self).__init__(preprocessor) |
122 | 117 |
|
123 | 118 | def fit(self, X, y): |
124 | 119 | if self.num_dims != 'deprecated': |
@@ -344,40 +339,3 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None): |
344 | 339 | if weights is not None: |
345 | 340 | return np.dot(Xab.T, Xab * weights[:,None]) |
346 | 341 | return np.dot(Xab.T, Xab) |
347 | | - |
348 | | - |
349 | | -try: |
350 | | - # use the fast C++ version, if available |
351 | | - from modshogun import LMNN as shogun_LMNN |
352 | | - from modshogun import RealFeatures, MulticlassLabels |
353 | | - |
354 | | - class LMNN(_base_LMNN): |
355 | | - """Large Margin Nearest Neighbor (LMNN) |
356 | | -
|
357 | | - Attributes |
358 | | - ---------- |
359 | | - n_iter_ : `int` |
360 | | - The number of iterations the solver has run. |
361 | | -
|
362 | | - transformer_ : `numpy.ndarray`, shape=(n_components, n_features) |
363 | | - The learned linear transformation ``L``. |
364 | | - """ |
365 | | - |
366 | | - def fit(self, X, y): |
367 | | - X, y = self._prepare_inputs(X, y, dtype=float, |
368 | | - ensure_min_samples=2) |
369 | | - labels = MulticlassLabels(y) |
370 | | - self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k) |
371 | | - self._lmnn.set_maxiter(self.max_iter) |
372 | | - self._lmnn.set_obj_threshold(self.convergence_tol) |
373 | | - self._lmnn.set_regularization(self.regularization) |
374 | | - self._lmnn.set_stepsize(self.learn_rate) |
375 | | - if self.use_pca: |
376 | | - self._lmnn.train() |
377 | | - else: |
378 | | - self._lmnn.train(np.eye(X.shape[1])) |
379 | | - self.transformer_ = self._lmnn.get_linear_transform(X) |
380 | | - return self |
381 | | - |
382 | | -except ImportError: |
383 | | - LMNN = python_LMNN |
0 commit comments