
Conversation

@danielafrimi danielafrimi (Contributor) commented Aug 27, 2025

This PR builds on #23644.

In addition to supporting the new VL model, it introduces a native vLLM implementation of the C-RADIO vision encoder.

To reduce code duplication, the implementation reuses InternVisionModel blocks.

Signed-off-by: Daniel Afrimi <[email protected]>
@danielafrimi danielafrimi marked this pull request as draft August 27, 2025 14:28
@mergify mergify bot added the multi-modality (Related to multi-modality (#4194)) and new-model (Requests to new models) labels Aug 27, 2025
@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the NemotronH Nano VLM model, including a native vLLM implementation of the C-RADIO vision encoder. The changes are well structured, introducing new model files for nano_vlm and radio along with corresponding tests. My review identified a critical bug in the new test file that would prevent it from running correctly, a potential TypeError in model initialization due to unsafe dictionary access, and the use of print statements for debugging, which should be replaced with proper logging. Addressing these points will improve the robustness and maintainability of the new model support.

Comment on lines +85 to +90
pixel_values = [
    img_processor(
        images,
        return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
    for images in images
]

critical

The list comprehension for pixel_values seems to have a bug. The img_processor is called with a single PIL image, which returns a 3D tensor for pixel_values. However, a 4D slice [:, :, :, :640] is then applied to this 3D tensor, which will cause a runtime error.

To fix this, you can process each image as a list containing a single image to ensure the pixel_values tensor is 4D. Also, using a more descriptive loop variable name would improve readability.

Suggested change
-pixel_values = [
-    img_processor(
-        images,
-        return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
-    for images in images
-]
+pixel_values = [
+    img_processor(
+        [img],
+        return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
+    for img in images
+]
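
For context, a minimal sketch of why wrapping each image in a list matters, assuming img_processor, images, and torch_dtype are the fixtures from the test above and a HuggingFace-style image processor: a list input yields a batched (N, C, H, W) pixel_values tensor, so the trailing width slice applies to the last dimension as intended.

# Sketch only: img_processor, images, and torch_dtype stand in for the
# test fixtures; shapes assume a HuggingFace-style image processor.
for img in images:
    out = img_processor([img], return_tensors='pt').pixel_values
    # Batched output makes the 4D slice [:, :, :, :640] valid.
    assert out.ndim == 4, f"expected (N, C, H, W), got {tuple(out.shape)}"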

Comment on lines +689 to +691
model_name = hf_config.args.get("model")
hidden_size, num_layers, num_heads, intermediate_size = vit_dims.get(
    model_name)

high

The call to vit_dims.get(model_name) can return None if model_name is not found in the vit_dims dictionary. This would cause a TypeError when attempting to unpack the result into hidden_size, num_layers, num_heads, intermediate_size. It's safer to check if the model name exists in the dictionary before attempting to get the value.

Suggested change
-model_name = hf_config.args.get("model")
-hidden_size, num_layers, num_heads, intermediate_size = vit_dims.get(
-    model_name)
+model_name = hf_config.args.get("model")
+if model_name not in vit_dims:
+    raise ValueError(
+        f"Unsupported ViT model type for Radio: {model_name}. "
+        f"Supported types are: {list(vit_dims.keys())}")
+hidden_size, num_layers, num_heads, intermediate_size = vit_dims[
+    model_name]
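
End to end, the suggested lookup-plus-validation pattern would look like the sketch below; the dictionary entries here are placeholders, not the actual table in radio.py.

# Placeholder entries mapping model name to
# (hidden_size, num_layers, num_heads, intermediate_size);
# illustrative values only, not the real table from radio.py.
vit_dims = {
    "vit_base_patch16_224": (768, 12, 12, 3072),
    "vit_large_patch16_224": (1024, 24, 16, 4096),
}

model_name = hf_config.args.get("model")
if model_name not in vit_dims:
    raise ValueError(f"Unsupported ViT model type for Radio: {model_name}. "
                     f"Supported types are: {list(vit_dims.keys())}")
hidden_size, num_layers, num_heads, intermediate_size = vit_dims[model_name]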

Comment on lines +109 to +111
print("in intervit cls token init num_tokens: ", num_tokens)
print("in intervit cls token init num_registers: ",
self.num_registers)

high

There are print statements in the ClsToken initializer. In a library like vLLM, using print for logging can clutter the output and is generally discouraged. Please consider using the logger from vllm.logger for debugging information, or removing these statements if they are not necessary.
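
For reference, vLLM's logging convention is a module-level logger created with init_logger from vllm.logger; a minimal sketch of the replacement, reusing the variable names from the snippet above:

from vllm.logger import init_logger

logger = init_logger(__name__)

# Debug-level replacement for the print() calls in the ClsToken initializer.
logger.debug("ClsToken init: num_tokens=%d, num_registers=%d",
             num_tokens, self.num_registers)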
