1 parent b6c8874 · commit 1b950fa
lambda_networks/lambda_networks.py
@@ -65,16 +65,16 @@ def forward(self, x):
         k = k.softmax(dim=-1)

         λc = einsum('b u k m, b u v m -> b k v', k, v)
-        Yc = einsum('b h k n, b k v -> b n h v', q, λc)
+        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

         if self.local_contexts:
             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
             λp = self.pos_conv(v)
-            Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
+            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
         else:
             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
-            Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
+            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

         Y = Yc + Yp
-        out = rearrange(Y, 'b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)
-        return out.contiguous()
+        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
+        return out
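The change above only permutes the einsum output axes: Yc and Yp now come out as 'b h v n' instead of 'b n h v', so the final rearrange merges the head and value axes and splits the token axis without any transposition, which is presumably why the trailing .contiguous() could be dropped. A minimal sketch (not part of the commit; shapes and variable names are illustrative assumptions) that checks the old and new orderings produce the same output tensor for the content path:

# Minimal sketch: illustrative shapes only, comparing the old and new
# einsum/rearrange axis orderings for the content-lambda path.
import torch
from torch import einsum
from einops import rearrange

b, h, k, v, hh, ww = 2, 4, 16, 8, 7, 7
n = hh * ww
q = torch.randn(b, h, k, n)        # queries, as in the einsum patterns above
lambda_c = torch.randn(b, k, v)    # content lambda (λc)

# old path: tokens come out first ('b n h v'), then rearrange permutes them
Y_old = einsum('b h k n, b k v -> b n h v', q, lambda_c)
out_old = rearrange(Y_old, 'b (hh ww) h v -> b (h v) hh ww', hh=hh, ww=ww).contiguous()

# new path: heads and values come out first ('b h v n'), tokens stay last
Y_new = einsum('b h k n, b k v -> b h v n', q, lambda_c)
out_new = rearrange(Y_new, 'b h v (hh ww) -> b (h v) hh ww', hh=hh, ww=ww)

# both orderings give the same (b, h*v, hh, ww) output
assert torch.allclose(out_old, out_new, atol=1e-6)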
setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lambda-networks',
   packages = find_packages(),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Lambda Networks - Pytorch',
   author = 'Phil Wang',