Commit 027d37d

toncao, cpatonn, and jeejeelee authored
[Bugfix][Qwen3-Next] add prefixes to shared_expert in qwen3-next and mlp in qwen2moe to successfully load ignored params in quantized models (#24960)
Signed-off-by: toncao <[email protected]>
Co-authored-by: toncao <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent b982196 commit 027d37d
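
Context for the fix: quantized checkpoints (for example, compressed-tensors exports) commonly carry an "ignore" list of fully qualified module names whose weights must be loaded unquantized. vLLM picks the quantization method for each linear layer by matching the prefix the layer was constructed with against that list, so a sublayer built without a prefix can never match, and its "ignored" parameters fail to load. This commit threads prefix through to the shared-expert and dense-MLP sublayers so their names line up with the checkpoint. Below is a minimal sketch of that matching idea, using illustrative module names; it is not vLLM's actual implementation.

# Minimal sketch (not vLLM's actual code) of why layer prefixes matter when
# loading a quantized checkpoint: the quantization config carries an "ignore"
# list of fully qualified module names that must stay unquantized, and the
# loader matches each layer's prefix against that list.

IGNORED_MODULES = {
    # Illustrative entries, as a compressed-tensors config might list them.
    "model.layers.0.mlp.shared_expert.gate_up_proj",
    "model.layers.0.mlp.shared_expert.down_proj",
}


def is_ignored(prefix: str) -> bool:
    # A layer is left unquantized iff its fully qualified name matches.
    return prefix in IGNORED_MODULES


parent = "model.layers.0.mlp"
# After this patch, the shared expert forwards prefix=f"{parent}.shared_expert"
# to its linear layers, so their names line up with the checkpoint:
with_fix = f"{parent}.shared_expert.gate_up_proj"
without_fix = ".gate_up_proj"  # prefix defaulted to "" before the patch

assert is_ignored(with_fix)         # matched -> loaded as plain weights
assert not is_ignored(without_fix)  # unmatched -> wrong scheme, load fails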

File tree

2 files changed: +26 −23 lines

vllm/model_executor/models/qwen2_moe.py

Lines changed: 25 additions & 23 deletions
@@ -72,17 +72,20 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
-                                           reduce_results=reduce_results)
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -123,7 +126,8 @@ def __init__(
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
                                      bias=False,
-                                     quant_config=None)
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
@@ -132,6 +136,7 @@ def __init__(
                 quant_config=quant_config,
                 reduce_results=self.experts.must_reduce_shared_expert_outputs(
                 ),
+                prefix=f"{prefix}.shared_expert",
             )
         else:
             self.shared_expert = None
@@ -203,21 +208,19 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.dual_chunk_attention_config = dual_chunk_attention_config
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=True,
-            quant_config=quant_config,
-        )
+        self.qkv_proj = QKVParallelLinear(hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
 
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
 
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -296,12 +299,11 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp")
         else:
-            self.mlp = Qwen2MoeMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-            )
+            self.mlp = Qwen2MoeMLP(hidden_size=config.hidden_size,
+                                   intermediate_size=config.intermediate_size,
+                                   hidden_act=config.hidden_act,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
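
For reference, these prefix strings compose hierarchically as modules are built, so each linear layer ends up with the fully qualified name the checkpoint uses. A rough illustration of the chain for the shared expert (the surrounding module path is assumed for illustration, not taken from this diff):

# Hypothetical prefix chain for a shared expert's projection in layer 0;
# each constructor appends its attribute name to the prefix it receives.
decoder_layer = "model.layers.0"              # decoder layer prefix
moe_block = f"{decoder_layer}.mlp"            # Qwen2MoeSparseMoeBlock
shared_expert = f"{moe_block}.shared_expert"  # Qwen2MoeMLP (added here)
gate_up = f"{shared_expert}.gate_up_proj"     # MergedColumnParallelLinear

assert gate_up == "model.layers.0.mlp.shared_expert.gate_up_proj"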

vllm/model_executor/models/qwen3_next.py

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ def __init__(
                 quant_config=quant_config,
                 reduce_results=self.experts.must_reduce_shared_expert_outputs(
                 ),
+                prefix=f"{prefix}.shared_expert",
             )
         else:
             self.shared_expert = None
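
This qwen3_next.py change is the one-line counterpart of the shared_expert hunk in qwen2_moe.py above: the MoE block now passes its prefix down when constructing the shared expert. Note that PyTorch itself names nested parameters by attribute path, which is the same convention these prefix strings reproduce; a toy example (illustrative only, not vLLM code):

# Toy check: torch names nested parameters by attribute path, exactly the
# dotted-name convention that vLLM's prefix strings are built to match.
import torch.nn as nn


class TinySharedExpert(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate_up_proj = nn.Linear(4, 8, bias=False)
        self.down_proj = nn.Linear(4, 4, bias=False)


block = nn.ModuleDict({"shared_expert": TinySharedExpert()})
print([name for name, _ in block.named_parameters()])
# ['shared_expert.gate_up_proj.weight', 'shared_expert.down_proj.weight']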
