35 changes: 19 additions & 16 deletions tests/entrypoints/openai/test_transcription_validation.py
@@ -12,8 +12,6 @@
import pytest_asyncio
import soundfile as sf

from vllm.assets.audio import AudioAsset

from ...utils import RemoteOpenAIServer

MODEL_NAME = "openai/whisper-large-v3-turbo"
@@ -24,20 +22,6 @@
]


@pytest.fixture
def mary_had_lamb():
path = AudioAsset('mary_had_lamb').get_local_path()
with open(str(path), "rb") as f:
yield f


@pytest.fixture
def winning_call():
path = AudioAsset('winning_call').get_local_path()
with open(str(path), "rb") as f:
yield f


@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
@@ -73,6 +57,25 @@ async def test_basic_audio(mary_had_lamb, model_name):
assert "Mary had a little lamb," in out


@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
# Gemma accuracy on some of the audio samples we use is particularly bad,
# hence we use a different one here. WER is evaluated separately.
model_name = "google/gemma-3n-E2B-it"
server_args = ["--enforce-eager"]

with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=foscolo,
language="it",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert "da cui vergine nacque Venere" in out


@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
# text to text model
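The removed audio fixtures (and `foscolo`, removed from the translation tests below) are still referenced by the tests, so they presumably move into a shared `conftest.py` that this diff does not show. A sketch of what that shared file might look like, assuming the fixtures keep their original bodies:

```python
# Hypothetical tests/entrypoints/openai/conftest.py; the diff does not show
# where the fixtures actually moved, so this is an assumption.
import pytest

from vllm.assets.audio import AudioAsset


@pytest.fixture
def mary_had_lamb():
    path = AudioAsset('mary_had_lamb').get_local_path()
    with open(str(path), "rb") as f:
        yield f


@pytest.fixture
def foscolo():
    # Test translation it->en
    path = AudioAsset('azacinto_foscolo').get_local_path()
    with open(str(path), "rb") as f:
        yield f
```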
78 changes: 44 additions & 34 deletions tests/entrypoints/openai/test_translation_validation.py
@@ -12,32 +12,24 @@
import pytest_asyncio
import soundfile as sf

from vllm.assets.audio import AudioAsset

from ...utils import RemoteOpenAIServer

MODEL_NAME = "openai/whisper-small"
SERVER_ARGS = ["--enforce-eager"]


@pytest.fixture
def foscolo():
# Test translation it->en
path = AudioAsset('azacinto_foscolo').get_local_path()
with open(str(path), "rb") as f:
yield f


@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
yield remote_server
@pytest.fixture(scope="module",
params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
def server(request):
# Parametrize over model name
with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
yield remote_server, request.param


@pytest_asyncio.fixture
async def client(server):
async def client_and_model(server):
server, model_name = server
async with server.get_async_client() as async_client:
yield async_client
yield async_client, model_name


@pytest.mark.asyncio
@@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo):

# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client):
async def test_basic_audio(foscolo, client_and_model):
client, model_name = client_and_model
translation = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
response_format="text",
# TODO remove once language detection is implemented
extra_body=dict(language="it"),
# TODO remove `language="it"` once language detection is implemented
extra_body=dict(language="it", to_language="en"),
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
assert "greek sea" in out


@pytest.mark.asyncio
async def test_audio_prompt(foscolo, client):
async def test_audio_prompt(foscolo, client_and_model):
client, model_name = client_and_model
# Condition whisper on starting text
prompt = "Nor have I ever"
transcription = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
prompt=prompt,
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en"),
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
@@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client):


@pytest.mark.asyncio
async def test_streaming_response(foscolo, client, server):
async def test_streaming_response(foscolo, client_and_model, server):
client, model_name = client_and_model
translation = ""
res_no_stream = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=foscolo,
response_format="json",
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en", seed=42),
temperature=0.0)

# Stream via HTTPX since OpenAI translation client doesn't expose streaming
server, model_name = server
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"model": model_name,
"language": "it",
"to_language": "en",
"stream": True,
"temperature": 0.0,
"seed": 42,
}
foscolo.seek(0)
async with httpx.AsyncClient() as http_client:
@@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server):
text = chunk["choices"][0].get("delta", {}).get("content")
translation += text or ""

assert translation == res_no_stream.text
res_stream = translation.split()
# NOTE There's a small non-deterministic issue here, likely in the attn
# computation, which will cause a few tokens to be different, while still
# being very close semantically.
assert sum([
x == y for x, y in zip(res_stream, res_no_stream.text.split())
]) >= len(res_stream) * 0.9


@pytest.mark.asyncio
async def test_stream_options(foscolo, client, server):
async def test_stream_options(foscolo, server):
server, model_name = server
url = server.url_for("v1/audio/translations")
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
data = {
"model": MODEL_NAME,
"model": model_name,
"language": "it",
"to_language": "en",
"stream": True,
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
@@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server):


@pytest.mark.asyncio
async def test_long_audio_request(foscolo, client):
async def test_long_audio_request(foscolo, client_and_model):
client, model_name = client_and_model
if model_name == "google/gemma-3n-E2B-it":
pytest.skip("Gemma3n does not support long audio requests")
foscolo.seek(0)
audio, sr = librosa.load(foscolo)
repeated_audio = np.tile(audio, 2)
@@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client):
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
translation = await client.audio.translations.create(
model=MODEL_NAME,
model=model_name,
file=buffer,
extra_body=dict(language="it"),
extra_body=dict(language="it", to_language="en"),
response_format="text",
temperature=0.0)
out = json.loads(translation)['text'].strip().lower()
19 changes: 19 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -2125,6 +2125,13 @@ class TranscriptionRequest(OpenAIBaseModel):
)
# --8<-- [end:transcription-extra-params]

to_language: Optional[str] = None
"""The target language of the transcription output.

Please note that this is not currently used by supported models; it is a
placeholder for future use, matching the translation API.
"""

# --8<-- [start:transcription-sampling-params]
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
@@ -2352,6 +2359,9 @@ class TranslationRequest(OpenAIBaseModel):

# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""

temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.

@@ -2371,6 +2381,14 @@
will improve accuracy.
"""

to_language: Optional[str] = None
"""The language of the input audio we translate to.

Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""

stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
@@ -2402,6 +2420,7 @@ def to_sampling_params(

return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)
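Since `to_language` and `seed` are vLLM extensions rather than upstream OpenAI fields, clients pass them through `extra_body`. A minimal usage sketch, assuming a vLLM OpenAI-compatible server is already running locally (the URL, API key, and file name are placeholders):

```python
# Sketch: exercising the new `to_language` and `seed` fields via extra_body.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample_it.wav", "rb") as f:  # placeholder audio file
    translation = client.audio.translations.create(
        model="openai/whisper-small",
        file=f,
        response_format="text",
        temperature=0.0,
        # vLLM-specific fields, not part of the upstream OpenAI API:
        extra_body=dict(language="it", to_language="en", seed=42),
    )
print(translation)
```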
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/speech_to_text.py
@@ -110,7 +110,9 @@ async def _preprocess_speech_to_text(
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt)
request_prompt=request.prompt,
to_language=request.to_language,
)
prompts.append(prompt)
return prompts, duration

65 changes: 61 additions & 4 deletions vllm/model_executor/models/gemma3n_mm.py
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, TypedDict, Union, cast
from typing import Any, Literal, Optional, TypedDict, Union, cast

import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
@@ -13,14 +14,16 @@
Gemma3nVisionConfig)
from transformers.models.siglip import SiglipImageProcessorFast

from vllm.config import VllmConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ def forward(
@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -694,3 +701,53 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
return "<audio_soft_token>"
else:
raise ValueError(f"Unsupported modality: {modality}")

@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
We fix its prompt here to standardize transcription/translation
requests.
"""
# Transcribe this audio [into <>] | for transcription
# Translate this audio [from <> into <>] | for translation
prompt = "<start_of_turn>user\n"
prompt += "Transcribe" if task_type == "transcribe" else "Translate"
prompt += " this audio"

# We assume the language is a valid ISO 639-1 code.
full_lang_name = cls.supported_languages.get(language, "")
# Translation only for now
full_lang_name_to = cls.supported_languages.get(to_language, "")

if task_type == "transcribe" and full_lang_name:
prompt += f" into {full_lang_name}"
elif task_type == "translate":
if full_lang_name:
Member: We should validate that both languages are valid when doing translation

Contributor (Author): I am assuming languages are validated beforehand here: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py#L91. Do you have some extra checks in mind?

Member: I see, in that case perhaps we should pass the full_lang_name directly into the method?

Member: I also think that we should have a separate function for each task to reduce branching

prompt += f" from {full_lang_name}"
if full_lang_name_to:
prompt += f" into {full_lang_name_to}"

prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"

audio = (audio, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
return cast(PromptType, prompts_dict)

@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
return SpeechToTextConfig(
# Let's set this to 30 as suggested in the docs for now, although
# the model is only limited by its context length.
max_audio_clip_s=30,
sample_rate=16000,
# TODO enable chunking after more thorough testing.
min_energy_split_window_size=None,
)
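For reference, the template in `get_generation_prompt` yields prompts like the following, assuming `ISO639_1_SUPPORTED_LANGS` maps "it" to "Italian" and "en" to "English":

```python
# Illustrative prompts produced by the template above (not exhaustive).
transcribe_it = (
    "<start_of_turn>user\n"
    "Transcribe this audio into Italian: "
    "<audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
)
translate_it_en = (
    "<start_of_turn>user\n"
    "Translate this audio from Italian into English: "
    "<audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
)
```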
6 changes: 4 additions & 2 deletions vllm/model_executor/models/interfaces.py
@@ -729,8 +729,10 @@ def __init_subclass__(cls, **kwargs):
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
returns a valid PromptType."""
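A minimal sketch of a model class satisfying the updated hook; everything outside the interface signature (the class name, the prompt layout) is hypothetical:

```python
# Hypothetical implementer of the updated SupportsTranscription hook; only
# the prompt-construction piece is sketched.
from typing import Literal, Optional, cast

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType


class MyASRModel:  # hypothetical model class
    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        # `to_language` is only meaningful for task_type == "translate".
        prompt = f"{task_type} ({language} -> {to_language}): "
        return cast(PromptType, {
            "prompt": prompt,
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        })
```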