Files
diffusers/tests/models/testing_utils/utils.py
Sayak Paul cbdedbaf03 [tests] add attention backend tests. (#13174)
* add attention backend tests.

* remove existing tests/others/test_attention_backends.py file

* modify generate_model_tests.py

* remove native.

* account for _keep_in_fp32_modules

* don't skip when exception is raised.

* use is_kernels_available()

* mark with compile.

* move rtol and atol to methods as defaults.

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* up

* up
2026-05-12 17:00:00 +09:00

26 lines
921 B
Python

import torch
from diffusers.models.attention_dispatch import AttentionBackendName
# Attention backends for which `_maybe_cast_to_bf16` casts the model and its
# floating-point inputs to bfloat16 before running. Any backend not listed
# here leaves the model and inputs in their original dtype.
_BF16_REQUIRED_BACKENDS = set(
    (
        AttentionBackendName._NATIVE_CUDNN,
        AttentionBackendName.FLASH_HUB,
        AttentionBackendName.FLASH_VARLEN_HUB,
        AttentionBackendName._FLASH_3_HUB,
    )
)
def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
if getattr(model, "_keep_in_fp32_modules", None):
raise NotImplementedError("Do not know how to define casting for models with `_keep_in_fp32_modules`.")
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict