Skip to content

Commit 22d6206

Browse files
committed
cleanup keras einsum
1 parent 06a48f2 commit 22d6206

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

lambda_networks/tfkeras.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,17 @@ def call(self, inputs, **kwargs):
7070

7171
k = Softmax()(k)
7272

73-
Lc = Lambda(lambda x: einsum('b u k m, b u v m -> b k v', x[0], x[1]))([k, v])
74-
Yc = Lambda(lambda x: einsum('b h k n, b k v -> b n h v', x[0], x[1]))([q, Lc])
73+
Lc = einsum('b u k m, b u v m -> b k v', k, v)
74+
Yc = einsum('b h k n, b k v -> b n h v', q, Lc)
7575

7676
if self.local_contexts:
7777
v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
7878
Lp = self.pos_conv(v)
7979
Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
80-
Yp = Lambda(lambda x: einsum('b h k n, b v k n -> b n h v', x[0], x[1]))([q, Lp])
80+
Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
8181
else:
82-
Lp = Lambda(lambda x: einsum('n m k u, b u v m -> b n k v', x[0], x[1]))([self.pos_emb, v])
83-
Yp = Lambda(lambda x: einsum('b h k n, b n k v -> b n h v', x[0], x[1]))([q, Lp])
82+
Lp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
83+
Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)
8484

8585
Y = Add()([Yc, Yp])
8686
out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lambda-networks',
55
packages = find_packages(),
6-
version = '0.3.1',
6+
version = '0.3.2',
77
license='MIT',
88
description = 'Lambda Networks - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)