From 3817aae2ad4d065f791ea419a35d9ef0609288e6 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:54:35 +0530 Subject: [PATCH 1/5] updated encode prompt and clip encod prompt --- examples/dreambooth/train_dreambooth_sd3.py | 27 ++++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 5d10345304ab..518913da27cd 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -902,20 +902,26 @@ def _encode_prompt_with_clip( tokenizer, prompt: str, device=None, + text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt", - ) + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") - text_input_ids = text_inputs.input_ids prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] @@ -937,6 +943,7 @@ def encode_prompt( max_sequence_length, device=None, num_images_per_prompt: int = 1, + text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -945,13 +952,14 @@ def encode_prompt( clip_prompt_embeds_list = [] clip_pooled_prompt_embeds_list = [] - for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders): + for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoder, tokenizer=tokenizer, prompt=prompt, device=device if device is not None else text_encoder.device, num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, ) clip_prompt_embeds_list.append(prompt_embeds) clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) @@ -965,6 +973,7 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, device=device if device is not None else text_encoders[-1].device, ) From 0fba12dcb2dd4da800e1628585844bc9feb2a3b0 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:35:45 +0530 Subject: [PATCH 2/5] empty comit --- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 518913da27cd..61ea755829af 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -936,7 +936,7 @@ def _encode_prompt_with_clip( return prompt_embeds, pooled_prompt_embeds -def encode_prompt( +def _( text_encoders, tokenizers, prompt: str, From fddb463af64cc057cd6679c35c9a19a4af1a4562 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:36:34 +0530 Subject: [PATCH 3/5] fix comit --- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 61ea755829af..518913da27cd 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -936,7 +936,7 @@ def _encode_prompt_with_clip( return prompt_embeds, pooled_prompt_embeds -def _( +def encode_prompt( text_encoders, tokenizers, prompt: str, From 7a02c1790f1e8603392a062474d557d27f3f6411 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:03:20 +0530 Subject: [PATCH 4/5] fix text_input_ids --- examples/dreambooth/train_dreambooth_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 518913da27cd..67275405d20e 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -973,7 +973,6 @@ def encode_prompt( max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, device=device if device is not None else text_encoders[-1].device, ) From 04f03654fb6acad300bdec438b4fc6c69356e35d Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Fri, 1 Nov 2024 23:24:50 +0530 Subject: [PATCH 5/5] This is an empty commit