@@ -150,8 +150,6 @@ def load_hparams(dir_model):
150150
151151 @staticmethod
152152 def from_model_architecture (model_architecture ):
153- if model_architecture == "StableLMEpochForCausalLM" :
154- return StableLMModel
155153 if model_architecture == "GPTNeoXForCausalLM" :
156154 return GPTNeoXModel
157155 if model_architecture == "BloomForCausalLM" :
@@ -168,6 +166,8 @@ def from_model_architecture(model_architecture):
168166 return RefactModel
169167 if model_architecture == "PersimmonForCausalLM" :
170168 return PersimmonModel
169+ if model_architecture in ("StableLMEpochForCausalLM" , "LlavaStableLMEpochForCausalLM" ):
170+ return StableLMModel
171171 return Model
172172
173173 def _is_model_safetensors (self ) -> bool :
@@ -201,6 +201,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
201201 return gguf .MODEL_ARCH .REFACT
202202 if arch == "PersimmonForCausalLM" :
203203 return gguf .MODEL_ARCH .PERSIMMON
204+ if arch in ("StableLMEpochForCausalLM" , "LlavaStableLMEpochForCausalLM" ):
205+ return gguf .MODEL_ARCH .STABLELM
204206
205207 raise NotImplementedError (f'Architecture "{ arch } " not supported!' )
206208
@@ -294,15 +296,6 @@ def _set_vocab_sentencepiece(self):
294296 special_vocab .add_to_gguf (self .gguf_writer )
295297
296298
297- class StableLMModel (Model ):
298- def set_gguf_parameters (self ):
299- super ().set_gguf_parameters ()
300- self .gguf_writer .add_rope_dimension_count (
301- int (self .hparams ["rope_pct" ] * (self .hparams ["hidden_size" ] // self .hparams ["num_attention_heads" ])),
302- )
303- self .gguf_writer .add_layer_norm_eps (1e-5 )
304-
305-
306299class GPTNeoXModel (Model ):
307300 def set_gguf_parameters (self ):
308301 block_count = self .hparams ["num_hidden_layers" ]
@@ -824,6 +817,21 @@ def write_tensors(self):
824817 self .gguf_writer .add_tensor (new_name , data )
825818
826819
class StableLMModel(Model):
    """Model subclass that emits StableLM-specific GGUF hyperparameters."""

    def set_gguf_parameters(self):
        """Write the StableLM hyperparameters through ``self.gguf_writer``.

        All values are read from ``self.hparams`` except the layer-norm
        epsilon, which is fixed at 1e-5. ``use_parallel_residual`` is
        optional in the hyperparameters and defaults to ``True``.

        Raises:
            KeyError: if a required hyperparameter is missing.
        """
        hparams = self.hparams
        writer = self.gguf_writer
        head_count = hparams["num_attention_heads"]
        # Per-head width; only a ``rope_pct`` fraction of it is rotary.
        head_dim = hparams["hidden_size"] // head_count

        # Read the model directory from the instance: a bare ``dir_model``
        # is not defined in this method's scope and would only resolve if a
        # module-level global of that name happened to exist.
        # NOTE(review): assumes Model stores the directory as
        # ``self.dir_model`` -- confirm against Model.__init__.
        writer.add_name(self.dir_model.name)
        writer.add_context_length(hparams["max_position_embeddings"])
        writer.add_embedding_length(hparams["hidden_size"])
        writer.add_block_count(hparams["num_hidden_layers"])
        writer.add_feed_forward_length(hparams["intermediate_size"])
        writer.add_rope_dimension_count(int(hparams["rope_pct"] * head_dim))
        writer.add_head_count(head_count)
        writer.add_parallel_residual(hparams.get("use_parallel_residual", True))
        writer.add_layer_norm_eps(1e-5)
834+
827835###### CONVERSION LOGIC ######
828836
829837def parse_args () -> argparse .Namespace :
0 commit comments