@@ -70,17 +70,17 @@ def call(self, inputs, **kwargs):
70
70
71
71
k = Softmax ()(k )
72
72
73
- Lc = Lambda ( lambda x : einsum ('b u k m, b u v m -> b k v' , x [ 0 ], x [ 1 ]))([ k , v ] )
74
- Yc = Lambda ( lambda x : einsum ('b h k n, b k v -> b n h v' , x [ 0 ], x [ 1 ]))([ q , Lc ] )
73
+ Lc = einsum ('b u k m, b u v m -> b k v' , k , v )
74
+ Yc = einsum ('b h k n, b k v -> b n h v' , q , Lc )
75
75
76
76
if self .local_contexts :
77
77
v = Rearrange ('b u v (hh ww) -> b v hh ww u' , hh = hh , ww = ww )(v )
78
78
Lp = self .pos_conv (v )
79
79
Lp = Rearrange ('b v h w k -> b v k (h w)' )(Lp )
80
- Yp = Lambda ( lambda x : einsum ('b h k n, b v k n -> b n h v' , x [ 0 ], x [ 1 ]))([ q , Lp ] )
80
+ Yp = einsum ('b h k n, b v k n -> b n h v' , q , Lp )
81
81
else :
82
- Lp = Lambda ( lambda x : einsum ('n m k u, b u v m -> b n k v' , x [ 0 ], x [ 1 ]))([ self .pos_emb , v ] )
83
- Yp = Lambda ( lambda x : einsum ('b h k n, b n k v -> b n h v' , x [ 0 ], x [ 1 ]))([ q , Lp ] )
82
+ Lp = einsum ('n m k u, b u v m -> b n k v' , self .pos_emb , v )
83
+ Yp = einsum ('b h k n, b n k v -> b n h v' , q , Lp )
84
84
85
85
Y = Add ()([Yc , Yp ])
86
86
out = Rearrange ('b (hh ww) h v -> b hh ww (h v)' , hh = hh , ww = ww )(Y )
0 commit comments