@@ -42,8 +42,7 @@ def __init__(
         self.local_contexts = exists(r)
         if exists(r):
             assert (r % 2) == 1, 'Receptive kernel size should be odd'
-            self.padding = r // 2
-            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
+            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
         else:
             assert exists(n), 'You must specify the total sequence length (h x w)'
             self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
@@ -60,8 +59,8 @@ def forward(self, x):
         v = self.norm_v(v)

         q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
-        k = rearrange(k, 'b (k u) hh ww -> b u k (hh ww)', u = u)
-        v = rearrange(v, 'b (v u) hh ww -> b u v (hh ww)', u = u)
+        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
+        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

         k = k.softmax(dim = -1)

@@ -70,7 +69,7 @@ def forward(self, x):

         if self.local_contexts:
             v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
-            λp = F.conv3d(v, self.R, padding = (0, self.padding, self.padding))
+            λp = self.pos_conv(v)
             Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
         else:
             λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
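
As a sanity check on this change (not part of the commit), here is a minimal standalone sketch of the positional-lambda path with the new pos_conv, assuming illustrative sizes (heads = 4, dim_k = 16, dim_u = 1, dim_v = 64, r = 7, 32x32 input). It uses the corrected (u k)/(u v) channel grouping from this diff; the content-lambda einsums are paraphrased from the unchanged parts of forward() only to make the example self-contained, so they may not match that code verbatim.

# hedged, standalone sketch with made-up sizes; not the committed code
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange

b, hh, ww = 2, 32, 32
heads, dim_k, dim_u, dim_v, r = 4, 16, 1, 64, 7

# new positional conv from this commit: in = dim_u, out = dim_k, kernel (1, r, r)
pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))

q = torch.randn(b, heads * dim_k, hh, ww)
k = torch.randn(b, dim_u * dim_k, hh, ww)
v = torch.randn(b, dim_u * dim_v, hh, ww)

q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = heads)
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = dim_u)  # (u k), not (k u)
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = dim_u)  # (u v), not (v u)

k = k.softmax(dim = -1)
λc = einsum('b u k m, b u v m -> b k v', k, v)        # content lambda (paraphrased)
Yc = einsum('b h k n, b k v -> b n h v', q, λc)

v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
λp = pos_conv(v)                                      # (b, dim_k, dim_v, hh, ww)
Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))

out = Yc + Yp
print(out.shape)                                      # torch.Size([2, 1024, 4, 64])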