Changes from all commits · 22 commits
8be9a3a
Add initial implementation of Phi-4 model with core components
lab176344 Mar 4, 2025
e794953
Add dependencies for Phi-4 model in pyproject.toml
lab176344 Mar 4, 2025
0c6cd53
Implement vision-only processing for Phi-4 model, including model loa…
lab176344 Mar 4, 2025
2b4a3c3
Update Phi-4 model to support eager attention implementation and add …
lab176344 Mar 4, 2025
b4db4f7
Enhance Phi-4 model loading with fast option and streamline device ha…
lab176344 Mar 4, 2025
b39742a
Add training command for Phi-4 model with configurable options and ou…
lab176344 Mar 4, 2025
a25ea78
Fix naming convention for default model ID and revision constants in …
lab176344 Mar 4, 2025
5dd3138
Add collate functions for training and evaluation data in Phi-4 model
lab176344 Mar 4, 2025
be0cba2
added core
lab176344 Mar 4, 2025
0bc7c99
Add inference functions for multimodal predictions in Phi-4 model
lab176344 Mar 5, 2025
4434ce3
Add input_mode to training and evaluation collate functions and steps
lab176344 Mar 5, 2025
107be0e
Add bitsandbytes dependency and update input_mode comment in training…
lab176344 Mar 5, 2025
7c8bd38
Add attribution comment for loader based on Hugging Face Phi-4 sample
lab176344 Mar 5, 2025
9b15fed
Refactor collate functions to use AutoProcessor and set input_mode as…
lab176344 Mar 5, 2025
a38611d
Add TODO comments to clarify target modules and device map handling i…
lab176344 Mar 6, 2025
3c2c826
Remove unused finetune_phi4_pmc_vqa.py example file
lab176344 Mar 6, 2025
cbc3aa8
Refactor predict_with_inputs to include image embeddings and attentio…
lab176344 Mar 10, 2025
3abcb95
updates test
lab176344 Mar 10, 2025
822e945
Removed QLora and documentation checks
lab176344 Mar 10, 2025
9c3ead4
added docs
lab176344 Mar 10, 2025
f89b86b
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Mar 10, 2025
179d69e
fixed mypy error
lab176344 Mar 11, 2025
489 changes: 489 additions & 0 deletions cookbooks/maestro_full_finetune_phi_4_vl_json_extraction.ipynb

Large diffs are not rendered by default.

533 changes: 533 additions & 0 deletions cookbooks/maestro_lora_phi_4_vl_json_extraction.ipynb

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions docs/models/phi_4.md
@@ -0,0 +1,138 @@
---
comments: true
---

## Overview

Phi-4 is Microsoft's state-of-the-art multimodal model that combines advanced vision capabilities with powerful language understanding. It can process image, audio, and text inputs to perform a variety of tasks, including image understanding, reasoning, and generating contextually relevant responses.

The model excels at tasks requiring visual comprehension alongside natural language processing, making it suitable for applications such as visual question answering, image captioning, and multimodal reasoning.

## Install

To use Phi-4 with Maestro, install the required dependencies:

```bash
pip install "maestro[phi_4]"
```

This will install the necessary packages including transformers, torch, flash-attention (for non-macOS platforms), and other dependencies required for working with Phi-4.
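
Flash Attention is skipped on macOS, so pass `use_flash_attention=False` in the examples below on that platform. If you are unsure whether the package is present in your environment, a minimal check such as the following (an illustrative snippet, not part of the Maestro API) can decide the flag:

```python
import importlib.util

# Flash Attention 2 can only be enabled when the flash_attn package is importable.
use_flash_attention = importlib.util.find_spec("flash_attn") is not None
print(f"Flash Attention available: {use_flash_attention}")
```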

## Train

Fine-tune the Phi-4 model using LoRA or full fine-tuning to adapt it to your specific multimodal tasks.

### CLI

Start training from the command line with the following command:

```bash
maestro phi_4 train \
--dataset "dataset/location" \
--epochs 10 \
--batch-size 4 \
--optimization_strategy "lora" \
--metrics "edit_distance"
```

Customize the command with your specific dataset path and adjust hyperparameters like learning rate and batch size according to your requirements.

### Python

For more control over the training process, use the Python API:

```python
from maestro.trainer.models.phi_4.core import train

config = {
    "dataset": "dataset/location",
    "model_id": "microsoft/Phi-4-multimodal-instruct",
    "epochs": 10,
    "batch_size": 4,
    "lr": 1e-5,
    "optimization_strategy": "lora",
    "metrics": ["edit_distance"],
    "use_flash_attention": True,
}

train(config)
```

## Load

Load a pre-trained or fine-tuned Phi-4 model along with its processor:

```python
from maestro.trainer.models.phi_4.checkpoints import (
    OptimizationStrategy, load_model
)

processor, model = load_model(
    model_id_or_path="microsoft/Phi-4-multimodal-instruct",  # or your fine-tuned model path
    optimization_strategy=OptimizationStrategy.NONE,
    use_flash_attention=True
)
```
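
The remaining `load_model` parameters (`revision`, `device`, `cache_dir`) can also be set explicitly. Below is a minimal sketch for loading a locally saved checkpoint on hardware without Flash Attention 2; the path is a placeholder:

```python
processor, model = load_model(
    model_id_or_path="training/phi_4/checkpoints/latest",  # placeholder: your saved checkpoint directory
    optimization_strategy=OptimizationStrategy.NONE,
    use_flash_attention=False,  # fall back to eager attention, e.g. on macOS
    device="cuda:0",            # defaults to "auto"
)
```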

## Predict

Generate predictions using your Phi-4 model with the dedicated prediction function:

```python
from PIL import Image
from maestro.trainer.models.phi_4.inference import predict

# Load an image
image = Image.open("path/to/your/image.jpg")

# Generate a prediction
result = predict(
    model=model,
    processor=processor,
    image=image,
    prefix="Describe this image in detail:"
)

print(result)
```

For more control over the prediction process, you can use the lower-level API:

```python
import torch
from PIL import Image
from transformers import BatchFeature
from maestro.trainer.models.phi_4.checkpoints import filter_audio_components

# Load an image
image = Image.open("path/to/your/image.jpg")

# Prepare inputs
inputs = processor(
    text="Describe this image in detail:",
    images=image,
    return_tensors="pt"
)

# Filter out audio components if any
inputs = filter_audio_components(inputs)
inputs = BatchFeature(inputs)
input_len = inputs.input_ids.size(1)

# Generate response
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        eos_token_id=processor.tokenizer.eos_token_id
    )

# Decode the generated text
generated_text = processor.batch_decode(outputs[:, input_len:], skip_special_tokens=True)[0]
formatted_prompt = "Describe this image in detail:".strip()
response_text = generated_text.split(formatted_prompt)[-1].strip()

print(response_text)
```

These workflows cover image understanding and text generation with Phi-4: the simplified `predict` function handles the common case, while the lower-level API gives you more control over the generation process.
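
This PR also adds a `save_model` helper (see `checkpoints.py` below) for persisting a fine-tuned model together with its processor. A minimal sketch, with a placeholder output directory:

```python
from maestro.trainer.models.phi_4.checkpoints import save_model

# Writes the model and processor to disk and normalizes the audio-related
# entries in preprocessor_config.json.
save_model(
    target_dir="training/phi_4/export",  # placeholder output directory
    processor=processor,
    model=model,
)
```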
Empty file.
237 changes: 237 additions & 0 deletions maestro/trainer/models/phi_4/checkpoints.py
@@ -0,0 +1,237 @@
import json
import os
from enum import Enum
from typing import Any, Optional

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoProcessor

from maestro.trainer.common.utils.device import parse_device_spec
from maestro.trainer.logger import get_maestro_logger

DEFAULT_PHI_4_MODEL_ID = "microsoft/Phi-4-multimodal-instruct"
DEFAULT_PHI_4_MODEL_REVISION = "refs/heads/main"
logger = get_maestro_logger()


class OptimizationStrategy(Enum):
    """Enumeration for optimization strategies."""

    LORA = "lora"
    NONE = "none"


def load_model(
    model_id_or_path: str = DEFAULT_PHI_4_MODEL_ID,
    revision: str = DEFAULT_PHI_4_MODEL_REVISION,
    device: str | torch.device = "auto",
    optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
    cache_dir: Optional[str] = None,
    use_flash_attention: bool = True,
) -> tuple[AutoProcessor, AutoModelForCausalLM]:
    """
    Loads a Phi-4 multimodal model and its associated processor with optional LoRA.

    Args:
        model_id_or_path (str): The model name or path.
        revision (str): The model revision to load.
        device (str | torch.device): The device to load the model onto.
        optimization_strategy (OptimizationStrategy): LORA or NONE.
        cache_dir (Optional[str]): Directory to cache downloaded model files.
        use_flash_attention (bool): Whether to use Flash Attention 2.

    Returns:
        (AutoProcessor, AutoModelForCausalLM):
            A tuple containing the loaded processor and model.
    """
    device = parse_device_spec(device)
    processor = AutoProcessor.from_pretrained(
        model_id_or_path,
        revision=revision,
        trust_remote_code=True,
        cache_dir=cache_dir,
        use_fast=True,
    )

    processor.tokenizer.padding_side = "right"
    attn_implementation = "flash_attention_2" if use_flash_attention else "eager"

    if optimization_strategy in {OptimizationStrategy.LORA}:
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            target_modules=[
                "qkv_proj",
                "o_proj",
                "gate_up_proj",
                "down_proj",
            ],  # Todo: Check what target modules will be better
            task_type="CAUSAL_LM",
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="auto",
            cache_dir=cache_dir,
            attn_implementation=attn_implementation,
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype="auto",
            cache_dir=cache_dir,
            attn_implementation=attn_implementation,
        )
    model.to(device)
    return processor, model


def save_model(
    target_dir: str,
    processor: AutoProcessor,
    model: AutoModelForCausalLM,
) -> None:
    """
    Save a Phi-4 model and its processor to disk, normalizing audio-related preprocessor settings.

    Args:
        target_dir: Directory path where the model and processor will be saved.
            Will be created if it doesn't exist.
        processor: The Phi-4 processor to save.
        model: The Phi-4 model to save.
    """
    os.makedirs(target_dir, exist_ok=True)

    processor.save_pretrained(target_dir)
    model.save_pretrained(target_dir)

    chat_template_path = os.path.join(target_dir, "chat_template.json")
    if os.path.exists(chat_template_path):
        os.remove(chat_template_path)
        logger.info(f"Removed {chat_template_path}")

    preprocessor_config_path = os.path.join(target_dir, "preprocessor_config.json")
    if os.path.exists(preprocessor_config_path):
        try:
            with open(preprocessor_config_path) as f:
                preprocessor_config = json.load(f)

            for param in ["feature_size", "sampling_rate", "padding_value"]:
                if param in preprocessor_config:
                    del preprocessor_config[param]
                    logger.info(f"Removed '{param}' from preprocessor_config.json")

            audio_params = {"audio_compression_rate": 8, "audio_downsample_rate": 1, "audio_feat_stride": 1}

            for param, value in audio_params.items():
                preprocessor_config[param] = value
                logger.info(f"Added '{param}': {value} to preprocessor_config.json")

            with open(preprocessor_config_path, "w") as f:
                json.dump(preprocessor_config, f, indent=2)
        except Exception as e:
            logger.warning(f"Error modifying preprocessor_config.json: {e}")


def _remove_audio_layers(model):
    """
    Remove audio-related parameters from the model to optimize for vision-only tasks.

    This function removes the audio embedding layers and audio-specific LoRA components
    to reduce memory usage and focus the model on vision processing.

    Args:
        model: The Phi-4 model from which to remove audio-related layers.

    Returns:
        The modified model with audio layers removed.
    """
    # TODO: can be used to strip audio processing components from the encoder to save some parameters
    try:
        logger.info("Removing audio layers to optimize for vision-only processing...")

        if hasattr(model, "model") and hasattr(model.model, "embed_tokens_extend"):
            if hasattr(model.model.embed_tokens_extend, "audio_embed"):
                del model.model.embed_tokens_extend.audio_embed

        if hasattr(model, "model") and hasattr(model.model, "layers"):
            for layer_idx, layer in enumerate(model.model.layers):
                removed_components = 0
                lora_components = [
                    (layer.mlp.down_proj, "lora_A", "speech"),
                    (layer.mlp.down_proj, "lora_B", "speech"),
                    (layer.mlp.gate_up_proj, "lora_A", "speech"),
                    (layer.mlp.gate_up_proj, "lora_B", "speech"),
                    (layer.self_attn.o_proj, "lora_A", "speech"),
                    (layer.self_attn.o_proj, "lora_B", "speech"),
                    (layer.self_attn.qkv_proj, "lora_A", "speech"),
                    (layer.self_attn.qkv_proj, "lora_B", "speech"),
                ]

                for component, lora_type, key in lora_components:
                    try:
                        if hasattr(component, lora_type) and hasattr(getattr(component, lora_type), key):
                            delattr(getattr(component, lora_type), key)
                            removed_components += 1
                    except AttributeError:
                        continue

                if removed_components > 0:
                    logger.debug(f"Removed {removed_components} audio LoRA components from layer {layer_idx}")

        logger.info("Audio layer removal complete")
    except Exception as e:
        logger.warning(
            f"Could not remove some audio layers. This is expected if using a different model variant. Error: {e}"
        )

    return model


def filter_audio_components(inputs: dict[str, Any]) -> dict[str, Any]:
    """
    Filter out audio-related components from the input dictionary.

    Args:
        inputs: Dictionary containing model inputs, potentially including audio components.

    Returns:
        A new dictionary with audio-related keys removed.
    """
    audio_related_keys = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    filtered_inputs = {k: v for k, v in inputs.items() if k not in audio_related_keys}

    return filtered_inputs


def process_model_inputs(model: AutoModelForCausalLM, inputs: dict[str, Any], **kwargs) -> Any:
    """
    Process inputs before passing to the model, removing audio components.

    Args:
        model: The model to use for processing.
        inputs: Dictionary of input tensors and parameters.
        **kwargs: Additional arguments to pass to the model.

    Returns:
        The model's output after processing the filtered inputs.
    """
    filtered_inputs = filter_audio_components(inputs)

    filtered_inputs.update(kwargs)

    # Pass the filtered inputs to the model
    return model(**filtered_inputs)