Let's say we have:

>>> input_ids_batch = [[0, 1, 2, 3, 4, 5, 6]]
>>> logits_to_keep=3

What we want are the logits for tokens 4, 5, and 6.

The output at each position is the prediction for the next token, so the logits are shifted by one relative to the input:

>>> model(input_ids=input_ids_batch).logits
[[logits_for_1, logits_for_2, logits_for_3, logits_for_4, logits_for_5, logits_for_6, logits_for_token_after_6]]
>>> model(input_ids=input_ids_batch, logits_to_keep=1).logits
[[logits_for_token_after_6]]
>>> model(input_ids=input_ids_batch, logits_to_keep=logits_to_keep+1).logits
[[logits_for_4, logits_for_5, logits_for_6, logits_for_token_after_6]]

Now we can exclude the last token (because we don't care about the token after 6):

>>> logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it cor…
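
To check the indexing without loading a model, here is a minimal pure-Python sketch. `toy_logits` is a hypothetical stand-in for `model(...).logits` (it is not part of any library): it returns one string label per position instead of a vocabulary-sized vector, and it assumes that `logits_to_keep=0` keeps every position while a positive integer keeps only the last k.

```python
def toy_logits(input_ids, logits_to_keep=0):
    """Hypothetical stand-in for model(...).logits on a single sequence.

    Position i holds the prediction for the token that follows input_ids[i].
    Assumption: logits_to_keep=0 keeps every position, k > 0 keeps the last k.
    """
    labels = [f"logits_for_{tok}" for tok in input_ids[1:]]
    labels.append(f"logits_for_token_after_{input_ids[-1]}")
    return labels[-logits_to_keep:] if logits_to_keep else labels


input_ids = [0, 1, 2, 3, 4, 5, 6]
logits_to_keep = 3

# Ask for one extra position, then drop the final one, which predicts
# the token after 6 and is not needed here.
kept = toy_logits(input_ids, logits_to_keep=logits_to_keep + 1)
completion_logits = kept[:-1]
print(completion_logits)
# ['logits_for_4', 'logits_for_5', 'logits_for_6']
```

With real tensors the same two steps would be the `logits_to_keep=logits_to_keep+1` call followed by the `[:, :-1, :]` slice shown above.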

Answer selected by JenWei0312