quantize.py fails to export important data to config.json (e.g. rotary scaling) #1676

@janpetrov

Description

System Info

4x NVIDIA H100, TensorRT-LLM backend 0.9.0

Who can help?

@Tracin

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

(1) Have a HF transformers model with linear rope scaling.
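For reference, such a model's HF config.json carries a rope_scaling entry along these lines (the factor 4.0 matches the configs shown below):

"rope_scaling": {
    "factor": 4.0,
    "type": "linear"
}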

(2) Edit is_linear in /usr/local/lib/python3.10/dist-packages/ammo/torch/export/layer_utils.py as follows (adding the trailing and ("Rotary" ...) condition)

def is_linear(module: nn.Module) -> bool:
    """Returns whether the module is a linear layer."""
    return any([k in type(module).__name__ for k in ["Linear", "Conv1D", "NormHead"]]) and (
        "Rotary" not in type(module).__name__  # added: skip rotary embedding layers
    )

so that the rope-scaling model is exported without crashing on an error that weights cannot be exported from the rotary scaling layer (see this issue).

(3) Then run, as recommended here:

python examples/quantization/quantize.py \
    --model_dir "$MODEL_DIR" \
    --dtype bfloat16 \
    --output_dir "$TMP_DIR" \
    --tp_size 2 \
    --qformat fp8 \
    --kv_cache_dtype fp8 \
    --calib_size 512

Expected behavior

quantize.py should generate a detailed config.json file in the output dir. The subsequent run of

trtllm-build \
    --checkpoint_dir "$TMP_DIR" \
    --gpt_attention_plugin bfloat16 \
    --gemm_plugin bfloat16 \
    --max_input_len 16384 \
    --max_output_len 16384 \
    --max_batch_size 8 \
    --strongly_typed \
    --workers 2 \
    --output_dir "$OUTPUT_DIR" \
    --multi_block_mode enable

should build a well-working engine.

Actual behavior

The config.json generated by quantize.py contains just the following (note, e.g., that the rope scaling is missing). The engine built by trtllm-build generates nonsense.

{
    "producer": {
        "name": "ammo",
        "version": "0.7.4"
    },
    "architecture": "LlamaForCausalLM",
    "dtype": "bfloat16",
    "num_hidden_layers": 80,
    "num_attention_heads": 64,
    "num_key_value_heads": 8,
    "hidden_size": 8192,
    "norm_epsilon": 1e-05,
    "vocab_size": 32000,
    "max_position_embeddings": 4096,
    "hidden_act": "silu",
    "use_parallel_embedding": true,
    "embedding_sharding_dim": 0,
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8"
    },
    "mapping": {
        "world_size": 2,
        "tp_size": 2,
        "pp_size": 1
    },
    "head_size": 128,
    "intermediate_size": 28672,
    "position_embedding_type": "rope_gpt_neox",
    "rotary_base": 10000.0
}

Additional notes

When I edit the config.json to have the following contents and then re-run trtllm-build, the resulting engine generates fine text (a scripted version of this edit is sketched after the listing).

{
    "producer": {
        "name": "ammo",
        "version": "0.7.4"
    },
    "architecture": "LlamaForCausalLM",
    "dtype": "bfloat16",
    "logits_dtype": "float32",
    "vocab_size": 32000,
    "max_position_embeddings": 4096,
    "hidden_size": 8192,
    "num_hidden_layers": 80,
    "num_attention_heads": 64,
    "num_key_value_heads": 8,
    "head_size": 128,
    "hidden_act": "silu",
    "intermediate_size": 28672,
    "norm_epsilon": 1e-05,
    "position_embedding_type": "rope_gpt_neox",
    "use_parallel_embedding": true,
    "embedding_sharding_dim": 0,
    "mapping": {
        "world_size": 2,
        "tp_size": 2,
        "pp_size": 1
    },
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8"
    },
    "rotary_scaling": {
        "factor": 4.0,
        "type": "linear"
    },
    "moe_normalization_mode": null,
    "rotary_base": 10000.0,
    "moe_num_experts": 0,
    "moe_top_k": 0,
    "moe_tp_mode": 2,
    "attn_bias": false,
    "disable_weight_only_quant_plugin": false,
    "mlp_bias": false
}
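The manual edit above can also be scripted. A minimal sketch, assuming MODEL_DIR and TMP_DIR as in the commands above and the rope_scaling entry of the HF config as shown in the reproduction:

import json
import os

hf_config_path = os.path.join(os.environ["MODEL_DIR"], "config.json")  # HF model config
trt_config_path = os.path.join(os.environ["TMP_DIR"], "config.json")   # quantize.py output

with open(hf_config_path) as f:
    hf_config = json.load(f)
with open(trt_config_path) as f:
    trt_config = json.load(f)

# Copy over the rope scaling parameters that quantize.py drops.
rope = hf_config.get("rope_scaling")  # e.g. {"factor": 4.0, "type": "linear"}
if rope is not None and "rotary_scaling" not in trt_config:
    trt_config["rotary_scaling"] = {"factor": rope["factor"], "type": rope["type"]}
    with open(trt_config_path, "w") as f:
        json.dump(trt_config, f, indent=4)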

Please note that when the input to trtllm-build is generated by examples/llama/convert_checkpoint.py (rather than by examples/quantization/quantize.py), the config.json looks as follows. This is for the same model, but without quantization. Note the much richer data, including the rotary scaling (a sketch for diffing the two configs follows the listing).

{
    "architecture": "LlamaForCausalLM",
    "dtype": "bfloat16",
    "logits_dtype": "float32",
    "vocab_size": 32000,
    "max_position_embeddings": 4096,
    "hidden_size": 8192,
    "num_hidden_layers": 80,
    "num_attention_heads": 64,
    "num_key_value_heads": 8,
    "head_size": 128,
    "hidden_act": "silu",
    "intermediate_size": 28672,
    "norm_epsilon": 1e-05,
    "position_embedding_type": "rope_gpt_neox",
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "mapping": {
        "world_size": 4,
        "tp_size": 4,
        "pp_size": 1
    },
    "quantization": {
        "quant_algo": null,
        "kv_cache_quant_algo": null,
        "group_size": 128,
        "smoothquant_val": null,
        "has_zero_point": false,
        "pre_quant_scale": false,
        "exclude_modules": [
            "lm_head"
        ]
    },
    "kv_dtype": "bfloat16",
    "rotary_scaling": {
        "factor": 4.0,
        "type": "linear"
    },
    "moe_normalization_mode": null,
    "rotary_base": 10000.0,
    "moe_num_experts": 0,
    "moe_top_k": 0,
    "moe_tp_mode": 2,
    "attn_bias": false,
    "disable_weight_only_quant_plugin": false,
    "mlp_bias": false
}
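For completeness, the keys dropped by quantize.py can be listed by diffing the two configs. A small sketch, with hypothetical filenames standing in for the two listings above:

import json

with open("config_quantize.json") as f:  # config.json written by quantize.py
    quantized = json.load(f)
with open("config_convert.json") as f:   # config.json written by convert_checkpoint.py
    converted = json.load(f)

# Top-level keys present in the convert_checkpoint.py output but missing
# from the quantize.py output; expect "rotary_scaling" among them.
print(sorted(set(converted) - set(quantized)))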
