Skip to content

Commit 6208340

Browse files
authored
Enable long audio for WhisperAudioProcessor
Differential Revision: D81093558 Pull Request resolved: #13736
1 parent 785d298 commit 6208340

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

extension/audio/mel_spectrogram.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,23 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
123123
r"""
124124
Args:
125125
waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
126-
where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30
127126
128127
Returns:
129-
torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames]
130-
[1, 80, 3000] with default options
128+
torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks]
129+
n_chunks is the number of chunks of `sampling_rate` samples in the input waveform.
130+
[1, 80, 3000] with default options and 1 chunk
131131
"""
132-
# TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
132+
n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1
133133
waveform = F.pad(
134134
waveform,
135-
(0, self.n_samples - waveform.shape[0] - 1),
135+
(0, self.n_samples * n_chunks - waveform.shape[0]),
136136
mode="constant",
137137
value=self.padding_value,
138138
)
139+
# Ideally we should do:
140+
# window = torch.hann_window(self.n_fft)
141+
# but this is not currently supported when lowering.
142+
# torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
139143
window = 0.5 * (
140144
1
141145
- torch.cos(
@@ -145,10 +149,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
145149
/ self.n_fft
146150
)
147151
)
148-
# Ideally we should do instead
149-
# window = torch.hann_window(self.n_fft)
150-
# but this is not currently supported when lowering
151-
# torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
152152
stft = torch.stft(
153153
waveform,
154154
n_fft=self.n_fft,
@@ -157,7 +157,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
157157
center=True,
158158
return_complex=True,
159159
)
160-
magnitudes = torch.abs(stft) ** 2 # pyre-ignore[58]
160+
magnitudes = torch.abs(stft)[..., :-1] ** 2 # pyre-ignore[58]
161161

162162
mel_spec = self.mel_filters @ magnitudes
163163

@@ -173,8 +173,7 @@ def export_processor():
173173
audio_tensor = torch.randn(480000)
174174
chunk_tensor = audio_tensor[:93680]
175175
with torch.no_grad():
176-
# export. What is the min of waveforms?
177-
dim = Dim("waveform", min=1600, max=audio_tensor.size(0))
176+
dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10) # 10 chunks max
178177
ep: ExportedProgram = export(
179178
model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
180179
)

0 commit comments

Comments
 (0)