model : support LiquidAI LFM2 hybrid family #13805
Open
tdakhran wants to merge 5 commits into pytorch:main from tdakhran:tarek/feat/lfm2_upstream
@@ -0,0 +1,62 @@
## Summary
[LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) is a new generation of hybrid models developed by [Liquid AI](https://www.liquid.ai/), available in three variants: 350M, 700M, and 1.2B.

## Instructions

LFM2 uses the same example code as the optimized Llama model, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.
LFM2 is a hybrid model in which some attention layers are replaced with short convolutions.

### Example export
Here is a basic example for exporting LFM2; please refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.

Export 350M to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
    --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
    +base.model_class="lfm2_350m" \
    +base.params="examples/models/lfm2/config/lfm2_350m_config.json" \
    +export.output_name="lfm2_350m_8da4w.pte"
```

Export 700M to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
    --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
    +base.model_class="lfm2_700m" \
    +base.params="examples/models/lfm2/config/lfm2_700m_config.json" \
    +export.output_name="lfm2_700m_8da4w.pte"
```

Export 1.2B to XNNPACK, quantized with 8da4w:
```
python -m extension.llm.export.export_llm \
    --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \
    +base.model_class="lfm2_1_2b" \
    +base.params="examples/models/lfm2/config/lfm2_1_2b_config.json" \
    +export.output_name="lfm2_1_2b_8da4w.pte"
```

### Example run
With ExecuTorch pybindings:
```
python -m examples.models.llama.runner.native \
    --model lfm2_700m \
    --pte lfm2_700m_8da4w.pte \
    --tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
    --tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \
    --prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
    --params examples/models/lfm2/config/lfm2_700m_config.json \
    --max_len 128 \
    -kv \
    --temperature 0.3
```

With ExecuTorch's sample C++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
```
cmake-out/examples/models/llama/llama_main \
    --model_path lfm2_700m_8da4w.pte \
    --tokenizer_path ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
    --prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
    --temperature 0.3
```

To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.
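The tokenizer paths in the run examples above point at a specific snapshot inside the local Hugging Face cache. As a convenience (not part of this PR), the same files can be resolved programmatically, which avoids hard-coding the snapshot hash; the sketch below assumes the `huggingface_hub` package is installed:

```python
# Convenience sketch (not part of this PR): resolve the LFM2-700M tokenizer
# files from the Hugging Face Hub / local cache instead of hard-coding the
# snapshot path used in the commands above.
from huggingface_hub import hf_hub_download

tokenizer_path = hf_hub_download("LiquidAI/LFM2-700M", "tokenizer.json")
tokenizer_config_path = hf_hub_download("LiquidAI/LFM2-700M", "tokenizer_config.json")
print(tokenizer_path)
print(tokenizer_config_path)
```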
@@ -0,0 +1,5 @@
from executorch.examples.models.lfm2.convert_weights import convert_weights

__all__ = [
    "convert_weights",
]
@@ -0,0 +1,34 @@
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 8192,
  "n_heads": 32,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}
@@ -0,0 +1,34 @@
{
  "dim": 1024,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 4608,
  "n_heads": 16,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}
@@ -0,0 +1,34 @@
{
  "dim": 1536,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 6912,
  "n_heads": 24,
  "n_kv_heads": 8,
  "n_layers": 16,
  "norm_eps": 1e-5,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 65536,
  "use_hf_rope": true,
  "use_qk_norm": true,
  "qk_norm_before_rope": true,
  "layer_types": [
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "full_attention",
    "conv",
    "conv"
  ]
}
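The `layer_types` array in each of the three params files above is what makes LFM2 hybrid: it selects, layer by layer, between a short-convolution block and a full-attention block. The sketch below is illustrative only (it is not the builder the example runner uses, and both block factories are stand-ins), but it shows how such a list could drive block construction:

```python
# Illustrative sketch only: how a layer_types list from the params files above
# could select between conv and attention blocks. The factories are stand-ins
# (e.g. one wrapping the ShortConvBlock added in this PR); this is not the
# actual builder used by the example runner.
import json
from typing import Callable, Dict

from torch import nn


def build_blocks(
    params_path: str,
    conv_block: Callable[[Dict], nn.Module],
    attention_block: Callable[[Dict], nn.Module],
) -> nn.ModuleList:
    with open(params_path) as f:
        params = json.load(f)
    factories = {"conv": conv_block, "full_attention": attention_block}
    return nn.ModuleList(
        factories[layer_type](params) for layer_type in params["layer_types"]
    )
```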
@@ -0,0 +1,12 @@
base:
  metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

backend:
  xnnpack:
    enabled: True
    extended_ops: True
@@ -0,0 +1,15 @@
base:
  metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}'

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

quantization:
  qmode: 8da4w

backend:
  xnnpack:
    enabled: True
    extended_ops: True
@@ -0,0 +1,74 @@
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

_LFM_2_TO_META = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.embedding_norm.weight": "norm.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
    "model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
}


def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from LFM2 HF format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}

    for key, value in state_dict.items():
        try:
            new_key = get_mapped_key(key, _LFM_2_TO_META)
        except Exception:
            # Not a mapped attention/norm key (e.g. conv weights); just strip
            # the "model." prefix.
            new_key = key.removeprefix("model.")

        # Split the fused conv in_proj into separate B/C/x projections.
        if new_key.endswith(".conv.in_proj.weight"):
            for name, split_value in zip(
                ["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0)
            ):
                converted_state_dict[new_key.replace("in_proj", name)] = split_value
        else:
            converted_state_dict[new_key] = value

    # If lm_head.weight is not present in the state dict, assume tied embeddings.
    if "lm_head.weight" not in state_dict:
        converted_state_dict["output.weight"] = converted_state_dict[
            "tok_embeddings.weight"
        ]

    return converted_state_dict


def load_checkpoint(input_dir: str) -> Dict:
    print("Loading checkpoint from safetensors directory")
    state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
    return state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = lfm_2_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")
@@ -0,0 +1,110 @@
from typing import Optional

import torch
from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.feed_forward import FeedForward

from executorch.examples.models.llama.norm import RMSNorm
from torch import nn


class ShortConv(nn.Module):
    def __init__(
        self,
        dim: int,
        L_cache: int = 3,
        bias: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.dim = dim
        self.L_cache = L_cache
        self.device = device
        self.dtype = dtype
        self.bias = bias

        self.conv = nn.Conv1d(
            dim,
            dim,
            kernel_size=L_cache,
            padding=0,  # no padding needed here since it is handled manually in forward()
            groups=dim,
            bias=bias,
        )

        conv_state = torch.zeros(
            1,  # batch size is assumed to be 1 for now
            dim,
            L_cache - 1,
            device="cpu",
        )
        self.register_buffer("conv_state", conv_state)

        # Separate projections give better performance in ExecuTorch.
        self.B_proj = nn.Linear(dim, dim, bias=bias)
        self.C_proj = nn.Linear(dim, dim, bias=bias)
        self.x_proj = nn.Linear(dim, dim, bias=bias)

        self.out_proj = nn.Linear(dim, dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seqlen, dim = x.size()
        assert batch_size == 1, "batch_size must be 1"

        B = self.B_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
        C = self.C_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
        x = self.x_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)

        Bx = B * x  # (batch_size, dim, seq_len)

        # This is where padding is handled.
        # By default, conv_state is initialized to 0, so assuming prefill is
        # done on an empty cache, concatenating conv_state to the beginning of
        # the sequence acts similarly to using nn.Conv1d(padding=L_cache-1)
        # for prefill, without any manual padding.
        # However, the manual padding has the added benefit of being correct
        # during decode, when the cache is not initialized to 0.
        Bx = torch.cat(
            [self.conv_state, Bx], dim=-1
        )  # (batch_size, dim, seq_len + L_cache - 1)

        # Update the conv_state with the last L_cache - 1 columns.
        new_conv_state = Bx[
            ..., -(self.L_cache - 1) :
        ]  # (batch_size, dim, L_cache - 1)
        with torch.no_grad():
            self.conv_state.copy_(new_conv_state)

        conv_out = self.conv(Bx)[..., : x.size(-1)]  # (batch_size, dim, seq_len)
        y = C * conv_out  # (batch_size, dim, seq_len)

        y = y.transpose(-1, -2)  # (batch_size, seq_len, dim)
        y = y.contiguous()  # (batch_size, seq_len, dim)
        y = self.out_proj(y)  # (batch_size, seq_len, dim)
        return y

    def reset_cache(self):
        self.conv_state.zero_()


class ShortConvBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
        super().__init__()
        self.L_cache = 3  # hardcoded to 3 for now
        self.conv = ShortConv(dim, self.L_cache, bias=False)
        self.feed_forward = FeedForward(dim, hidden_dim)
        self.ffn_norm = RMSNorm(dim, norm_eps)
        # Use the attention_norm name instead of operator_norm to unify with
        # TransformerBlock.
        self.attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x,
        freqs_cos=None,
        freqs_sin=None,
        _unused_attn_options: Optional[ForwardOptions] = None,
    ):  # x: 1xN
        h = self.conv.forward(self.attention_norm(x))
        h = x + h
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, None

    def reset_cache(self):
        self.conv.reset_cache()
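A standalone sketch (not part of the PR) that checks the padding argument made in the `forward` comments: with a zero-initialized state, left-concatenating `L_cache - 1` zero columns before the depthwise conv matches `nn.Conv1d` with `padding=L_cache - 1` on the causal outputs:

```python
# Standalone check of the manual-padding argument above (illustrative only).
import torch
from torch import nn

dim, L_cache, seq_len = 8, 3, 5
x = torch.randn(1, dim, seq_len)
conv = nn.Conv1d(dim, dim, kernel_size=L_cache, padding=0, groups=dim, bias=False)

# Manual padding path, as in ShortConv.forward during prefill on an empty cache.
state = torch.zeros(1, dim, L_cache - 1)
manual = conv(torch.cat([state, x], dim=-1))[..., :seq_len]

# Built-in padding path: pad on both sides, then keep only the causal outputs.
padded = nn.functional.conv1d(
    x, conv.weight, bias=None, padding=L_cache - 1, groups=dim
)[..., :seq_len]

assert torch.allclose(manual, padded, atol=1e-6)
```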
Review comment on the `conv_state` update:
BTW, it looks like you want a ring-buffer kind of update here, which could maybe be done a bit more efficiently. However, it complicates the conv computation.
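For illustration only (not part of the PR): the kind of ring-buffer update the comment alludes to would replace the shift-style `conv_state` copy with an overwrite of the oldest column plus a write pointer; the re-ordering needed before the conv is the complication mentioned.

```python
# Hypothetical ring-buffer variant of the conv_state update (not in this PR).
# A decode step overwrites the oldest column instead of shifting the buffer;
# the columns must then be re-ordered before being prepended to the conv input.
import torch

dim, L_cache = 8, 3
state = torch.zeros(1, dim, L_cache - 1)  # same shape as ShortConv.conv_state
write_idx = 0


def ring_update(new_col: torch.Tensor) -> None:
    """Overwrite the oldest column with the newest sample (seq_len == 1 decode)."""
    global write_idx
    state[..., write_idx] = new_col
    write_idx = (write_idx + 1) % (L_cache - 1)


def ordered_state() -> torch.Tensor:
    """Return the cached columns oldest-first, ready to prepend as left padding."""
    idx = [(write_idx + i) % (L_cache - 1) for i in range(L_cache - 1)]
    return state[..., idx]
```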