System Info
GPU: NVIDIA A10G, one g5.12xlarge instance
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Set up the Docker container as follows:
docker run --rm --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
# Install dependencies, TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com
pip uninstall -y mpmath
pip install mpmath==1.3.0
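Optionally, a quick sanity check that the wheel installed correctly (a minimal sketch; run it inside python3 in the container before building):
import tensorrt_llm  # should import without errors after the pip install above
print(tensorrt_llm.__version__)  # prints the installed TensorRT-LLM version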
- Use the following commands to convert and build engines from the public flan-t5-xxl:
python TensorRT-LLM/examples/enc_dec/t5/convert.py -i google/flan-t5-xxl -o /public_t5_trt_covert_official/ --weight_data_type float32 --inference_tensor_para_size 4
python TensorRT-LLM/examples/enc_dec/build.py --model_type t5 --world_size 4 --tp_size 4 --gpus_per_node 4 --weight_dir /public_t5_trt_covert_official/tp4 -o /public_t5_trt_engine_official --engine_name t5 --use_bert_attention_plugin --use_gpt_attention_plugin --use_gemm_plugin --dtype bfloat16 --max_batch_size 32 --max_encoder_input_len 128 --max_output_len 128 --parallel_build
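After the build finishes, the output directory can be listed to confirm engines were produced for all four ranks (a sketch; the exact file layout depends on the TensorRT-LLM version):
import os

# Walk the engine output directory passed to build.py above and print every file;
# it should contain the generated encoder/decoder engine files and their configs.
for root, _, files in os.walk("/public_t5_trt_engine_official"):
    for name in files:
        print(os.path.join(root, name))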
- Call examples/enc_dec/run.py (modified) inside the container as follows:
mpirun --allow-run-as-root -np 4 python TensorRT-LLM/examples/enc_dec/run_modified.py --engine_dir /public_t5_trt_engine_official --engine_name t5 --model_name /fluency_model/ --max_new_tokens=128 --num_beams=1 --compare_hf_fp32
I slightly modified run.py to compare against HF bfloat16 results and to use one example prompt. To replicate, just replace everything after if __name__ == "__main__": with the following:
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    logger.set_level(args.log_level)

    # FairSeq NMT test logic is different from HuggingFace models
    if 'wmt' in args.model_name:
        test_fairseq_models(args)
        exit()

    test_remove_padding = True
    if not test_remove_padding:
        if 't5' in args.model_name:
            input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
        elif 'bart' in args.model_name:
            input_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
        else:
            raise RuntimeError('Unsupported model type!')
    else:
        input_text = [
            "Keeping the Secret of Genetic Testing",
            # "translate English to German: The house is wonderful.",
            # "summarize: I am a high-performance inference optimizer and runtime.",
            # "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world",
        ]

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)  # TODO: use model path instead
    tokenized_inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)

    max_new_tokens = args.max_new_tokens
    input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to('cuda')  # [batch_size, padded_length]
    # by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...]

    if tensorrt_llm.mpi_rank() == 0:
        print("--------------------------------------")
        print(f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}")
        print("input text: ", input_text)
        print("input ids: ", input_ids)
        print("input lengths: ", tokenized_inputs.attention_mask.sum(dim=1))
        print("--------------------------------------")

    model_config = AutoConfig.from_pretrained(args.model_name)

    # start_id for decoder (could add more input_ids as forced_decoder_ids)
    decoder_input_ids = torch.IntTensor([[model_config.decoder_start_token_id]]).to('cuda')
    decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))

    if tensorrt_llm.mpi_rank() == 0:
        print("Starting comparison with HF bfloat16")

    # simple comparison with HF (bfloat16 here)
    if args.compare_hf_fp32:
        if tensorrt_llm.mpi_rank() == 0:
            hf_model = AutoModelForSeq2SeqLM.from_pretrained(
                args.model_name,  # TODO: use model path instead
                device_map="balanced_low_0",
                torch_dtype=torch.bfloat16
                # torch_dtype=torch.float16 if '16' in dtype else torch.float32,  # TODO: use matched torch dtype
            ).eval()  # TODO: create config model path instead
            assert type(hf_model) in (
                T5ForConditionalGeneration, BartForConditionalGeneration,
                MBartForConditionalGeneration), 'Unsupported model!'

            tik = time.time()
            # breakpoint()
            hf_gen_output = hf_model.generate(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=args.num_beams,
                bos_token_id=tokenizer.bos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                # control logits processors
                no_repeat_ngram_size=0,  # disable no repeat post-processor
                forced_bos_token_id=None,  # disable forced first/last token
                forced_eos_token_id=None,
                min_length=0,
                # for debug
                output_scores=True,
                output_hidden_states=True,
                return_dict_in_generate=True)
            # get hf output scores
            hf_output_ids = hf_gen_output.sequences
            torch.cuda.synchronize()
            tok = time.time()

            output_ids = hf_output_ids.squeeze(dim=1)
            hf_output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(dim=1)
            output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(dim=1) - decoder_input_lengths
            print("--------------------------------------")
            print("HF output_ids: ", output_ids)
            print("HF output text: ", hf_output_text)
            print("HF output generated lengths: ", output_gen_lengths)
            print(f"HF E2E time {(tok-tik)*1000}ms")
            print("--------------------------------------")

            # Clean cache
            del hf_model
            gc.collect()
            torch.cuda.empty_cache()

    if tensorrt_llm.mpi_rank() == 0:
        print("Done with HF inference")
        # print(torch.cuda.memory_summary())

    # TRT-LLM runtime
    tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
                                               args.engine_dir,
                                               debug_mode=args.debug_mode)
    tik = time.time()
    tllm_output_ids = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug_mode=args.debug_mode,
        return_dict=False,  # when set return_dict=True, get outputs by key
        attention_mask=tokenized_inputs.attention_mask)
    tok = time.time()

    inference_dtype = tllm_model.encoder_model_config.dtype

    if tensorrt_llm.mpi_rank() == 0:
        output_ids = tllm_output_ids[:, 0, :]
        output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(dim=1)
        output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(dim=1) - decoder_input_lengths
        print("--------------------------------------")
        print("TRT-LLM output_ids: ", output_ids)
        print("TRT-LLM output text: ", output_text)
        print("TRT-LLM output generated lengths: ", output_gen_lengths)
        print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
        print("Precision:", inference_dtype)
        print("--------------------------------------")

        # simple accuracy check
        if args.compare_hf_fp32:
            from difflib import SequenceMatcher
            match_rate = SequenceMatcher(None, "\n".join(output_text),
                                         "\n".join(hf_output_text)).ratio()
            print(output_text)
            print(hf_output_text)
            if inference_dtype != "float32":
                print("")
                print(
                    f"[CAVEAT] Comparing TRT-LLM {inference_dtype} results with HF float32 results. Close match are not expected!"
                )
            assert match_rate > 0.8, f"Incorrect results! Match rate {match_rate}"
            print(
                f"TRT-LLM results match HF FP32 results with literal match rate {match_rate}"
            )
Expected behavior
HF and TRT-LLM results should be roughly the same.
Actual behavior
HF output text: ['Keeping the Secret of Genetic Testing']
TRT-LLM output text: ['Keeping the Secret of Genetic Testing - The New York Times']
There is also a TensorRT error during the run:
[03/24/2024-05:29:38] [TRT] [E] 3: [engine.cpp::getProfileObliviousBindingIndex::1530] Error Code 3: Internal Error (setTensorAddress given invalid tensor name: attention_mask)
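The error seems to correspond to the attention_mask keyword passed to tllm_model.generate in the modified script above, since the built engine apparently has no tensor binding with that name. For reference, this is the same call with only that argument removed (a sketch to isolate the error, reusing the variables from the script; not a verified fix):
# Same TRT-LLM generate() call as in the modified run.py, minus the
# attention_mask kwarg, to check whether the invalid-binding error goes away.
tllm_output_ids = tllm_model.generate(
    encoder_input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    max_new_tokens=max_new_tokens,
    num_beams=args.num_beams,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    debug_mode=args.debug_mode,
    return_dict=False)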
Additional notes
On a larger dataset, the difference in results is very obvious. Using FasterTransformer gives results much closer to HF.
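For context, the kind of comparison behind this observation is just the SequenceMatcher check from run.py averaged over paired outputs; a minimal, self-contained sketch (the helper name and example inputs are hypothetical):
from difflib import SequenceMatcher

def mean_match_rate(trt_texts, hf_texts):
    # Average literal match rate between paired TRT-LLM and HF generations,
    # mirroring the single-example SequenceMatcher check in run.py.
    rates = [SequenceMatcher(None, t, h).ratio() for t, h in zip(trt_texts, hf_texts)]
    return sum(rates) / len(rates)

# Example with the two outputs reported above:
print(mean_match_rate(
    ["Keeping the Secret of Genetic Testing - The New York Times"],
    ["Keeping the Secret of Genetic Testing"]))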