[Ernie-Image] Add lora support (#13575)

add lora support

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Álvaro Somoza
2026-04-30 12:32:09 -04:00
committed by GitHub
parent 716f246031
commit 303c1d8b04
5 changed files with 222 additions and 2 deletions

View File

@@ -34,6 +34,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`ErnieImageLoraLoaderMixin`] provides similar functions for [Ernie-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ernie_image).
- [`LTX2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
@@ -64,6 +65,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## ErnieImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.ErnieImageLoraLoaderMixin
## LTX2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin

View File

@@ -85,6 +85,7 @@ if is_torch_available():
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
"ErnieImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -117,6 +118,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
ErnieImageLoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
HeliosLoraLoaderMixin,

View File

@@ -5829,6 +5829,217 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class ErnieImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
# PEFT format -> normalize to diffusion_model.* prefix
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
if is_peft_format:
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
# AI-Toolkit / diffusion_model.* prefix -> swap to transformer.*
# The Ernie LoRA naming under diffusion_model.* already matches diffusers module
# paths (layers.X.self_attention.to_q etc.), so only the prefix needs to change.
is_diffusion_model_prefix = any(k.startswith("diffusion_model.") for k in state_dict)
if is_diffusion_model_prefix:
state_dict = {k.replace("diffusion_model.", "transformer."): v for k, v in state_dict.items()}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
adapter_name: str | None = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ErnieImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: dict | None = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: list[str] | None = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."

View File

@@ -25,6 +25,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import BaseOutput, logging
from ..attention import AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -288,7 +289,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
return x
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_repeated_blocks = ["ErnieImageSharedAdaLNBlock"]

View File

@@ -23,6 +23,7 @@ import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from ...loaders import ErnieImageLoraLoaderMixin
from ...models import AutoencoderKLFlux2
from ...models.transformers import ErnieImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -31,7 +32,7 @@ from ...utils.torch_utils import randn_tensor
from .pipeline_output import ErnieImagePipelineOutput
class ErnieImagePipeline(DiffusionPipeline):
class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
"""
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.