Skip to content

Commit ba3f3cd

Browse files
ryanheisejongwook
andauthored
Skip silence around hallucinations (#1838)
* Add clip_timestamps option * Add hallucination_silence_threshold option * Fix typing for python < 3.9 --------- Co-authored-by: Jong Wook Kim <[email protected]>
1 parent 8bc8860 commit ba3f3cd

File tree

3 files changed

+153
-19
lines changed

3 files changed

+153
-19
lines changed

whisper/timing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def add_word_timestamps(
299299
word_durations = np.array([t.end - t.start for t in alignment])
300300
word_durations = word_durations[word_durations.nonzero()]
301301
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
302+
median_duration = min(0.7, float(median_duration))
302303
max_duration = median_duration * 2
303304

304305
# hack: truncate long words at sentence boundaries.

whisper/transcribe.py

Lines changed: 135 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import traceback
44
import warnings
5-
from typing import TYPE_CHECKING, Optional, Tuple, Union
5+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
66

77
import numpy as np
88
import torch
@@ -23,6 +23,7 @@
2323
from .utils import (
2424
exact_div,
2525
format_timestamp,
26+
get_end,
2627
get_writer,
2728
make_safe,
2829
optional_float,
@@ -48,6 +49,8 @@ def transcribe(
4849
word_timestamps: bool = False,
4950
prepend_punctuations: str = "\"'“¿([{-",
5051
append_punctuations: str = "\"'.。,,!!??::”)]}、",
52+
clip_timestamps: Union[str, List[float]] = "0",
53+
hallucination_silence_threshold: Optional[float] = None,
5154
**decode_options,
5255
):
5356
"""
@@ -102,6 +105,14 @@ def transcribe(
102105
decode_options: dict
103106
Keyword arguments to construct `DecodingOptions` instances
104107
108+
clip_timestamps: Union[str, List[float]]
109+
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
110+
The last end timestamp defaults to the end of the file.
111+
112+
hallucination_silence_threshold: Optional[float]
113+
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
114+
when a possible hallucination is detected
115+
105116
Returns
106117
-------
107118
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -121,6 +132,7 @@ def transcribe(
121132
# Pad 30-seconds of silence to the input audio, for slicing
122133
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
123134
content_frames = mel.shape[-1] - N_FRAMES
135+
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
124136

125137
if decode_options.get("language", None) is None:
126138
if not model.is_multilingual:
@@ -147,6 +159,19 @@ def transcribe(
147159
task=task,
148160
)
149161

162+
if isinstance(clip_timestamps, str):
163+
clip_timestamps = [
164+
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
165+
]
166+
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
167+
if len(seek_points) == 0:
168+
seek_points.append(0)
169+
if len(seek_points) % 2 == 1:
170+
seek_points.append(content_frames)
171+
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
172+
173+
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
174+
150175
if word_timestamps and task == "translate":
151176
warnings.warn("Word-level timestamps on translations may not be reliable.")
152177

@@ -190,7 +215,8 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
190215

191216
return decode_result
192217

193-
seek = 0
218+
clip_idx = 0
219+
seek = seek_clips[clip_idx][0]
194220
input_stride = exact_div(
195221
N_FRAMES, model.dims.n_audio_ctx
196222
) # mel frames per output token: 2
@@ -229,10 +255,23 @@ def new_segment(
229255
total=content_frames, unit="frames", disable=verbose is not False
230256
) as pbar:
231257
last_speech_timestamp = 0.0
232-
while seek < content_frames:
258+
# NOTE: This loop is obscurely flattened to make the diff readable.
259+
# A later commit should turn this into a simpler nested loop.
260+
# for seek_clip_start, seek_clip_end in seek_clips:
261+
# while seek < seek_clip_end
262+
while clip_idx < len(seek_clips):
263+
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
264+
if seek < seek_clip_start:
265+
seek = seek_clip_start
266+
if seek >= seek_clip_end:
267+
clip_idx += 1
268+
if clip_idx < len(seek_clips):
269+
seek = seek_clips[clip_idx][0]
270+
continue
233271
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
234-
mel_segment = mel[:, seek : seek + N_FRAMES]
235-
segment_size = min(N_FRAMES, content_frames - seek)
272+
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
273+
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
274+
mel_segment = mel[:, seek : seek + segment_size]
236275
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
237276
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
238277

@@ -257,6 +296,30 @@ def new_segment(
257296
previous_seek = seek
258297
current_segments = []
259298

299+
# anomalous words are very long/short/improbable
300+
def word_anomaly_score(word: dict) -> float:
301+
probability = word.get("probability", 0.0)
302+
duration = word["end"] - word["start"]
303+
score = 0.0
304+
if probability < 0.15:
305+
score += 1.0
306+
if duration < 0.133:
307+
score += (0.133 - duration) * 15
308+
if duration > 2.0:
309+
score += duration - 2.0
310+
return score
311+
312+
def is_segment_anomaly(segment: Optional[dict]) -> bool:
313+
if segment is None or not segment["words"]:
314+
return False
315+
words = [w for w in segment["words"] if w["word"] not in punctuation]
316+
words = words[:8]
317+
score = sum(word_anomaly_score(w) for w in words)
318+
return score >= 3 or score + 0.01 >= len(words)
319+
320+
def next_words_segment(segments: List[dict]) -> Optional[dict]:
321+
return next((s for s in segments if s["words"]), None)
322+
260323
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
261324
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
262325

@@ -330,17 +393,71 @@ def new_segment(
330393
append_punctuations=append_punctuations,
331394
last_speech_timestamp=last_speech_timestamp,
332395
)
333-
word_end_timestamps = [
334-
w["end"] for s in current_segments for w in s["words"]
335-
]
336-
if len(word_end_timestamps) > 0:
337-
last_speech_timestamp = word_end_timestamps[-1]
338-
if not single_timestamp_ending and len(word_end_timestamps) > 0:
339-
seek_shift = round(
340-
(word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
341-
)
342-
if seek_shift > 0:
343-
seek = previous_seek + seek_shift
396+
397+
if not single_timestamp_ending:
398+
last_word_end = get_end(current_segments)
399+
if last_word_end is not None and last_word_end > time_offset:
400+
seek = round(last_word_end * FRAMES_PER_SECOND)
401+
402+
# skip silence before possible hallucinations
403+
if hallucination_silence_threshold is not None:
404+
threshold = hallucination_silence_threshold
405+
if not single_timestamp_ending:
406+
last_word_end = get_end(current_segments)
407+
if last_word_end is not None and last_word_end > time_offset:
408+
remaining_duration = window_end_time - last_word_end
409+
if remaining_duration > threshold:
410+
seek = round(last_word_end * FRAMES_PER_SECOND)
411+
else:
412+
seek = previous_seek + segment_size
413+
414+
# if first segment might be a hallucination, skip leading silence
415+
first_segment = next_words_segment(current_segments)
416+
if first_segment is not None and is_segment_anomaly(first_segment):
417+
gap = first_segment["start"] - time_offset
418+
if gap > threshold:
419+
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
420+
continue
421+
422+
# skip silence before any possible hallucination that is surrounded
423+
# by silence or more hallucinations
424+
hal_last_end = last_speech_timestamp
425+
for si in range(len(current_segments)):
426+
segment = current_segments[si]
427+
if not segment["words"]:
428+
continue
429+
if is_segment_anomaly(segment):
430+
next_segment = next_words_segment(
431+
current_segments[si + 1 :]
432+
)
433+
if next_segment is not None:
434+
hal_next_start = next_segment["words"][0]["start"]
435+
else:
436+
hal_next_start = time_offset + segment_duration
437+
silence_before = (
438+
segment["start"] - hal_last_end > threshold
439+
or segment["start"] < threshold
440+
or segment["start"] - time_offset < 2.0
441+
)
442+
silence_after = (
443+
hal_next_start - segment["end"] > threshold
444+
or is_segment_anomaly(next_segment)
445+
or window_end_time - segment["end"] < 2.0
446+
)
447+
if silence_before and silence_after:
448+
seek = round(
449+
max(time_offset + 1, segment["start"])
450+
* FRAMES_PER_SECOND
451+
)
452+
if content_duration - segment["end"] < threshold:
453+
seek = content_frames
454+
current_segments[si:] = []
455+
break
456+
hal_last_end = segment["end"]
457+
458+
last_word_end = get_end(current_segments)
459+
if last_word_end is not None:
460+
last_speech_timestamp = last_word_end
344461

345462
if verbose:
346463
for segment in current_segments:
@@ -427,6 +544,8 @@ def valid_model_name(name):
427544
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
428545
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
429546
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
547+
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
548+
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
430549
# fmt: on
431550

432551
args = parser.parse_args().__dict__

whisper/utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import sys
55
import zlib
6-
from typing import Callable, Optional, TextIO
6+
from typing import Callable, List, Optional, TextIO
77

88
system_encoding = sys.getdefaultencoding()
99

@@ -68,6 +68,20 @@ def format_timestamp(
6868
)
6969

7070

71+
def get_start(segments: List[dict]) -> Optional[float]:
72+
return next(
73+
(w["start"] for s in segments for w in s["words"]),
74+
segments[0]["start"] if segments else None,
75+
)
76+
77+
78+
def get_end(segments: List[dict]) -> Optional[float]:
79+
return next(
80+
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
81+
segments[-1]["end"] if segments else None,
82+
)
83+
84+
7185
class ResultWriter:
7286
extension: str
7387

@@ -129,8 +143,8 @@ def iterate_subtitles():
129143
line_len = 0
130144
line_count = 1
131145
# the next subtitle to yield (a list of word timings with whitespace)
132-
subtitle: list[dict] = []
133-
last = result["segments"][0]["words"][0]["start"]
146+
subtitle: List[dict] = []
147+
last: float = get_start(result["segments"]) or 0.0
134148
for segment in result["segments"]:
135149
chunk_index = 0
136150
words_count = max_words_per_line

0 commit comments

Comments
 (0)