@@ -72,17 +72,20 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
-                                           reduce_results=reduce_results)
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -123,7 +126,8 @@ def __init__(
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
                                      bias=False,
-                                     quant_config=None)
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
@@ -132,6 +136,7 @@ def __init__(
                 quant_config=quant_config,
                 reduce_results=self.experts.must_reduce_shared_expert_outputs(
                 ),
+                prefix=f"{prefix}.shared_expert",
             )
         else:
             self.shared_expert = None
@@ -203,21 +208,19 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.dual_chunk_attention_config = dual_chunk_attention_config
 
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=True,
-            quant_config=quant_config,
-        )
+        self.qkv_proj = QKVParallelLinear(hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
 
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
 
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -296,12 +299,11 @@ def __init__(
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.mlp")
         else:
-            self.mlp = Qwen2MoeMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-            )
+            self.mlp = Qwen2MoeMLP(hidden_size=config.hidden_size,
+                                   intermediate_size=config.intermediate_size,
+                                   hidden_act=config.hidden_act,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.mlp")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,