Skip to content

Commit 1b950fa

Browse files
committed
reorder einsums to avoid contiguous call at the end
1 parent b6c8874 commit 1b950fa

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

lambda_networks/lambda_networks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,16 @@ def forward(self, x):
6565
k = k.softmax(dim=-1)
6666

6767
λc = einsum('b u k m, b u v m -> b k v', k, v)
68-
Yc = einsum('b h k n, b k v -> b n h v', q, λc)
68+
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
6969

7070
if self.local_contexts:
7171
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
7272
λp = self.pos_conv(v)
73-
Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
73+
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
7474
else:
7575
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
76-
Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
76+
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
7777

7878
Y = Yc + Yp
79-
out = rearrange(Y, 'b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)
80-
return out.contiguous()
79+
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
80+
return out

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lambda-networks',
55
packages = find_packages(),
6-
version = '0.2.1',
6+
version = '0.2.2',
77
license='MIT',
88
description = 'Lambda Networks - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)