4 changes: 2 additions & 2 deletions src/peft/tuners/lora/config.py
@@ -669,8 +669,8 @@ class LoraConfig(PeftConfig):
"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
- "are also tied. This is only applicable for layers passed via "
- "`modules_to_save`."
+ "are also tied. This is applicable for layers passed via "
+ "`modules_to_save` and `trainable_token_indices`."
)
},
)
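For orientation, a minimal sketch of how the updated option combines with `trainable_token_indices` (illustrative values only; the tests added to `tests/test_trainable_tokens.py` below exercise the authoritative behavior):

```python
from peft import LoraConfig

# Illustrative config mirroring the tests below: token adapters on both the
# embedding and lm_head, with ensure_weight_tying keeping them in sync.
config = LoraConfig(
    target_modules="all-linear",
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
    ensure_weight_tying=True,
)
```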
76 changes: 67 additions & 9 deletions src/peft/utils/other.py
@@ -1464,7 +1464,46 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
f"`trainable_tokens={{'{target_layer}': x}}` but not both."
)

# Check weight tying configuration first to determine which layers to wrap
weights_tied = (
model_config.get("tie_word_embeddings", False)
# some models may be misconfigured to have weight tying enabled but don't define tied weights keys
and model._tied_weights_keys is not None
)
ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# Check if we're dealing with dict format that specifies both embed_tokens and lm_head
is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
Member
I don't think we need `is_dict_format`. The check below, `len(target_layers) > 1`, is already enough, is it not?

Contributor Author
Yes, I re-reviewed this and simplified the logic significantly; see 232c6e7 for the implementation.
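For reference, a minimal sketch of the simplification discussed in this thread (hypothetical; the actual change is in 232c6e7). It assumes `target_layers` is already a dict at this point regardless of whether `trainable_token_indices` was given as a list or a dict, so a length check alone identifies the multi-layer case:

```python
# Hypothetical example value for target_layers; in the real code this dict is
# built from peft_config.trainable_token_indices.
target_layers = {"embed_tokens": [1, 2], "lm_head": [1, 2]}

# No separate is_dict_format flag needed: the length check covers it.
if len(target_layers) > 1:
    print("multiple target layers: compare them against the model's tied weights")
```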

has_both_layers = False
indices_mismatch = False
embed_key = None
lm_head_key = None
layers_to_skip = set()

if is_dict_format and len(target_layers) > 1:
# Find embedding and lm_head keys
for key in target_layers:
key_lower = key.lower()
if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
embed_key = key
elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
lm_head_key = key
Member
I wonder if we overcomplicate things here. If there are multiple `target_layers`, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?

Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth it to make the error message more generic IMO.

Contributor Author
I revisited this and removed the string-matching logic (checking for "embed", "lm_head", etc.); the code now directly compares the target layers against `model._tied_weights_keys` and the actual embedding layer. The error message is now generic, listing all conflicting tied layers instead of assuming specific names.
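A minimal sketch of what that direct comparison could look like (hypothetical helper; the names and exact scope are assumptions, not the code in 232c6e7). It assumes `model._tied_weights_keys` holds parameter names such as "lm_head.weight", so the module prefix is what gets compared against the targeted layer names:

```python
def find_tied_targets(target_layers, tied_weights_keys):
    # Strip the trailing parameter name ("lm_head.weight" -> "lm_head") so the
    # result is comparable to the keys used in trainable_token_indices.
    tied_modules = {".".join(key.split(".")[:-1]) for key in (tied_weights_keys or [])}
    return [name for name in target_layers if name in tied_modules]


# Example with made-up values: only lm_head is among the tied weights.
print(find_tied_targets({"embed_tokens": [1, 2], "lm_head": [3, 4]}, ["lm_head.weight"]))  # ['lm_head']
```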


if embed_key and lm_head_key:
has_both_layers = True
# Check if indices are different
if target_layers[embed_key] != target_layers[lm_head_key]:
indices_mismatch = True
else:
# Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
if weights_tied and not (not ensure_weight_tying and False): # Will apply tying
Member
This check makes no sense to me: why `and False`?

Contributor Author
resolved in 232c6e7
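For what it's worth, a quick check confirming the reviewer's point (the real fix is in 232c6e7): `not (not ensure_weight_tying and False)` is always `True`, so the guard reduces to `if weights_tied:`.

```python
# The questioned sub-expression is a no-op for either value of the flag.
for ensure_weight_tying in (True, False):
    assert not (not ensure_weight_tying and False)
```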

layers_to_skip.add(lm_head_key)

for target_layer, token_indices in target_layers.items():
# Skip layers that will be handled by weight tying
if target_layer in layers_to_skip:
continue

_set_trainable(
model,
adapter_name,
@@ -1476,16 +1515,35 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
activate_adapter=activate_adapter,
)

# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
if (
model_config.get("tie_word_embeddings", False)
# some models may be misconfigured to have weight tying enabled but don't define tied weights keys
and model._tied_weights_keys is not None
# Case 1: weights NOT tied + ensure_weight_tying=True -> WARNING
if not weights_tied and ensure_weight_tying:
warnings.warn(
"ensure_weight_tying=True but the model does not have tied weights "
"(tie_word_embeddings=False). Weight tying will not be applied for trainable_token_indices."
)

# Case 2: weights tied + ensure_weight_tying=True + different indices -> ERROR
if weights_tied and ensure_weight_tying and has_both_layers and indices_mismatch:
raise ValueError(
f"Cannot ensure weight tying when different token indices are specified for "
f"embedding ({embed_key}: {target_layers[embed_key]}) and "
f"lm_head ({lm_head_key}: {target_layers[lm_head_key]}). "
f"Please use the same indices for both layers or set ensure_weight_tying=False."
)

# Case 3: Apply weight tying when appropriate
# - When weights are tied AND we should apply tying
# - Apply tying unless: weights tied + ensure_weight_tying=False + different indices (BC: treat as separate)
should_apply_tying = (
weights_tied
and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
):
# the embedding layer is modified and we want weight tying.
and not (not ensure_weight_tying and has_both_layers and indices_mismatch)
)

if should_apply_tying:
# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
module_keys = [".".join(n.split(".")[:-1]) for n in model._tied_weights_keys]

token_adapter = model.get_input_embeddings().token_adapter
116 changes: 116 additions & 0 deletions tests/test_trainable_tokens.py
@@ -978,3 +978,119 @@ def test_scaled_embedding_with_lora(self):
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(x)
assert (embedding_output == 0.0).all()

# Tests for ensure_weight_tying parameter with trainable_token_indices
Member
Let's mention #2864 here; I think it helps with understanding the tests.

Contributor Author (@sambhavnoobcoder, Oct 29, 2025)
Added! I've included the comment `# See #2864 for details on the expected behavior` at the beginning of the `ensure_weight_tying` test section. This helps readers understand the context and refer back to the original issue for the full specification.

def test_ensure_weight_tying_warns_when_model_not_tied_list_format(self, model_weight_untied, recwarn):
"""Should warn when ensure_weight_tying=True but model doesn't have tied weights (list format)"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=[1, 2, 3],
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_untied, peft_config)

warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
Member
I think it's a bit more elegant to do:

expected = ...
assert any(expected in msg for msg in warnings_list)

Contributor Author
resolved in 232c6e7


def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
Member
This test can be merged with `test_ensure_weight_tying_warns_when_model_not_tied_list_format` by parametrizing the `trainable_token_indices` argument.

Contributor Author
resolved in 232c6e7

"""Should warn when ensure_weight_tying=True with dict format but model doesn't have tied weights"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"embed_tokens": [1, 2, 3]},
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_untied, peft_config)

warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
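Following up on the parametrization suggestion in the thread above, a hedged sketch of how the two warning tests could be merged (hypothetical drop-in method for this test class; the actual change is in 232c6e7):

```python
@pytest.mark.parametrize("token_indices", [[1, 2, 3], {"embed_tokens": [1, 2, 3]}])
def test_ensure_weight_tying_warns_when_model_not_tied(self, model_weight_untied, recwarn, token_indices):
    """Warn when ensure_weight_tying=True but the model does not tie its weights."""
    peft_config = LoraConfig(
        target_modules="all-linear",
        trainable_token_indices=token_indices,
        ensure_weight_tying=True,
    )
    get_peft_model(model_weight_untied, peft_config)

    expected = "ensure_weight_tying=True but the model does not have tied weights"
    assert any(expected in str(w.message) for w in recwarn)
```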

def test_weight_tying_bc_different_indices_treated_separately(self, model_weight_tied):
"""Backwards compatibility: different indices should be treated separately when ensure_weight_tying=False"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
ensure_weight_tying=False, # BC behavior
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Check that both layers have token adapters but they're NOT tied
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

assert embed_adapter is not None
assert lm_head_adapter is not None
# They should NOT share the same delta parameters (treated as separate)
assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta
# They should have different token indices
assert embed_adapter.token_indices["default"] == [3, 4]
assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_ensure_weight_tying_errors_with_different_indices(self, model_weight_tied):
"""Should raise error when ensure_weight_tying=True with different indices for embedding and lm_head"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
ensure_weight_tying=True,
)

with pytest.raises(ValueError) as e:
Member
Let's use:

msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):

Contributor Author
resolved in 232c6e7

peft_model = get_peft_model(model_weight_tied, peft_config)

assert "Cannot ensure weight tying when different token indices are specified" in str(e.value)

def test_ensure_weight_tying_applied_with_same_indices(self, model_weight_tied):
"""Should apply weight tying when ensure_weight_tying=True with same indices"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Check that weight tying is properly applied
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters (weight tying)
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
# They should have the same token indices
assert embed_adapter.token_indices["default"] == [1, 2]
assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
"""Backwards compatibility: same indices should have weight tying even when ensure_weight_tying=False"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
ensure_weight_tying=False, # BC: still applies tying when indices are the same
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Even with ensure_weight_tying=False, BC behavior should still tie when indices are same
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters (BC behavior)
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta

def test_ensure_weight_tying_with_single_layer(self, model_weight_tied):
"""ensure_weight_tying should work with single layer (list format)"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=[1, 2, 3],
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Should apply weight tying to tied layers automatically
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta