Conversation

@yao-matrix yao-matrix commented Sep 22, 2025

Issue

4 test cases fail:

pytest -rA tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit
pytest -rA tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_dora
pytest -rA tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5
pytest -rA tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit

with the following log:

raise ValueError(
    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
    " model has already been set to the correct devices and casted to the correct `dtype`."
)

E ValueError: `.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.

Root Cause

8-bit models don't support `.to()` or `.cuda()`, according to transformers' logic.

Fix

Check whether the model is loaded in 8-bit and, if so, skip the `.to()` call.
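The check can be sketched roughly as follows. This is an illustrative simplification, not PEFT's actual code: `maybe_move` is a hypothetical helper, and it assumes the `is_loaded_in_8bit` attribute that transformers sets on models loaded with `load_in_8bit=True`.

```python
def maybe_move(model, device):
    """Move `model` to `device`, skipping 8-bit bitsandbytes models.

    8-bit bnb models are already placed on the correct device at load
    time, and calling .to() on them raises the ValueError shown above.
    """
    if getattr(model, "is_loaded_in_8bit", False):
        # Leave the 8-bit model where bitsandbytes put it.
        return model
    return model.to(device)
```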

Results:

All 4 cases pass.

@BenjaminBossan, please help review, thanks very much.

Signed-off-by: Yao, Matrix <[email protected]>
@yao-matrix
Contributor Author

@BenjaminBossan, could you please take a look? Thanks very much.

@githubnemo
Collaborator

Thanks for the suggestion! I wonder if we should skip the test instead since we're "requesting" to use self.device but that model doesn't support using that device. WDYT?

    weight = weight.to(torch.device("cuda"))
elif is_xpu_available():
    weight = weight.to(torch.device("xpu"))
    state.SCB = state.SCB.to(torch.device("xpu"))
Contributor Author

The weight is moved to the accelerator, but SCB is not, which makes the later int8_vectorwise_dequant complain that the weight is on xpu while SCB is on cpu, breaking the CPU CI. So SCB is moved to the accelerator too.
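A minimal numpy sketch of row-wise int8 dequantization (an illustration of the scheme, not bnb's actual implementation) shows why the weight and its SCB scales participate in the same elementwise computation and therefore must live on the same device:

```python
import numpy as np

def int8_vectorwise_dequant_sketch(weight_int8, scb):
    """Dequantize an int8 weight with per-row absmax scales.

    bnb's 8-bit scheme stores per-row scales in state.SCB; each
    dequantized row is int8_row * scb_row / 127. Because weight and
    SCB are multiplied elementwise, a device mismatch between them
    (e.g. weight on xpu, SCB on cpu) fails immediately.
    """
    return weight_int8.astype(np.float32) * scb[:, None] / 127.0
```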


if is_cpu:
    dequantized = dequantized.to(device)
    state.SCB = state.SCB.to(device)
Contributor Author

Both should be moved back.

@yao-matrix
Contributor Author

yao-matrix commented Sep 25, 2025

Thanks for the suggestion! I wonder if we should skip the test instead since we're "requesting" to use self.device but that model doesn't support using that device. WDYT?

@githubnemo, I took some time to dig into the issue and cleaned up the test cases. We had torch_device, the device attribute of the TestLoftQ class, and a device passed through method arguments, which made the test cases confusing; after cleaning this up, the accelerator (xpu, cuda) cases pass.

The CPU failures are caused by incomplete tensor transfers between the CPU and the accelerator; I commented on them inline.

Thanks very much for the review, and sorry for the earlier lazy workaround.

@yao-matrix
Contributor Author

@githubnemo, could you please take another look? Thanks very much.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for working on this with 8bit bnb and device placement. Unfortunately, when I run this locally, I still get errors:

$ CUDA_VISIBLE_DEVICES=0 pytest tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_dora tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5 tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit

FAILED tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit[cpu] - assert tensor(nan, grad_fn=<MeanBackward0>) > 0.0
FAILED tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_dora[cpu] - RuntimeError: invalid argument to getCurrentStream
FAILED tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5[cuda] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered
FAILED tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5[cpu] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered
FAILED tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit[cuda] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered
FAILED tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit[cpu] - torch.AcceleratorError: CUDA error: an illegal memory access was encountered

I cannot test on XPU, but for CPU and CUDA, I get these errors. Do these tests pass for you?

Furthermore, I was wondering why these tests are not running on CI and I think it's a simple oversight, there is no test marker for TestLoftQ. Could you please add the @pytest.mark.single_gpu_tests decorator?

# The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
# quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
# conservative value to prevent flakiness, in practice most gains are > 1.5
error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
Member

Did you find it necessary to reduce the factor to 1.005 for XPU and CPU?

Contributor Author

Currently, it's needed. We are continuing to enhance bnb support; once that works, we can remove it.

Member

I see. I checked again with CUDA/CPU and the improvement is much larger than 1.03 (which, as indicated, is a very conservative value to avoid flakiness). So if XPU is below that, I think it's an indicator that something is missing.

@yao-matrix
Contributor Author


@BenjaminBossan Could you try installing the latest bnb with pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl? There is a bug in bnb, and the fixing PR bitsandbytes-foundation/bitsandbytes#1769 was just merged. In my environment, with this bnb fix, the test cases pass on XPU, CPU, and CUDA (A100). Thanks.

Signed-off-by: Yao, Matrix <[email protected]>

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for fixing the 8bit bnb LoftQ tests and making LoftQ work with XPU, the PR LGTM.

The stuck CI is unrelated and can be ignored.


@BenjaminBossan BenjaminBossan merged commit ffa971a into huggingface:main Oct 1, 2025
12 of 14 checks passed
@yao-matrix yao-matrix deleted the issue-540 branch October 1, 2025 15:34
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 15, 2025
This is to fix an oversight from huggingface#2797, where the LoftQ test was
slightly refactored but one test was not updated accordingly.
BenjaminBossan added a commit that referenced this pull request Oct 15, 2025
This is to fix an oversight from #2797, where the LoftQ test was
slightly refactored but one test was not updated accordingly.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 17, 2025
Fixes some failing GPU tests in CI.

A bug was introduced in huggingface#2797 where state.SCB was accessed while
dequantizing 4bit bnb weights even though state is None. This would
occur, for instance, when using DoRA, which needs to dequantize the
weight. The attribute access is now restricted to 8bit bnb weights.
BenjaminBossan added a commit that referenced this pull request Oct 27, 2025
Fixes some failing GPU tests in CI.

A bug was introduced in #2797 where state.SCB was accessed while
dequantizing 4bit bnb weights even though state is None. This would
occur, for instance, when using DoRA, which needs to dequantize the
weight. The attribute access is now restricted to 8bit bnb weights.
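The guard described in that commit message can be sketched as follows. This is a hypothetical simplification (`restore_device` is not PEFT's actual function name); it assumes, as the commit message states, that 4-bit bnb weights are dequantized with `state=None` while 8-bit weights carry `state.SCB`.

```python
def restore_device(dequantized, state, device):
    """Move a dequantized weight back to `device`.

    For 8-bit bnb weights, the SCB scale tensor must move along with
    it. 4-bit weights have no quantization `state` here (it is None),
    so accessing state.SCB unconditionally would crash, e.g. in the
    DoRA path that dequantizes the weight.
    """
    dequantized = dequantized.to(device)
    # The fix: restrict the SCB access to the 8-bit path.
    if state is not None:
        state.SCB = state.SCB.to(device)
    return dequantized
```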