Customizable Logit Warping Strategies for Generation #40010 #40403
base: main
Conversation
cc @gante
Very cool PR! I like where this feature is going 🔥
I've added a few comments, but I think they are simple to solve 🤗 Let me know if you have questions/counter-proposals regarding my comments.
Other global comments:
- Missing: documentation of `logits_processor` in `GenerationConfig`. I would add a link to the example file inside the docs, your examples are quite clear!
- Suggestion: in the documentation of `logits_processor` in `generate()`, I would add a mention of the new format (and a link to the examples)
- Some classes should be directly imported from the top level, i.e. we should be able to do `from transformers import LogitProcessorRegistry`. This is done by adding the classes to the outermost `__init__.py`. I think the only class we need to add there is `LogitProcessorRegistry` (we can already do `from transformers import LogitsProcessor`)
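The top-level re-export the reviewer describes is just the outermost `__init__.py` re-binding a name from a submodule. A minimal sketch, simulating the mechanism with a toy package (`mypkg` is a made-up stand-in, not the actual transformers layout):

```python
import sys
import types

# Stand-in for the class the PR adds (the real one would live in
# transformers.generation.logits_process).
class LogitProcessorRegistry:
    pass

# Simulate the submodule...
sub = types.ModuleType("mypkg.generation.logits_process")
sub.LogitProcessorRegistry = LogitProcessorRegistry

# ...and the outermost __init__.py, which simply re-binds the name so that
# `from mypkg import LogitProcessorRegistry` resolves at the top level.
pkg = types.ModuleType("mypkg")
pkg.LogitProcessorRegistry = sub.LogitProcessorRegistry

sys.modules["mypkg"] = pkg
sys.modules["mypkg.generation.logits_process"] = sub

from mypkg import LogitProcessorRegistry as TopLevel
print(TopLevel is LogitProcessorRegistry)  # True
```

In the real repo this amounts to one import line added to `transformers/__init__.py`.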
@@ -58,7 +58,16 @@
if is_torch_available():
    from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
    from ..cache_utils import (
(this diff seems related to a merge conflict, we can probably remove these added lines)
Fixed
# Special handling for logit_processors - serialize to JSON string
if "logit_processors" in output and output["logit_processors"] is not None:
    if not isinstance(output["logit_processors"], str):
        # Convert to JSON string
        output["logit_processors"] = json.dumps(output["logit_processors"])
Is this needed? If `self.logit_processors` is a list of dicts, then returning it directly should be fine :)
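The reviewer's point can be demonstrated with plain `json`: a list of dicts serializes directly, while pre-serializing the field to a string double-encodes it, forcing consumers to call `json.loads` twice:

```python
import json

# A JSON-style processor config, as stored on the generation config.
processors = [{"type": "temperature", "value": 0.7}]

# Serializing the list of dicts directly round-trips cleanly...
direct = json.dumps({"logit_processors": processors})

# ...while pre-serializing the field first double-encodes it.
double = json.dumps({"logit_processors": json.dumps(processors)})

print(type(json.loads(direct)["logit_processors"]).__name__)  # list
print(type(json.loads(double)["logit_processors"]).__name__)  # str
```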
@@ -1039,6 +1054,21 @@ def from_pretrained(
    config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
    return config

def get_logit_processors(self):
Can we move this function to `utils.py`?
Goal: let's keep classes as independent as possible -- the generation config doesn't need to know about the expected format that `logit_processors` takes in `generate()`. It's mostly a data storage class.
# Enhanced LogitsProcessorList
class ConfigurableLogitsProcessorList(LogitsProcessorList):
    """Extended LogitsProcessorList that supports configuration-based construction."""
missing: an example in the docstring
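A sketch of the kind of docstring example being requested. The `from_config` classmethod and the `{"type": ..., **kwargs}` dict format are assumptions about the PR's interface, with minimal stand-ins for the transformers classes:

```python
class LogitsProcessorList(list):
    """Minimal stand-in for transformers.LogitsProcessorList."""

class TemperatureLogitsWarper:
    """Toy stand-in processor holding a temperature parameter."""
    def __init__(self, temperature):
        self.temperature = temperature

class ConfigurableLogitsProcessorList(LogitsProcessorList):
    """Extended LogitsProcessorList that supports configuration-based construction.

    Example:
        >>> config = [{"type": "temperature", "temperature": 0.7}]
        >>> processors = ConfigurableLogitsProcessorList.from_config(config)
        >>> processors[0].temperature
        0.7
    """

    # Hypothetical mapping of config names to processor classes.
    _registry = {"temperature": TemperatureLogitsWarper}

    @classmethod
    def from_config(cls, config):
        instances = []
        for entry in config:
            entry = dict(entry)  # don't mutate the caller's config
            factory = cls._registry[entry.pop("type")]
            instances.append(factory(**entry))
        return cls(instances)

processors = ConfigurableLogitsProcessorList.from_config(
    [{"type": "temperature", "temperature": 0.7}]
)
print(processors[0].temperature)  # 0.7
```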
src/transformers/generation/utils.py
Outdated
# Add any directly passed processors last
if logits_processor is not None:
    processors.extend(logits_processor)
Merging with the custom `logits_processor` is already done in the `self._merge_criteria_processor_list` call a few lines above.
src/transformers/generation/utils.py
Outdated
if configured_processors is not None:
    # Check for duplicates and warn
    existing_types = {type(p).__name__ for p in processors}
    config_types = {type(p).__name__ for p in configured_processors}
    duplicates = existing_types & config_types

    if duplicates:
        warnings.warn(
            f"Duplicate LogitProcessors detected: {duplicates}. "
            f"Configured processors will be added after standard ones."
        )

    processors.extend(configured_processors)
Perhaps we can reuse the logic in `_merge_criteria_processor_list`?
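The duplicate check in the snippet above, reduced to plain Python so the merge-and-warn behavior can be seen in isolation (the warper classes here are placeholders, not transformers classes):

```python
import warnings

class TemperatureWarper:
    pass

class TopKWarper:
    pass

def merge_processor_lists(processors, configured_processors):
    """Append configured_processors to processors, warning on duplicate types."""
    if configured_processors:
        existing_types = {type(p).__name__ for p in processors}
        config_types = {type(p).__name__ for p in configured_processors}
        duplicates = existing_types & config_types
        if duplicates:
            warnings.warn(
                f"Duplicate LogitProcessors detected: {duplicates}. "
                "Configured processors will be added after standard ones."
            )
        processors.extend(configured_processors)
    return processors

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    merged = merge_processor_lists(
        [TemperatureWarper()], [TemperatureWarper(), TopKWarper()]
    )
print(len(merged), len(caught))  # 3 1
```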
@@ -1391,3 +1403,148 @@ def check_eos_logits(out, logits, batch, channel, eos):
    check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
    self.assertTrue(delay_pattern_processor.active_batches.all())
    self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())

class TestLogitProcessorRegistry(unittest.TestCase):
Let's have a single class for everything related to the configurable logits processors 🤗
from transformers.generation.logits_process import LogitsProcessor, LogitProcessorRegistry
import torch
All the imports in the example should be:
1 - at the top of the file
2 - `transformers` imports should be top-level imports, i.e. `from transformers import LogitsProcessor, LogitProcessorRegistry`. See the global PR comment I added about this 🤗
Thanks a lot for reviewing. I will address these soon and send a new PR.
@gante this test is failing in the latest build, but I believe it has nothing to do with my change: FAILED tests/models/gpt_oss/test_modeling_gpt_oss.py::GptOssModelTest::test_assisted_decoding_matches_greedy_search_1_same - AssertionError: False is not true. Any pointers for me?
What does this PR do?
Feature request
Improve the generate() API by supporting custom, declarative logit warping strategies. Make it easier for users to plug in standard and custom LogitProcessors via configuration or arguments without needing to subclass or dive into internals.
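One possible shape for such a declarative API: a registry maps string names to processor classes, and a JSON-style list of dicts selects and parameterizes them. Every name below (`register`, `build_processors`, `BanTokens`) is illustrative, not the PR's actual interface, and plain dicts stand in for score tensors:

```python
_REGISTRY = {}

def register(name):
    """Class decorator that records a processor class under a string name."""
    def decorator(cls):
        _REGISTRY[name] = cls
        return cls
    return decorator

@register("ban_tokens")
class BanTokens:
    """Sets the scores of banned token ids to -inf."""
    def __init__(self, token_ids):
        self.token_ids = set(token_ids)

    def __call__(self, scores):
        # scores: mapping of token id -> logit (toy stand-in for a tensor)
        return {
            tok: float("-inf") if tok in self.token_ids else logit
            for tok, logit in scores.items()
        }

def build_processors(config):
    """Instantiate processors from a JSON-style list of dicts."""
    processors = []
    for entry in config:
        entry = dict(entry)  # don't mutate the caller's config
        processors.append(_REGISTRY[entry.pop("type")](**entry))
    return processors

processors = build_processors([{"type": "ban_tokens", "token_ids": [2]}])
scores = {0: 1.0, 1: 0.5, 2: 3.0}
for proc in processors:
    scores = proc(scores)
print(scores)  # {0: 1.0, 1: 0.5, 2: -inf}
```

The same config could come from a JSON file or a `GenerationConfig` field, which is the accessibility win the motivation below describes.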
Motivation
The generation module already supports rich logit manipulation through LogitsProcessorList, but:
- It is undocumented and hard to use for casual users
- It requires advanced subclassing to customize behaviors (e.g., word bans, domain constraints)
- It doesn't support JSON- or dict-style configuration like many other parts of Transformers

Making logit warping more accessible enables:
- Prompt engineers and power users to fine-tune generation behavior
- Safer generation via blacklists or probability shifting
- Dynamic controls like repetition penalties or temperature annealing
Fixes #40010
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. Customizable Logit Warping Strategies for Generation #40010
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.