Skip to content

Commit c02e6e5

Browse files
committed
adagrad adaptive learning
1 parent c551344 commit c02e6e5

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

metric_learn/scml.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,16 @@ def _fit(self, triplets, basis=None, n_basis=None):
6060

6161
n_triplets = triplets.shape[0]
6262

63+
# weight vector
6364
w = np.zeros((1, n_basis))
65+
# average obj gradient wrt weights
6466
avg_grad_w = np.zeros((1, n_basis))
6567

68+
# accumulated l2 norm over time of all obj gradients wrt weights
69+
ada_grad_w = np.zeros((1, n_basis))
70+
# small constant to avoid division by zero
71+
delta = 0.001
72+
6673
best_obj = np.inf
6774

6875
rng = check_random_state(self.random_state)
@@ -102,7 +109,9 @@ def _fit(self, triplets, basis=None, n_basis=None):
102109
axis=0, keepdims=True)/self.batch_size
103110
avg_grad_w = (iter * avg_grad_w + grad_w) / (iter+1)
104111

105-
scale_f = -np.sqrt(iter+1) / self.gamma
112+
ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))
113+
114+
scale_f = -(iter+1) / self.gamma / (delta + ada_grad_w)
106115

107116
# proximal operator with negative trimming equivalent
108117
w = scale_f * np.minimum(avg_grad_w + self.beta, 0)

0 commit comments

Comments
 (0)