NucleusMoE-Image (#13317)

* adding NucleusMoE-Image model

* update system prompt

* Add text kv caching

* Class/function name changes

* add missing imports

* add RoPE credits

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* update defaults

* Update src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* review updates

* fix the tests

* clean up

* update apply_text_kv_cache

* SwiGLUExperts addition

* fuse SwiGLUExperts up and gate proj

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/hooks/text_kv_cache.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* _SharedCacheKey -> TextKVCacheState

* Apply style fixes

* Run python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite

* Apply style fixes

* run `make fix-copies`

* fix import

* refactor text KV cache to be managed by StateManager

---------

Co-authored-by: Murali Nandan Nagarapu <nmn@withnucleus.ai>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
sippycoder
2026-04-03 02:01:13 -07:00
committed by GitHub
parent 5adc544b79
commit 447e571ada
17 changed files with 2445 additions and 3 deletions

View File

@@ -169,22 +169,23 @@ else:
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"TextKVCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
"apply_text_kv_cache",
]
)
_import_structure["image_processor"] = [
"IPAdapterMaskProcessor",
"InpaintProcessor",
"IPAdapterMaskProcessor",
"PixArtImageProcessor",
"VaeImageProcessor",
"VaeImageProcessorLDM3D",
]
_import_structure["video_processor"] = ["VideoProcessor"]
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -262,6 +263,7 @@ else:
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
"NucleusMoEImageTransformer2DModel",
"OmniGenTransformer2DModel",
"OvisImageTransformer2DModel",
"ParallelConfig",
@@ -396,6 +398,7 @@ else:
]
)
_import_structure["training_utils"] = ["EMAModel"]
_import_structure["video_processor"] = ["VideoProcessor"]
try:
if not (is_torch_available() and is_scipy_available()):
@@ -613,6 +616,7 @@ else:
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"NucleusMoEImagePipeline",
"OmniGenPipeline",
"OvisImagePipeline",
"PaintByExamplePipeline",
@@ -967,12 +971,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
apply_text_kv_cache,
)
from .image_processor import (
InpaintProcessor,
@@ -1057,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
NucleusMoEImageTransformer2DModel,
OmniGenTransformer2DModel,
OvisImageTransformer2DModel,
ParallelConfig,
@@ -1384,6 +1391,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
NucleusMoEImagePipeline,
OmniGenPipeline,
OvisImagePipeline,
PaintByExamplePipeline,

View File

@@ -27,3 +27,4 @@ if is_torch_available():
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache

View File

@@ -0,0 +1,173 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
@dataclass
class TextKVCacheConfig:
"""Enable exact (lossless) text K/V caching for transformer models.
Pre-computes per-block text key and value projections once before the denoising loop and reuses them across all
steps. Positive and negative prompts are distinguished via a stable cache key captured by a transformer-level hook
before any intermediate tensor allocations.
"""
pass
class TextKVCacheState(BaseState):
"""Shared state between the transformer-level and block-level hooks.
The transformer hook writes the stable ``encoder_hidden_states`` ``data_ptr()`` (captured *before* ``txt_norm``) so
that block hooks can use it as a reliable cache key across denoising steps.
"""
def __init__(self):
self.key: int | None = None
def reset(self):
self.key = None
class TextKVCacheBlockState(BaseState):
"""Per-block state holding cached text key/value projections."""
def __init__(self):
self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
def reset(self):
self.kv_cache.clear()
class TextKVCacheTransformerHook(ModelHook):
"""Captures ``encoder_hidden_states.data_ptr()`` before ``txt_norm``
and writes it to shared state for the block hooks to read."""
_is_stateful = True
def __init__(self, state_manager: StateManager):
super().__init__()
self.state_manager = state_manager
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
encoder_hidden_states = kwargs.get("encoder_hidden_states")
if encoder_hidden_states is not None:
state: TextKVCacheState = self.state_manager.get_state()
state.key = encoder_hidden_states.data_ptr()
return self.fn_ref.original_forward(*args, **kwargs)
def reset_state(self, module: torch.nn.Module):
self.state_manager.reset()
return module
class TextKVCacheBlockHook(ModelHook):
"""Caches ``(txt_key, txt_value)`` per block per unique prompt using
the stable cache key from the shared state."""
_is_stateful = True
def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
super().__init__()
self.state_manager = state_manager
self.block_state_manager = block_state_manager
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
if self.block_state_manager._current_context is None:
self.block_state_manager.set_context("inference")
if "encoder_hidden_states" in kwargs:
encoder_hidden_states = kwargs["encoder_hidden_states"]
else:
encoder_hidden_states = args[1]
if "image_rotary_emb" in kwargs:
image_rotary_emb = kwargs["image_rotary_emb"]
elif len(args) > 3:
image_rotary_emb = args[3]
else:
image_rotary_emb = None
state: TextKVCacheState = self.state_manager.get_state()
cache_key = state.key
block_state: TextKVCacheBlockState = self.block_state_manager.get_state()
if cache_key not in block_state.kv_cache:
context = module.encoder_proj(encoder_hidden_states)
attn = module.attn
head_dim = attn.inner_dim // attn.heads
num_kv_heads = attn.inner_kv_dim // head_dim
txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))
if attn.norm_added_k is not None:
txt_key = attn.norm_added_k(txt_key)
if image_rotary_emb is not None:
_, txt_freqs = image_rotary_emb
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
block_state.kv_cache[cache_key] = (txt_key, txt_value)
txt_key, txt_value = block_state.kv_cache[cache_key]
attn_kwargs = kwargs.get("attention_kwargs") or {}
attn_kwargs["cached_txt_key"] = txt_key
attn_kwargs["cached_txt_value"] = txt_value
kwargs["attention_kwargs"] = attn_kwargs
return self.fn_ref.original_forward(*args, **kwargs)
def reset_state(self, module: torch.nn.Module):
self.block_state_manager.reset()
return module
def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock
HookRegistry.check_if_exists_or_initialize(module)
state_manager = StateManager(TextKVCacheState)
transformer_hook = TextKVCacheTransformerHook(state_manager)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(transformer_hook, _TEXT_KV_CACHE_TRANSFORMER_HOOK)
for _, submodule in module.named_modules():
if isinstance(submodule, NucleusMoEImageTransformerBlock):
block_state_manager = StateManager(TextKVCacheBlockState)
hook = TextKVCacheBlockHook(state_manager, block_state_manager)
block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
block_registry.register_hook(hook, _TEXT_KV_CACHE_BLOCK_HOOK)

View File

@@ -116,6 +116,7 @@ if is_torch_available():
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
@@ -236,6 +237,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
NucleusMoEImageTransformer2DModel,
OmniGenTransformer2DModel,
OvisImageTransformer2DModel,
PixArtTransformer2DModel,

View File

@@ -41,11 +41,12 @@ class CacheMixin:
Enable caching techniques on the model.
Args:
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
- [`~hooks.FirstBlockCacheConfig`]
- [`~hooks.TextKVCacheConfig`]
Example:
@@ -71,11 +72,13 @@ class CacheMixin:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
apply_text_kv_cache,
)
if self.is_cache_enabled:
@@ -89,6 +92,8 @@ class CacheMixin:
apply_first_block_cache(self, config)
elif isinstance(config, MagCacheConfig):
apply_mag_cache(self, config)
elif isinstance(config, TextKVCacheConfig):
apply_text_kv_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
@@ -106,12 +111,14 @@ class CacheMixin:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
from ..hooks.text_kv_cache import _TEXT_KV_CACHE_BLOCK_HOOK, _TEXT_KV_CACHE_TRANSFORMER_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -129,6 +136,9 @@ class CacheMixin:
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TextKVCacheConfig):
registry.remove_hook(_TEXT_KV_CACHE_TRANSFORMER_HOOK, recurse=True)
registry.remove_hook(_TEXT_KV_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:

View File

@@ -40,6 +40,7 @@ if is_torch_available():
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_ovis_image import OvisImageTransformer2DModel
from .transformer_prx import PRXTransformer2DModel

View File

@@ -0,0 +1,925 @@
# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
logger = logging.get_logger(__name__)
# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus
def _apply_rotary_emb_nucleus(
x: torch.Tensor,
freqs_cis: torch.Tensor | tuple[torch.Tensor],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
def _compute_text_seq_len_from_mask(
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
if encoder_hidden_states_mask is None:
return text_seq_len, None, None
if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
raise ValueError(
f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
)
if encoder_hidden_states_mask.dtype != torch.bool:
encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
)
return text_seq_len, per_sample_len, encoder_hidden_states_mask
class NucleusMoETimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, use_additional_t_cond=False):
super().__init__()
self.time_proj = Timesteps(
num_channels=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000
)
self.timestep_embedder = TimestepEmbedding(
in_channels=embedding_dim, time_embed_dim=4 * embedding_dim, out_dim=embedding_dim
)
self.norm = RMSNorm(embedding_dim, eps=1e-6)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
conditioning = timesteps_emb
if self.use_additional_t_cond:
if addition_t_cond is None:
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
conditioning = conditioning + addition_t_emb
return self.norm(conditioning)
class NucleusMoEEmbedRope(nn.Module):
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self._rope_params(pos_index, self.axes_dim[0], self.theta),
self._rope_params(pos_index, self.axes_dim[1], self.theta),
self._rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self._rope_params(neg_index, self.axes_dim[0], self.theta),
self._rope_params(neg_index, self.axes_dim[1], self.theta),
self._rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.scale_rope = scale_rope
@staticmethod
def _rope_params(index, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(
self,
video_fhw: tuple[int, int, int] | list[tuple[int, int, int]],
device: torch.device = None,
max_txt_seq_len: int | torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
device: (`torch.device`, *optional*):
The device on which to perform the RoPE computation.
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
The maximum text sequence length for RoPE computation.
"""
if max_txt_seq_len is None:
raise ValueError("Either `max_txt_seq_len` must be provided.")
if isinstance(video_fhw, list) and len(video_fhw) > 1:
first_fhw = video_fhw[0]
if not all(fhw == first_fhw for fhw in video_fhw):
logger.warning(
"Batch inference with variable-sized images is not currently supported in NucleusMoEEmbedRope. "
"All images in the batch should have the same dimensions (frame, height, width). "
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
"for RoPE computation, which may lead to incorrect results for other images in the batch."
)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
vid_freqs.append(video_freq)
max_txt_seq_len_int = int(max_txt_seq_len)
if self.scale_rope:
max_vid_index = torch.maximum(
torch.tensor(height // 2, device=device, dtype=torch.long),
torch.tensor(width // 2, device=device, dtype=torch.long),
)
else:
max_vid_index = torch.maximum(
torch.tensor(height, device=device, dtype=torch.long),
torch.tensor(width, device=device, dtype=torch.long),
)
txt_freqs = self.pos_freqs.to(device)[max_vid_index + torch.arange(max_txt_seq_len_int, device=device)]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class NucleusMoEAttnProcessor2_0:
"""
Attention processor for the NucleusMoE architecture. Image queries attend to concatenated image+text keys/values
(cross-attention style, no text query). Supports grouped-query attention (GQA) when num_key_value_heads is set on
the Attention module.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"NucleusMoEAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: torch.FloatTensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
cached_txt_key: torch.FloatTensor | None = None,
cached_txt_value: torch.FloatTensor | None = None,
) -> torch.FloatTensor:
head_dim = attn.inner_dim // attn.heads
num_kv_heads = attn.inner_kv_dim // head_dim
num_kv_groups = attn.heads // num_kv_heads
img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
img_key = attn.to_k(hidden_states).unflatten(-1, (num_kv_heads, -1))
img_value = attn.to_v(hidden_states).unflatten(-1, (num_kv_heads, -1))
if attn.norm_q is not None:
img_query = attn.norm_q(img_query)
if attn.norm_k is not None:
img_key = attn.norm_k(img_key)
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False)
img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False)
if cached_txt_key is not None and cached_txt_value is not None:
txt_key, txt_value = cached_txt_key, cached_txt_value
joint_key = torch.cat([img_key, txt_key], dim=1)
joint_value = torch.cat([img_value, txt_value], dim=1)
elif encoder_hidden_states is not None:
txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
if attn.norm_added_k is not None:
txt_key = attn.norm_added_k(txt_key)
if image_rotary_emb is not None:
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
joint_key = torch.cat([img_key, txt_key], dim=1)
joint_value = torch.cat([img_value, txt_value], dim=1)
else:
joint_key = img_key
joint_value = img_value
if num_kv_groups > 1:
joint_key = joint_key.repeat_interleave(num_kv_groups, dim=2)
joint_value = joint_value.repeat_interleave(num_kv_groups, dim=2)
hidden_states = dispatch_attention_fn(
img_query,
joint_key,
joint_value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(img_query.dtype)
hidden_states = attn.to_out[0](hidden_states)
if len(attn.to_out) > 1:
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def _is_moe_layer(strategy: str, layer_idx: int, num_layers: int) -> bool:
if strategy == "leave_first_three_and_last_block_dense":
return layer_idx >= 3 and layer_idx < num_layers - 1
elif strategy == "leave_first_three_blocks_dense":
return layer_idx >= 3
elif strategy == "leave_first_block_dense":
return layer_idx >= 1
elif strategy == "all_moe":
return True
elif strategy == "all_dense":
return False
return True
class SwiGLUExperts(nn.Module):
"""
Packed SwiGLU feed-forward experts for MoE: ``gate, up = (x @ gate_up_proj).chunk(2); out = (silu(gate) * up) @
down_proj``.
Gate and up projections are fused into a single weight ``gate_up_proj`` so that only two grouped matmuls are needed
at runtime (gate+up combined, then down).
Weights are stored pre-transposed relative to the standard linear-layer convention so that matmuls can be issued
without a transpose at runtime.
Weight shapes:
gate_up_proj: (num_experts, hidden_size, 2 * moe_intermediate_dim) -- fused gate + up projection down_proj:
(num_experts, moe_intermediate_dim, hidden_size) -- down projection
"""
def __init__(
self,
hidden_size: int,
moe_intermediate_dim: int,
num_experts: int,
use_grouped_mm: bool = False,
):
super().__init__()
self.num_experts = num_experts
self.moe_intermediate_dim = moe_intermediate_dim
self.hidden_size = hidden_size
self.use_grouped_mm = use_grouped_mm
self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * moe_intermediate_dim))
self.down_proj = nn.Parameter(torch.empty(num_experts, moe_intermediate_dim, hidden_size))
def _run_experts_for_loop(
self,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
"""
Compute SwiGLU MoE expert outputs using a sequential per-expert for loop.
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
and so on — i.e. the layout produced by a standard token-permutation step (e.g. ``generate_permute_indices``).
``x`` may contain trailing padding rows appended by the permutation utility to reach a length that is a
multiple of some alignment requirement. The padding rows are stripped before expert computation and re-appended
as zeros so that the output shape matches ``x.shape``, keeping downstream scatter/gather indices valid.
.. note::
``num_tokens_per_expert.tolist()`` synchronises the device with the host. This is acceptable for the loop
path but means the method introduces a pipeline bubble. Use :meth:`forward` with ``use_grouped_mm=True``
when a fully device-resident kernel is required (e.g. inside ``torch.compile``).
SwiGLU formula::
gate, up = (x @ gate_up_proj).chunk(2) out = (silu(gate) * up) @ down_proj
Args:
x (Tensor): Pre-permuted input tokens of shape
``(total_tokens_including_padding, hidden_dim)``.
num_tokens_per_expert (Tensor): 1-D integer tensor of length
``num_experts`` giving the number of real (non-padding) tokens assigned to each expert. Values may
differ across experts to support load-imbalanced routing.
Returns:
Tensor of shape ``(total_tokens_including_padding, hidden_dim)``. Positions corresponding to padding rows
contain zeros.
"""
# .tolist() triggers a host-device sync; see docstring note above.
num_tokens_per_expert_list = num_tokens_per_expert.tolist()
# x may be padded to a larger buffer size by the permutation utility.
# Track the padding count so we can restore the original buffer shape.
num_real_tokens = sum(num_tokens_per_expert_list)
num_padding = x.shape[0] - num_real_tokens
# Split the real-token prefix of x into per-expert slices (variable length).
x_per_expert = torch.split(
x[:num_real_tokens],
split_size_or_sections=num_tokens_per_expert_list,
dim=0,
)
expert_outputs = []
for expert_idx, x_expert in enumerate(x_per_expert):
gate_up = torch.matmul(x_expert, self.gate_up_proj[expert_idx])
gate, up = gate_up.chunk(2, dim=-1)
out_expert = torch.matmul(F.silu(gate) * up, self.down_proj[expert_idx])
expert_outputs.append(out_expert)
# Concatenate real-token outputs, then re-append zero rows for the padding.
out = torch.cat(expert_outputs, dim=0)
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
return out
def _run_experts_grouped_mm(
self,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
"""
Compute SwiGLU MoE expert outputs using fused grouped GEMM kernels.
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
and so on — the same layout required by :meth:`_run_experts_for_loop`.
This method is fully device-resident (no host-device sync) and is compatible with ``torch.compile``.
``F.grouped_mm`` is called with *exclusive end* offsets: ``offsets[k]`` is the exclusive end index of expert
``k``'s token range in ``x`` (equivalently the inclusive start of expert ``k+1``'s range). This is the
cumulative sum of ``num_tokens_per_expert``.
SwiGLU formula::
gate, up = (x @ gate_up_proj).chunk(2) out = (silu(gate) * up) @ down_proj
Args:
x (Tensor): Pre-permuted input tokens of shape
``(total_tokens, hidden_dim)``. No padding rows expected; ``total_tokens`` must equal
``num_tokens_per_expert.sum()``.
num_tokens_per_expert (Tensor): 1-D integer tensor of length
``num_experts`` giving the number of tokens assigned to each expert.
Returns:
Tensor of shape ``(total_tokens, hidden_dim)`` with dtype matching ``x``.
"""
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
gate_up = F.grouped_mm(x, self.gate_up_proj, offs=offsets)
gate, up = gate_up.chunk(2, dim=-1)
out = F.grouped_mm(F.silu(gate) * up, self.down_proj, offs=offsets)
return out.type_as(x)
def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
if self.use_grouped_mm:
return self._run_experts_grouped_mm(x, num_tokens_per_expert)
return self._run_experts_for_loop(x, num_tokens_per_expert)
class NucleusMoELayer(nn.Module):
"""
Mixture-of-Experts layer with expert-choice routing and a shared expert.
Routed expert weights live in :class:`SwiGLUExperts`. The router concatenates a timestep embedding with the
(unmodulated) hidden state to produce per-token affinity scores, then selects the top-C tokens per expert
(expert-choice routing). A shared expert processes all tokens in parallel and its output is combined with the
routed expert outputs via scatter-add.
SwiGLU expert computation is implemented by :class:`SwiGLUExperts`.
"""
def __init__(
self,
hidden_size: int,
moe_intermediate_dim: int,
num_experts: int,
capacity_factor: float,
use_sigmoid: bool,
route_scale: float,
use_grouped_mm: bool = False,
):
super().__init__()
self.num_experts = num_experts
self.moe_intermediate_dim = moe_intermediate_dim
self.hidden_size = hidden_size
self.capacity_factor = capacity_factor
self.use_sigmoid = use_sigmoid
self.route_scale = route_scale
self.gate = nn.Linear(hidden_size * 2, num_experts, bias=False)
self.experts = SwiGLUExperts(
hidden_size=hidden_size,
moe_intermediate_dim=moe_intermediate_dim,
num_experts=num_experts,
use_grouped_mm=use_grouped_mm,
)
self.shared_expert = FeedForward(
dim=hidden_size,
dim_out=hidden_size,
inner_dim=moe_intermediate_dim,
activation_fn="swiglu",
bias=False,
)
def forward(
self,
hidden_states: torch.Tensor,
hidden_states_unmodulated: torch.Tensor,
timestep: torch.Tensor | None = None,
) -> torch.Tensor:
bs, slen, dim = hidden_states.shape
if timestep is not None:
timestep_expanded = timestep.unsqueeze(1).expand(-1, slen, -1)
router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1)
else:
router_input = hidden_states_unmodulated
logits = self.gate(router_input)
if self.use_sigmoid:
scores = torch.sigmoid(logits.float()).to(logits.dtype)
else:
scores = F.softmax(logits.float(), dim=-1).to(logits.dtype)
affinity = scores.transpose(1, 2) # (B, E, S)
capacity = max(1, math.ceil(self.capacity_factor * slen / self.num_experts))
topk = torch.topk(affinity, k=capacity, dim=-1)
top_indices = topk.indices # (B, E, C)
gating = affinity.gather(dim=-1, index=top_indices) # (B, E, C)
batch_offsets = torch.arange(bs, device=hidden_states.device, dtype=torch.long).view(bs, 1, 1) * slen
global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
gating_flat = gating.transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
token_score_sums = torch.zeros(bs * slen, device=hidden_states.device, dtype=gating_flat.dtype)
token_score_sums.scatter_add_(0, global_token_indices, gating_flat)
gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12)
gating_flat = gating_flat * self.route_scale
x_flat = hidden_states.reshape(bs * slen, dim)
routed_input = x_flat[global_token_indices]
tokens_per_expert = bs * capacity
num_tokens_per_expert = torch.full(
(self.num_experts,),
tokens_per_expert,
device=hidden_states.device,
dtype=torch.long,
)
routed_output = self.experts(routed_input, num_tokens_per_expert)
routed_output = (routed_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype)
out = self.shared_expert(hidden_states).reshape(bs * slen, dim)
scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, dim)
out = out.scatter_add(dim=0, index=scatter_idx, src=routed_output)
out = out.reshape(bs, slen, dim)
return out
class NucleusMoEImageTransformerBlock(nn.Module):
"""
Single-stream DiT block with optional Mixture-of-Experts MLP. Only the image stream receives adaptive modulation;
the text context is projected per-block and used as cross-attention keys/values.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
num_key_value_heads: int | None = None,
joint_attention_dim: int = 3584,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
mlp_ratio: float = 4.0,
moe_enabled: bool = False,
num_experts: int = 128,
moe_intermediate_dim: int = 1344,
capacity_factor: float = 8.0,
use_sigmoid: bool = False,
route_scale: float = 2.5,
use_grouped_mm: bool = False,
):
super().__init__()
self.dim = dim
self.moe_enabled = moe_enabled
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 4 * dim, bias=True),
)
self.encoder_proj = nn.Linear(joint_attention_dim, dim)
self.pre_attn_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
self.attn = Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_key_value_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=dim,
added_proj_bias=False,
out_dim=dim,
out_bias=False,
bias=False,
processor=NucleusMoEAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
context_pre_only=None,
)
self.pre_mlp_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
if moe_enabled:
self.img_mlp = NucleusMoELayer(
hidden_size=dim,
moe_intermediate_dim=moe_intermediate_dim,
num_experts=num_experts,
capacity_factor=capacity_factor,
use_sigmoid=use_sigmoid,
route_scale=route_scale,
use_grouped_mm=use_grouped_mm,
)
else:
mlp_inner_dim = int(dim * mlp_ratio * 2 / 3) // 128 * 128
self.img_mlp = FeedForward(
dim=dim,
dim_out=dim,
inner_dim=mlp_inner_dim,
activation_fn="swiglu",
bias=False,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor:
scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1)
gate1 = gate1.clamp(min=-2.0, max=2.0)
gate2 = gate2.clamp(min=-2.0, max=2.0)
attn_kwargs = attention_kwargs or {}
context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states)
img_normed = self.pre_attn_norm(hidden_states)
img_modulated = img_normed * (1 + scale1)
img_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=context,
image_rotary_emb=image_rotary_emb,
**attn_kwargs,
)
hidden_states = hidden_states + gate1.tanh() * img_attn_output
img_normed2 = self.pre_mlp_norm(hidden_states)
img_modulated2 = img_normed2 * (1 + scale2)
if self.moe_enabled:
img_mlp_output = self.img_mlp(img_modulated2, img_normed2, timestep=temb)
else:
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + gate2.tanh() * img_mlp_output
if hidden_states.dtype == torch.float16:
fp16_finfo = torch.finfo(torch.float16)
hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max)
return hidden_states
class NucleusMoEImageTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
"""
Nucleus MoE Transformer for image generation. Single-stream DiT with cross-attention to text and optional
Mixture-of-Experts feed-forward layers.
Args:
patch_size (`int`, defaults to `2`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `64`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of channels in the output. If not specified, it defaults to `in_channels`.
num_layers (`int`, defaults to `24`):
The number of transformer blocks.
attention_head_dim (`int`, defaults to `128`):
The number of dimensions to use for each attention head.
num_attention_heads (`int`, defaults to `16`):
The number of attention heads to use.
num_key_value_heads (`int`, *optional*):
The number of key/value heads for grouped-query attention. Defaults to `num_attention_heads`.
joint_attention_dim (`int`, defaults to `3584`):
The embedding dimension of the encoder hidden states (text).
axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions to use for the rotary positional embeddings.
mlp_ratio (`float`, defaults to `4.0`):
Multiplier for the MLP hidden dimension in dense (non-MoE) blocks.
moe_enabled (`bool`, defaults to `True`):
Whether to use Mixture-of-Experts layers.
dense_moe_strategy (`str`, defaults to ``"leave_first_three_and_last_block_dense"``):
Strategy for choosing which layers are MoE vs dense.
num_experts (`int`, defaults to `128`):
Number of experts per MoE layer.
moe_intermediate_dim (`int`, defaults to `1344`):
Hidden dimension inside each expert.
capacity_factors (`float | list[float]`, defaults to `8.0`):
Expert-choice capacity factor per layer.
use_sigmoid (`bool`, defaults to `False`):
Use sigmoid instead of softmax for routing scores.
route_scale (`float`, defaults to `2.5`):
Scaling factor applied to routing weights.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["NucleusMoEImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["NucleusMoEImageTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 64,
out_channels: int | None = None,
num_layers: int = 24,
attention_head_dim: int = 128,
num_attention_heads: int = 16,
num_key_value_heads: int | None = None,
joint_attention_dim: int = 3584,
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
mlp_ratio: float = 4.0,
moe_enabled: bool = True,
dense_moe_strategy: str = "leave_first_three_and_last_block_dense",
num_experts: int = 128,
moe_intermediate_dim: int = 1344,
capacity_factors: float | list[float] = 8.0,
use_sigmoid: bool = False,
route_scale: float = 2.5,
use_grouped_mm: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
capacity_factors = capacity_factors if isinstance(capacity_factors, list) else [capacity_factors] * num_layers
self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
self.time_text_embed = NucleusMoETimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
self.img_in = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
NucleusMoEImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_key_value_heads=num_key_value_heads,
joint_attention_dim=joint_attention_dim,
mlp_ratio=mlp_ratio,
moe_enabled=moe_enabled and _is_moe_layer(dense_moe_strategy, idx, num_layers),
num_experts=num_experts,
moe_intermediate_dim=moe_intermediate_dim,
capacity_factor=capacity_factors[idx],
use_sigmoid=use_sigmoid,
route_scale=route_scale,
use_grouped_mm=use_grouped_mm,
)
for idx in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
img_shapes: tuple[int, int, int] | list[tuple[int, int, int]],
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor | Transformer2DModelOutput:
"""
The [`NucleusMoEImageTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
img_shapes (`list[tuple[int, int, int]]`, *optional*):
Image shapes ``(frame, height, width)`` for RoPE computation.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Boolean mask for the encoder hidden states.
timestep (`torch.LongTensor`):
Used to indicate denoising step.
attention_kwargs (`dict`, *optional*):
Extra kwargs forwarded to the attention processor.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`].
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
text_seq_len, _, encoder_hidden_states_mask = _compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)
temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
block_attention_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=block_attention_kwargs,
)
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -420,6 +420,7 @@ else:
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
]
_import_structure["nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
_import_structure["qwenimage"] = [
"QwenImagePipeline",
"QwenImageImg2ImgPipeline",
@@ -768,6 +769,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
from .nucleusmoe_image import NucleusMoEImagePipeline
from .omnigen import OmniGenPipeline
from .ovis_image import OvisImagePipeline
from .pag import (

View File

@@ -77,6 +77,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaPipeline
from .lumina2 import Lumina2Pipeline
from .nucleusmoe_image import NucleusMoEImagePipeline
from .ovis_image import OvisImagePipeline
from .pag import (
HunyuanDiTPAGPipeline,
@@ -179,6 +180,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("helios", HeliosPipeline),
("helios-pyramid", HeliosPyramidPipeline),
("cogview4-control", CogView4ControlPipeline),
("nucleusmoe-image", NucleusMoEImagePipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
("z-image", ZImagePipeline),

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["NucleusMoEImagePipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_nucleusmoe_image import NucleusMoEImagePipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,644 @@
# Copyright 2025 Nucleus-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable
import numpy as np
import torch
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLQwenImage, NucleusMoEImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import NucleusMoEImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__)
DEFAULT_SYSTEM_PROMPT = "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects."
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import NucleusMoEImagePipeline
>>> pipe = NucleusMoEImagePipeline.from_pretrained("NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt, num_inference_steps=50).images[0]
>>> image.save("nucleus_moe.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`list[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`list[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class NucleusMoEImagePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using NucleusMoE.
This pipeline uses a single-stream DiT with Mixture-of-Experts feed-forward layers, cross-attention to a Qwen3-VL
text encoder, and a flow-matching Euler discrete scheduler.
Args:
transformer ([`NucleusMoEImageTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLQwenImage`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen3VLForConditionalGeneration`]):
Text encoder for computing prompt embeddings.
processor ([`Qwen3VLProcessor`]):
Processor for tokenizing text inputs.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
transformer: NucleusMoEImageTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLQwenImage,
text_encoder: Qwen3VLForConditionalGeneration,
processor: Qwen3VLProcessor,
):
super().__init__()
self.register_modules(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
processor=processor,
)
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 128
self.default_max_sequence_length = 1024
self.default_return_index = -8
def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def encode_prompt(
self,
prompt: str | list[str] = None,
device: torch.device | None = None,
num_images_per_prompt: int = 1,
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int | None = None,
return_index: int | None = None,
):
r"""
Encode text prompt(s) into embeddings using the Qwen3-VL text encoder.
Args:
prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to encode.
device (`torch.device`, *optional*):
Torch device for the resulting tensors.
num_images_per_prompt (`int`, defaults to 1):
Number of images to generate per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Skips encoding when provided.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Attention mask for pre-generated embeddings.
max_sequence_length (`int`, defaults to 1024):
Maximum token length for the encoded prompt.
"""
device = device or self._execution_device
return_index = return_index or self.default_return_index
if prompt_embeds is None:
prompt = [prompt] if isinstance(prompt, str) else prompt
formatted = [self._format_prompt(p) for p in prompt]
inputs = self.processor(
text=formatted,
padding="longest",
pad_to_multiple_of=8,
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
).to(device=device)
prompt_embeds_mask = inputs.attention_mask
outputs = self.text_encoder(**inputs, use_cache=False, return_dict=True, output_hidden_states=True)
prompt_embeds = outputs.hidden_states[return_index]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(device=device)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.to(device=device)
if num_images_per_prompt > 1:
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_images_per_prompt, dim=0)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None
return prompt_embeds, prompt_embeds_mask
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
prompt_embeds_mask=None,
negative_prompt_embeds=None,
negative_prompt_embeds_mask=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
return_index=None,
):
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} "
f"but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, "
f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.")
elif prompt is not None and not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and "
f"`negative_prompt_embeds`: {negative_prompt_embeds}. "
"Please make sure to only forward one of the two."
)
if return_index is not None and abs(return_index) >= self.text_encoder.config.text_config.num_hidden_layers:
raise ValueError(
f"absolute value of `return_index` cannot be >= {self.text_encoder.config.text_config.num_hidden_layers} "
f"but is {abs(return_index)}"
)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size):
latents = latents.view(
batch_size, num_channels_latents, height // patch_size, patch_size, width // patch_size, patch_size
)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(
batch_size, (height // patch_size) * (width // patch_size), num_channels_latents * patch_size * patch_size
)
return latents
@staticmethod
def _unpack_latents(latents, height, width, patch_size, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
latents = latents.view(
batch_size,
height // patch_size,
width // patch_size,
channels // (patch_size * patch_size),
patch_size,
patch_size,
)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
patch_size,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = patch_size * (int(height) // (self.vae_scale_factor * patch_size))
width = patch_size * (int(width) // (self.vae_scale_factor * patch_size))
shape = (batch_size, 1, num_channels_latents, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str | list[str] = None,
negative_prompt: str | list[str] = None,
guidance_scale: float = 4.0,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
num_images_per_prompt: int = 1,
max_sequence_length: int | None = None,
return_index: int | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds_mask: torch.Tensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
negative_prompt (`str` or `list[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, an empty string is used when
`true_cfg_scale > 1`.
true_cfg_scale (`float`, *optional*, defaults to 4.0):
Classifier-free guidance scale. Values greater than 1 enable CFG.
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list[float]`, *optional*):
Custom sigmas for the denoising schedule. If not defined, a linear schedule is used.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
One or a list of torch generators to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents to be used as inputs for image generation.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings.
prompt_embeds_mask (`torch.Tensor`, *optional*):
Attention mask for pre-generated text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings.
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
Attention mask for pre-generated negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `"pil"`, `"np"`, or `"latent"`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`NucleusMoEImagePipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
Kwargs passed to the attention processor.
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`list`, *optional*):
Tensor inputs for the `callback_on_step_end` function.
max_sequence_length (`int`, defaults to 512):
Maximum sequence length for the text prompt.
Examples:
Returns:
[`NucleusMoEImagePipelineOutput`] or `tuple`:
[`NucleusMoEImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple` where the first element
is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
max_sequence_length = max_sequence_length or self.default_max_sequence_length
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
return_index=return_index,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs or {}
self._current_timestep = None
self._interrupt = False
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_cfg = guidance_scale > 1
if do_cfg and not has_neg_prompt:
negative_prompt = [""] * batch_size
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
return_index=return_index,
)
if do_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
return_index=return_index,
)
num_channels_latents = self.transformer.config.in_channels // 4
patch_size = self.transformer.config.patch_size
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
patch_size,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
img_shapes = [
(1, height // self.vae_scale_factor // patch_size, width // self.vae_scale_factor // patch_size)
] * (batch_size * num_images_per_prompt)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
self.scheduler.set_begin_index(0)
if self.transformer.is_cache_enabled:
self.transformer._reset_stateful_cache()
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / self.scheduler.config.num_train_timesteps,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
attention_kwargs=self._attention_kwargs,
return_dict=False,
)[0]
if do_cfg:
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / self.scheduler.config.num_train_timesteps,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
img_shapes=img_shapes,
attention_kwargs=self._attention_kwargs,
return_dict=False,
)[0]
comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
noise_pred = -noise_pred
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, patch_size, self.vae_scale_factor)
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
image = self.image_processor.postprocess(image, output_type=output_type)
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return NucleusMoEImagePipelineOutput(images=image)

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class NucleusMoEImagePipelineOutput(BaseOutput):
"""
Output class for NucleusMoE Image pipelines.
Args:
images (`list[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: list[PIL.Image.Image] | np.ndarray

View File

@@ -287,6 +287,21 @@ class TaylorSeerCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class TextKVCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
@@ -311,6 +326,10 @@ def apply_taylorseer_cache(*args, **kwargs):
requires_backends(apply_taylorseer_cache, ["torch"])
def apply_text_kv_cache(*args, **kwargs):
requires_backends(apply_text_kv_cache, ["torch"])
class InpaintProcessor(metaclass=DummyObject):
_backends = ["torch"]
@@ -1511,6 +1530,21 @@ class MultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class NucleusMoEImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class OmniGenTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -2567,6 +2567,21 @@ class MusicLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class NucleusMoEImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class OmniGenPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,220 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import NucleusMoEImageTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return NucleusMoEImageTransformer2DModel
@property
def output_shape(self) -> tuple[int, int]:
return (16, 16)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 16)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4),
"moe_enabled": False,
"capacity_factors": [8.0, 8.0],
}
def get_dummy_inputs(self) -> dict:
batch_size = 1
in_channels = 16
joint_attention_dim = 16
height = width = 4
sequence_length = 8
hidden_states = randn_tensor(
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
img_shapes = [(1, height, width)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin):
def test_with_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Mask out some text tokens
mask = inputs["encoder_hidden_states_mask"].clone()
mask[:, 4:] = 0
inputs["encoder_hidden_states_mask"] = mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_without_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
inputs["encoder_hidden_states_mask"] = None
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for NucleusMoE Image Transformer."""
class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for NucleusMoE Image Transformer."""
class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for NucleusMoE Image Transformer."""
class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for NucleusMoE Image Transformer."""
class TestNucleusMoEImageTransformerLoRAHotSwap(
NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin
):
"""LoRA hot-swapping tests for NucleusMoE Image Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
batch_size = 1
in_channels = 16
joint_attention_dim = 16
sequence_length = 8
hidden_states = randn_tensor(
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
img_shapes = [(1, height, width)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for NucleusMoE Image Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
batch_size = 1
in_channels = 16
joint_attention_dim = 16
sequence_length = 8
hidden_states = randn_tensor(
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
img_shapes = [(1, height, width)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for NucleusMoE Image Transformer."""
class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for NucleusMoE Image Transformer."""

View File

@@ -0,0 +1,337 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor
from diffusers import (
AutoencoderKLQwenImage,
FlowMatchEulerDiscreteScheduler,
NucleusMoEImagePipeline,
NucleusMoEImageTransformer2DModel,
)
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = NucleusMoEImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = NucleusMoEImageTransformer2DModel(
patch_size=2,
in_channels=16,
out_channels=4,
num_layers=2,
attention_head_dim=16,
num_attention_heads=4,
joint_attention_dim=16,
axes_dims_rope=(8, 4, 4),
moe_enabled=False,
capacity_factors=[8.0, 8.0],
)
torch.manual_seed(0)
z_dim = 4
vae = AutoencoderKLQwenImage(
base_dim=z_dim * 6,
z_dim=z_dim,
dim_mult=[1, 2, 4],
num_res_blocks=1,
temperal_downsample=[False, True],
# fmt: off
latents_mean=[0.0] * z_dim,
latents_std=[1.0] * z_dim,
# fmt: on
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
config = Qwen3VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 8,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
"vocab_size": 151936,
"head_dim": 8,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_channels": 16,
},
)
text_encoder = Qwen3VLForConditionalGeneration(config).eval()
processor = Qwen3VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"processor": processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A cat sitting on a mat",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"return_index": -1,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
def test_true_cfg(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["guidance_scale"] = 4.0
inputs["negative_prompt"] = "low quality"
image = pipe(**inputs).images
self.assertEqual(image[0].shape, (3, 32, 32))
def test_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
prompt=inputs["prompt"],
device=device,
max_sequence_length=inputs["max_sequence_length"],
)
inputs_with_embeds = self.get_dummy_inputs(device)
inputs_with_embeds.pop("prompt")
inputs_with_embeds["prompt_embeds"] = prompt_embeds
inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask
image = pipe(**inputs_with_embeds).images
self.assertEqual(image[0].shape, (3, 32, 32))
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
# PipelineTesterMixin compares outputs with assert_mean_pixel_difference, which assumes HWC numpy/PIL layout.
# With output_type="pt", tensors are CHW; numpy_to_pil then fails. Match QwenImage: only assert max diff.
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
# PipelineTesterMixin only keeps components whose keys contain "text" or "tokenizer"; this pipeline also
# needs `processor` for encode_prompt (apply_chat_template). Mirror the mixin with that key included.
if not hasattr(self.pipeline_class, "encode_prompt"):
return
components = self.get_dummy_components()
for key in components:
if "text_encoder" in key and hasattr(components[key], "eval"):
components[key].eval()
def _is_text_stack_component(k):
return "text" in k or "tokenizer" in k or k == "processor"
components_with_text_encoders = {}
for k in components:
if _is_text_stack_component(k):
components_with_text_encoders[k] = components[k]
else:
components_with_text_encoders[k] = None
pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt)
encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
required_params = []
for param in encode_prompt_parameters:
if param.name == "self" or param.name == "kwargs":
continue
if param.default is inspect.Parameter.empty:
required_params.append(param.name)
encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
input_keys = list(inputs.keys())
encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
pipe_call_parameters = pipe_call_signature.parameters
for required_param_name in required_params:
if required_param_name not in encode_prompt_inputs:
pipe_call_param = pipe_call_parameters.get(required_param_name, None)
if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
encode_prompt_inputs[required_param_name] = pipe_call_param.default
elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict):
encode_prompt_inputs[required_param_name] = extra_required_param_value_dict[required_param_name]
else:
raise ValueError(
f"Required parameter '{required_param_name}' in "
f"encode_prompt has no default in either encode_prompt or __call__."
)
with torch.no_grad():
encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
ast_visitor = ReturnNameVisitor()
encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class)
ast_visitor.visit(encode_prompt_tree)
prompt_embed_kwargs = ast_visitor.return_names
prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))
adapted_prompt_embeds_kwargs = {
k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
}
components_with_text_encoders = {}
for k in components:
if _is_text_stack_component(k):
components_with_text_encoders[k] = None
else:
components_with_text_encoders[k] = components[k]
pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs}
if (
pipe_call_parameters.get("negative_prompt", None) is not None
and pipe_call_parameters.get("negative_prompt").default is not None
):
pipe_without_tes_inputs.update({"negative_prompt": None})
if (
pipe_call_parameters.get("prompt", None) is not None
and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty
and pipe_call_parameters.get("prompt_embeds", None) is not None
and pipe_call_parameters.get("prompt_embeds").default is None
):
pipe_without_tes_inputs.update({"prompt": None})
pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]
full_pipe = self.pipeline_class(**components).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
pipe_out_2 = full_pipe(**inputs)[0]
if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray):
self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor):
self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))