From 8d431dc967a4118168af74aae9c41f2a68764851 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 28 Jul 2025 13:27:20 +0530 Subject: [PATCH] tighten compilation tests for quantization --- tests/quantization/bnb/test_4bit.py | 1 + tests/quantization/test_torch_compile_utils.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 8e2a8515c6..08c0fee43b 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase): components_to_quantize=["transformer", "text_encoder_2"], ) + @require_bitsandbytes_version_greater("0.46.1") def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True super().test_torch_compile() diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index c742927646..91ed173fc6 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -56,12 +56,18 @@ class QuantCompileTests: pipe.transformer.compile(fullgraph=True) # small resolutions to ensure speedy execution. - pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) + with torch._dynamo.config.patch(error_on_recompile=True): + pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16): pipe = self._init_pipeline(self.quantization_config, torch_dtype) pipe.enable_model_cpu_offload() - pipe.transformer.compile() + # regional compilation is better for offloading. + # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/ + if getattr(pipe.transformer, "_repeated_blocks"): + pipe.transformer.compile_repeated_blocks(fullgraph=True) + else: + pipe.transformer.compile() # small resolutions to ensure speedy execution. pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)