[Frontend] Gemma3n audio transcriptions/translations endpoint #23735
base: main
Changes from 3 commits
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union, cast
+from typing import Any, Literal, Optional, TypedDict, Union, cast

+import numpy as np
 import torch
 from torch import nn
 from transformers import AutoModel, BatchFeature
@@ -13,14 +14,16 @@
                                          Gemma3nVisionConfig)
 from transformers.models.siglip import SiglipImageProcessorFast

-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -40,7 +43,8 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -410,7 +414,10 @@ def forward(
 @MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
                                         info=Gemma3nProcessingInfo,
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                      SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -694,3 +701,53 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
return "<audio_soft_token>" | ||
else: | ||
raise ValueError(f"Unsupported modality: {modality}") | ||
|
||
@classmethod | ||
def get_generation_prompt(cls, audio: np.ndarray, | ||
stt_config: SpeechToTextConfig, | ||
model_config: ModelConfig, | ||
language: Optional[str], | ||
task_type: Literal["transcribe", "translate"], | ||
request_prompt: str, | ||
to_language: Optional[str]) -> PromptType: | ||
""" | ||
Gemma3n supports "free-form" transcription. | ||
We fix its prompt here to standardize transcriptions/translations | ||
requests. | ||
""" | ||
+        # Transcribe this audio [into <>] | for transcription
+        # Translate this audio [from <> into <>] | for translation
+        prompt = "<start_of_turn>user\n"
+        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
+        prompt += " this audio"
+
+        # We assume the language is a valid ISO 639-1 code.
+        full_lang_name = cls.supported_languages.get(language, "")
+        # Translation only for now
+        full_lang_name_to = cls.supported_languages.get(to_language, "")
+
+        if task_type == "transcribe" and full_lang_name:
+            prompt += f" into {full_lang_name}"
+        elif task_type == "translate":
+            if full_lang_name:
Review thread:

- "We should validate that both languages are valid when doing translation"
- "I am assuming languages are validated beforehand here https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/speech_to_text.py#L91."
- "I see, in that case perhaps we should pass the [...]"
- "I also think that we should have a separate function for each task to reduce branching"
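A rough sketch of the validation discussed in this thread; the helper and its signature are illustrative only, not code from this PR or from vLLM:

# Hypothetical helper (not part of this PR): reject unknown ISO 639-1
# codes on translation requests before any prompt is built.
from typing import Mapping, Optional

def check_translation_languages(supported: Mapping[str, str],
                                language: Optional[str],
                                to_language: Optional[str]) -> None:
    """Raise if either side of a translation request uses an unknown code."""
    for field, code in (("language", language), ("to_language", to_language)):
        if code is not None and code not in supported:
            raise ValueError(
                f"{field}={code!r} is not a supported ISO 639-1 language")

Called with cls.supported_languages at the top of the translate branch, this would fail fast instead of silently dropping an unrecognized language from the prompt.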
+                prompt += f" from {full_lang_name}"
+            if full_lang_name_to:
+                prompt += f" into {full_lang_name_to}"
+
+        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
+
+        audio = (audio, stt_config.sample_rate)
+        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
+        return cast(PromptType, prompts_dict)
+
+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        return SpeechToTextConfig(
+            # Let's set this to 30 as suggested in the docs for now, although
+            # the model is only limited by its context length.
+            max_audio_clip_s=30,
+            sample_rate=16000,
+            # TODO enable chunking after more thorough testing.
+            min_energy_split_window_size=None,
+        )
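For reference, tracing get_generation_prompt above, a transcription request with language="en" produces the following prompt (reconstructed from the code, not captured from a running server):

<start_of_turn>user
Transcribe this audio into English: <audio_soft_token><end_of_turn>
<start_of_turn>model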
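And a minimal client-side sketch against the /v1/audio/translations endpoint this PR wires up for Gemma3n; the model id, server address, and file name are assumptions for illustration:

# Minimal usage sketch, assuming a vLLM OpenAI-compatible server is already
# serving a Gemma3n checkpoint (model id below is an assumption), e.g.:
#   vllm serve google/gemma-3n-E2B-it
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:
    # Translation request; the server builds the prompt via
    # get_generation_prompt above.
    result = client.audio.translations.create(
        model="google/gemma-3n-E2B-it",  # assumed model id
        file=f,
    )
print(result.text)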