
PamelaBha

What does this PR do?

Feature request
Improve the generate() API by supporting custom, declarative logit warping strategies. Make it easier for users to plug in standard and custom LogitProcessors via configuration or arguments without needing to subclass or dive into internals.
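
For illustration, a hypothetical sketch of the kind of call this would enable (the dict-style spec format, its keys, and the parameter values are assumptions about the final design, not current transformers behavior):

# Hypothetical usage sketch: dict-style processor specs passed straight to generate().
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    logits_processor=[
        {"type": "temperature", "temperature": 0.7},
        {"type": "repetition_penalty", "penalty": 1.2},
    ],
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))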

Motivation
The generation module already supports rich logit manipulation through LogitsProcessorList, but:

  • It is undocumented and hard to use for casual users
  • It requires advanced subclassing to customize behaviors (e.g., word bans, domain constraints)
  • It doesn't support JSON- or dict-style configuration like many other parts of Transformers

Making logit warping more accessible enables:

  • Prompt engineers and power users to fine-tune generation behavior
  • Safer generation via blacklists or probability shifting
  • Dynamic controls like repetition penalties or temperature annealing

Fixes #40010

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1 (Member)

cc @gante

@gante (Member) left a comment


Very cool PR! I like where this feature is going 🔥

I've added a few comments, but I think they are simple to solve 🤗 Let me know if you have questions/counter-proposals regarding my comments.

Other global comments:

  • Missing: documentation of logits_processor in GenerationConfig. I would add a link to the example file inside the docs, your examples are quite clear!
  • Suggestion: in the documentation of logits_processor in generate(), I would add a mention about the new format (and a link to the examples)
  • Some classes should be directly imported from the top level, i.e. we should be able to do from transformers import LogitProcessorRegistry. This is done by adding the classes to the outermost __init__.py. I think the only class we need to add there is LogitProcessorRegistry (we can already do from transformers import LogitsProcessor)

@@ -58,7 +58,16 @@


if is_torch_available():
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
from ..cache_utils import (
Member


(this diff seems related to a merge conflict, we can probably remove these added lines)

Comment on lines 568 to 570



Member


Suggested change

Author


Fixed

Comment on lines 1164 to 1168
# Special handling for logit_processors - serialize to JSON string
if "logit_processors" in output and output["logit_processors"] is not None:
    if not isinstance(output["logit_processors"], str):
        # Convert to JSON string
        output["logit_processors"] = json.dumps(output["logit_processors"])
Member


Is this needed? If self.logit_processors is a list of dicts, then returning it directly should be fine :)
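
For example, a plain list of dicts already round-trips through JSON with no special casing (a quick sketch; the field name follows the PR, the spec contents are placeholders):

import json

# to_dict() can return the list of dicts directly; JSON serialization handles it.
output = {"logit_processors": [{"type": "temperature", "temperature": 0.7}]}
roundtrip = json.loads(json.dumps(output))
assert roundtrip == output  # no manual json.dumps of the field is required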

@@ -1039,6 +1054,21 @@ def from_pretrained(
        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
        return config

    def get_logit_processors(self):
Member


Can we move this function to utils.py?

Goal: let's keep classes as independent as possible -- the generation config doesn't need to know about the expected format that logit_processors takes in generate. It's mostly a data storage class.
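
Roughly something like this in generation/utils.py (a sketch: the function name mirrors the PR's method, while the registry call is an assumed API, not something the PR necessarily exposes):

# Module-level helper so GenerationConfig stays a pure data-storage class.
# Assumes LogitProcessorRegistry is importable here; build_from_config is a placeholder name.
def get_logit_processors(generation_config):
    specs = getattr(generation_config, "logit_processors", None)
    if not specs:
        return None
    return [LogitProcessorRegistry.build_from_config(spec) for spec in specs]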


# Enhanced LogitsProcessorList
class ConfigurableLogitsProcessorList(LogitsProcessorList):
"""Extended LogitsProcessorList that supports configuration-based construction."""
Member


missing: an example in the docstring
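
For instance, a short usage example along these lines (hypothetical; from_config and the spec keys are assumptions about the PR's API, shown only to illustrate the shape of the docstring):

class ConfigurableLogitsProcessorList(LogitsProcessorList):
    """Extended LogitsProcessorList that supports configuration-based construction.

    Example (assumed API):

        >>> processors = ConfigurableLogitsProcessorList.from_config(
        ...     [{"type": "temperature", "temperature": 0.7}]
        ... )
        >>> next_token_scores = processors(input_ids, next_token_scores)
    """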


# Add any directly passed processors last
if logits_processor is not None:
    processors.extend(logits_processor)
Member


Merging with the custom logits_processor is already done in the self._merge_criteria_processor_list call a few lines above, so this extend would add the user-passed processors twice.

Comment on lines 1313 to 1325
if configured_processors is not None:
    # Check for duplicates and warn
    existing_types = {type(p).__name__ for p in processors}
    config_types = {type(p).__name__ for p in configured_processors}
    duplicates = existing_types & config_types

    if duplicates:
        warnings.warn(
            f"Duplicate LogitProcessors detected: {duplicates}. "
            f"Configured processors will be added after standard ones."
        )

    processors.extend(configured_processors)
Member


Perhaps we can reuse the logic in _merge_criteria_processor_list?
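
E.g., something along these lines (a sketch, assuming _merge_criteria_processor_list keeps its current (default_list, custom_list) signature and its existing duplicate handling):

# Reuse the existing merge helper instead of the manual duplicate check above.
if configured_processors is not None:
    processors = self._merge_criteria_processor_list(processors, configured_processors)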

@@ -1391,3 +1403,148 @@ def check_eos_logits(out, logits, batch, channel, eos):
        check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
        self.assertTrue(delay_pattern_processor.active_batches.all())
        self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())

class TestLogitProcessorRegistry(unittest.TestCase):
Member


Let's have a single class for everything related to the configurable logits processors 🤗
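
For example, one consolidated class along these lines (the class name, test names, and the APIs they would exercise are placeholders, not the PR's actual tests):

class ConfigurableLogitsProcessorTest(unittest.TestCase):
    # Single home for everything related to the configurable logits processors.
    def test_registry_lookup(self):
        ...

    def test_build_from_dict_config(self):
        ...

    def test_generate_with_configured_processors(self):
        ...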

Comment on lines 69 to 70
from transformers.generation.logits_process import LogitsProcessor, LogitProcessorRegistry
import torch
Member


all the imports in the example should be:
1 - at the top of the file
2 - transformers imports should be top-level imports, i.e. from transformers import LogitsProcessor, LogitProcessorRegistry. See the global PR comment I added about this 🤗
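
I.e., roughly this at the top of the example file (note that from transformers import LogitProcessorRegistry only works once the class is added to the outermost __init__.py, as per the global comment):

# All imports at the top of the file; transformers imports from the top level.
import torch

from transformers import LogitsProcessor, LogitProcessorRegistry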

@PamelaBha (Author)

Thanks a lot for reviewing. I will address these soon and send a new PR.

@PamelaBha (Author)

@gante this test is failing in the latest build, but I believe it has nothing to do with my change: FAILED tests/models/gpt_oss/test_modeling_gpt_oss.py::GptOssModelTest::test_assisted_decoding_matches_greedy_search_1_same - AssertionError: False is not true

Any pointers for me?
