@@ -123,19 +123,23 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
123
123
r"""
124
124
Args:
125
125
waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
126
- where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30
127
126
128
127
Returns:
129
- torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames]
130
- [1, 80, 3000] with default options
128
+ torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks]
129
+ n_chunks is the number of chunks of `n_samples` samples needed to cover the input waveform (rounded up).
130
+ [1, 80, 3000] with default options and 1 chunk
131
131
"""
132
- # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
132
+ n_chunks = ( waveform . shape [ 0 ] - 1 ) // self . n_samples + 1
133
133
waveform = F .pad (
134
134
waveform ,
135
- (0 , self .n_samples - waveform .shape [0 ] - 1 ),
135
+ (0 , self .n_samples * n_chunks - waveform .shape [0 ]),
136
136
mode = "constant" ,
137
137
value = self .padding_value ,
138
138
)
139
+ # Ideally we should do:
140
+ # window = torch.hann_window(self.n_fft)
141
+ # but this is not currently supported when lowering.
142
+ # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
139
143
window = 0.5 * (
140
144
1
141
145
- torch .cos (
@@ -145,10 +149,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
145
149
/ self .n_fft
146
150
)
147
151
)
148
- # Ideally we should do instead
149
- # window = torch.hann_window(self.n_fft)
150
- # but this is not currently supported when lowering
151
- # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
152
152
stft = torch .stft (
153
153
waveform ,
154
154
n_fft = self .n_fft ,
@@ -157,7 +157,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
157
157
center = True ,
158
158
return_complex = True ,
159
159
)
160
- magnitudes = torch .abs (stft ) ** 2 # pyre-ignore[58]
160
+ magnitudes = torch .abs (stft )[..., : - 1 ] ** 2 # pyre-ignore[58]
161
161
162
162
mel_spec = self .mel_filters @ magnitudes
163
163
@@ -173,8 +173,7 @@ def export_processor():
173
173
audio_tensor = torch .randn (480000 )
174
174
chunk_tensor = audio_tensor [:93680 ]
175
175
with torch .no_grad ():
176
- # export. What is the min of waveforms?
177
- dim = Dim ("waveform" , min = 1600 , max = audio_tensor .size (0 ))
176
+ dim = Dim ("waveform" , min = 1600 , max = audio_tensor .size (0 ) * 10 ) # 10 chunks max
178
177
ep : ExportedProgram = export (
179
178
model , (chunk_tensor ,), dynamic_shapes = {"waveform" : {0 : dim }}, strict = True
180
179
)
0 commit comments