Skip to content

Commit 17471ca

Browse files
committed
save on transposes, further cleanup
1 parent 01a05d1 commit 17471ca

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

lambda_networks/lambda_networks.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def __init__(
4242
self.local_contexts = exists(r)
4343
if exists(r):
4444
assert (r % 2) == 1, 'Receptive kernel size should be odd'
45-
self.padding = r // 2
46-
self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
45+
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
4746
else:
4847
assert exists(n), 'You must specify the total sequence length (h x w)'
4948
self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
@@ -60,8 +59,8 @@ def forward(self, x):
6059
v = self.norm_v(v)
6160

6261
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
63-
k = rearrange(k, 'b (k u) hh ww -> b u k (hh ww)', u = u)
64-
v = rearrange(v, 'b (v u) hh ww -> b u v (hh ww)', u = u)
62+
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
63+
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
6564

6665
k = k.softmax(dim=-1)
6766

@@ -70,7 +69,7 @@ def forward(self, x):
7069

7170
if self.local_contexts:
7271
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
73-
λp = F.conv3d(v, self.R, padding = (0, self.padding, self.padding))
72+
λp = self.pos_conv(v)
7473
Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
7574
else:
7675
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lambda-networks',
55
packages = find_packages(),
6-
version = '0.1.2',
6+
version = '0.2.0',
77
license='MIT',
88
description = 'Lambda Networks - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)