From 8fd35729556cddc4385001fa44a48e0260839fe8 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 25 Aug 2025 15:50:40 +0000 Subject: [PATCH 1/5] init Signed-off-by: NickLucche --- vllm/model_executor/models/gemma3n_mm.py | 52 +++++++++++++++++++++--- vllm/model_executor/models/interfaces.py | 5 ++- vllm/model_executor/models/voxtral.py | 7 ++-- vllm/model_executor/models/whisper.py | 6 +-- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index d59dde1560ae..54fca79cf11e 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, TypedDict, Union, cast +from typing import Any, Literal, Optional, TypedDict, Union, cast +import numpy as np import torch from torch import nn + from transformers import AutoModel, BatchFeature from transformers.models.gemma3n import (Gemma3nAudioConfig, Gemma3nAudioFeatureExtractor, @@ -12,8 +14,8 @@ Gemma3nTextConfig, Gemma3nVisionConfig) from transformers.models.siglip import SiglipImageProcessorFast - -from vllm.config import VllmConfig +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import RowParallelLinear @@ -21,6 +23,7 @@ VocabParallelEmbedding) from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -40,7 +43,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -410,7 +414,10 @@ def forward( @MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor, info=Gemma3nProcessingInfo, dummy_inputs=Gemma3nDummyInputsBuilder) -class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal): +class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsTranscription): + supported_languages = ISO639_1_SUPPORTED_LANGS + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -694,3 +701,38 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: return "" else: raise ValueError(f"Unsupported modality: {modality}") + + @classmethod + def get_generation_prompt(cls, audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str) -> PromptType: + """ + Gemma3n supports "free-form" transcription. + We fix its prompt here to standardize transcriptions/translations + requests. 
+ """ + prompt = "user\n" + prompt += "Transcribe" if task_type == "transcribe" else "Translate" + prompt += " this audio" + if language is not None: + # We assume the language is a valid ISO 639-1 code. + full_lang_name = cls.supported_languages[language] + prompt += f" into {full_lang_name}" + prompt += ":\nmodel\n" + + audio = (audio, stt_config.sample_rate) + prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt} + return cast(PromptType, prompts_dict) + + @classmethod + def get_speech_to_text_config(cls, model_config: ModelConfig, + task_type: str) -> SpeechToTextConfig: + return SpeechToTextConfig( + # Let's set this to 30 as suggested in the docs for now, although + # the model is only limited by its context length. + max_audio_clip_s=30, + sample_rate=16000, + ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9415e67924e7..83ccac7d5b30 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -8,9 +8,9 @@ import numpy as np import torch from torch import Tensor -from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs +from transformers.models.whisper.tokenization_whisper import LANGUAGES from vllm.config import ModelConfig, SpeechToTextConfig from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType @@ -729,7 +729,8 @@ def __init_subclass__(cls, **kwargs): def get_generation_prompt(cls, audio: np.ndarray, stt_config: SpeechToTextConfig, model_config: ModelConfig, - language: Optional[str], task_type: str, + language: Optional[str], + task_type: Literal["transcribe", "translate"], request_prompt: str) -> PromptType: """Get the prompt for the ASR model. 
The model has control over the construction, as long as it diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 77f11a691e08..b7398682c966 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -5,7 +5,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from math import ceil -from typing import Optional, Union, cast +from typing import Literal, Optional, Union, cast import numpy as np import regex as re @@ -17,9 +17,9 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder + from transformers import TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput - from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -451,7 +451,8 @@ def get_speech_to_text_config(cls, model_config: ModelConfig, def get_generation_prompt(cls, audio: np.ndarray, model_config: ModelConfig, stt_config: SpeechToTextConfig, - language: Optional[str], task_type: str, + language: Optional[str], + task_type: Literal["transcribe", "translate"], request_prompt: str) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) audio = Audio(audio, int(stt_config.sample_rate), diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 16bbe2f2010a..bf8715bb70fe 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,15 +4,15 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Optional, TypedDict, Union, cast +from typing import Literal, Optional, TypedDict, Union, cast import numpy as np import torch from torch import nn + from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids - from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, @@ -783,7 +783,7 @@ def get_generation_prompt( model_config: ModelConfig, # not needed here stt_config: SpeechToTextConfig, language: Optional[str], - task_type: str, + task_type: Literal["transcribe", "translate"], request_prompt: str) -> PromptType: if language is None: raise ValueError( From 54e0f948510d836319e6629c046d48eeb0e8cd01 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 27 Aug 2025 10:22:53 +0000 Subject: [PATCH 2/5] to_language interface + seed Signed-off-by: NickLucche --- vllm/entrypoints/openai/protocol.py | 19 ++++++++++++++++ vllm/entrypoints/openai/speech_to_text.py | 4 +++- vllm/model_executor/models/gemma3n_mm.py | 27 ++++++++++++++++++----- vllm/model_executor/models/interfaces.py | 5 +++-- vllm/model_executor/models/voxtral.py | 5 +++-- vllm/model_executor/models/whisper.py | 5 +++-- 6 files changed, 52 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a3d7b78cf455..96e6985cfadb 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -2125,6 +2125,13 @@ class TranscriptionRequest(OpenAIBaseModel): ) # --8<-- [end:transcription-extra-params] + 
to_language: Optional[str] = None + """The language of the output audio we transcribe to. + + Please note that this is not currently used by supported models at this + time, but it is a placeholder for future use, matching translation api. + """ + # --8<-- [start:transcription-sampling-params] temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -2352,6 +2359,9 @@ class TranslationRequest(OpenAIBaseModel): # TODO support additional sampling parameters # --8<-- [start:translation-sampling-params] + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. @@ -2371,6 +2381,14 @@ class TranslationRequest(OpenAIBaseModel): will improve accuracy. """ + to_language: Optional[str] = None + """The language of the input audio we translate to. + + Please note that this is not supported by all models, refer to the specific + model documentation for more details. + For instance, Whisper only supports `to_language=en`. + """ + stream: Optional[bool] = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat @@ -2402,6 +2420,7 @@ def to_sampling_params( return SamplingParams.from_optional(temperature=temperature, max_tokens=max_tokens, + seed=self.seed, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY) diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 01140a4bfea7..e36644119c1c 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -110,7 +110,9 @@ async def _preprocess_speech_to_text( model_config=self.model_config, language=language, task_type=self.task_type, - request_prompt=request.prompt) + request_prompt=request.prompt, + to_language=request.to_language, + ) prompts.append(prompt) return prompts, duration diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 54fca79cf11e..c25bbcd420c3 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -6,7 +6,6 @@ import numpy as np import torch from torch import nn - from transformers import AutoModel, BatchFeature from transformers.models.gemma3n import (Gemma3nAudioConfig, Gemma3nAudioFeatureExtractor, @@ -14,6 +13,7 @@ Gemma3nTextConfig, Gemma3nVisionConfig) from transformers.models.siglip import SiglipImageProcessorFast + from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -708,20 +708,33 @@ def get_generation_prompt(cls, audio: np.ndarray, model_config: ModelConfig, language: Optional[str], task_type: Literal["transcribe", "translate"], - request_prompt: str) -> PromptType: + request_prompt: str, + to_language: Optional[str]) -> PromptType: """ Gemma3n supports "free-form" transcription. We fix its prompt here to standardize transcriptions/translations requests. """ + # Transcribe this audio [into <>] | for transcription + # Translate this audio [from <> into <>] | for translation prompt = "user\n" prompt += "Transcribe" if task_type == "transcribe" else "Translate" prompt += " this audio" - if language is not None: - # We assume the language is a valid ISO 639-1 code. 
- full_lang_name = cls.supported_languages[language] + + # We assume the language is a valid ISO 639-1 code. + full_lang_name = cls.supported_languages.get(language, "") + # Translation only for now + full_lang_name_to = cls.supported_languages.get(to_language, "") + + if task_type == "transcribe" and full_lang_name: prompt += f" into {full_lang_name}" - prompt += ":\nmodel\n" + elif task_type == "translate": + if full_lang_name: + prompt += f" from {full_lang_name}" + if full_lang_name_to: + prompt += f" into {full_lang_name_to}" + + prompt += ": \nmodel\n" audio = (audio, stt_config.sample_rate) prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt} @@ -735,4 +748,6 @@ def get_speech_to_text_config(cls, model_config: ModelConfig, # the model is only limited by its context length. max_audio_clip_s=30, sample_rate=16000, + # TODO enable chunking after more thorough testing. + min_energy_split_window_size=None, ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 83ccac7d5b30..47c3f9a121d8 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -8,9 +8,9 @@ import numpy as np import torch from torch import Tensor +from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs -from transformers.models.whisper.tokenization_whisper import LANGUAGES from vllm.config import ModelConfig, SpeechToTextConfig from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType @@ -731,7 +731,8 @@ def get_generation_prompt(cls, audio: np.ndarray, model_config: ModelConfig, language: Optional[str], task_type: Literal["transcribe", "translate"], - request_prompt: str) -> PromptType: + request_prompt: str, + to_language: Optional[str]) -> PromptType: """Get the prompt for the ASR model. 
The model has control over the construction, as long as it returns a valid PromptType.""" diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index b7398682c966..477d1dae2845 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -17,9 +17,9 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder - from transformers import TensorType, WhisperConfig from transformers.tokenization_utils_base import TextInput + from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -453,7 +453,8 @@ def get_generation_prompt(cls, audio: np.ndarray, stt_config: SpeechToTextConfig, language: Optional[str], task_type: Literal["transcribe", "translate"], - request_prompt: str) -> PromptType: + request_prompt: str, + to_language: Optional[str]) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index bf8715bb70fe..848b6e0f8093 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -9,10 +9,10 @@ import numpy as np import torch from torch import nn - from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids + from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, @@ -784,7 +784,8 @@ def get_generation_prompt( stt_config: SpeechToTextConfig, language: Optional[str], task_type: Literal["transcribe", "translate"], - request_prompt: str) -> PromptType: + request_prompt: str, + to_language: Optional[str]) -> PromptType: if language is None: raise ValueError( "Language must be specified when creating the Whisper prompt") From 79bebe9bf4ca79f06bc242d2a00f0b4ca4938874 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 27 Aug 2025 10:34:15 +0000 Subject: [PATCH 3/5] tests Signed-off-by: NickLucche --- .../openai/test_transcription_validation.py | 35 +++++---- .../openai/test_translation_validation.py | 78 +++++++++++-------- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 93239f41a4ae..87e9dc24f59f 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -12,8 +12,6 @@ import pytest_asyncio import soundfile as sf -from vllm.assets.audio import AudioAsset - from ...utils import RemoteOpenAIServer MODEL_NAME = "openai/whisper-large-v3-turbo" @@ -24,20 +22,6 @@ ] -@pytest.fixture -def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_local_path() - with open(str(path), "rb") as f: - yield f - - -@pytest.fixture -def winning_call(): - path = AudioAsset('winning_call').get_local_path() - with open(str(path), "rb") as f: - yield f - - @pytest.fixture(scope="module") def server(): with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: @@ -73,6 +57,25 @@ async def test_basic_audio(mary_had_lamb, model_name): 
assert "Mary had a little lamb," in out +@pytest.mark.asyncio +async def test_basic_audio_gemma(foscolo): + # Gemma accuracy on some of the audio samples we use is particularly bad, + # hence we use a different one here. WER is evaluated separately. + model_name = "google/gemma-3n-E2B-it" + server_args = ["--enforce-eager"] + + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=foscolo, + language="it", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "da cui vergine nacque Venere" in out + + @pytest.mark.asyncio async def test_non_asr_model(winning_call): # text to text model diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index f4f5c66f2dee..f43b7a253d28 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -12,32 +12,24 @@ import pytest_asyncio import soundfile as sf -from vllm.assets.audio import AudioAsset - from ...utils import RemoteOpenAIServer -MODEL_NAME = "openai/whisper-small" SERVER_ARGS = ["--enforce-eager"] -@pytest.fixture -def foscolo(): - # Test translation it->en - path = AudioAsset('azacinto_foscolo').get_local_path() - with open(str(path), "rb") as f: - yield f - - -@pytest.fixture(scope="module") -def server(): - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server: - yield remote_server +@pytest.fixture(scope="module", + params=["openai/whisper-small", "google/gemma-3n-E2B-it"]) +def server(request): + # Parametrize over model name + with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server: + yield remote_server, request.param @pytest_asyncio.fixture -async def client(server): +async def client_and_model(server): + server, model_name = server async with server.get_async_client() as async_client: - yield async_client + yield async_client, model_name @pytest.mark.asyncio @@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo): # NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! 
@pytest.mark.asyncio -async def test_basic_audio(foscolo, client): +async def test_basic_audio(foscolo, client_and_model): + client, model_name = client_and_model translation = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, response_format="text", - # TODO remove once language detection is implemented - extra_body=dict(language="it"), + # TODO remove `language="it"` once language detection is implemented + extra_body=dict(language="it", to_language="en"), temperature=0.0) out = json.loads(translation)['text'].strip().lower() assert "greek sea" in out @pytest.mark.asyncio -async def test_audio_prompt(foscolo, client): +async def test_audio_prompt(foscolo, client_and_model): + client, model_name = client_and_model # Condition whisper on starting text prompt = "Nor have I ever" transcription = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, prompt=prompt, - extra_body=dict(language="it"), + extra_body=dict(language="it", to_language="en"), response_format="text", temperature=0.0) out = json.loads(transcription)['text'] @@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client): @pytest.mark.asyncio -async def test_streaming_response(foscolo, client, server): +async def test_streaming_response(foscolo, client_and_model, server): + client, model_name = client_and_model translation = "" res_no_stream = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=foscolo, response_format="json", - extra_body=dict(language="it"), + extra_body=dict(language="it", to_language="en", seed=42), temperature=0.0) + # Stream via HTTPX since OpenAI translation client doesn't expose streaming + server, model_name = server url = server.url_for("v1/audio/translations") headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} data = { - "model": MODEL_NAME, + "model": model_name, "language": "it", + "to_language": "en", "stream": True, "temperature": 0.0, + "seed": 42, } foscolo.seek(0) async with httpx.AsyncClient() as http_client: @@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server): text = chunk["choices"][0].get("delta", {}).get("content") translation += text or "" - assert translation == res_no_stream.text + res_stream = translation.split() + # NOTE There's a small non-deterministic issue here, likely in the attn + # computation, which will cause a few tokens to be different, while still + # being very close semantically. 
+ assert sum([ + x == y for x, y in zip(res_stream, res_no_stream.text.split()) + ]) >= len(res_stream) * 0.9 @pytest.mark.asyncio -async def test_stream_options(foscolo, client, server): +async def test_stream_options(foscolo, server): + server, model_name = server url = server.url_for("v1/audio/translations") headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"} data = { - "model": MODEL_NAME, + "model": model_name, "language": "it", + "to_language": "en", "stream": True, "stream_include_usage": True, "stream_continuous_usage_stats": True, @@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server): @pytest.mark.asyncio -async def test_long_audio_request(foscolo, client): +async def test_long_audio_request(foscolo, client_and_model): + client, model_name = client_and_model + if model_name == "google/gemma-3n-E2B-it": + pytest.skip("Gemma3n does not support long audio requests") foscolo.seek(0) audio, sr = librosa.load(foscolo) repeated_audio = np.tile(audio, 2) @@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client): sf.write(buffer, repeated_audio, sr, format='WAV') buffer.seek(0) translation = await client.audio.translations.create( - model=MODEL_NAME, + model=model_name, file=buffer, - extra_body=dict(language="it"), + extra_body=dict(language="it", to_language="en"), response_format="text", temperature=0.0) out = json.loads(translation)['text'].strip().lower() From 52bab20db9ba40dfc132c5123ede8d88c4e17d0f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 27 Aug 2025 13:13:28 +0000 Subject: [PATCH 4/5] to_language validation Signed-off-by: NickLucche --- vllm/entrypoints/openai/speech_to_text.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index e36644119c1c..50241c8944b7 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -87,6 +87,9 @@ async def _preprocess_speech_to_text( ) -> tuple[list[PromptType], float]: # Validate request language = self.model_cls.validate_language(request.language) + # Skip to_language validation to avoid extra logging for Whisper. 
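+        # That is, only run validate_language when `to_language` is provided;
+        # requests that omit it (the common case for Whisper) are passed
+        # through as None without extra validation or logging.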
+ to_language = self.model_cls.validate_language(request.to_language) \ + if request.to_language else None if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: raise ValueError("Maximum file size exceeded.") @@ -111,7 +114,7 @@ async def _preprocess_speech_to_text( language=language, task_type=self.task_type, request_prompt=request.prompt, - to_language=request.to_language, + to_language=to_language, ) prompts.append(prompt) return prompts, duration From 853ac5bae564eca467557588a94babe36a37aa8d Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 1 Sep 2025 07:27:44 +0000 Subject: [PATCH 5/5] conftest Signed-off-by: NickLucche --- tests/entrypoints/openai/conftest.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/entrypoints/openai/conftest.py diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/openai/conftest.py new file mode 100644 index 000000000000..0ecdd4245df4 --- /dev/null +++ b/tests/entrypoints/openai/conftest.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.assets.audio import AudioAsset + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def foscolo(): + # Test translation it->en + path = AudioAsset('azacinto_foscolo').get_local_path() + with open(str(path), "rb") as f: + yield f
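---

For reference, a minimal usage sketch of the `to_language` and `seed` fields introduced by this series, sent to the OpenAI-compatible v1/audio/translations endpoint. It assumes a locally running server started with `vllm serve google/gemma-3n-E2B-it --enforce-eager`; the base URL, API key and audio path below are placeholders.

from openai import OpenAI

# Placeholder endpoint and key for a locally running vLLM server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Placeholder path to an Italian-language audio clip.
with open("sample_it.wav", "rb") as audio_file:
    translation = client.audio.translations.create(
        model="google/gemma-3n-E2B-it",
        file=audio_file,
        response_format="text",
        temperature=0.0,
        # vLLM-specific fields go through extra_body: source language,
        # target language and the sampling seed added in this series.
        extra_body=dict(language="it", to_language="en", seed=42),
    )
print(translation)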