Skip to content

Commit 1f2291f

Browse files
committed
add test_speculate_get_token_penalty_multi_scores
1 parent 41aee08 commit 1f2291f

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import speculate_get_token_penalty_multi_scores
21+
22+
23+
def min_length_logits_process(
    logits,
    cur_len,
    min_len,
    eos_token_id,
    output_padding_offset,
    output_cum_offsets,
    token_num,
    bs,
    length,
    end_length,
    max_seq_len,
):
    """Reference impl: suppress EOS logits for sequences below their min length.

    For each output token, recover its batch index from the padding offset;
    if that batch's decoded length has not reached ``min_len`` yet, force the
    logits of every EOS token id to -1e10 so EOS cannot be sampled.

    Args:
        logits: [token_num, length] tensor, modified in place.
        cur_len: [bs] current decoded length per batch; < 0 marks an inactive slot.
        min_len: [bs] minimum generation length per batch.
        eos_token_id: [end_length] EOS token ids to suppress.
        output_padding_offset: [token_num] padding offset per output token.
        output_cum_offsets: [bs] cumulative padding offset per batch.
        token_num, bs, length, end_length, max_seq_len: shape scalars.
    """
    for token_idx in range(token_num):
        # Integer floor division recovers the batch index directly; the
        # original float division + int32 cast truncates the same way for
        # non-negative values but risks float rounding on large offsets.
        bi = (token_idx + output_padding_offset[token_idx]) // max_seq_len
        bi = bi.astype(paddle.int32)
        if bi >= bs:
            continue
        query_start_token_idx = bi * max_seq_len - output_cum_offsets[bi]

        if cur_len[bi] < 0:
            # Inactive batch slot.
            continue
        if cur_len[bi] + (token_idx - query_start_token_idx) < min_len[bi]:
            for i in range(end_length):
                logits[token_idx][eos_token_id[i]] = -1e10
48+
49+
50+
def update_repeat_times(
    pre_ids, cur_len, repeat_times, output_padding_offset, token_num, bs, length, length_id, max_seq_len
):
    """Reference impl: count, per output token, how often each vocab id already
    appeared in its batch's previously generated ids.

    Args:
        pre_ids: [bs, length_id] previously generated ids; values < 0 terminate
            the valid history for that row.
        cur_len: [bs] current decoded length; < 0 marks an inactive slot.
        repeat_times: [token_num, length] counter tensor, updated in place.
        output_padding_offset: [token_num] padding offset per output token.
        token_num, bs, length, length_id, max_seq_len: shape scalars.
    """
    for token_idx in range(token_num):
        # Integer floor division recovers the batch index (see note in
        # min_length_logits_process about avoiding float truncation).
        bi = (token_idx + output_padding_offset[token_idx]) // max_seq_len
        bi = bi.astype(paddle.int32)
        if bi >= bs:
            continue
        if cur_len[bi] < 0:
            continue

        pre_ids_now = pre_ids[bi]
        repeat_times_now = repeat_times[token_idx]

        for i in range(length_id):
            # Renamed from `id` to avoid shadowing the builtin.
            token_id = pre_ids_now[i]
            if token_id < 0:
                break  # negative padding marks the end of valid history
            repeat_times_now[token_id] = repeat_times_now[token_id] + 1
69+
70+
71+
def update_value_by_repeat_times(
    repeat_times,
    penalty_scores,
    frequency_score,
    presence_score,
    temperatures,
    logits,
    output_padding_offset,
    token_num,
    bs,
    length,
    max_seq_len,
):
    """Reference impl: apply repetition/frequency/presence penalties and
    temperature scaling to the logits, in place.

    For each token id seen at least once: scale the logit by the repetition
    penalty (multiply if negative, divide if positive), then subtract
    ``times * frequency + presence``. Every logit is finally divided by the
    per-batch temperature.

    Args:
        repeat_times: [token_num, length] per-token occurrence counts.
        penalty_scores, frequency_score, presence_score, temperatures: [bs].
        logits: [token_num, length] tensor, modified in place.
        output_padding_offset: [token_num] padding offset per output token.
        token_num, bs, length, max_seq_len: shape scalars.
    """
    for token_idx in range(token_num):
        # Integer floor division recovers the batch index (avoids the float
        # division + int cast of the previous version).
        bi = (token_idx + output_padding_offset[token_idx]) // max_seq_len
        bi = bi.astype(paddle.int32)
        if bi >= bs:
            continue
        logits_now = logits[token_idx]
        repeat_times_now = repeat_times[token_idx]
        alpha = penalty_scores[bi]  # repetition penalty
        beta = frequency_score[bi]  # frequency penalty
        gamma = presence_score[bi]  # presence penalty
        for i in range(length):
            times = repeat_times_now[i]
            logit_now = logits_now[i]
            if times != 0:
                # Penalize repeats: shrink positive logits, amplify negative ones.
                logit_now = logit_now * alpha if logit_now < 0 else logit_now / alpha
                logit_now = logit_now - times * beta - gamma

            logits_now[i] = logit_now / temperatures[bi]
102+
103+
104+
def ban_bad_words(logits, bad_words_list, output_padding_offset, token_num, bs, length, bad_words_length, max_seq_len):
    """Reference impl: force the logits of banned token ids to -1e10, in place.

    Args:
        logits: [token_num, length] tensor, modified in place.
        bad_words_list: banned token ids; out-of-vocab or negative ids are skipped.
        output_padding_offset: [token_num] padding offset per output token.
        token_num, bs, length, bad_words_length, max_seq_len: shape scalars.
    """
    for token_idx in range(token_num):
        # Integer floor division recovers the batch index (avoids float
        # division + int cast).
        bi = (token_idx + output_padding_offset[token_idx]) // max_seq_len
        bi = bi.astype(paddle.int32)
        if bi >= bs:
            continue
        logits_now = logits[token_idx]
        for i in range(bad_words_length):
            bad_words_token_id = bad_words_list[i]
            # Skip ids outside the vocabulary range.
            if bad_words_token_id >= length or bad_words_token_id < 0:
                continue
            logits_now[bad_words_token_id] = -1e10
116+
117+
118+
def speculate_get_token_penalty_multi_scores_ref(
    pre_ids,
    logits,
    penalty_scores,
    frequency_score,
    presence_score,
    temperatures,
    bad_tokens,
    cur_len,
    min_len,
    eos_token_id,
    seq_lens_this_time,
    output_padding_offset,
    output_cum_offsets,
    max_seq_len,
):
    """Pure-Python reference for the `speculate_get_token_penalty_multi_scores` op.

    Mutates `logits` in place, applying the four stages in order:
    min-length EOS masking, repeat counting, penalty + temperature scaling,
    and bad-word banning. Returns nothing.
    """
    token_num, vocab_size = logits.shape
    batch_size = seq_lens_this_time.shape[0]
    history_len = pre_ids.shape[1]
    num_bad_words = bad_tokens.shape[1]
    num_eos = eos_token_id.shape[0]

    # Occurrence counters, one row per output token.
    repeat_times = paddle.full([token_num, vocab_size], 0, dtype=paddle.int32)

    min_length_logits_process(
        logits,
        cur_len,
        min_len,
        eos_token_id,
        output_padding_offset,
        output_cum_offsets,
        token_num,
        batch_size,
        vocab_size,
        num_eos,
        max_seq_len,
    )

    update_repeat_times(
        pre_ids,
        cur_len,
        repeat_times,
        output_padding_offset,
        token_num,
        batch_size,
        vocab_size,
        history_len,
        max_seq_len,
    )

    update_value_by_repeat_times(
        repeat_times,
        penalty_scores,
        frequency_score,
        presence_score,
        temperatures,
        logits,
        output_padding_offset,
        token_num,
        batch_size,
        vocab_size,
        max_seq_len,
    )

    ban_bad_words(
        logits,
        bad_tokens,
        output_padding_offset,
        token_num,
        batch_size,
        vocab_size,
        num_bad_words,
        max_seq_len,
    )
177+
178+
179+
class TestSpeculateGetTokenPenaltyMultiScores(unittest.TestCase):
    def test_speculate_get_token_penalty_multi_scores(self):
        """Compare the CUDA op against the pure-Python reference on random data."""
        paddle.seed(2023)
        np.random.seed(2023)

        bs = 64
        max_seq_len = 1024
        data_type = "float32"

        # prepare output_padding_offset and output_cum_offsets
        tokens = [1] * bs
        token_num = np.sum(tokens)
        output_padding_offset = []
        output_cum_offsets = [0]
        opo_offset = 0
        for bid in range(bs):
            ts = tokens[bid]
            for i in range(ts):
                output_padding_offset.append(opo_offset)
            opo_offset += max_seq_len - ts
            output_cum_offsets.append(opo_offset)
        output_cum_offsets = output_cum_offsets[:-1]
        output_padding_offset = paddle.to_tensor(output_padding_offset, "int32")
        output_cum_offsets = paddle.to_tensor(output_cum_offsets, "int32")

        # prepare pre_ids and logits
        pre_ids_len = 122
        logits_len = 110
        pre_ids = np.random.randint(1, logits_len, size=(bs, pre_ids_len))
        negative_start = np.random.randint(1, pre_ids_len + 1, size=(bs))
        for i in range(bs):
            # BUGFIX: mask only row i. The previous `pre_ids[:, ...]` masked
            # every row each iteration, collapsing all batches to the
            # shortest sampled history length.
            pre_ids[i, negative_start[i] :] = -1
        pre_ids = paddle.to_tensor(pre_ids).astype("int64")
        logits = paddle.zeros([token_num, logits_len]).astype(data_type)
        # prepare other params
        penalty_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
        frequency_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
        presence_scores = paddle.to_tensor(np.random.random([bs])).astype(data_type)
        temperatures = paddle.to_tensor(np.random.random([bs])).astype("float32")
        bad_tokens = paddle.to_tensor(np.random.randint(1, 2, size=([bs, 1]))).astype("int64")
        cur_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64")
        min_len = paddle.to_tensor(np.random.randint(1, 50, size=(bs))).astype("int64")
        eos_token_id = paddle.to_tensor(np.random.randint(1, 64, size=(bs))).astype("int64")
        seq_len_this_time = paddle.to_tensor(
            np.random.randint(0, 1, size=(bs)), "int32"
        )  # value of seq_len_this_time is useless

        inputs = (
            pre_ids,
            logits,
            penalty_scores,
            frequency_scores,
            presence_scores,
            temperatures,
            bad_tokens,
            cur_len,
            min_len,
            eos_token_id,
            seq_len_this_time,
            output_padding_offset,
            output_cum_offsets,
            max_seq_len,
        )
        # The op modifies `logits` in place and returns nothing, so run the
        # kernel and the reference on independent clones of the same inputs.
        inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs]
        speculate_get_token_penalty_multi_scores(*inputs)
        speculate_get_token_penalty_multi_scores_ref(*inputs_clone)

        np.testing.assert_allclose(inputs[1].numpy(), inputs_clone[1].numpy(), atol=1e-5, rtol=1e-5)
250+
251+
252+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)