@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
134134
135135 for validation_prompt , validation_image in zip (validation_prompts , validation_images ):
136136 validation_image = Image .open (validation_image ).convert ("RGB" )
137- validation_image = validation_image .resize ((args .resolution , args .resolution ))
137+
138+ try :
139+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper ())
140+ except (AttributeError , KeyError ):
141+ supported_interpolation_modes = [
142+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
143+ ]
144+ raise ValueError (
145+ f"Interpolation mode { args .image_interpolation_mode } is not supported. "
146+ f"Please select one of the following: { ', ' .join (supported_interpolation_modes )} "
147+ )
148+
149+ transform = transforms .Compose (
150+ [
151+ transforms .Resize (args .resolution , interpolation = interpolation ),
152+ transforms .CenterCrop (args .resolution ),
153+ ]
154+ )
155+ validation_image = transform (validation_image )
138156
139157 images = []
140158
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
587605 " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
588606 ),
589607 )
608+ parser .add_argument (
609+ "--image_interpolation_mode" ,
610+ type = str ,
611+ default = "lanczos" ,
612+ choices = [
613+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
614+ ],
615+ help = "The image interpolation method to use for resizing images." ,
616+ )
590617
591618 if input_args is not None :
592619 args = parser .parse_args (input_args )
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
732759
733760
734761def prepare_train_dataset (dataset , accelerator ):
762+ try :
763+ interpolation_mode = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper ())
764+ except (AttributeError , KeyError ):
765+ supported_interpolation_modes = [
766+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
767+ ]
768+ raise ValueError (
769+ f"Interpolation mode { args .image_interpolation_mode } is not supported. "
770+ f"Please select one of the following: { ', ' .join (supported_interpolation_modes )} "
771+ )
772+
735773 image_transforms = transforms .Compose (
736774 [
737- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
775+ transforms .Resize (args .resolution , interpolation = interpolation_mode ),
738776 transforms .CenterCrop (args .resolution ),
739777 transforms .ToTensor (),
740778 transforms .Normalize ([0.5 ], [0.5 ]),
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
743781
744782 conditioning_image_transforms = transforms .Compose (
745783 [
746- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
784+ transforms .Resize (args .resolution , interpolation = interpolation_mode ),
747785 transforms .CenterCrop (args .resolution ),
748786 transforms .ToTensor (),
749787 ]
0 commit comments