@@ -201,6 +201,8 @@ def from_model_architecture(model_architecture):
             return PlamoModel
         if model_architecture == "CodeShellForCausalLM":
            return CodeShellModel
+        if model_architecture == "OrionForCausalLM":
+            return OrionModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -250,6 +252,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.PLAMO
         if arch == "CodeShellForCausalLM":
             return gguf.MODEL_ARCH.CODESHELL
+        if arch == "OrionForCausalLM":
+            return gguf.MODEL_ARCH.ORION
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -572,6 +576,83 @@ def write_tensors(self):
         self.gguf_writer.add_tensor("output.weight", data)
 
 
+class OrionModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+        hf_repo = self.hparams.get("_name_or_path", "")
+
+        ctx_length = 0
+        if "max_sequence_length" in self.hparams:
+            ctx_length = self.hparams["max_sequence_length"]
+        elif "max_position_embeddings" in self.hparams:
+            ctx_length = self.hparams["max_position_embeddings"]
+        elif "model_max_length" in self.hparams:
+            ctx_length = self.hparams["model_max_length"]
+        else:
+            print("gguf: can not find ctx length parameter.")
+            sys.exit()
+
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_source_hf_repo(hf_repo)
+        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
+        self.gguf_writer.add_context_length(ctx_length)
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(head_count)
+        self.gguf_writer.add_head_count_kv(head_count_kv)
+        self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
+
+    def write_tensors(self):
+        # collect tensors from the generator object
+        model_kv = dict(self.get_tensors())
+        block_count = self.hparams["num_hidden_layers"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 class BaichuanModel(Model):
     def set_vocab(self):
         self._set_vocab_sentencepiece()
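
For context on the first two hunks: the converter dispatches on the "architectures" string from the checkpoint's config.json, and this change extends that if-chain so "OrionForCausalLM" resolves to the new OrionModel class. A minimal standalone sketch of the same dispatch pattern (simplified, hypothetical names; not the llama.cpp API) could look like this:

    import json
    from pathlib import Path

    class Model:
        """Fallback converter used when no architecture-specific subclass matches."""

    class OrionModel(Model):
        """Converter for OrionForCausalLM checkpoints, as added in this diff."""

    # hypothetical registry standing in for the if-chain in from_model_architecture()
    _ARCH_TO_CONVERTER = {
        "OrionForCausalLM": OrionModel,
    }

    def pick_converter(model_dir: Path) -> type[Model]:
        # the HF config.json lists the model class under "architectures"
        hparams = json.loads((model_dir / "config.json").read_text())
        arch = hparams["architectures"][0]
        return _ARCH_TO_CONVERTER.get(arch, Model)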