from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast

-from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import (
+    AutoencoderKLCosmos,
+    AutoencoderKLWan,
+    Cosmos2TextToImagePipeline,
+    Cosmos2VideoToWorldPipeline,
+    CosmosTextToWorldPipeline,
+    CosmosTransformer3DModel,
+    CosmosVideoToWorldPipeline,
+    EDMEulerScheduler,
+    FlowMatchEulerDiscreteScheduler,
+)


def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -29,7 +39,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    state_dict[new_key] = state_dict.pop(key)


-TRANSFORMER_KEYS_RENAME_DICT = {
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
@@ -56,14 +66,53 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
5666 "final_layer.linear" : "proj_out" ,
5767}
5868
59- TRANSFORMER_SPECIAL_KEYS_REMAP = {
69+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
6070 "blocks.block" : rename_transformer_blocks_ ,
6171 "logvar.0.freqs" : remove_keys_ ,
6272 "logvar.0.phases" : remove_keys_ ,
6373 "logvar.1.weight" : remove_keys_ ,
6474 "pos_embedder.seq" : remove_keys_ ,
6575}
6676
77+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
78+ "t_embedder.1" : "time_embed.t_embedder" ,
79+ "t_embedding_norm" : "time_embed.norm" ,
80+ "blocks" : "transformer_blocks" ,
81+ "adaln_modulation_self_attn.1" : "norm1.linear_1" ,
82+ "adaln_modulation_self_attn.2" : "norm1.linear_2" ,
83+ "adaln_modulation_cross_attn.1" : "norm2.linear_1" ,
84+ "adaln_modulation_cross_attn.2" : "norm2.linear_2" ,
85+ "adaln_modulation_mlp.1" : "norm3.linear_1" ,
86+ "adaln_modulation_mlp.2" : "norm3.linear_2" ,
87+ "self_attn" : "attn1" ,
88+ "cross_attn" : "attn2" ,
89+ "q_proj" : "to_q" ,
90+ "k_proj" : "to_k" ,
91+ "v_proj" : "to_v" ,
92+ "output_proj" : "to_out.0" ,
93+ "q_norm" : "norm_q" ,
94+ "k_norm" : "norm_k" ,
95+ "mlp.layer1" : "ff.net.0.proj" ,
96+ "mlp.layer2" : "ff.net.2" ,
97+ "x_embedder.proj.1" : "patch_embed.proj" ,
98+ # "extra_pos_embedder": "learnable_pos_embed",
99+ "final_layer.adaln_modulation.1" : "norm_out.linear_1" ,
100+ "final_layer.adaln_modulation.2" : "norm_out.linear_2" ,
101+ "final_layer.linear" : "proj_out" ,
102+ }
103+
104+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
105+ "accum_video_sample_counter" : remove_keys_ ,
106+ "accum_image_sample_counter" : remove_keys_ ,
107+ "accum_iteration" : remove_keys_ ,
108+ "accum_train_in_hours" : remove_keys_ ,
109+ "pos_embedder.seq" : remove_keys_ ,
110+ "pos_embedder.dim_spatial_range" : remove_keys_ ,
111+ "pos_embedder.dim_temporal_range" : remove_keys_ ,
112+ "_extra_state" : remove_keys_ ,
113+ }
114+
115+
67116TRANSFORMER_CONFIGS = {
68117 "Cosmos-1.0-Diffusion-7B-Text2World" : {
69118 "in_channels" : 16 ,
@@ -125,6 +174,66 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
125174 "concat_padding_mask" : True ,
126175 "extra_pos_embed_type" : "learnable" ,
127176 },
177+ "Cosmos-2.0-Diffusion-2B-Text2Image" : {
178+ "in_channels" : 16 ,
179+ "out_channels" : 16 ,
180+ "num_attention_heads" : 16 ,
181+ "attention_head_dim" : 128 ,
182+ "num_layers" : 28 ,
183+ "mlp_ratio" : 4.0 ,
184+ "text_embed_dim" : 1024 ,
185+ "adaln_lora_dim" : 256 ,
186+ "max_size" : (128 , 240 , 240 ),
187+ "patch_size" : (1 , 2 , 2 ),
188+ "rope_scale" : (1.0 , 4.0 , 4.0 ),
189+ "concat_padding_mask" : True ,
190+ "extra_pos_embed_type" : None ,
191+ },
192+ "Cosmos-2.0-Diffusion-14B-Text2Image" : {
193+ "in_channels" : 16 ,
194+ "out_channels" : 16 ,
195+ "num_attention_heads" : 40 ,
196+ "attention_head_dim" : 128 ,
197+ "num_layers" : 36 ,
198+ "mlp_ratio" : 4.0 ,
199+ "text_embed_dim" : 1024 ,
200+ "adaln_lora_dim" : 256 ,
201+ "max_size" : (128 , 240 , 240 ),
202+ "patch_size" : (1 , 2 , 2 ),
203+ "rope_scale" : (1.0 , 4.0 , 4.0 ),
204+ "concat_padding_mask" : True ,
205+ "extra_pos_embed_type" : None ,
206+ },
207+ "Cosmos-2.0-Diffusion-2B-Video2World" : {
208+ "in_channels" : 16 + 1 ,
209+ "out_channels" : 16 ,
210+ "num_attention_heads" : 16 ,
211+ "attention_head_dim" : 128 ,
212+ "num_layers" : 28 ,
213+ "mlp_ratio" : 4.0 ,
214+ "text_embed_dim" : 1024 ,
215+ "adaln_lora_dim" : 256 ,
216+ "max_size" : (128 , 240 , 240 ),
217+ "patch_size" : (1 , 2 , 2 ),
218+ "rope_scale" : (1.0 , 3.0 , 3.0 ),
219+ "concat_padding_mask" : True ,
220+ "extra_pos_embed_type" : None ,
221+ },
222+ "Cosmos-2.0-Diffusion-14B-Video2World" : {
223+ "in_channels" : 16 + 1 ,
224+ "out_channels" : 16 ,
225+ "num_attention_heads" : 40 ,
226+ "attention_head_dim" : 128 ,
227+ "num_layers" : 36 ,
228+ "mlp_ratio" : 4.0 ,
229+ "text_embed_dim" : 1024 ,
230+ "adaln_lora_dim" : 256 ,
231+ "max_size" : (128 , 240 , 240 ),
232+ "patch_size" : (1 , 2 , 2 ),
233+ "rope_scale" : (20 / 24 , 2.0 , 2.0 ),
234+ "concat_padding_mask" : True ,
235+ "extra_pos_embed_type" : None ,
236+ },
128237}
129238
130239VAE_KEYS_RENAME_DICT = {
@@ -216,9 +325,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    return state_dict


-def convert_transformer(transformer_type: str, ckpt_path: str):
+def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
    PREFIX_KEY = "net."
-    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
+
+    if "Cosmos-1.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+    elif "Cosmos-2.0" in transformer_type:
+        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+    else:
+        assert False

    with init_empty_weights():
        config = TRANSFORMER_CONFIGS[transformer_type]
@@ -281,13 +399,61 @@ def convert_vae(vae_type: str):
    return vae


+def save_pipeline_cosmos_1_0(args, transformer, vae):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+    # The original code initializes the EDM config with sigma_min=0.0002 but does not make use of it anywhere directly,
+    # so the sigma_min value that is actually used is the default of 0.002.
+    scheduler = EDMEulerScheduler(
+        sigma_min=0.002,
+        sigma_max=80,
+        sigma_data=0.5,
+        sigma_schedule="karras",
+        num_train_timesteps=1000,
+        prediction_type="epsilon",
+        rho=7.0,
+        final_sigmas_type="sigma_min",
+    )
+
+    pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
+    pipe = pipe_cls(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def save_pipeline_cosmos_2_0(args, transformer, vae):
+    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+
+    scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+
+    pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
+    pipe = pipe_cls(
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+        vae=vae,
+        scheduler=scheduler,
+        safety_checker=lambda *args, **kwargs: None,
+    )
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
-    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
+    parser.add_argument(
+        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+    )
    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--save_pipeline", action="store_true")
@@ -316,37 +482,26 @@ def get_args():
        assert args.tokenizer_path is not None

    if args.transformer_ckpt_path is not None:
-        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
+        weights_only = "Cosmos-1.0" in args.transformer_type
+        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_type is not None:
-        vae = convert_vae(args.vae_type)
+        if "Cosmos-1.0" in args.transformer_type:
+            vae = convert_vae(args.vae_type)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+            )
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
-        text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
-        tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
-        # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
-        # So, the sigma_min values that is used is the default value of 0.002.
-        scheduler = EDMEulerScheduler(
-            sigma_min=0.002,
-            sigma_max=80,
-            sigma_data=0.5,
-            sigma_schedule="karras",
-            num_train_timesteps=1000,
-            prediction_type="epsilon",
-            rho=7.0,
-            final_sigmas_type="sigma_min",
-        )
-
-        pipe = CosmosTextToWorldPipeline(
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            transformer=transformer,
-            vae=vae,
-            scheduler=scheduler,
-        )
-        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+        if "Cosmos-1.0" in args.transformer_type:
+            save_pipeline_cosmos_1_0(args, transformer, vae)
+        elif "Cosmos-2.0" in args.transformer_type:
+            save_pipeline_cosmos_2_0(args, transformer, vae)
+        else:
+            assert False
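Note on how the new tables are consumed: the loop that actually applies `TRANSFORMER_KEYS_RENAME_DICT_*` and `TRANSFORMER_SPECIAL_KEYS_REMAP_*` sits in the part of `convert_transformer` that this diff does not touch. The sketch below is not part of the commit; it is a minimal illustration assuming the usual diffusers conversion pattern (strip the `net.` checkpoint prefix, apply substring renames, then run the matching special-key handlers such as `remove_keys_`). The helper name `apply_remaps` and the guard against already-removed keys are illustrative, not taken from the script.

```python
from typing import Any, Callable, Dict

PREFIX_KEY = "net."


def apply_remaps(
    original_state_dict: Dict[str, Any],
    rename_dict: Dict[str, str],
    special_keys_remap: Dict[str, Callable[[str, Dict[str, Any]], None]],
) -> Dict[str, Any]:
    state_dict = {}
    # 1. Strip the checkpoint prefix and apply plain substring renames.
    for key, value in original_state_dict.items():
        new_key = key[len(PREFIX_KEY):] if key.startswith(PREFIX_KEY) else key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        state_dict[new_key] = value
    # 2. Run the special handlers (e.g. remove_keys_) on every key that matches a pattern,
    #    skipping keys a previous handler already removed.
    for key in list(state_dict.keys()):
        for pattern, handler_fn in special_keys_remap.items():
            if pattern not in key or key not in state_dict:
                continue
            handler_fn(key, state_dict)
    return state_dict
```

For a Cosmos-2.0 checkpoint this would presumably be called with `TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0` and `TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0`; once converted and saved with `--save_pipeline`, the output folder should then be loadable with the matching pipeline class, e.g. `Cosmos2TextToImagePipeline.from_pretrained(output_path)`.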