mirror of
https://github.com/huggingface/diffusers.git
synced 2026-06-05 00:53:09 +08:00
* add attention backend tests. * remove existing tests/others/test_attention_backends.py file * modify generate_model_tests.py * remove native. * account for _keep_in_fp32_modules * don't skip when exception is raised. * use is_kernels_available() * mark with compile. * move rtol and atol to methods as defaults. * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * up * up
26 lines
921 B
Python
26 lines
921 B
Python
import torch
|
|
|
|
from diffusers.models.attention_dispatch import AttentionBackendName
|
|
|
|
|
|
# Attention backends for which `_maybe_cast_to_bf16` converts the model and its
# floating-point inputs to bfloat16 before running.
# NOTE(review): presumably these kernels do not accept float32 tensors — confirm
# against the backend implementations in `diffusers.models.attention_dispatch`.
_BF16_REQUIRED_BACKENDS = {
    # cuDNN path (name suggests native SDPA via cuDNN — assumption)
    AttentionBackendName._NATIVE_CUDNN,
    # FlashAttention variants (the "_HUB" suffix suggests Hub-loaded kernels — assumption)
    AttentionBackendName.FLASH_HUB,
    AttentionBackendName.FLASH_VARLEN_HUB,
    AttentionBackendName._FLASH_3_HUB,
}
|
|
|
|
|
|
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Return ``(model, inputs_dict)`` converted to bfloat16 when ``backend`` needs it.

    Args:
        backend: The attention backend under test. A falsy value, or any backend
            not contained in ``_BF16_REQUIRED_BACKENDS``, leaves everything untouched.
        model: The model to convert; ``model.to(dtype=torch.bfloat16)`` is called on it.
        inputs_dict: Mapping of forward-pass keyword arguments.

    Returns:
        A ``(model, inputs_dict)`` pair. When casting applies, ``inputs_dict`` is a
        new dict in which every floating-point tensor has been converted to
        bfloat16 and all other values are passed through as-is. Otherwise the
        original ``model`` and ``inputs_dict`` objects are returned unchanged.

    Raises:
        NotImplementedError: If casting applies and the model has a truthy
            ``_keep_in_fp32_modules`` attribute.
    """
    needs_bf16 = bool(backend) and backend in _BF16_REQUIRED_BACKENDS
    if not needs_bf16:
        return model, inputs_dict

    # A blanket `.to(dtype=...)` would also downcast the modules listed in
    # `_keep_in_fp32_modules`; casting for that case is not defined here, so refuse.
    if getattr(model, "_keep_in_fp32_modules", None):
        raise NotImplementedError("Do not know how to define casting for models with `_keep_in_fp32_modules`.")

    target_dtype = torch.bfloat16

    def _convert(value):
        # Only floating-point tensors are cast; integer/bool tensors and
        # non-tensor values keep their original type.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.to(dtype=target_dtype)
        return value

    converted_model = model.to(dtype=target_dtype)
    converted_inputs = {name: _convert(value) for name, value in inputs_dict.items()}
    return converted_model, converted_inputs
|