@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
 
         for _ in range(args.num_validation_images):
             with autocast_ctx:
-                # need to fix in pipeline_flux_controlnet
                 image = pipeline(
                     prompt=validation_prompt,
                     control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
                 images = log["images"]
                 validation_prompt = log["validation_prompt"]
                 validation_image = log["validation_image"]
-                formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+                formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
                 for image in images:
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
             img_str += f"\n"
 
     model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}
 
 These are Control weights trained on {base_model} with new type of conditioning.
 {img_str}
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
434433 "--conditioning_image_column" ,
435434 type = str ,
436435 default = "conditioning_image" ,
437- help = "The column of the dataset containing the controlnet conditioning image." ,
436+ help = "The column of the dataset containing the control conditioning image." ,
438437 )
439438 parser .add_argument (
440439 "--caption_column" ,
441440 type = str ,
442441 default = "text" ,
443442 help = "The column of the dataset containing a caption or a list of captions." ,
444443 )
+    parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
     parser.add_argument(
         "--max_train_samples",
         type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
         default=None,
         nargs="+",
         help=(
-            "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
471+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
472472 " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
473473 " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
474474 " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to the jsonl file containing the training data.",
     )
-
+    parser.add_argument(
+        "--only_target_transformer_blocks",
+        action="store_true",
+        help="Whether to train only the transformer blocks along with the input layer (`x_embedder`).",
+    )
     parser.add_argument(
         "--guidance_scale",
         type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
 
     if args.resolution % 8 != 0:
         raise ValueError(
-            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
         )
 
     return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
         conditioning_images = [image_transforms(image) for image in conditioning_images]
         examples["pixel_values"] = images
         examples["conditioning_pixel_values"] = conditioning_images
-        examples["captions"] = list(examples[args.caption_column])
+
+        is_caption_list = isinstance(examples[args.caption_column][0], list)
+        if is_caption_list:
+            examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+        else:
+            examples["captions"] = list(examples[args.caption_column])
 
         return examples
 
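A quick, self-contained illustration of the caption handling introduced above (the toy batch is invented for this note, not part of the change): when the caption column holds a list of captions per example, the longest caption is kept; plain string captions pass through unchanged.

    # Toy batch mimicking the structure preprocess_train receives (hypothetical values).
    examples = {"text": [["a cat", "a photo of a cat on a sofa"], ["a dog sitting", "dog"]]}
    if isinstance(examples["text"][0], list):
        captions = [max(example, key=len) for example in examples["text"]]
    else:
        captions = list(examples["text"])
    print(captions)  # ['a photo of a cat on a sofa', 'a dog sitting']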
@@ -765,7 +774,8 @@ def main(args):
         subfolder="scheduler",
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-    flux_transformer.requires_grad_(True)
+    if not args.only_target_transformer_blocks:
+        flux_transformer.requires_grad_(True)
     vae.requires_grad_(False)
 
     # cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
+    if args.only_target_transformer_blocks:
+        flux_transformer.x_embedder.requires_grad_(True)
+        for name, module in flux_transformer.named_modules():
+            if "transformer_blocks" in name:
+                module.requires_grad_(True)
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
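A minimal sketch (not part of the PR) of what the selective-freezing branch above marks as trainable; the checkpoint id is illustrative and the model is assumed to be a diffusers `FluxTransformer2DModel`. Note that the substring check `"transformer_blocks" in name` also matches the `single_transformer_blocks` modules.

    import torch
    from diffusers import FluxTransformer2DModel

    # Illustrative checkpoint; the script takes the model path from its CLI arguments.
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )

    # Freeze everything, then re-enable the input layer and the transformer blocks,
    # mirroring the --only_target_transformer_blocks branch above.
    flux_transformer.requires_grad_(False)
    flux_transformer.x_embedder.requires_grad_(True)
    for name, module in flux_transformer.named_modules():
        if "transformer_blocks" in name:  # also matches single_transformer_blocks
            module.requires_grad_(True)

    trainable = sum(p.numel() for p in flux_transformer.parameters() if p.requires_grad)
    total = sum(p.numel() for p in flux_transformer.parameters())
    print(f"trainable: {trainable:,} / {total:,} parameters")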
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
     else:
         initial_global_step = 0
 
+    if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+        logger.info("Logging some dataset samples.")
+        formatted_images = []
+        formatted_control_images = []
+        all_prompts = []
+        for i, batch in enumerate(train_dataloader):
+            images = (batch["pixel_values"] + 1) / 2
+            control_images = (batch["conditioning_pixel_values"] + 1) / 2
+            prompts = batch["captions"]
+
+            if len(formatted_images) > 10:
+                break
+
+            for img, control_img, prompt in zip(images, control_images, prompts):
+                formatted_images.append(img)
+                formatted_control_images.append(control_img)
+                all_prompts.append(prompt)
+
+        logged_artifacts = []
+        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+            logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+            logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+        wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
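For context, a small standalone sketch of the logging pattern used in the block above (hypothetical data and project name; assumes `wandb` is installed and a run can be created): training tensors in [-1, 1] are rescaled to [0, 1] and logged as (conditioning, target) pairs under a single key.

    import torch
    import wandb

    run = wandb.init(project="flux-control-debug")  # illustrative project name

    pixel_values = torch.rand(2, 3, 64, 64) * 2 - 1   # stand-in for batch["pixel_values"]
    conditioning = torch.rand(2, 3, 64, 64) * 2 - 1   # stand-in for batch["conditioning_pixel_values"]
    prompts = ["a toy prompt", "another toy prompt"]

    images = (pixel_values + 1) / 2
    control_images = (conditioning + 1) / 2

    logged = []
    for img, control_img, prompt in zip(images, control_images, prompts):
        logged.append(wandb.Image(control_img, caption="Conditioning"))
        logged.append(wandb.Image(img, caption=prompt))
    run.log({"dataset_samples": logged})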