Skip to content

Commit c551344

Browse files
committed
batch grad refactored
1 parent bdf981e commit c551344

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

metric_learn/scml.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ def _fit(self, triplets, basis=None, n_basis=None):
7373
# regularization part of obj function
7474
obj1 = np.sum(w)*self.beta
7575

76-
# Every triplet distance difference in the space given by L
77-
# plus a slack of one
76+
# Every triplet distance difference in the space given by L
77+
# plus a slack of one
7878
slack_val = 1 + np.matmul(dist_diff, w.T)
79-
# Mask of places with positive slack
79+
# Mask of places with positive slack
8080
slack_mask = slack_val > 0
8181

8282
# loss function of learning task part of obj function
@@ -96,13 +96,13 @@ def _fit(self, triplets, basis=None, n_basis=None):
9696
idx = rand_int[iter]
9797

9898
slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
99-
10099
slack_mask = np.squeeze(slack_val > 0, axis=1)
101-
avg_grad_w = ((iter * avg_grad_w + np.sum(dist_diff[idx[slack_mask], :],
102-
axis=0, keepdims=True))
103-
/ (iter+1))
104100

105-
scale_f = -np.sqrt(iter+1) / (self.gamma*self.batch_size)
101+
grad_w = np.sum(dist_diff[idx[slack_mask], :],
102+
axis=0, keepdims=True)/self.batch_size
103+
avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
104+
105+
scale_f = -np.sqrt(iter+1) / self.gamma
106106

107107
# proximal operator with negative trimming equivalent
108108
w = scale_f * np.minimum(avg_grad_w + self.beta, 0)

0 commit comments

Comments
 (0)