Implement ensure_weight_tying for trainable_token_indices (#2864) #2870
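For orientation, a minimal usage sketch of the option this PR adds. This is not code from the PR itself: the checkpoint name is only an illustrative placeholder, and the behavior noted in the comments follows the diff below (warn when the model is untied, error on mismatching indices, otherwise tie the token adapters).

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Any causal LM with tie_word_embeddings=True works; opt-125m is just a small example.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    target_modules="all-linear",
    # Make the embeddings of a few selected tokens trainable.
    trainable_token_indices=[1, 2, 3],
    # New option from this PR / #2864: ask PEFT to guarantee weight tying for the
    # token adapters, warning or erroring if it cannot.
    ensure_weight_tying=True,
)
peft_model = get_peft_model(model, config)
```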
```diff
@@ -1464,7 +1464,46 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             f"`trainable_tokens={{'{target_layer}': x}}` but not both."
         )

+    # Check weight tying configuration first to determine which layers to wrap
+    weights_tied = (
+        model_config.get("tie_word_embeddings", False)
+        # some models may be misconfigured to have weight tying enabled but don't define tied weights keys
+        and model._tied_weights_keys is not None
+    )
+    ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)
+
+    # Check if we're dealing with dict format that specifies both embed_tokens and lm_head
+    is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
+    has_both_layers = False
+    indices_mismatch = False
+    embed_key = None
+    lm_head_key = None
+    layers_to_skip = set()
+
+    if is_dict_format and len(target_layers) > 1:
+        # Find embedding and lm_head keys
+        for key in target_layers:
+            key_lower = key.lower()
+            if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
+                embed_key = key
+            elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
+                lm_head_key = key
+
+        if embed_key and lm_head_key:
+            has_both_layers = True
+            # Check if indices are different
+            if target_layers[embed_key] != target_layers[lm_head_key]:
+                indices_mismatch = True
+            else:
+                # Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
+                if weights_tied and not (not ensure_weight_tying and False):  # Will apply tying
+                    layers_to_skip.add(lm_head_key)
+
     for target_layer, token_indices in target_layers.items():
+        # Skip layers that will be handled by weight tying
+        if target_layer in layers_to_skip:
+            continue
+
         _set_trainable(
             model,
             adapter_name,
```
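A side note on the detection above: embedding and LM-head entries are identified purely by substring matching on the user-supplied layer names. Below is a standalone, runnable restatement of that loop with illustrative key names (not part of the diff):

```python
# Standalone restatement of the key-detection loop above, on illustrative key names.
target_layers = {"embed_tokens": [1, 2], "lm_head": [1, 2]}

embed_key = None
lm_head_key = None
for key in target_layers:
    key_lower = key.lower()
    if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
        embed_key = key
    elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
        lm_head_key = key

# The heuristic classifies the two entries as expected.
assert embed_key == "embed_tokens"
assert lm_head_key == "lm_head"
```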
```diff
@@ -1476,16 +1515,35 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             activate_adapter=activate_adapter,
         )

-    # There might be the possibility that we have output weights that are tied to the input weights.
-    # In that case we will tie any module that wants tied weights to the token adapter to make sure that
-    # any modification is reflected in the tied layers as well.
-    if (
-        model_config.get("tie_word_embeddings", False)
-        # some models may be misconfigured to have weight tying enabled but don't define tied weights keys
-        and model._tied_weights_keys is not None
-        and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
-    ):
-        # the embedding layer is modified and we want weight tying.
+    # Case 1: weights NOT tied + ensure_weight_tying=True -> WARNING
+    if not weights_tied and ensure_weight_tying:
+        warnings.warn(
+            "ensure_weight_tying=True but the model does not have tied weights "
+            "(tie_word_embeddings=False). Weight tying will not be applied for trainable_token_indices."
+        )
+
+    # Case 2: weights tied + ensure_weight_tying=True + different indices -> ERROR
+    if weights_tied and ensure_weight_tying and has_both_layers and indices_mismatch:
+        raise ValueError(
+            f"Cannot ensure weight tying when different token indices are specified for "
+            f"embedding ({embed_key}: {target_layers[embed_key]}) and "
+            f"lm_head ({lm_head_key}: {target_layers[lm_head_key]}). "
+            f"Please use the same indices for both layers or set ensure_weight_tying=False."
+        )
+
+    # Case 3: Apply weight tying when appropriate
+    # - When weights are tied AND we should apply tying
+    # - Apply tying unless: weights tied + ensure_weight_tying=False + different indices (BC: treat as separate)
+    should_apply_tying = (
+        weights_tied
+        and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
+        and not (not ensure_weight_tying and has_both_layers and indices_mismatch)
+    )
+
+    if should_apply_tying:
+        # There might be the possibility that we have output weights that are tied to the input weights.
+        # In that case we will tie any module that wants tied weights to the token adapter to make sure that
+        # any modification is reflected in the tied layers as well.
         module_keys = [".".join(n.split(".")[:-1]) for n in model._tied_weights_keys]

         token_adapter = model.get_input_embeddings().token_adapter
```
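The three cases above can be condensed into one boolean. Below is an illustrative, self-contained restatement of the `should_apply_tying` expression, dropping the `has_both_layers` and `TrainableTokensWrapper` guards for brevity; it is not part of the diff.

```python
# Pure-function restatement of the should_apply_tying expression from the diff above.
def should_apply_tying(weights_tied: bool, ensure_weight_tying: bool, indices_mismatch: bool) -> bool:
    # Tying is applied whenever the model ties weights, except in the backwards-compatibility
    # case: tied weights + ensure_weight_tying=False + mismatching indices, which keeps
    # treating embed_tokens and lm_head as separate token adapters.
    return weights_tied and not (not ensure_weight_tying and indices_mismatch)


# Untied model: never tied via this path (ensure_weight_tying=True additionally warns).
assert should_apply_tying(False, True, False) is False
# Tied model, same indices: tying is applied regardless of the flag.
assert should_apply_tying(True, False, False) is True
assert should_apply_tying(True, True, False) is True
# Tied model, different indices: BC behavior keeps the adapters separate...
assert should_apply_tying(True, False, True) is False
# ...while ensure_weight_tying=True raises a ValueError before this point in the real code.
```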
```diff
@@ -978,3 +978,119 @@ def test_scaled_embedding_with_lora(self):
         orig_embedding.embed_scale.fill_(0)
         embedding_output = peft_embedding(x)
         assert (embedding_output == 0.0).all()
+
+    # Tests for ensure_weight_tying parameter with trainable_token_indices
```
> **Review comment:** Let's mention #2864 here, I think it helps understanding the tests.
>
> **Reply:** Added! I've included the comment.
```diff
+    def test_ensure_weight_tying_warns_when_model_not_tied_list_format(self, model_weight_untied, recwarn):
+        """Should warn when ensure_weight_tying=True but model doesn't have tied weights (list format)"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices=[1, 2, 3],
+            ensure_weight_tying=True,
+        )
+        peft_model = get_peft_model(model_weight_untied, peft_config)
+
+        warnings_list = [w.message.args[0] for w in recwarn]
+        warnings_found = [
+            msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
+        ]
+        assert warnings_found
+
+    def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
+        """Should warn when ensure_weight_tying=True with dict format but model doesn't have tied weights"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices={"embed_tokens": [1, 2, 3]},
+            ensure_weight_tying=True,
+        )
+        peft_model = get_peft_model(model_weight_untied, peft_config)
+
+        warnings_list = [w.message.args[0] for w in recwarn]
+        warnings_found = [
+            msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
+        ]
+        assert warnings_found
+
+    def test_weight_tying_bc_different_indices_treated_separately(self, model_weight_tied):
+        """Backwards compatibility: different indices should be treated separately when ensure_weight_tying=False"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
+            ensure_weight_tying=False,  # BC behavior
+        )
+        peft_model = get_peft_model(model_weight_tied, peft_config)
+
+        # Check that both layers have token adapters but they're NOT tied
+        embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
+        lm_head_adapter = peft_model.model.lm_head.token_adapter
+
+        assert embed_adapter is not None
+        assert lm_head_adapter is not None
+        # They should NOT share the same delta parameters (treated as separate)
+        assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta
+        # They should have different token indices
+        assert embed_adapter.token_indices["default"] == [3, 4]
+        assert lm_head_adapter.token_indices["default"] == [1, 2]
+
+    def test_ensure_weight_tying_errors_with_different_indices(self, model_weight_tied):
+        """Should raise error when ensure_weight_tying=True with different indices for embedding and lm_head"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
+            ensure_weight_tying=True,
+        )
+
+        with pytest.raises(ValueError) as e:
+            peft_model = get_peft_model(model_weight_tied, peft_config)
+
+        assert "Cannot ensure weight tying when different token indices are specified" in str(e.value)
+
+    def test_ensure_weight_tying_applied_with_same_indices(self, model_weight_tied):
+        """Should apply weight tying when ensure_weight_tying=True with same indices"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
+            ensure_weight_tying=True,
+        )
+        peft_model = get_peft_model(model_weight_tied, peft_config)
+
+        # Check that weight tying is properly applied
+        embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
+        lm_head_adapter = peft_model.model.lm_head.token_adapter
+
+        # They should share the same delta parameters (weight tying)
+        assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
+        # They should have the same token indices
+        assert embed_adapter.token_indices["default"] == [1, 2]
+        assert lm_head_adapter.token_indices["default"] == [1, 2]
+
+    def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
+        """Backwards compatibility: same indices should have weight tying even when ensure_weight_tying=False"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
+            ensure_weight_tying=False,  # BC: still applies tying when indices are the same
+        )
+        peft_model = get_peft_model(model_weight_tied, peft_config)
+
+        # Even with ensure_weight_tying=False, BC behavior should still tie when indices are same
+        embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
+        lm_head_adapter = peft_model.model.lm_head.token_adapter
+
+        # They should share the same delta parameters (BC behavior)
+        assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
+
+    def test_ensure_weight_tying_with_single_layer(self, model_weight_tied):
+        """ensure_weight_tying should work with single layer (list format)"""
+        peft_config = LoraConfig(
+            target_modules="all-linear",
+            trainable_token_indices=[1, 2, 3],
+            ensure_weight_tying=True,
+        )
+        peft_model = get_peft_model(model_weight_tied, peft_config)
+
+        # Should apply weight tying to tied layers automatically
+        embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
+        lm_head_adapter = peft_model.model.lm_head.token_adapter
+
+        # They should share the same delta parameters
+        assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
```
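Outside the test suite, the same tying check can be reproduced roughly as follows. The attribute paths mirror the OPT-style `model_weight_tied` fixture used above; other architectures nest their embedding layer differently, and the checkpoint name is only a placeholder.

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Placeholder checkpoint standing in for the tied-weights fixture used in the tests.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    target_modules="all-linear",
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
    ensure_weight_tying=True,
)
peft_model = get_peft_model(model, config)

embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter
# With tied weights and identical indices, both layers share one delta parameter.
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
```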
> **Review comment** (on the `is_dict_format` check in the first hunk): I don't think we need `is_dict_format`. The check below, `len(target_layers) > 1`, is already enough, is it not?
>
> **Reply:** Yes, re-reviewed this and simplified the logic significantly. Reference 232c6e7 for the implementation.