33"""
44
55import numpy as np
6- from numpy .random import normal
7- import matplotlib .pyplot as pyplot
6+ from matplotlib import pyplot as plt
87from sklearn .metrics import pairwise_distances
98from sklearn .neighbors import NearestNeighbors
109
1514def sandwich_demo ():
1615 x , y = sandwich_data ()
1716 knn = nearest_neighbors (x , k = 2 )
18- ax = pyplot .subplot (3 , 1 , 1 ) # take the whole top row
17+ ax = plt .subplot (3 , 1 , 1 ) # take the whole top row
1918 plot_sandwich_data (x , y , ax )
2019 plot_neighborhood_graph (x , knn , y , ax )
2120 ax .set_title ('input space' )
@@ -31,27 +30,26 @@ def sandwich_demo():
3130 (LSML (), (x , C .relative_quadruplets (y , num_constraints )))
3231 ]
3332
34- for ax_num , (ml ,args ) in zip (xrange (3 ,7 ), mls ):
33+ for ax_num , (ml ,args ) in zip (range (3 ,7 ), mls ):
3534 ml .fit (* args )
3635 tx = ml .transform ()
3736 ml_knn = nearest_neighbors (tx , k = 2 )
38- ax = pyplot .subplot (3 ,2 ,ax_num )
37+ ax = plt .subplot (3 ,2 ,ax_num )
3938 plot_sandwich_data (tx , y , ax )
4039 plot_neighborhood_graph (tx , ml_knn , y , ax )
4140 ax .set_title ('%s space' % ml .__class__ .__name__ )
4241 ax .set_xticks ([])
4342 ax .set_yticks ([])
44- pyplot .show ()
43+ plt .show ()
4544
4645
4746# TODO: use this somewhere
4847def visualize_class_separation (X , labels ):
49- _ , (ax1 ,ax2 ) = pyplot .subplots (ncols = 2 )
48+ _ , (ax1 ,ax2 ) = plt .subplots (ncols = 2 )
5049 label_order = np .argsort (labels )
5150 ax1 .imshow (pairwise_distances (X [label_order ]), interpolation = 'nearest' )
5251 ax2 .imshow (pairwise_distances (labels [label_order ,None ]),
5352 interpolation = 'nearest' )
54- pyplot .show ()
5553
5654
5755def nearest_neighbors (X , k = 5 ):
@@ -67,27 +65,30 @@ def sandwich_data():
6765 num_points = 9
6866 # distance between layers, the points of each class are in a layer
6967 dist = 0.7
70- # memory pre-allocation
71- x = np .zeros ((num_classes * num_points , 2 ))
72- y = np .zeros (num_classes * num_points , dtype = int )
73- for i ,j in zip (xrange (num_classes ), xrange (- num_classes // 2 ,num_classes // 2 + 1 )):
74- for k ,l in zip (xrange (num_points ), xrange (- num_points // 2 ,num_points // 2 + 1 )):
75- x [i * num_points + k , :] = np .array ([normal (l , 0.1 ), normal (dist * j , 0.1 )])
76- y [i * num_points :i * num_points + num_points ] = i
77- return x ,y
78-
79-
80- def plot_sandwich_data (x , y , axis = pyplot , cols = 'rbgmky' ):
81- for idx ,val in enumerate (np .unique (y )):
68+
69+ data = np .zeros ((num_classes , num_points , 2 ), dtype = float )
70+ labels = np .zeros ((num_classes , num_points ), dtype = int )
71+
72+ x_centers = np .arange (num_points , dtype = float ) - num_points / 2
73+ y_centers = dist * (np .arange (num_classes , dtype = float ) - num_classes / 2 )
74+ for i , yc in enumerate (y_centers ):
75+ for k , xc in enumerate (x_centers ):
76+ data [i , k , 0 ] = np .random .normal (xc , 0.1 )
77+ data [i , k , 1 ] = np .random .normal (yc , 0.1 )
78+ labels [i ,:] = i
79+ return data .reshape ((- 1 , 2 )), labels .ravel ()
80+
81+
82+ def plot_sandwich_data (x , y , axis = plt , colors = 'rbgmky' ):
83+ for idx , val in enumerate (np .unique (y )):
8284 xi = x [y == val ]
83- axis .scatter (xi [:, 0 ], xi [:, 1 ], s = 50 , facecolors = 'none' ,edgecolors = cols [idx ])
85+ axis .scatter (* xi . T , s = 50 , facecolors = 'none' , edgecolors = colors [idx ])
8486
8587
86- def plot_neighborhood_graph (x , nn , y , axis = pyplot , cols = 'rbgmky' ):
87- for i in xrange (x .shape [0 ]):
88- xs = [x [i ,0 ], x [nn [i ,1 ], 0 ]]
89- ys = [x [i ,1 ], x [nn [i ,1 ], 1 ]]
90- axis .plot (xs , ys , cols [y [i ]])
88+ def plot_neighborhood_graph (x , nn , y , axis = plt , colors = 'rbgmky' ):
89+ for i , a in enumerate (x ):
90+ b = x [nn [i ,1 ]]
91+ axis .plot ((a [0 ], b [0 ]), (a [1 ], b [1 ]), colors [y [i ]])
9192
9293
9394if __name__ == '__main__' :
0 commit comments