Skip to content

Commit e31fb50

Browse files
More example code cleanup
1 parent 7fdf150 commit e31fb50

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

examples/sandwich.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
"""
44

55
import numpy as np
6-
from numpy.random import normal
7-
import matplotlib.pyplot as pyplot
6+
from matplotlib import pyplot as plt
87
from sklearn.metrics import pairwise_distances
98
from sklearn.neighbors import NearestNeighbors
109

@@ -15,7 +14,7 @@
1514
def sandwich_demo():
1615
x, y = sandwich_data()
1716
knn = nearest_neighbors(x, k=2)
18-
ax = pyplot.subplot(3, 1, 1) # take the whole top row
17+
ax = plt.subplot(3, 1, 1) # take the whole top row
1918
plot_sandwich_data(x, y, ax)
2019
plot_neighborhood_graph(x, knn, y, ax)
2120
ax.set_title('input space')
@@ -31,27 +30,26 @@ def sandwich_demo():
3130
(LSML(), (x, C.relative_quadruplets(y, num_constraints)))
3231
]
3332

34-
for ax_num, (ml,args) in zip(xrange(3,7), mls):
33+
for ax_num, (ml,args) in zip(range(3,7), mls):
3534
ml.fit(*args)
3635
tx = ml.transform()
3736
ml_knn = nearest_neighbors(tx, k=2)
38-
ax = pyplot.subplot(3,2,ax_num)
37+
ax = plt.subplot(3,2,ax_num)
3938
plot_sandwich_data(tx, y, ax)
4039
plot_neighborhood_graph(tx, ml_knn, y, ax)
4140
ax.set_title('%s space' % ml.__class__.__name__)
4241
ax.set_xticks([])
4342
ax.set_yticks([])
44-
pyplot.show()
43+
plt.show()
4544

4645

4746
# TODO: use this somewhere
4847
def visualize_class_separation(X, labels):
49-
_, (ax1,ax2) = pyplot.subplots(ncols=2)
48+
_, (ax1,ax2) = plt.subplots(ncols=2)
5049
label_order = np.argsort(labels)
5150
ax1.imshow(pairwise_distances(X[label_order]), interpolation='nearest')
5251
ax2.imshow(pairwise_distances(labels[label_order,None]),
5352
interpolation='nearest')
54-
pyplot.show()
5553

5654

5755
def nearest_neighbors(X, k=5):
@@ -67,27 +65,30 @@ def sandwich_data():
6765
num_points = 9
6866
# distance between layers, the points of each class are in a layer
6967
dist = 0.7
70-
# memory pre-allocation
71-
x = np.zeros((num_classes*num_points, 2))
72-
y = np.zeros(num_classes*num_points, dtype=int)
73-
for i,j in zip(xrange(num_classes), xrange(-num_classes//2,num_classes//2+1)):
74-
for k,l in zip(xrange(num_points), xrange(-num_points//2,num_points//2+1)):
75-
x[i*num_points + k, :] = np.array([normal(l, 0.1), normal(dist*j, 0.1)])
76-
y[i*num_points:i*num_points + num_points] = i
77-
return x,y
78-
79-
80-
def plot_sandwich_data(x, y, axis=pyplot, cols='rbgmky'):
81-
for idx,val in enumerate(np.unique(y)):
68+
69+
data = np.zeros((num_classes, num_points, 2), dtype=float)
70+
labels = np.zeros((num_classes, num_points), dtype=int)
71+
72+
x_centers = np.arange(num_points, dtype=float) - num_points / 2
73+
y_centers = dist * (np.arange(num_classes, dtype=float) - num_classes / 2)
74+
for i, yc in enumerate(y_centers):
75+
for k, xc in enumerate(x_centers):
76+
data[i, k, 0] = np.random.normal(xc, 0.1)
77+
data[i, k, 1] = np.random.normal(yc, 0.1)
78+
labels[i,:] = i
79+
return data.reshape((-1, 2)), labels.ravel()
80+
81+
82+
def plot_sandwich_data(x, y, axis=plt, colors='rbgmky'):
83+
for idx, val in enumerate(np.unique(y)):
8284
xi = x[y==val]
83-
axis.scatter(xi[:,0], xi[:,1], s=50, facecolors='none',edgecolors=cols[idx])
85+
axis.scatter(*xi.T, s=50, facecolors='none', edgecolors=colors[idx])
8486

8587

86-
def plot_neighborhood_graph(x, nn, y, axis=pyplot, cols='rbgmky'):
87-
for i in xrange(x.shape[0]):
88-
xs = [x[i,0], x[nn[i,1], 0]]
89-
ys = [x[i,1], x[nn[i,1], 1]]
90-
axis.plot(xs, ys, cols[y[i]])
88+
def plot_neighborhood_graph(x, nn, y, axis=plt, colors='rbgmky'):
89+
for i, a in enumerate(x):
90+
b = x[nn[i,1]]
91+
axis.plot((a[0], b[0]), (a[1], b[1]), colors[y[i]])
9192

9293

9394
if __name__ == '__main__':

0 commit comments

Comments
 (0)