Commit 173ab0a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 0474deb · commit: 173ab0a

22 files changed: +130 additions, -138 deletions
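
pre-commit.ci commits like this one simply run the repository's configured hooks and push whatever they rewrite. The configuration itself is not shown on this page, but the changes below (imports grouped and alphabetized, short calls joined onto one line, trailing-comma signatures exploded one parameter per line) are characteristic of import and code formatters in the isort/black mold. A minimal, hypothetical .pre-commit-config.yaml sketch that would produce fixes of this shape; the rev pins and line length are placeholders, not values taken from the repo:

# Hypothetical sketch only; the repository's actual .pre-commit-config.yaml is not shown in this commit.
repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.9.3  # placeholder pin
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 21.9b0  # placeholder pin
    hooks:
      - id: black
        # Joined lines below run to roughly 100 characters, so the limit is
        # presumably raised above black's default of 88; the exact value is unknown.
        args: ["--line-length", "100"]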

perceiver_pytorch/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
-from perceiver_pytorch.perceiver_pytorch import Perceiver
-from perceiver_pytorch.perceiver_io import PerceiverIO
 from perceiver_pytorch.multi_perceiver_pytorch import MultiPerceiver
+from perceiver_pytorch.perceiver_io import PerceiverIO
+from perceiver_pytorch.perceiver_pytorch import Perceiver

perceiver_pytorch/convolutions.py
Lines changed: 2 additions & 5 deletions

@@ -21,9 +21,7 @@ def __init__(
 
         layers = [self.make_layer(input_channels, output_channels, batch=use_batchnorm)]
         for _ in range(num_layers - 1):
-            layers += [
-                self.make_layer(output_channels, output_channels, batch=use_batchnorm)
-            ]
+            layers += [self.make_layer(output_channels, output_channels, batch=use_batchnorm)]
 
         super().__init__(*layers)
 
@@ -94,8 +92,7 @@ def __init__(
         # The decoder sets the number of upsamples as log2(upsample_value), and this changes the number of channels
         # in a similar way, so it all scales together.
         intermediate_output_channels = [
-            output_channels * pow(2, num_upsamples - 1 - i)
-            for i in range(0, num_upsamples)
+            output_channels * pow(2, num_upsamples - 1 - i) for i in range(0, num_upsamples)
         ]
         intermediate_input_channels = [input_channels] + intermediate_output_channels
 
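
As a quick worked check of the channel scaling in this hunk (illustrative values, not taken from the diff), each step toward the output halves the channel count:

# Illustrative check of the comprehension above; the values are made up for the example.
output_channels = 16
num_upsamples = 3

intermediate_output_channels = [
    output_channels * pow(2, num_upsamples - 1 - i) for i in range(0, num_upsamples)
]
print(intermediate_output_channels)  # [64, 32, 16] -- halves at each upsample step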

perceiver_pytorch/decoders.py
Lines changed: 7 additions & 6 deletions

@@ -1,7 +1,8 @@
-import torch
 import numpy as np
-from perceiver_pytorch.utils import reverse_space_to_depth
+import torch
+
 from perceiver_pytorch.convolutions import Conv2DUpsample, Conv3DUpsample
+from perceiver_pytorch.utils import reverse_space_to_depth
 
 
 class ImageDecoder(torch.nn.Module):
@@ -180,7 +181,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
 
 class ImageDecoderPatches(torch.nn.Module):
     def __init__(
-        self, spatial_upsample: int = 1, temporal_upsample: int = 1,
+        self,
+        spatial_upsample: int = 1,
+        temporal_upsample: int = 1,
     ):
         """
         Patch-based image decoder
@@ -196,7 +199,5 @@ def __init__(
         self.spatial_upsample = spatial_upsample
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        inputs = reverse_space_to_depth(
-            inputs, self.temporal_upsample, self.spatial_upsample
-        )
+        inputs = reverse_space_to_depth(inputs, self.temporal_upsample, self.spatial_upsample)
         return inputs

perceiver_pytorch/encoders.py
Lines changed: 15 additions & 10 deletions

@@ -1,9 +1,10 @@
+import math
+
+import numpy as np
 import torch
-from torch import nn
-import torchvision
 import torch.nn.functional as F
-import numpy as np
-import math
+import torchvision
+from torch import nn
 
 from perceiver_pytorch.convolutions import Conv2DDownsample
 from perceiver_pytorch.utils import space_to_depth
@@ -66,9 +67,7 @@ def __init__(
                 spatial_downsample=spatial_downsample,
             )
         elif self.prep_type == "metnet":
-            self.encoder = ImageEncoderMetNet(
-                crop_size=crop_size, use_space2depth=use_space2depth
-            )
+            self.encoder = ImageEncoderMetNet(crop_size=crop_size, use_space2depth=use_space2depth)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.encoder(x)
@@ -169,7 +168,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class ImageEncoderPatches(torch.nn.Module):
     def __init__(
-        self, spatial_downsample: int = 4, temporal_downsample: int = 1,
+        self,
+        spatial_downsample: int = 4,
+        temporal_downsample: int = 1,
     ):
         """
         Image encoder that uses patches
@@ -198,7 +199,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class ImageEncoderPixel(torch.nn.Module):
     def __init__(
-        self, spatial_downsample: int = 4, temporal_downsample: int = 1,
+        self,
+        spatial_downsample: int = 4,
+        temporal_downsample: int = 1,
     ):
         """
         Image encoder class for simple downsampling with pixels
@@ -231,7 +234,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class ImageEncoderMetNet(nn.Module):
     def __init__(
-        self, crop_size: int = 256, use_space2depth: bool = True,
+        self,
+        crop_size: int = 256,
+        use_space2depth: bool = True,
     ):
         """
         Performs the MetNet preprocessing of mean pooling Sat channels, followed by

perceiver_pytorch/gated.py
Lines changed: 7 additions & 19 deletions

@@ -1,13 +1,11 @@
 import torch
-from torch import nn, einsum
 import torch.nn.functional as F
-
 from einops import rearrange, repeat
+from torch import einsum, nn
 
-from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention
+from perceiver_pytorch.layers import Attention, FeedForward, PreNorm, cache_fn, default, exists
 from perceiver_pytorch.utils import fourier_encode
 
-
 # helpers
 
 
@@ -31,9 +29,7 @@ def forward(self, x, **kwargs):
         b, dim = x.shape[0], self.dim
         y = self.fn(x, **kwargs)
 
-        gated_output = self.gru(
-            rearrange(y, "... d -> (...) d"), rearrange(x, "... d -> (...) d")
-        )
+        gated_output = self.gru(rearrange(y, "... d -> (...) d"), rearrange(x, "... d -> (...) d"))
 
         gated_output = rearrange(gated_output, "(b n) d -> b n d", b=b)
         return gated_output
@@ -127,30 +123,22 @@ def __init__(
                 )
             )
 
-        self.to_logits = nn.Sequential(
-            nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes)
-        )
+        self.to_logits = nn.Sequential(nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes))
 
     def forward(self, data, mask=None):
         b, *axis, _, device = *data.shape, data.device
-        assert (
-            len(axis) == self.input_axis
-        ), "input data must have the right number of axis"
+        assert len(axis) == self.input_axis, "input data must have the right number of axis"
 
         # calculate fourier encoded positions in the range of [-1, 1], for all axis
 
         axis_pos = list(
             map(
-                lambda size: torch.linspace(
-                    -1.0, 1.0, steps=size, device=device
-                ),
+                lambda size: torch.linspace(-1.0, 1.0, steps=size, device=device),
                 axis,
             )
         )
         pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
-        enc_pos = fourier_encode(
-            pos, self.max_freq, self.num_freq_bands, base=self.freq_base
-        )
+        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base=self.freq_base)
         enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
         enc_pos = repeat(enc_pos, "... -> b ...", b=b)
 
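
For orientation, the context lines above build a coordinate grid before Fourier-encoding it. A minimal runnable sketch of just the grid construction; the axis sizes are illustrative, and fourier_encode from perceiver_pytorch.utils is not reproduced here:

import torch

# Illustrative: a 2-axis (image-like) input with axis sizes 4 and 6.
axis = [4, 6]
device = torch.device("cpu")

# One tensor of positions in [-1, 1] per input axis, as in the hunk above.
axis_pos = list(map(lambda size: torch.linspace(-1.0, 1.0, steps=size, device=device), axis))

# Stack the meshgrid into per-position coordinates; fourier_encode then
# expands the trailing coordinate dimension into sine/cosine frequency bands.
pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
print(pos.shape)  # torch.Size([4, 6, 2])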

perceiver_pytorch/layers.py
Lines changed: 6 additions & 13 deletions

@@ -2,7 +2,7 @@
 
 import torch
 from einops import rearrange, repeat
-from torch import nn, einsum
+from torch import einsum, nn
 from torch.nn import functional as F
 
 from perceiver_pytorch.rotary import apply_rotary_emb
@@ -37,9 +37,7 @@ def __init__(self, dim, fn, context_dim=None):
         super().__init__()
         self.fn = fn
         self.norm = nn.LayerNorm(dim)
-        self.norm_context = (
-            nn.LayerNorm(context_dim) if exists(context_dim) else None
-        )
+        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
 
     def forward(self, x, **kwargs):
         x = self.norm(x)
@@ -57,6 +55,7 @@ class GEGLU(nn.Module):
     Gaussian Error Gated Linear Unit.
     See Shazer 2020: https://arxiv.org/abs/2002.05202
     """
+
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
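
The hunk above only inserts a blank line after the docstring; GEGLU itself is unchanged. As a shape sketch (sizes illustrative), it gates half of the features with a GELU of the other half, so the last dimension halves:

import torch
import torch.nn.functional as F

class GEGLU(torch.nn.Module):
    # Same two-line forward as in the file: split the last dim, gate one half.
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)

out = GEGLU()(torch.randn(2, 10, 128))
print(out.shape)  # torch.Size([2, 10, 64]) -- last dimension halved
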
@@ -85,9 +84,7 @@ def forward(self, x):
 
 
 class Attention(nn.Module):
-    def __init__(
-        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0
-    ):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         """
         Args:
             query_dim: Size of the queries.
@@ -108,9 +105,7 @@ def __init__(
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
 
-        self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
     def forward(self, x, context=None, mask=None, pos_emb=None):
         """
@@ -134,9 +129,7 @@ def forward(self, x, context=None, mask=None, pos_emb=None):
         # Rearrange the query, key and value tensors.
         # b = batch size; n = TODO (PD-2021-09-13)
         # h = number of heads; d = number of dims per head.
-        q, k, v = map(
-            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
-        )
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
 
         if exists(pos_emb):
             q, k = apply_rotary_emb(q, k, pos_emb)
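
The rearrange in this hunk is the standard multi-head split, folding the heads into the batch dimension. A minimal sketch with illustrative sizes:

import torch
from einops import rearrange

b, n, h, d = 2, 16, 8, 64  # illustrative: batch, tokens, heads, dims per head

t = torch.randn(b, n, h * d)
# "b n (h d) -> (b h) n d": split features into h heads of d dims each,
# then merge the heads into the batch axis, exactly as for q, k and v above.
t = rearrange(t, "b n (h d) -> (b h) n d", h=h)
print(t.shape)  # torch.Size([16, 16, 64])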

perceiver_pytorch/mixed_latents.py
Lines changed: 2 additions & 4 deletions

@@ -1,13 +1,11 @@
 import torch
-from torch import nn, einsum
 import torch.nn.functional as F
-
 from einops import rearrange, repeat
+from torch import einsum, nn
 
-from perceiver_pytorch.layers import exists, default, cache_fn, PreNorm, FeedForward, Attention
+from perceiver_pytorch.layers import Attention, FeedForward, PreNorm, cache_fn, default, exists
 from perceiver_pytorch.utils import fourier_encode
 
-
 # latent mixer
 
 

perceiver_pytorch/modalities.py
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
-import torch
 from dataclasses import dataclass
 
+import torch
+
 
 @dataclass
 class InputModality:

perceiver_pytorch/multi_perceiver_pytorch.py
Lines changed: 7 additions & 5 deletions

@@ -1,10 +1,12 @@
-from perceiver_pytorch.perceiver_io import PerceiverIO
-from perceiver_pytorch.modalities import InputModality, modality_encoding
-from perceiver_pytorch.utils import encode_position, fourier_encode
+from math import prod
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
 import torch
-from typing import List, Iterable, Dict, Optional, Any, Union, Tuple
 from einops import rearrange, repeat
-from math import prod
+
+from perceiver_pytorch.modalities import InputModality, modality_encoding
+from perceiver_pytorch.perceiver_io import PerceiverIO
+from perceiver_pytorch.utils import encode_position, fourier_encode
 
 
 class MultiPerceiver(torch.nn.Module):
