|
4 | 4 | """ |
5 | 5 | import numpy as np |
6 | 6 | import random |
| 7 | +import warnings |
7 | 8 | from six.moves import xrange |
| 9 | +from scipy.sparse import coo_matrix |
8 | 10 |
|
9 | | -# @TODO: consider creating a stateful class |
10 | | -# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386226 |
| 11 | +__all__ = ['Constraints'] |
11 | 12 |
|
12 | 13 |
|
13 | | -def adjacency_matrix(labels, num_points, num_constraints): |
14 | | - a, c = np.random.randint(len(labels), size=(2,num_constraints)) |
15 | | - b, d = np.empty((2, num_constraints), dtype=int) |
16 | | - for i,(al,cl) in enumerate(zip(labels[a],labels[c])): |
17 | | - b[i] = random.choice(np.nonzero(labels == al)[0]) |
18 | | - d[i] = random.choice(np.nonzero(labels != cl)[0]) |
19 | | - W = np.zeros((num_points,num_points)) |
20 | | - W[a,b] = 1 |
21 | | - W[c,d] = -1 |
22 | | - # make W symmetric |
23 | | - W[b,a] = 1 |
24 | | - W[d,c] = -1 |
25 | | - return W |
class Constraints(object):
  """Build and sample pairwise/chunk constraints from partially-labeled data.

  Stores the positions and values of the known labels; entries of -1 in the
  input mark unknown labels and are never used for constraint generation.
  """

  def __init__(self, partial_labels):
    """partial_labels : 1-D int array-like, with -1 marking an unknown label."""
    partial_labels = np.asanyarray(partial_labels)
    # unpacking enforces that the input is 1-D
    (self.num_points,) = partial_labels.shape
    known_mask = partial_labels >= 0
    # indices (into partial_labels) of the labeled points, and their labels
    (self.known_label_idx,) = np.where(known_mask)
    self.known_labels = partial_labels[self.known_label_idx]
26 | 21 |
|
| 22 | + def adjacency_matrix(self, num_constraints): |
| 23 | + a, b, c, d = self.positive_negative_pairs(num_constraints) |
| 24 | + row = np.concatenate((a, c)) |
| 25 | + col = np.concatenate((b, d)) |
| 26 | + data = np.ones_like(row, dtype=int) |
| 27 | + data[len(a):] = -1 |
| 28 | + adj = coo_matrix((data, (row, col)), shape=(self.num_points,)*2) |
| 29 | + # symmetrize |
| 30 | + return adj + adj.T |
27 | 31 |
|
28 | | -def positive_negative_pairs(labels, num_points, num_constraints): |
29 | | - ac,bd = np.random.randint(num_points, size=(2,num_constraints)) |
30 | | - pos = labels[ac] == labels[bd] |
31 | | - a,c = ac[pos], ac[~pos] |
32 | | - b,d = bd[pos], bd[~pos] |
33 | | - return a,b,c,d |
| 32 | + def positive_negative_pairs(self, num_constraints, same_length=False): |
| 33 | + a, b = self._pairs(num_constraints, same_label=True) |
| 34 | + c, d = self._pairs(num_constraints, same_label=False) |
| 35 | + if same_length and len(a) != len(c): |
| 36 | + n = min(len(a), len(c)) |
| 37 | + return a[:n], b[:n], c[:n], d[:n] |
| 38 | + return a, b, c, d |
34 | 39 |
|
| 40 | + def _pairs(self, num_constraints, same_label=True, max_iter=10): |
| 41 | + num_labels = len(self.known_labels) |
| 42 | + ab = set() |
| 43 | + it = 0 |
| 44 | + while it < max_iter and len(ab) < num_constraints: |
| 45 | + nc = num_constraints - len(ab) |
| 46 | + for aidx in np.random.randint(num_labels, size=nc): |
| 47 | + if same_label: |
| 48 | + mask = self.known_labels[aidx] == self.known_labels |
| 49 | + mask[aidx] = False # avoid identity pairs |
| 50 | + else: |
| 51 | + mask = self.known_labels[aidx] != self.known_labels |
| 52 | + b_choices, = np.where(mask) |
| 53 | + if len(b_choices) > 0: |
| 54 | + ab.add((aidx, np.random.choice(b_choices))) |
| 55 | + it += 1 |
| 56 | + if len(ab) < num_constraints: |
| 57 | + warnings.warn("Only generated %d %s constraints (requested %d)" % ( |
| 58 | + len(ab), 'positive' if same_label else 'negative', num_constraints)) |
| 59 | + ab = np.array(list(ab)[:num_constraints], dtype=int) |
| 60 | + return self.known_label_idx[ab.T] |
35 | 61 |
|
36 | | -def relative_quadruplets(labels, num_constraints): |
37 | | - C = np.empty((num_constraints,4), dtype=int) |
38 | | - a, c = np.random.randint(len(labels), size=(2,num_constraints)) |
39 | | - for i,(al,cl) in enumerate(zip(labels[a],labels[c])): |
40 | | - C[i,1] = random.choice(np.nonzero(labels == al)[0]) |
41 | | - C[i,3] = random.choice(np.nonzero(labels != cl)[0]) |
42 | | - C[:,0] = a |
43 | | - C[:,2] = c |
44 | | - return C |
| 62 | + def chunks(self, num_chunks=100, chunk_size=2): |
| 63 | + chunks = -np.ones_like(self.known_label_idx, dtype=int) |
| 64 | + uniq, lookup = np.unique(self.known_labels, return_inverse=True) |
| 65 | + all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))] |
| 66 | + idx = 0 |
| 67 | + while idx < num_chunks and all_inds: |
| 68 | + c = random.randint(0, len(all_inds)-1) |
| 69 | + inds = all_inds[c] |
| 70 | + if len(inds) < chunk_size: |
| 71 | + del all_inds[c] |
| 72 | + continue |
| 73 | + ii = random.sample(inds, chunk_size) |
| 74 | + inds.difference_update(ii) |
| 75 | + chunks[ii] = idx |
| 76 | + idx += 1 |
| 77 | + if idx < num_chunks: |
| 78 | + raise ValueError('Unable to make %d chunks of %d examples each' % |
| 79 | + (num_chunks, chunk_size)) |
| 80 | + return chunks |
45 | 81 |
|
46 | | - |
47 | | -def chunks(Y, num_chunks=100, chunk_size=2, seed=None): |
48 | | - # @TODO: remove seed from params and use numpy RandomState |
49 | | - # https://github.com/all-umass/metric-learn/pull/19#discussion_r67386666 |
50 | | - random.seed(seed) |
51 | | - chunks = -np.ones_like(Y, dtype=int) |
52 | | - uniq, lookup = np.unique(Y, return_inverse=True) |
53 | | - all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))] |
54 | | - idx = 0 |
55 | | - while idx < num_chunks and all_inds: |
56 | | - c = random.randint(0, len(all_inds)-1) |
57 | | - inds = all_inds[c] |
58 | | - if len(inds) < chunk_size: |
59 | | - del all_inds[c] |
60 | | - continue |
61 | | - ii = random.sample(inds, chunk_size) |
62 | | - inds.difference_update(ii) |
63 | | - chunks[ii] = idx |
64 | | - idx += 1 |
65 | | - if idx < num_chunks: |
66 | | - raise ValueError('Unable to make %d chunks of %d examples each' % |
67 | | - (num_chunks, chunk_size)) |
68 | | - return chunks |
| 82 | + @staticmethod |
| 83 | + def random_subset(all_labels, num_preserved=np.inf): |
| 84 | + n = len(all_labels) |
| 85 | + num_ignored = max(0, n - num_preserved) |
| 86 | + idx = np.random.randint(n, size=num_ignored) |
| 87 | + partial_labels = np.array(all_labels, copy=True) |
| 88 | + partial_labels[idx] = -1 |
| 89 | + return Constraints(partial_labels) |
0 commit comments