Improve TorchAO quantization test coverage and XPU support#13530

Open
jiqing-feng wants to merge 11 commits into huggingface:main from jiqing-feng:torchao
Conversation

@jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Apr 21, 2026

What does this PR do?

This PR improves the TorchAO quantization testing infrastructure with several fixes: enabling int4wo tests on Intel XPU, implementing _dequantize for TorchAO, fixing input dtype mismatches, and fixing training gradient underflow.

Changes

  1. Enable int4wo tests on XPU: Removed the _int4wo_skip marker that restricted int4wo tests to CUDA only, allowing them to run on all accelerator backends.

  2. XPU-specific int4 packing format: Added XPU-specific handling in _get_quant_config() — Intel XPU requires int4_packing_format="plain_int32" for Int4WeightOnlyConfig.

  3. Fix input dtype casting: Introduced _get_dummy_inputs_for_model(model) helper in QuantizationTesterMixin to automatically cast floating-point input tensors to the model's parameter dtype, preventing dtype mismatches during quantized model inference.

  4. Implement _dequantize for TorchAO: Added _dequantize() method in TorchAoHfQuantizer that iterates all nn.Linear modules, calls weight.dequantize() on TorchAOBaseTensor weights, and replaces them with standard nn.Parameter. Also fixed _verify_if_layer_quantized to check isinstance(module.weight, TorchAOBaseTensor) so dequantized layers are correctly detected as non-quantized.

  5. Fix training gradient underflow: Changed autocast dtype from float16 to bfloat16 in _test_quantization_training. Float16's limited dynamic range (max ~65504, min subnormal ~5.96e-8) causes gradients to underflow to zero when passing through quantized tensor subclass operations; bfloat16 shares float32's exponent range and avoids this issue.

  6. Reduce WanAnimate TorchAO test input sizes: Shrunk dummy inputs in TestWanAnimateTransformer3DTorchAo to avoid OOM on devices without FlashAttention (e.g. XPU, which falls back to math SDPA and materializes the full O(S²) attention matrix). Reduced hidden_states from (1,36,21,64,64) to (1,36,5,16,16) and face_pixel_values from (1,3,77,512,512) to (1,3,13,512,512), bringing self-attention sequence length from 21,504 to 320 and peak attention memory from ~74 GiB to ~16 MB. Face frame count (13) is chosen so the face encoder's two stride-2 convolutions produce temporal output 4, plus 1 padding = 5, matching hidden_states temporal dim.
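
A hypothetical sketch of the device branch described in change 2 (the helper name get_int4_config_kwargs and the group_size value are illustrative, not taken from the diff; int4_packing_format="plain_int32" is the kwarg named in this PR):

```python
def get_int4_config_kwargs(device_type: str) -> dict:
    # Kwargs later passed to torchao's Int4WeightOnlyConfig; group_size is illustrative.
    kwargs = {"group_size": 128}
    if device_type == "xpu":
        # Intel XPU requires the plain int32 packing format (per this PR).
        kwargs["int4_packing_format"] = "plain_int32"
    return kwargs
```

The branch keeps the CUDA path untouched, so only XPU callers pick up the extra kwarg.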

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@github-actions github-actions bot added the tests and size/M (PR with diff < 200 LOC) labels Apr 21, 2026
@jiqing-feng
Contributor Author

Hi @sayakpaul . Would you please review this PR? Thanks!

@github-actions github-actions bot added the quantization label Apr 21, 2026
@jiqing-feng jiqing-feng changed the title from "Enable TorchAO int4 weight-only quantization tests on Intel XPU" to "Improve TorchAO quantization test coverage and XPU support" Apr 21, 2026
@sayakpaul
Member

There are a bunch of things going on in this PR. I would suggest breaking the PR into smaller PRs.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 22, 2026

Hi @sayakpaul . Thanks for the review! I've split this PR into 5 smaller independent PRs as suggested:

  1. Enable TorchAO int4wo quantization tests on XPU#13537

    • Remove _int4wo_skip marker on XPU
    • Add int4_packing_format="plain_int32" for XPU int4 quantization
  2. Implement _dequantize for TorchAO quantizer#13538

    • Add _dequantize() method to TorchAoHfQuantizer
    • Fix _verify_if_layer_quantized to check TorchAOBaseTensor weight type
  3. Add _get_dummy_inputs_for_model helper#13539

    • Cast floating-point input tensors to model's parameter dtype automatically
    • Defined in both QuantizationTesterMixin and QuantizationCompileTesterMixin
  4. Fix training gradient underflow#13540

    • Change autocast from float16 to bfloat16 to prevent gradient underflow in quantized training tests
  5. Reduce WanAnimate test input sizes#13541

    • Reduce spatial/temporal dimensions to prevent OOM on devices without FlashAttention (SDPA math backend materializes O(S²) attention matrix)

Each PR is independent and can be reviewed and merged separately. I will close this PR once the split PRs are up.
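
The float16 underflow behind split PR 4 can be demonstrated with nothing but the standard library (a minimal sketch; bfloat16 is emulated here by truncating a float32 to its top 16 bits, since struct has no native bfloat16 format):

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE half precision (struct's 'e' format).
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    # Emulate bfloat16: keep the top 16 bits of the float32 encoding,
    # which preserves float32's full 8-bit exponent range.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (out,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return out

grad = 1e-8  # below float16's smallest subnormal (~5.96e-8)
assert to_fp16(grad) == 0.0   # the gradient flushes to zero in float16
assert to_bf16(grad) != 0.0   # bfloat16 retains a nonzero value
```

This is exactly the failure mode in the quantized training test: small gradients survive the bfloat16 autocast but vanish under float16.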

    def is_compileable(self) -> bool:
        return True

    def _dequantize(self, model):
Member

We shouldn't have dequantize here in this PR, right?

Contributor Author

Yes, please review the change here: #13538
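
As a rough sketch (not the actual diff in #13538), the _dequantize() pass described in this PR could look like the following; the torchao.utils import path for TorchAOBaseTensor may differ across torchao versions:

```python
import torch.nn as nn

try:
    from torchao.utils import TorchAOBaseTensor  # import path may vary by torchao version
except ImportError:
    TorchAOBaseTensor = ()  # keeps the sketch importable without torchao installed

def dequantize_linears(model: nn.Module) -> nn.Module:
    # Walk every nn.Linear; where the weight is a torchao tensor subclass,
    # materialize a plain tensor and swap in a standard nn.Parameter, so
    # _verify_if_layer_quantized later sees the layer as non-quantized.
    for module in model.modules():
        if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
            module.weight = nn.Parameter(module.weight.dequantize(), requires_grad=False)
    return model
```

Without torchao installed the isinstance check is always False, so the pass is a no-op on unquantized models.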

    return {
        "hidden_states": randn_tensor(
-           (1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
+           (1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
Member

Explain the changes.

Contributor Author

It's to avoid OOM; see #13541 for details. Please let me know if you want comments in the code.
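
The quoted sequence lengths can be sanity-checked with quick arithmetic; the spatial patch size of 2 used below is an assumption inferred from the numbers in the PR description:

```python
def attn_seq_len(frames: int, height: int, width: int, patch: int = 2) -> int:
    # Tokens per sample after spatial patchification (patch size assumed).
    return frames * (height // patch) * (width // patch)

old = attn_seq_len(21, 64, 64)  # 21 * 32 * 32 = 21504
new = attn_seq_len(5, 16, 16)   # 5 * 8 * 8 = 320
# A math-SDPA fallback materializes the full S x S attention matrix,
# so peak attention memory shrinks by roughly (old / new) ** 2 ~ 4516x.
```

That quadratic factor is why the smaller inputs fit comfortably on devices without FlashAttention.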

Comment on lines +1194 to +1200
    def _get_dummy_inputs_for_model(self, model):
        inputs = self.get_dummy_inputs()
        model_dtype = next(model.parameters()).dtype
        return {
            k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
            for k, v in inputs.items()
        }
Member

Why do we need to override?

Contributor Author

QuantizationCompileTesterMixin is an independent mixin that doesn't inherit from QuantizationTesterMixin. Test classes may use either one or both, so the method needs to be defined in both places.

Alternatively, I can extract it into a shared base class or a standalone utility function to avoid code duplication. Let me know which approach you prefer. Please review this change in #13539

@jiqing-feng
Contributor Author

Hi @sayakpaul. I have split this PR into 5 smaller PRs; please review them one by one if that is easier for you. Thanks!

Labels: quantization, size/M (PR with diff < 200 LOC), tests