Description
⚙️ Your current environment
The output of python collect_env.py
🐛 Describe the bug
Operating System: `Linux-5.14.0-427.50.1.el9_4.x86_64-x86_64-with-glibc2.34`
Python Version: `3.13.5 | packaged by conda-forge | (main, Jun 16 2025, 08:27:50) [GCC 13.3.0]`
llm-compressor Version: `0.7.2.dev5+g6af07785`
compressed-tensors Version: `0.11.1a20250821`
transformers Version: `4.55.2`
torch Version: `2.8.0+cu129`
CUDA Devices: `['NVIDIA A100-SXM4-80GB']`
AMD Devices: `None`
I adapted the multimodal gemma3 example and compressed the model with this config:

```python
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(
    actorder=True,
    offload_hessians=offload_hessians,
    ignore=ignore,
    targets=targets,
    config_groups={
        "group_0": {
            "targets": targets,
            "input_activations": {
                "num_bits": 8,
                "strategy": "group",
                "type": "float",
                "symmetric": True,
                "dynamic": True,
                "group_size": group_size,
            },
            "output_activations": None,
            "weights": {
                "num_bits": 4,
                "type": "int",
                "symmetric": True,
                "strategy": "group",
                "group_size": group_size,
            },
        }
    },
),
```
Note that `actorder=True`, since this increases accuracy.
The problem is that `weight_g_idx` is internally initialized to -1s, but HF loads the model in a context where the weights are on the meta device. The -1s are therefore discarded, and the tensors are afterwards materialized from "empty" in torch, leaving arbitrary values (mostly 0s).
```python
# Internally, weight_g_idx is initialized with -1s, but this is not used
# because the weight is on the meta device.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype="auto",
    trust_remote_code=True,
)
```
This causes a failure because of these lines in `compressed_tensors/quantization/lifecycle/forward.py:_process_quantization` (ll. 287-297):
```python
is_column_order = g_idx is None or -1 in g_idx
if is_column_order:
    num_groups = int(ceil(columns / group_size))
    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
else:
    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
    group_sizes = group_sizes[torch.argsort(group_indices)]
    perm = torch.argsort(g_idx)
    x = safe_permute(x, perm, dim=1)
```
because now `g_idx` is not None, and depending on the arbitrary values it may or may not contain a -1.
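To make the failure mode concrete, here is a torch-free sketch of the branch above (plain lists instead of tensors; `infer_group_sizes` is a hypothetical name, not the library's API). When `g_idx` keeps its -1 sentinel, the uniform-group path is taken; an arbitrarily initialized `g_idx` without a -1 is silently trusted as a real per-column group assignment:

```python
from math import ceil

def infer_group_sizes(g_idx, columns, group_size):
    # Simplified sketch of the branch in _process_quantization.
    is_column_order = g_idx is None or -1 in g_idx
    if is_column_order:
        # Sentinel present: assume plain column order, uniform groups.
        num_groups = ceil(columns / group_size)
        return [group_size] * num_groups
    # No sentinel: treat g_idx as a genuine group index per column.
    sizes = {}
    for g in g_idx:
        sizes[g] = sizes.get(g, 0) + 1
    return [sizes[g] for g in sorted(sizes)]

# With the intended -1 sentinel: uniform groups.
print(infer_group_sizes([-1] * 8, columns=8, group_size=4))  # [4, 4]

# With arbitrary "empty" values (e.g. all zeros, no -1 present):
# a single bogus group covering every column.
print(infer_group_sizes([0] * 8, columns=8, group_size=4))  # [8]
```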
When I fix it by overwriting the `g_idx` tensors with -1s after loading the model:
```python
# Here, weight_g_idx was initialized via torch.empty, so it can contain
# arbitrary values; restore the -1 sentinel in place.
fixed_state_dict = {}
for key, value in model.state_dict().items():
    if "weight_g_idx" in key and isinstance(value, torch.Tensor):
        value.fill_(-1)
        print(f"Filled {key} with -1's")
    fixed_state_dict[key] = value
model.load_state_dict(fixed_state_dict)
```
it runs.
As a side note, this quantization scheme generally does not work: the model produces random outputs/collapses. I verified that this is caused by the input quantization, even with a group size of 128 and dynamic input quantization, which surprises me. It would be nice if you could also comment on this, as I need 8-bit inputs.
🛠️ Steps to reproduce
No response