62 changes: 62 additions & 0 deletions examples/models/lfm2/README.md
@@ -0,0 +1,62 @@
## Summary
[LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) is a new generation of hybrid models developed by [Liquid AI](https://www.liquid.ai/), available in three variants: 350M, 700M, and 1.2B.

## Instructions

LFM2 uses the same example code as the optimized Llama model; only the checkpoint, model params, and tokenizer differ. Please see the [Llama README page](../llama/README.md) for details.
LFM2 is a hybrid model in which some attention layers are replaced with short convolutions.
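
The hybrid layout is declared by the `layer_types` field in the params JSON files below. As a hedged illustration only (not the exporter's actual builder logic), a minimal sketch that inspects which layers are short convolutions and which are full attention:
```
# Minimal sketch: print the hybrid layer layout declared in the params JSON.
# Mapping "conv" to ShortConvBlock and "full_attention" to a regular
# transformer block reflects the intent of this example; the loop itself is
# illustrative only.
import json

with open("examples/models/lfm2/config/lfm2_700m_config.json") as f:
    params = json.load(f)

for i, layer_type in enumerate(params["layer_types"]):
    block = "ShortConvBlock" if layer_type == "conv" else "TransformerBlock"
    print(f"layer {i:2d}: {block}")
```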

### Example export
Below are basic examples for exporting LFM2; please refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.

Export the 350M variant to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
--config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
+base.model_class="lfm2_350m" \
+base.params="examples/models/lfm2/config/lfm2_350m_config.json" \
+export.output_name="lfm2_350m_8da4w.pte"
```

Export the 700M variant to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
--config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
+base.model_class="lfm2_700m" \
+base.params="examples/models/lfm2/config/lfm2_700m_config.json" \
+export.output_name="lfm2_700m_8da4w.pte"
```

Export the 1.2B variant to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
--config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
+base.model_class="lfm2_1_2b" \
+base.params="examples/models/lfm2/config/lfm2_1_2b_config.json" \
+export.output_name="lfm2_1_2b_8da4w.pte"
```
### Example run
With ExecuTorch pybindings:
```
python -m examples.models.llama.runner.native \
--model lfm2_700m \
--pte lfm2_700m_8da4w.pte \
--tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
--tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \
--prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
--params examples/models/lfm2/config/lfm2_700m_config.json \
--max_len 128 \
-kv \
--temperature 0.3
```
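
The tokenizer paths above (also used by the C++ runner below) point into the local Hugging Face cache, and the snapshot hash on your machine may differ. If you have not downloaded the model yet, a minimal sketch using `huggingface_hub` (an assumption, not part of this example) fetches it and prints the snapshot directory:
```
# Illustrative only: download the LFM2-700M repo (which includes tokenizer.json
# and tokenizer_config.json) into the local Hugging Face cache.
from huggingface_hub import snapshot_download

snapshot_dir = snapshot_download(repo_id="LiquidAI/LFM2-700M")
print(snapshot_dir)  # contains tokenizer.json and tokenizer_config.json
```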

With ExecuTorch's sample C++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
```
cmake-out/examples/models/llama/llama_main \
--model_path lfm2_700m_8da4w.pte \
--tokenizer_path ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
--prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
--temperature 0.3
```

To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.
5 changes: 5 additions & 0 deletions examples/models/lfm2/__init__.py
@@ -0,0 +1,5 @@
from executorch.examples.models.lfm2.convert_weights import convert_weights

__all__ = [
"convert_weights",
]
34 changes: 34 additions & 0 deletions examples/models/lfm2/config/lfm2_1_2b_config.json
@@ -0,0 +1,34 @@
{
"dim": 2048,
"ffn_dim_multiplier": 1,
"hidden_dim": 8192,
"n_heads": 32,
"n_kv_heads": 8,
"n_layers": 16,
"norm_eps": 1e-5,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 65536,
"use_hf_rope": true,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"layer_types": [
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"conv"
]
}
34 changes: 34 additions & 0 deletions examples/models/lfm2/config/lfm2_350m_config.json
@@ -0,0 +1,34 @@
{
"dim": 1024,
"ffn_dim_multiplier": 1,
"hidden_dim": 4608,
"n_heads": 16,
"n_kv_heads": 8,
"n_layers": 16,
"norm_eps": 1e-5,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 65536,
"use_hf_rope": true,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"layer_types": [
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"conv"
]
}
34 changes: 34 additions & 0 deletions examples/models/lfm2/config/lfm2_700m_config.json
@@ -0,0 +1,34 @@
{
"dim": 1536,
"ffn_dim_multiplier": 1,
"hidden_dim": 6912,
"n_heads": 24,
"n_kv_heads": 8,
"n_layers": 16,
"norm_eps": 1e-5,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 65536,
"use_hf_rope": true,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"layer_types": [
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"conv"
]
}
12 changes: 12 additions & 0 deletions examples/models/lfm2/config/lfm2_xnnpack_fp32.yaml
@@ -0,0 +1,12 @@
base:
metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
use_kv_cache: True
use_sdpa_with_kv_cache: True
dtype_override: fp32

backend:
xnnpack:
enabled: True
extended_ops: True
15 changes: 15 additions & 0 deletions examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml
@@ -0,0 +1,15 @@
base:
metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
use_kv_cache: True
use_sdpa_with_kv_cache: True
dtype_override: fp32

quantization:
qmode: 8da4w

backend:
xnnpack:
enabled: True
extended_ops: True
74 changes: 74 additions & 0 deletions examples/models/lfm2/convert_weights.py
@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

_LFM_2_TO_META = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.embedding_norm.weight": "norm.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
"model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
}


def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from LFM2 HF format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}

for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _LFM_2_TO_META)
        except Exception:
            # Keys without an explicit mapping (e.g. the conv weights) keep
            # their name, minus the "model." prefix.
            new_key = key.removeprefix("model.")

        # Split the fused conv in_proj into the separate B/C/x projections used by ShortConv.
if new_key.endswith(".conv.in_proj.weight"):
for name, split_value in zip(
["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)
):
converted_state_dict[new_key.replace("in_proj", name)] = split_value
else:
converted_state_dict[new_key] = value

# If lm_head.weight is not present in state dict, assume tied embeddings
if "lm_head.weight" not in state_dict:
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]

return converted_state_dict


def load_checkpoint(input_dir: str) -> Dict:
print("Loading checkpoint from safetensors directory")
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
print("Converting checkpoint...")
sd = lfm_2_to_meta(sd)
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")
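
A minimal usage sketch for the helper above, assuming a local directory that contains the LFM2 `model.safetensors` (paths are illustrative):
```
# Illustrative only: convert a downloaded LFM2 HF checkpoint to Meta format.
from executorch.examples.models.lfm2 import convert_weights

convert_weights(
    input_dir="/path/to/LFM2-700M",    # must contain model.safetensors
    output_file="lfm2_700m_meta.pth",  # torch.save()-d Meta-format state dict
)
```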
110 changes: 110 additions & 0 deletions examples/models/lfm2/short_conv.py
@@ -0,0 +1,110 @@
import torch
from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.feed_forward import FeedForward

from executorch.examples.models.llama.norm import RMSNorm
from torch import nn


class ShortConv(nn.Module):
def __init__(
self,
dim: int,
L_cache: int = 3,
bias: bool = False,
device: torch.device = None,
dtype: torch.dtype = None,
):
super().__init__()
self.dim = dim
self.L_cache = L_cache
self.device = device
self.dtype = dtype
self.bias = bias

self.conv = nn.Conv1d(
dim,
dim,
kernel_size=L_cache,
padding=0, ## we don't need padding since we handle it manually
groups=dim,
bias=bias,
)

conv_state = torch.zeros(
1, ## batch size is assumed to be 1 for now
dim,
L_cache - 1,
device="cpu",
)
self.register_buffer("conv_state", conv_state)

## better performance in Executorch with separate projections
self.B_proj = nn.Linear(dim, dim, bias=bias)
self.C_proj = nn.Linear(dim, dim, bias=bias)
self.x_proj = nn.Linear(dim, dim, bias=bias)

self.out_proj = nn.Linear(dim, dim, bias=bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seqlen, dim = x.size()
assert batch_size == 1, "batch_size must be 1"

B = self.B_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)
C = self.C_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)
x = self.x_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)

Bx = B * x # (batch_size, dim, seq_len)

        ## This is where we handle padding.
        ## By default, conv_state is initialized to 0, so for a prefill on an empty cache,
        ## concatenating conv_state to the beginning of the sequence acts the same as
        ## using nn.Conv1d(padding=L_cache - 1) with no manual padding.
        ## However, the manual padding has the added benefit of being correct during decode,
        ## when the cache is no longer all zeros.
Bx = torch.cat(
[self.conv_state, Bx], dim=-1
) # (batch_size, dim, seq_len + L_cache - 1)

## Update the conv_state
new_conv_state = Bx[
..., -(self.L_cache - 1) :
] # (batch_size, dim, L_cache - 1)
with torch.no_grad():
self.conv_state.copy_(new_conv_state)
> Reviewer comment (Contributor), on the conv_state update above: it looks like you want a ring-buffer kind of update here, which could be done a bit more efficiently; however, it would complicate the conv computation.

conv_out = self.conv(Bx)[..., : x.size(-1)] # (batch_size, dim, seq_len)
y = C * conv_out # (batch_size, dim, seq_len)

y = y.transpose(-1, -2) # (batch_size, seq_len, dim)
y = y.contiguous() # (batch_size, seq_len, dim)
y = self.out_proj(y) # (batch_size, seq_len, dim)
return y

def reset_cache(self):
self.conv_state.zero_()


class ShortConvBlock(nn.Module):
def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
super().__init__()
self.L_cache = 3 # hardcode 3 for now
self.conv = ShortConv(dim, self.L_cache, bias=False)
self.feed_forward = FeedForward(dim, hidden_dim)
self.ffn_norm = RMSNorm(dim, norm_eps)
        # use the name attention_norm (instead of operator_norm) to stay consistent with TransformerBlock
self.attention_norm = RMSNorm(dim, norm_eps)

def forward(
self,
x,
freqs_cos=None,
freqs_sin=None,
_unused_attn_options: ForwardOptions = None,
): # x: 1xN
h = self.conv.forward(self.attention_norm(x))
h = x + h
out = h + self.feed_forward(self.ffn_norm(h))
return out, None

def reset_cache(self):
self.conv.reset_cache()
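
Not part of the PR, but a quick way to sanity-check the cache handling described in the comments above, assuming the module is importable as `executorch.examples.models.lfm2.short_conv`: decoding the fifth token from the cached state should reproduce the corresponding row of a full-sequence prefill.
```
# Sanity-check sketch: a cached single-token decode should match the same
# position of a full-sequence prefill started from a zeroed conv_state.
import torch
from executorch.examples.models.lfm2.short_conv import ShortConv

torch.manual_seed(0)
conv = ShortConv(dim=8, L_cache=3)

x = torch.randn(1, 5, 8)  # (batch=1, seq_len=5, dim=8)
with torch.no_grad():
    full = conv(x)            # prefill over all 5 tokens
    conv.reset_cache()
    _ = conv(x[:, :4, :])     # prefill the first 4 tokens (fills conv_state)
    step = conv(x[:, 4:, :])  # decode token 5 from the cached state

print(torch.allclose(full[:, 4:, :], step, atol=1e-6))  # expect True
```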