[TextGeneration] max token refactor #1217
Conversation
The good: much less additional code and complexity than I thought.
The bad and ugly: could you add appropriate tests in `tests/deepsparse/transformers/pipelines/test_text_generation.py`?
Force-pushed from d1d7b7a to 735d33d
Could `max_tokens` default to `sequence_length - prompt_length` so we don't risk running out of KV cache context? I'm not sure what happens there, actually, especially when using the internal KV cache.
@dbogunowicz would like to get your opinion on this
@dsikka this is a very good idea.
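For reference, a minimal sketch of the defaulting logic being discussed. The helper name and the idea of capping generation at the remaining KV cache capacity are illustrative assumptions, not the pipeline's actual code:

```python
def default_max_tokens(sequence_length: int, prompt_length: int) -> int:
    """Hypothetical helper: cap generation at the remaining KV cache capacity.

    The prompt already occupies `prompt_length` slots of the cache, so at most
    `sequence_length - prompt_length` new tokens can be generated without
    running out of context.
    """
    remaining = sequence_length - prompt_length
    if remaining <= 0:
        raise ValueError(
            f"prompt_length={prompt_length} leaves no room for generation "
            f"within sequence_length={sequence_length}"
        )
    return remaining
```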
- Remove `max_generated_tokens` from the constructor and add it to the `TextGenerationInput` schema
- Add `num_generated_predictions` to `TextGenerationInput`, which, if > 1, repeats the input sequence and turns off deterministic prediction. If a sequence is already provided multiple times, it is not repeated.
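A rough sketch of the repetition behavior described in this commit message, with a hypothetical helper name (not the actual implementation):

```python
from typing import List


def expand_sequences(sequences: List[str], num_generated_predictions: int) -> List[str]:
    """Repeat each input sequence so the engine produces multiple generations.

    If the caller already passed the same sequence multiple times, no extra
    repetition is applied, mirroring the commit message above.
    """
    if num_generated_predictions <= 1:
        return sequences
    if len(sequences) != len(set(sequences)):
        # the caller already duplicated the inputs themselves
        return sequences
    return [seq for seq in sequences for _ in range(num_generated_predictions)]


# When num_generated_predictions > 1, sampling must not be deterministic,
# otherwise every repeated sequence would yield an identical generation.
deterministic = False
```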
Force-pushed from 55d1c08 to b5de75f
Talking to the MLE team, I think for now we want to keep the defaults as-is and update them once we've established best practices.
For this ticket:
https://app.asana.com/0/1201735099598270/1205276886236972/f
Summary:
- Removes the `max_tokens` argument and makes it part of the pipeline input
- Adds `num_generated_predictions` to the input as well, which dictates the number of sequences that are generated for a given input. Similar to the Hugging Face implementation, we repeat the input based on the number provided, defaulting to 1. When `num_generated_predictions` is > 1, the engine's `deterministic` property is toggled to False
- Supports `str`, `List[str]`, and `List[List[str]]` inputs
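A hedged usage sketch of how the new input fields might be exercised. The keyword names follow the schema fields described above, but the model stub and exact call signature are assumptions to check against the pipeline:

```python
from deepsparse import Pipeline

# Hypothetical model stub; substitute a real text generation deployment path.
text_pipeline = Pipeline.create(task="text_generation", model_path="zoo:...")

output = text_pipeline(
    sequences="Write a haiku about sparsity",
    max_tokens=64,                  # now part of the input, not the constructor
    num_generated_predictions=3,    # repeats the input; engine becomes non-deterministic
)
print(output)
```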
Testing:
- Added tests for `num_generated_predictions`
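One possible shape for such a test, assuming a pytest-style test in `tests/deepsparse/transformers/pipelines/test_text_generation.py` with a preconstructed `pipeline` fixture (illustrative only, not the tests actually added):

```python
import pytest


@pytest.mark.parametrize("num_generated_predictions", [1, 2])
def test_num_generated_predictions(pipeline, num_generated_predictions):
    # A single prompt should yield `num_generated_predictions` generations.
    output = pipeline(
        sequences="The quick brown fox",
        max_tokens=16,
        num_generated_predictions=num_generated_predictions,
    )
    assert len(output.sequences) == num_generated_predictions
```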