68 changes: 68 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref


def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
    # Some non-diffusers Qwen LoRAs prefix every key with "lora_unet_"; strip it first.
    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
    if has_lora_unet:
        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}

    def convert_key(key: str) -> str:
        # Rewrite an underscore-separated key such as "transformer_blocks_0_attn_to_q"
        # into the dotted diffusers layout ("transformer_blocks.0.attn.to_q"), keeping
        # underscores inside protected module names like "to_q" intact.
        prefix = "transformer_blocks"
        if "." in key:
            base, suffix = key.rsplit(".", 1)
        else:
            base, suffix = key, ""

        start = f"{prefix}_"
        rest = base[len(start) :]

        if "." in rest:
            head, tail = rest.split(".", 1)
            tail = "." + tail
        else:
            head, tail = rest, ""

        # Protected n-grams that must keep their internal underscores
        protected = {
            # pairs
            ("to", "q"),
            ("to", "k"),
            ("to", "v"),
            ("to", "out"),
            ("add", "q"),
            ("add", "k"),
            ("add", "v"),
            ("txt", "mlp"),
            ("img", "mlp"),
            ("txt", "mod"),
            ("img", "mod"),
            # triplets
            ("add", "q", "proj"),
            ("add", "k", "proj"),
            ("add", "v", "proj"),
            ("to", "add", "out"),
        }

        prot_by_len = {}
        for ng in protected:
            prot_by_len.setdefault(len(ng), set()).add(ng)

        parts = head.split("_")
        merged = []
        i = 0
        lengths_desc = sorted(prot_by_len.keys(), reverse=True)

        # Greedy longest-match merge: try triplets before pairs so that e.g.
        # ("add", "q", "proj") wins over ("add", "q"); all other underscores
        # become dots.
        while i < len(parts):
            matched = False
            for L in lengths_desc:
                if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
                    merged.append("_".join(parts[i : i + L]))
                    i += L
                    matched = True
                    break
            if not matched:
                merged.append(parts[i])
                i += 1

        head_converted = ".".join(merged)
        converted_base = f"{prefix}.{head_converted}{tail}"
        return converted_base + (("." + suffix) if suffix else "")

    state_dict = {convert_key(k): v for k, v in state_dict.items()}

    converted_state_dict = {}
    all_keys = list(state_dict.keys())
    down_key = ".lora_down.weight"
    # ... (rest of the conversion logic collapsed in the diff view)
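To make the renaming concrete, here is a minimal standalone sketch of the greedy merge above, reduced to the head-rewriting step. The PROTECTED set is a subset of the full table for brevity, and the sample heads are illustrative keys in the naming scheme the converter targets, not taken from a real checkpoint:

# Condensed re-implementation of the merge loop, for inspection only.
PROTECTED = {("to", "q"), ("add", "q", "proj"), ("img", "mlp")}  # subset of the full set

def merge_protected(head: str) -> str:
    parts = head.split("_")
    merged, i = [], 0
    while i < len(parts):
        for n in (3, 2):  # longest first, mirroring lengths_desc
            if i + n <= len(parts) and tuple(parts[i : i + n]) in PROTECTED:
                merged.append("_".join(parts[i : i + n]))
                i += n
                break
        else:
            merged.append(parts[i])
            i += 1
    return ".".join(merged)

print(merge_protected("0_attn_to_q"))        # 0.attn.to_q
print(merge_protected("3_attn_add_q_proj"))  # 3.attn.add_q_proj
print(merge_protected("7_img_mlp_net_2"))    # 7.img_mlp.net.2

Run through convert_key, a full key such as transformer_blocks_0_attn_to_q.lora_down.weight thus becomes transformer_blocks.0.attn.to_q.lora_down.weight; judging from the down_key setup above, the collapsed remainder of the function then remaps the lora_down/lora_up pairs onto diffusers' naming.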
3 changes: 2 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
@@ -6643,7 +6643,8 @@ def lora_state_dict(
     state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

     has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
-    if has_alphas_in_sd:
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_alphas_in_sd or has_lora_unet:
         state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)

     out = (state_dict, metadata) if return_lora_metadata else state_dict
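The gate is simple enough to restate in isolation. In the sketch below, needs_qwen_conversion is a hypothetical helper name; the pipeline inlines these checks rather than defining such a function:

# Hypothetical helper mirroring the updated check: conversion is triggered either
# by old-style ".alpha" keys or by the "lora_unet_" prefix handled in this PR.
def needs_qwen_conversion(state_dict: dict) -> bool:
    has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
    return has_alphas_in_sd or has_lora_unet

print(needs_qwen_conversion({"lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight": 0}))  # True
print(needs_qwen_conversion({"transformer_blocks.0.attn.to_q.lora_A.weight": 0}))               # False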