
Commit 8c5f368

Add test for text-generation pipeline (#1139)
* Add test for text-generation pipeline
* Fix test
* Update test_text_generation.py
* Remove model_stub fixture
* Local model stub
* Add HF comparison
* Move imports
1 parent 3cab6a3 commit 8c5f368

1 file changed: 44 additions, 0 deletions

test_text_generation.py
@@ -0,0 +1,44 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoModelForCausalLM, AutoTokenizer

import pytest
from deepsparse import Pipeline


@pytest.mark.slow
def test_codegen():
    model_stub = (
        "zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
        "huggingface/bigpython_bigquery_thepile/base-none"
    )
    pipeline = Pipeline.create(
        task="text_generation",
        model_path=model_stub,
        max_generated_tokens=16,
        prompt_processing_sequence_length=1,
        use_deepsparse_cache=False,
    )

    # Generate a continuation with the DeepSparse text-generation pipeline
    prompt = "def fib():"
    out = pipeline(sequences=prompt)
    nm_output = prompt + out.sequences[0]

    # Generate the same continuation with the Hugging Face model as a reference
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_new_tokens=16)
    hf_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # The DeepSparse and Hugging Face outputs are expected to match exactly
    assert nm_output == hf_output
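
To run just this slow-marked test, an invocation like the following should work; the file path is an assumption based on the commit title, and the run downloads both the SparseZoo model and the Hugging Face Salesforce/codegen-350M-mono weights:

    pytest -m slow tests/deepsparse/transformers/pipelines/test_text_generation.py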

0 commit comments