
Commit 15ab548

nocaptions -> nospeech to match the paper figure
1 parent 6198952 commit 15ab548

3 files changed: 27 additions & 39 deletions


whisper/decoding.py

Lines changed: 14 additions & 14 deletions
@@ -108,7 +108,7 @@ class DecodingResult:
     tokens: List[int] = field(default_factory=list)
     text: str = ""
     avg_logprob: float = np.nan
-    no_caption_prob: float = np.nan
+    no_speech_prob: float = np.nan
     temperature: float = np.nan
     compression_ratio: float = np.nan

@@ -543,9 +543,9 @@ def _get_suppress_tokens(self) -> Tuple[int]:
         suppress_tokens.extend(
             [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
         )
-        if self.tokenizer.no_captions is not None:
-            # no-captions probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_captions)
+        if self.tokenizer.no_speech is not None:
+            # no-speech probability is collected separately
+            suppress_tokens.append(self.tokenizer.no_speech)

         return tuple(sorted(set(suppress_tokens)))

@@ -580,15 +580,15 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
-        no_caption_probs = [np.nan] * n_batch
+        no_speech_probs = [np.nan] * n_batch

         try:
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)

-                if i == 0 and self.tokenizer.no_captions is not None:  # save no_caption_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
-                    no_caption_probs = probs_at_sot[:, self.tokenizer.no_captions].tolist()
+                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
@@ -605,7 +605,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         finally:
             self.inference.cleanup_caching()

-        return tokens, sum_logprobs, no_caption_probs
+        return tokens, sum_logprobs, no_speech_probs

     @torch.no_grad()
     def run(self, mel: Tensor) -> List[DecodingResult]:
@@ -629,12 +629,12 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

         # call the main sampling loop
-        tokens, sum_logprobs, no_caption_probs = self._main_loop(audio_features, tokens)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
-        no_caption_probs = no_caption_probs[:: self.n_group]
-        assert audio_features.shape[0] == len(no_caption_probs) == n_audio
+        no_speech_probs = no_speech_probs[:: self.n_group]
+        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

         tokens = tokens.reshape(n_audio, self.n_group, -1)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
@@ -653,7 +653,7 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

-        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_caption_probs)
+        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
         if len(set(map(len, fields))) != 1:
             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

@@ -664,11 +664,11 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
                 tokens=tokens,
                 text=text,
                 avg_logprob=avg_logprob,
-                no_caption_prob=no_caption_prob,
+                no_speech_prob=no_speech_prob,
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_caption_prob in zip(*fields)
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]
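
For reference, the renamed field is exposed on every DecodingResult, so callers can read the probability of the <|nospeech|> token directly after decoding. The sketch below is not part of this commit; it assumes the whisper package is installed, "audio.wav" is a placeholder path, and the "base" checkpoint is available. It only shows where no_speech_prob (formerly no_caption_prob) surfaces in the public API.

# Minimal sketch (assumes a local audio.wav and the "base" checkpoint)
import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))   # 30-second window
mel = whisper.log_mel_spectrogram(audio).to(model.device)      # log-Mel features

result = whisper.decode(model, mel, whisper.DecodingOptions())
print(result.text)
print(result.no_speech_prob)  # renamed in this commit; previously no_caption_prob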

whisper/tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -178,8 +178,8 @@ def sot_prev(self) -> int:

     @property
     @lru_cache()
-    def no_captions(self) -> int:
-        return self._get_single_token_id("<|nocaptions|>")
+    def no_speech(self) -> int:
+        return self._get_single_token_id("<|nospeech|>")

     @property
     @lru_cache()
@@ -283,7 +283,7 @@ def build_tokenizer(name: str = "gpt2"):
         "<|transcribe|>",
         "<|startoflm|>",
         "<|startofprev|>",
-        "<|nocaptions|>",
+        "<|nospeech|>",
         "<|notimestamps|>",
     ]
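
On the tokenizer side, the renamed property resolves the id of the new <|nospeech|> special token. A minimal sketch, assuming the multilingual tokenizer built by whisper.tokenizer.get_tokenizer; the printed values depend on the vocabulary and are not asserted here.

# Sketch only: the property renamed in this commit
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True)
print(tokenizer.no_speech)   # single token id for "<|nospeech|>" (was tokenizer.no_captions)
print(tokenizer.sot)         # id of "<|startoftranscript|>", shown for comparison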

whisper/transcribe.py

Lines changed: 10 additions & 22 deletions
@@ -23,7 +23,7 @@ def transcribe(
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
-    no_captions_threshold: Optional[float] = 0.6,
+    no_speech_threshold: Optional[float] = 0.6,
     **decode_options,
 ):
     """
@@ -50,8 +50,8 @@ def transcribe(
     logprob_threshold: float
         If the average log probability over sampled tokens is below this value, treat as failed

-    no_captions_threshold: float
-        If the no_captions probability is higher than this value AND the average log probability
+    no_speech_threshold: float
+        If the no_speech probability is higher than this value AND the average log probability
         over sampled tokens is below `logprob_threshold`, consider the segment as silent

     decode_options: dict
@@ -148,7 +148,7 @@ def add_segment(
                 "temperature": result.temperature,
                 "avg_logprob": result.avg_logprob,
                 "compression_ratio": result.compression_ratio,
-                "no_caption_prob": result.no_caption_prob,
+                "no_speech_prob": result.no_speech_prob,
             }
         )
         if verbose:
@@ -163,11 +163,11 @@ def add_segment(
         result = decode_with_fallback(segment)[0]
         tokens = torch.tensor(result.tokens)

-        if no_captions_threshold is not None:
+        if no_speech_threshold is not None:
             # no voice activity check
-            should_skip = result.no_caption_prob > no_captions_threshold
+            should_skip = result.no_speech_prob > no_speech_threshold
             if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
-                # don't skip if the logprob is high enough, despite the no_captions_prob
+                # don't skip if the logprob is high enough, despite the no_speech_prob
                 should_skip = False

         if should_skip:
@@ -249,7 +249,7 @@ def cli():
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
     parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
     parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_caption_threshold", type=optional_float, default=0.6, help="if the probability of the <|nocaptions|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -261,12 +261,8 @@ def cli():
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"

-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    compression_ratio_threshold = args.pop("compression_ratio_threshold")
-    logprob_threshold = args.pop("logprob_threshold")
-    no_caption_threshold = args.pop("no_caption_threshold")
-
     temperature = args.pop("temperature")
+    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
     if temperature_increment_on_fallback is not None:
         temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
     else:
@@ -276,15 +272,7 @@ def cli():
     model = load_model(model_name, device=device)

     for audio_path in args.pop("audio"):
-        result = transcribe(
-            model,
-            audio_path,
-            temperature=temperature,
-            compression_ratio_threshold=compression_ratio_threshold,
-            logprob_threshold=logprob_threshold,
-            no_captions_threshold=no_caption_threshold,
-            **args,
-        )
+        result = transcribe(model, audio_path, temperature=temperature, **args)

         audio_basename = os.path.basename(audio_path)
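
Because cli() now forwards the remaining threshold arguments straight through **args, the Python API and the command line both use the renamed keyword. A minimal usage sketch follows; the file name and model size are placeholders, and the threshold values shown are simply the defaults.

# Sketch only: the high-level API with the keyword renamed in this commit
import whisper

model = whisper.load_model("base")
result = whisper.transcribe(
    model,
    "audio.wav",
    no_speech_threshold=0.6,   # was no_captions_threshold before this commit
    logprob_threshold=-1.0,
)
print(result["text"])

# Equivalent command-line call (flag renamed from --no_caption_threshold):
#   whisper audio.wav --model base --no_speech_threshold 0.6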