Question about token alignment and logits handling in GRPO implementation (grpo_trainer.py) #3411
-
Hello TRL team, I'm studying the GRPO implementation (in `grpo_trainer.py`) to understand how it aligns with the original paper, and I have a question about the token handling in the per-token log-probability computation. Specifically, I'm trying to understand the logic behind these operations:

```python
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(
    input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1
).logits
logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids_batch = input_ids_batch[:, -logits_to_keep:]
logits = logits[:, -logits_to_keep:]
```

I understand that autoregressive models predict the next token, but I'm trying to reconcile that with these slices: why add 1 to `logits_to_keep`, and how do the resulting logits end up aligned with `input_ids_batch`?
I'm implementing my own version for research purposes and want to make sure I understand the exact tensor alignment process. Thank you for any clarification you can provide!
-
Let's say

```python
>>> input_ids_batch = [[0, 1, 2, 3, 4, 5, 6]]
>>> logits_to_keep = 3
```

What we want is the logits for the tokens 4, 5 and 6. Then for the logits:

```python
>>> model(input_ids=input_ids_batch).logits
[[logits_for_1, logits_for_2, logits_for_3, logits_for_4, logits_for_5, logits_for_6, logits_for_token_after_6]]
>>> model(input_ids=input_ids_batch, logits_to_keep=1).logits
[[logits_for_token_after_6]]
>>> model(input_ids=input_ids_batch, logits_to_keep=logits_to_keep + 1).logits
[[logits_for_4, logits_for_5, logits_for_6, logits_for_token_after_6]]
```

Now we can exclude the last logit (because we don't care about the token after 6):

```python
>>> logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
[[logits_for_4, logits_for_5, logits_for_6]]
```

and for the input ids:

```python
>>> input_ids_batch[:, -logits_to_keep:]
[[4, 5, 6]]
```

You can see the input ids and logits are now aligned, and we indeed get the last 3 logits and ids. I hope I've answered your question.
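If you want to check the alignment end to end, here is a minimal runnable sketch. It fakes the model's logits with a random tensor (so nothing needs to be downloaded) and emulates the `logits_to_keep` argument by slicing the full logits; it then gathers per-token log-probs the way a GRPO-style loss would use them. Names like `full_logits` and the shapes are just for illustration, not the trainer's actual code:

```python
import torch
import torch.nn.functional as F

B, L, V = 1, 7, 11  # batch, sequence length, vocab size (arbitrary)
torch.manual_seed(0)
input_ids_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])

# Stand-in for `model(...).logits`: row t predicts token t + 1.
full_logits = torch.randn(B, L, V)

logits_to_keep = 3  # we want per-token log-probs for the last 3 tokens: 4, 5, 6

# Emulates `logits_to_keep=logits_to_keep + 1` in the forward call:
# keep the last logits_to_keep + 1 positions.
logits = full_logits[:, -(logits_to_keep + 1):, :]  # predicts 4, 5, 6, and the token after 6

# Exclude the last position: it predicts the token *after* the sequence.
logits = logits[:, :-1, :]  # predicts 4, 5, 6

# Slice the ids so each logit row is paired with the token it predicts.
ids = input_ids_batch[:, -logits_to_keep:]  # tensor([[4, 5, 6]])

# Gather the log-prob of each actual token under its predicting distribution.
per_token_logps = torch.gather(
    F.log_softmax(logits, dim=-1), dim=-1, index=ids.unsqueeze(-1)
).squeeze(-1)  # shape (B, logits_to_keep)

print(ids)                    # tensor([[4, 5, 6]])
print(per_token_logps.shape)  # torch.Size([1, 3])
```

The key invariant is that row `t` of a causal LM's logits predicts token `t + 1`, so dropping the last row and keeping the last `logits_to_keep` ids makes the two tensors line up index for index.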