diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cc3df0975f45..ec4fd6028c006 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1134,6 +1134,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": # ref: https://huggingface.co/JetBrains/Mellum-4b-base res = "mellum" + if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152": + # ref: https://huggingface.co/answerdotai/ModernBERT-base + res = "modern-bert" if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df": # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer res = "afmoe" @@ -9757,6 +9760,47 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification") +class ModernBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.MODERN_BERT + + def set_vocab(self): + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + self.gguf_writer.add_add_sep_token(True) + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_dense_every_n_layers(self.hparams["global_attn_every_n_layers"]) + self.gguf_writer.add_sliding_window(self.hparams["local_attention"]) + self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"]) + self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"]) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename custom "head" layers to standard bert "cls.predictions" names for compatibility + if name == "head.norm.weight": + name = "cls.predictions.transform.LayerNorm.weight" + elif name == "head.norm.bias": + name = "cls.predictions.transform.LayerNorm.bias" + elif name == "head.dense.weight": + name = "cls.predictions.transform.dense.weight" + elif name == "head.dense.bias": + name = "cls.predictions.transform.dense.bias" + + # These layers act as MLM head, so we don't need them + if name.startswith("decoder."): + return [] + + if name.startswith("model."): + name = name[6:] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("ApertusForCausalLM") class ApertusModel(LlamaModel): model_arch = gguf.MODEL_ARCH.APERTUS diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index b8f694e86c062..649a79d78e1d6 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, + {"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", }, {"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", }, {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1cd0efad4a8f1..d4caf09b24a58 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -161,11 +161,14 @@ class Attention: VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" + DENSE_EVERY_N_LAYERS = "{arch}.attention.dense_every_n_layers" + class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" @@ -339,6 +342,7 @@ class MODEL_ARCH(IntEnum): STARCODER = auto() REFACT = auto() BERT = auto() + MODERN_BERT = auto() NOMIC_BERT = auto() NOMIC_BERT_MOE = auto() NEO_BERT = auto() @@ -708,6 +712,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STARCODER: "starcoder", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", + MODEL_ARCH.MODERN_BERT: "modern-bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", MODEL_ARCH.NEO_BERT: "neo-bert", @@ -1285,6 +1290,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, ], + MODEL_ARCH.MODERN_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb1341..f225667baccaa 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -839,6 +839,9 @@ def add_iclr_lora_rank(self, length: int) -> None: def add_value_residual_mix_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + def add_rope_freq_base_swa(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value) + def add_gate_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) @@ -847,6 +850,9 @@ def add_relative_attn_buckets_count(self, value: int) -> None: def add_sliding_window(self, value: int) -> None: self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value) + + def add_dense_every_n_layers(self, value: int) -> None: + self.add_uint32(Keys.Attention.DENSE_EVERY_N_LAYERS.format(arch=self.arch), value) def add_attention_scale(self, value: float) -> None: self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8c7ed10f2e3eb..b00ea37fde6b1 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -17,6 +17,7 @@ class TensorNameMap: "embed_tokens", # embeddinggemma "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert + "embeddings.tok_embeddings", # modern-bert "language_model.embedding.word_embeddings", # persimmon "wte", # gpt2 "transformer.embd.wte", # phi2 @@ -46,6 +47,7 @@ class TensorNameMap: MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom "embeddings.LayerNorm", # bert + "embeddings.norm", # modern-bert "emb_ln", # nomic-bert "transformer.norm", # openelm "rwkv.blocks.0.pre_ln", # rwkv @@ -105,6 +107,7 @@ class TensorNameMap: "model.norm", # llama4 "model.transformer.ln_f", # llada "model.norm", # cogvlm + "final_norm", # modern-bert ), # Rope frequencies @@ -151,6 +154,7 @@ class TensorNameMap: "model.layers.{bid}.input_layernorm", # llama4 "layers.{bid}.input_layernorm", # embeddinggemma "transformer_encoder.{bid}.attention_norm", # neobert + "layers.{bid}.attn_norm", # modern-bert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding @@ -187,6 +191,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm + "layers.{bid}.attn.Wqkv", # modern-bert ), # Attention query @@ -260,6 +265,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "layers.{bid}.attn.Wo", # modern-bert "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon @@ -342,6 +348,7 @@ class TensorNameMap: "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding "model.layers.{bid}.feedforward_layernorm", # apertus + "layers.{bid}.mlp_norm" # modern-bert ), # Pre feed-forward norm @@ -402,6 +409,7 @@ class TensorNameMap: "layers.{bid}.mlp.up_proj", # embeddinggemma "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "layers.{bid}.mlp.Wi", # modern-bert "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact @@ -513,6 +521,7 @@ class TensorNameMap: "layers.{bid}.mlp.down_proj", # embeddinggemma "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "layers.{bid}.mlp.Wo", # modern-bert "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8ec95ee176240..add733c39344c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,6 +90,7 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b2eb2477f930d..acc967707c97d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -19,6 +19,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, @@ -187,6 +188,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_DENSE_EVERY_N_LAYERS, "%s.attention.dense_every_n_layers" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -196,6 +198,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, @@ -586,6 +589,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, + { + LLM_ARCH_MODERN_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, { LLM_ARCH_NOMIC_BERT, { diff --git a/src/llama-arch.h b/src/llama-arch.h index ae7fa222acaa6..b7cfb91cd3561 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -23,6 +23,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, @@ -191,6 +192,7 @@ enum llm_kv { LLM_KV_ATTENTION_GATE_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_DENSE_EVERY_N_LAYERS, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -200,6 +202,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9203af83b2e32..4656bd0002431 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -121,6 +121,7 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; + uint32_t n_swa_pattern = 1; // if swa_layers[il] == true, then layer il is SWA // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) // by default, all layers are dense diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 563823dc35d8e..6f70472c89966 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -182,6 +182,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + add_kv(LLM_KV_ATTENTION_DENSE_EVERY_N_LAYERS, hparams.n_swa_pattern); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e703181a19804..fa8328b370ff6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -871,6 +871,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_DENSE_EVERY_N_LAYERS, hparams.n_swa_pattern); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1071,6 +1092,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; } + // since vision model stacks deepstack features along feature dim + // we also create a fake "n_embd" for text model to be the main embd + deepstack embds + hparams.n_embd *= hparams.n_deepstack_layers + 1; } break; case LLM_ARCH_QWEN3MOE: { @@ -1094,6 +1118,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } + // since vision model stacks deepstack features along feature dim + // we also create a fake "n_embd" for text model to be the main embd + deepstack embds + hparams.n_embd *= hparams.n_deepstack_layers + 1; } break; case LLM_ARCH_PHI2: { @@ -3059,6 +3086,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + } break; case LLM_ARCH_NEO_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3367,6 +3425,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3VL: { + // for model loading, the weights only have the main embd + // so we need to divide by the number of deepstack layers + 1 + // n_embd is const int so we declare a new variable + int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3402,6 +3464,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: { + // for model loading, the weights only have the main embd + // so we need to divide by the number of deepstack layers + 1 + // n_embd is const int so we declare a new variable + int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -6877,6 +6943,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_GEMMA_EMBEDDING: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: @@ -7034,6 +7101,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + if (hparams.swa_type == LLAMA_SWA_TYPE_SYMMETRIC) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique(*this, params); @@ -7582,6 +7657,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index f730c49540cfe..533f6805cee99 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -24,12 +24,14 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, LLM_TYPE_140M, + LLM_TYPE_149M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -39,6 +41,7 @@ enum llm_type { LLM_TYPE_335M, LLM_TYPE_350M, LLM_TYPE_360M, + LLM_TYPE_395M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448ba53..611c511912d03 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1878,7 +1878,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert" ) { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -2527,6 +2528,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"", "", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if (_contains_any(model_name, {"modern-bert"})) { + if (token_to_id.count("[MASK]") == 0 ) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } + else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } } diff --git a/src/models/models.h b/src/models/models.h index 4d7aeb4f42caa..7806dec029e50 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -321,6 +321,11 @@ struct llm_build_minimax_m2 : public llm_graph_context { llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); }; +template +struct llm_build_modern_bert : public llm_graph_context { + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp new file mode 100644 index 0000000000000..fbbefd77060dd --- /dev/null +++ b/src/models/modern-bert.cpp @@ -0,0 +1,125 @@ +#include "models.h" + +template +llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const float rope_theta_global = hparams.rope_freq_base_train; + const float rope_theta_local = hparams.rope_freq_base_train_swa; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur = nullptr; + ggml_tensor * inpL = nullptr; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + auto * inp_attn = build_attn_inp_no_cache(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + const float rope_theta = (il % 3 == 0) ? rope_theta_global : rope_theta_local; + + // attention layer norm + if (model.layers[il].attn_norm) { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + } + + // self attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + const size_t type_size = ggml_type_size(cur->type); + + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + // attention layer norm + cur = build_norm(cur, model.layers[il].ffn_norm, nullptr, LLM_NORM, il); + + cb(ffn_inp, "ffn_inp", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, + NULL, NULL, NULL, NULL, NULL, + model.layers[il].ffn_down, + NULL, NULL, NULL, + LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + cb(cur, "final_norm_out", -1); + + if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + // extracting cls token + cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); + cb(cur, "cls_pooled_embd", -1); + } + + cb(cur, "res_embd", -1); + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); +} + +// Explicit template instantiations +template struct llm_build_modern_bert; +template struct llm_build_modern_bert;