@@ -2,7 +2,7 @@
 
 import torch
 from einops import rearrange, repeat
-from torch import nn, einsum
+from torch import einsum, nn
 from torch.nn import functional as F
 
 from perceiver_pytorch.rotary import apply_rotary_emb
@@ -37,9 +37,7 @@ def __init__(self, dim, fn, context_dim=None):
         super().__init__()
         self.fn = fn
         self.norm = nn.LayerNorm(dim)
-        self.norm_context = (
-            nn.LayerNorm(context_dim) if exists(context_dim) else None
-        )
+        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
 
     def forward(self, x, **kwargs):
         x = self.norm(x)
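
For context, the hunk above touches a pre-normalization wrapper: LayerNorm is applied to the input (and, when context_dim is given, presumably to a cross-attention context as well) before the wrapped module runs. A minimal sketch of that pattern, simplified to the input-only case and using a hypothetical name and illustrative sizes rather than this repository's actual class:

import torch
from torch import nn

class PreNormSketch(nn.Module):
    # Simplified sketch: normalize the input, then call the wrapped module.
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

block = PreNormSketch(64, nn.Linear(64, 64))
y = block(torch.randn(2, 10, 64))  # LayerNorm first, then the wrapped Linear
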
@@ -57,6 +55,7 @@ class GEGLU(nn.Module):
     Gaussian Error Gated Linear Unit.
     See Shazer 2020: https://arxiv.org/abs/2002.05202
     """
+
     def forward(self, x):
         x, gates = x.chunk(2, dim=-1)
         return x * F.gelu(gates)
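
The GEGLU block in this hunk halves the last dimension and gates one half with the GELU of the other, so the linear layer feeding it must emit twice the desired output width. A minimal, self-contained sketch of how such a block is typically wired; the surrounding layer sizes are illustrative, not taken from this file:

import torch
from torch import nn
from torch.nn import functional as F

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)  # split the last dimension in half
        return x * F.gelu(gates)       # gate one half with GELU of the other

# Hypothetical feed-forward usage: project to 2 * hidden so GEGLU returns hidden.
ff = nn.Sequential(nn.Linear(64, 2 * 128), GEGLU(), nn.Linear(128, 64))
out = ff(torch.randn(1, 10, 64))  # -> torch.Size([1, 10, 64])
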
@@ -85,9 +84,7 @@ def forward(self, x):
 
 
 class Attention(nn.Module):
-    def __init__(
-        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
-    ):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         """
         Args:
             query_dim: Size of the queries.
@@ -108,9 +105,7 @@ def __init__(
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
 
-        self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
     def forward(self, x, context=None, mask=None, pos_emb=None):
         """
@@ -134,9 +129,7 @@ def forward(self, x, context=None, mask=None, pos_emb=None):
         # Rearrange the query, key and value tensors.
         # b = batch size; n = TODO (PD-2021-09-13)
         # h = number of heads; d = number of dims per head.
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
-        )
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
 
         if exists(pos_emb):
             q, k = apply_rotary_emb(q, k, pos_emb)
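
The einops pattern "b n (h d) -> (b h) n d" used above unpacks the per-head dimension out of the feature axis and folds it into the batch axis, so attention can proceed as a plain batched matmul per head. A standalone shape check with illustrative sizes:

import torch
from einops import rearrange

b, n, h, d = 2, 10, 8, 64
q = torch.randn(b, n, h * d)                     # (2, 10, 512)
q = rearrange(q, "b n (h d) -> (b h) n d", h=h)  # (16, 10, 64)
# Equivalent to: q.view(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)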