mirror of
https://github.com/huggingface/diffusers.git
synced 2026-06-05 00:53:09 +08:00
NucleusMoE-Image (#13317)
* adding NucleusMoE-Image model * update system prompt * Add text kv caching * Class/function name changes * add missing imports * add RoPE credits * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * update defaults * Update src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * review updates * fix the tests * clean up * update apply_text_kv_cache * SwiGLUExperts addition * fuse SwiGLUExperts up and gate proj * Update src/diffusers/hooks/text_kv_cache.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/hooks/text_kv_cache.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/hooks/text_kv_cache.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/hooks/text_kv_cache.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * _SharedCacheKey -> TextKVCacheState * Apply style fixes * Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite * Apply style fixes * run `make fix-copies` * fix import * refactor text KV cache to be managed by StateManager --------- Co-authored-by: Murali Nandan Nagarapu <nmn@withnucleus.ai> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -169,22 +169,23 @@ else:
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"TaylorSeerCacheConfig",
|
||||
"TextKVCacheConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_mag_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
"apply_taylorseer_cache",
|
||||
"apply_text_kv_cache",
|
||||
]
|
||||
)
|
||||
_import_structure["image_processor"] = [
|
||||
"IPAdapterMaskProcessor",
|
||||
"InpaintProcessor",
|
||||
"IPAdapterMaskProcessor",
|
||||
"PixArtImageProcessor",
|
||||
"VaeImageProcessor",
|
||||
"VaeImageProcessorLDM3D",
|
||||
]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -262,6 +263,7 @@ else:
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
"MultiControlNetModel",
|
||||
"NucleusMoEImageTransformer2DModel",
|
||||
"OmniGenTransformer2DModel",
|
||||
"OvisImageTransformer2DModel",
|
||||
"ParallelConfig",
|
||||
@@ -396,6 +398,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["training_utils"] = ["EMAModel"]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
@@ -613,6 +616,7 @@ else:
|
||||
"MarigoldNormalsPipeline",
|
||||
"MochiPipeline",
|
||||
"MusicLDMPipeline",
|
||||
"NucleusMoEImagePipeline",
|
||||
"OmniGenPipeline",
|
||||
"OvisImagePipeline",
|
||||
"PaintByExamplePipeline",
|
||||
@@ -967,12 +971,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
TextKVCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
apply_text_kv_cache,
|
||||
)
|
||||
from .image_processor import (
|
||||
InpaintProcessor,
|
||||
@@ -1057,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
MultiControlNetModel,
|
||||
NucleusMoEImageTransformer2DModel,
|
||||
OmniGenTransformer2DModel,
|
||||
OvisImageTransformer2DModel,
|
||||
ParallelConfig,
|
||||
@@ -1384,6 +1391,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MarigoldNormalsPipeline,
|
||||
MochiPipeline,
|
||||
MusicLDMPipeline,
|
||||
NucleusMoEImagePipeline,
|
||||
OmniGenPipeline,
|
||||
OvisImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
|
||||
@@ -27,3 +27,4 @@ if is_torch_available():
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
||||
from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
|
||||
|
||||
173
src/diffusers/hooks/text_kv_cache.py
Normal file
173
src/diffusers/hooks/text_kv_cache.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
|
||||
_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextKVCacheConfig:
|
||||
"""Enable exact (lossless) text K/V caching for transformer models.
|
||||
|
||||
Pre-computes per-block text key and value projections once before the denoising loop and reuses them across all
|
||||
steps. Positive and negative prompts are distinguished via a stable cache key captured by a transformer-level hook
|
||||
before any intermediate tensor allocations.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextKVCacheState(BaseState):
|
||||
"""Shared state between the transformer-level and block-level hooks.
|
||||
|
||||
The transformer hook writes the stable ``encoder_hidden_states`` ``data_ptr()`` (captured *before* ``txt_norm``) so
|
||||
that block hooks can use it as a reliable cache key across denoising steps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.key: int | None = None
|
||||
|
||||
def reset(self):
|
||||
self.key = None
|
||||
|
||||
|
||||
class TextKVCacheBlockState(BaseState):
|
||||
"""Per-block state holding cached text key/value projections."""
|
||||
|
||||
def __init__(self):
|
||||
self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
|
||||
def reset(self):
|
||||
self.kv_cache.clear()
|
||||
|
||||
|
||||
class TextKVCacheTransformerHook(ModelHook):
|
||||
"""Captures ``encoder_hidden_states.data_ptr()`` before ``txt_norm``
|
||||
and writes it to shared state for the block hooks to read."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states")
|
||||
if encoder_hidden_states is not None:
|
||||
state: TextKVCacheState = self.state_manager.get_state()
|
||||
state.key = encoder_hidden_states.data_ptr()
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
def reset_state(self, module: torch.nn.Module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class TextKVCacheBlockHook(ModelHook):
|
||||
"""Caches ``(txt_key, txt_value)`` per block per unique prompt using
|
||||
the stable cache key from the shared state."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.block_state_manager = block_state_manager
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus
|
||||
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
if self.block_state_manager._current_context is None:
|
||||
self.block_state_manager.set_context("inference")
|
||||
|
||||
if "encoder_hidden_states" in kwargs:
|
||||
encoder_hidden_states = kwargs["encoder_hidden_states"]
|
||||
else:
|
||||
encoder_hidden_states = args[1]
|
||||
|
||||
if "image_rotary_emb" in kwargs:
|
||||
image_rotary_emb = kwargs["image_rotary_emb"]
|
||||
elif len(args) > 3:
|
||||
image_rotary_emb = args[3]
|
||||
else:
|
||||
image_rotary_emb = None
|
||||
|
||||
state: TextKVCacheState = self.state_manager.get_state()
|
||||
cache_key = state.key
|
||||
|
||||
block_state: TextKVCacheBlockState = self.block_state_manager.get_state()
|
||||
|
||||
if cache_key not in block_state.kv_cache:
|
||||
context = module.encoder_proj(encoder_hidden_states)
|
||||
|
||||
attn = module.attn
|
||||
head_dim = attn.inner_dim // attn.heads
|
||||
num_kv_heads = attn.inner_kv_dim // head_dim
|
||||
|
||||
txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
|
||||
txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
_, txt_freqs = image_rotary_emb
|
||||
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
block_state.kv_cache[cache_key] = (txt_key, txt_value)
|
||||
|
||||
txt_key, txt_value = block_state.kv_cache[cache_key]
|
||||
|
||||
attn_kwargs = kwargs.get("attention_kwargs") or {}
|
||||
attn_kwargs["cached_txt_key"] = txt_key
|
||||
attn_kwargs["cached_txt_value"] = txt_value
|
||||
kwargs["attention_kwargs"] = attn_kwargs
|
||||
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
def reset_state(self, module: torch.nn.Module):
|
||||
self.block_state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
|
||||
from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(TextKVCacheState)
|
||||
|
||||
transformer_hook = TextKVCacheTransformerHook(state_manager)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(transformer_hook, _TEXT_KV_CACHE_TRANSFORMER_HOOK)
|
||||
|
||||
for _, submodule in module.named_modules():
|
||||
if isinstance(submodule, NucleusMoEImageTransformerBlock):
|
||||
block_state_manager = StateManager(TextKVCacheBlockState)
|
||||
hook = TextKVCacheBlockHook(state_manager, block_state_manager)
|
||||
block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
block_registry.register_hook(hook, _TEXT_KV_CACHE_BLOCK_HOOK)
|
||||
@@ -116,6 +116,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
@@ -236,6 +237,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
NucleusMoEImageTransformer2DModel,
|
||||
OmniGenTransformer2DModel,
|
||||
OvisImageTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
|
||||
@@ -41,11 +41,12 @@ class CacheMixin:
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
|
||||
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
- [`~hooks.FasterCacheConfig`]
|
||||
- [`~hooks.FirstBlockCacheConfig`]
|
||||
- [`~hooks.TextKVCacheConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -71,11 +72,13 @@ class CacheMixin:
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
TextKVCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
apply_text_kv_cache,
|
||||
)
|
||||
|
||||
if self.is_cache_enabled:
|
||||
@@ -89,6 +92,8 @@ class CacheMixin:
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, MagCacheConfig):
|
||||
apply_mag_cache(self, config)
|
||||
elif isinstance(config, TextKVCacheConfig):
|
||||
apply_text_kv_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, TaylorSeerCacheConfig):
|
||||
@@ -106,12 +111,14 @@ class CacheMixin:
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
TextKVCacheConfig,
|
||||
)
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
|
||||
from ..hooks.text_kv_cache import _TEXT_KV_CACHE_BLOCK_HOOK, _TEXT_KV_CACHE_TRANSFORMER_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
@@ -129,6 +136,9 @@ class CacheMixin:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TextKVCacheConfig):
|
||||
registry.remove_hook(_TEXT_KV_CACHE_TRANSFORMER_HOOK, recurse=True)
|
||||
registry.remove_hook(_TEXT_KV_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
|
||||
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
|
||||
else:
|
||||
|
||||
@@ -40,6 +40,7 @@ if is_torch_available():
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_ovis_image import OvisImageTransformer2DModel
|
||||
from .transformer_prx import PRXTransformer2DModel
|
||||
|
||||
@@ -0,0 +1,925 @@
|
||||
# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus
|
||||
def _apply_rotary_emb_nucleus(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor | tuple[torch.Tensor],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
def _compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None
|
||||
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
|
||||
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
|
||||
if encoder_hidden_states_mask is None:
|
||||
return text_seq_len, None, None
|
||||
|
||||
if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
|
||||
raise ValueError(
|
||||
f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
|
||||
f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
|
||||
)
|
||||
|
||||
if encoder_hidden_states_mask.dtype != torch.bool:
|
||||
encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
|
||||
|
||||
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
|
||||
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
|
||||
has_active = encoder_hidden_states_mask.any(dim=1)
|
||||
per_sample_len = torch.where(
|
||||
has_active,
|
||||
active_positions.max(dim=1).values + 1,
|
||||
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
|
||||
)
|
||||
return text_seq_len, per_sample_len, encoder_hidden_states_mask
|
||||
|
||||
|
||||
class NucleusMoETimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=embedding_dim, time_embed_dim=4 * embedding_dim, out_dim=embedding_dim
|
||||
)
|
||||
self.norm = RMSNorm(embedding_dim, eps=1e-6)
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if use_additional_t_cond:
|
||||
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
conditioning = timesteps_emb
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
|
||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
|
||||
conditioning = conditioning + addition_t_emb
|
||||
|
||||
return self.norm(conditioning)
|
||||
|
||||
|
||||
class NucleusMoEEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self._rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self._rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self._rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self._rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self._rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self._rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
@staticmethod
|
||||
def _rope_params(index, dim, theta=10000):
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
device: torch.device = None,
|
||||
max_txt_seq_len: int | torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video.
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
|
||||
The maximum text sequence length for RoPE computation.
|
||||
"""
|
||||
if max_txt_seq_len is None:
|
||||
raise ValueError("Either `max_txt_seq_len` must be provided.")
|
||||
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
first_fhw = video_fhw[0]
|
||||
if not all(fhw == first_fhw for fhw in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in NucleusMoEEmbedRope. "
|
||||
"All images in the batch should have the same dimensions (frame, height, width). "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
|
||||
vid_freqs = []
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
if self.scale_rope:
|
||||
max_vid_index = torch.maximum(
|
||||
torch.tensor(height // 2, device=device, dtype=torch.long),
|
||||
torch.tensor(width // 2, device=device, dtype=torch.long),
|
||||
)
|
||||
else:
|
||||
max_vid_index = torch.maximum(
|
||||
torch.tensor(height, device=device, dtype=torch.long),
|
||||
torch.tensor(width, device=device, dtype=torch.long),
|
||||
)
|
||||
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index + torch.arange(max_txt_seq_len_int, device=device)]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class NucleusMoEAttnProcessor2_0:
|
||||
"""
|
||||
Attention processor for the NucleusMoE architecture. Image queries attend to concatenated image+text keys/values
|
||||
(cross-attention style, no text query). Supports grouped-query attention (GQA) when num_key_value_heads is set on
|
||||
the Attention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"NucleusMoEAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
cached_txt_key: torch.FloatTensor | None = None,
|
||||
cached_txt_value: torch.FloatTensor | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
head_dim = attn.inner_dim // attn.heads
|
||||
num_kv_heads = attn.inner_kv_dim // head_dim
|
||||
num_kv_groups = attn.heads // num_kv_heads
|
||||
|
||||
img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
|
||||
img_key = attn.to_k(hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
img_value = attn.to_v(hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False)
|
||||
img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False)
|
||||
|
||||
if cached_txt_key is not None and cached_txt_value is not None:
|
||||
txt_key, txt_value = cached_txt_key, cached_txt_value
|
||||
joint_key = torch.cat([img_key, txt_key], dim=1)
|
||||
joint_value = torch.cat([img_value, txt_value], dim=1)
|
||||
elif encoder_hidden_states is not None:
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
joint_key = torch.cat([img_key, txt_key], dim=1)
|
||||
joint_value = torch.cat([img_value, txt_value], dim=1)
|
||||
else:
|
||||
joint_key = img_key
|
||||
joint_value = img_value
|
||||
|
||||
if num_kv_groups > 1:
|
||||
joint_key = joint_key.repeat_interleave(num_kv_groups, dim=2)
|
||||
joint_value = joint_value.repeat_interleave(num_kv_groups, dim=2)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
img_query,
|
||||
joint_key,
|
||||
joint_value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(img_query.dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1:
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _is_moe_layer(strategy: str, layer_idx: int, num_layers: int) -> bool:
|
||||
if strategy == "leave_first_three_and_last_block_dense":
|
||||
return layer_idx >= 3 and layer_idx < num_layers - 1
|
||||
elif strategy == "leave_first_three_blocks_dense":
|
||||
return layer_idx >= 3
|
||||
elif strategy == "leave_first_block_dense":
|
||||
return layer_idx >= 1
|
||||
elif strategy == "all_moe":
|
||||
return True
|
||||
elif strategy == "all_dense":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SwiGLUExperts(nn.Module):
|
||||
"""
|
||||
Packed SwiGLU feed-forward experts for MoE: ``gate, up = (x @ gate_up_proj).chunk(2); out = (silu(gate) * up) @
|
||||
down_proj``.
|
||||
|
||||
Gate and up projections are fused into a single weight ``gate_up_proj`` so that only two grouped matmuls are needed
|
||||
at runtime (gate+up combined, then down).
|
||||
|
||||
Weights are stored pre-transposed relative to the standard linear-layer convention so that matmuls can be issued
|
||||
without a transpose at runtime.
|
||||
|
||||
Weight shapes:
|
||||
gate_up_proj: (num_experts, hidden_size, 2 * moe_intermediate_dim) -- fused gate + up projection down_proj:
|
||||
(num_experts, moe_intermediate_dim, hidden_size) -- down projection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
moe_intermediate_dim: int,
|
||||
num_experts: int,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.moe_intermediate_dim = moe_intermediate_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.use_grouped_mm = use_grouped_mm
|
||||
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * moe_intermediate_dim))
|
||||
self.down_proj = nn.Parameter(torch.empty(num_experts, moe_intermediate_dim, hidden_size))
|
||||
|
||||
def _run_experts_for_loop(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
num_tokens_per_expert: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute SwiGLU MoE expert outputs using a sequential per-expert for loop.
|
||||
|
||||
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
|
||||
and so on — i.e. the layout produced by a standard token-permutation step (e.g. ``generate_permute_indices``).
|
||||
|
||||
``x`` may contain trailing padding rows appended by the permutation utility to reach a length that is a
|
||||
multiple of some alignment requirement. The padding rows are stripped before expert computation and re-appended
|
||||
as zeros so that the output shape matches ``x.shape``, keeping downstream scatter/gather indices valid.
|
||||
|
||||
.. note::
|
||||
``num_tokens_per_expert.tolist()`` synchronises the device with the host. This is acceptable for the loop
|
||||
path but means the method introduces a pipeline bubble. Use :meth:`forward` with ``use_grouped_mm=True``
|
||||
when a fully device-resident kernel is required (e.g. inside ``torch.compile``).
|
||||
|
||||
SwiGLU formula::
|
||||
|
||||
gate, up = (x @ gate_up_proj).chunk(2) out = (silu(gate) * up) @ down_proj
|
||||
|
||||
Args:
|
||||
x (Tensor): Pre-permuted input tokens of shape
|
||||
``(total_tokens_including_padding, hidden_dim)``.
|
||||
num_tokens_per_expert (Tensor): 1-D integer tensor of length
|
||||
``num_experts`` giving the number of real (non-padding) tokens assigned to each expert. Values may
|
||||
differ across experts to support load-imbalanced routing.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(total_tokens_including_padding, hidden_dim)``. Positions corresponding to padding rows
|
||||
contain zeros.
|
||||
"""
|
||||
# .tolist() triggers a host-device sync; see docstring note above.
|
||||
num_tokens_per_expert_list = num_tokens_per_expert.tolist()
|
||||
|
||||
# x may be padded to a larger buffer size by the permutation utility.
|
||||
# Track the padding count so we can restore the original buffer shape.
|
||||
num_real_tokens = sum(num_tokens_per_expert_list)
|
||||
num_padding = x.shape[0] - num_real_tokens
|
||||
|
||||
# Split the real-token prefix of x into per-expert slices (variable length).
|
||||
x_per_expert = torch.split(
|
||||
x[:num_real_tokens],
|
||||
split_size_or_sections=num_tokens_per_expert_list,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
expert_outputs = []
|
||||
for expert_idx, x_expert in enumerate(x_per_expert):
|
||||
gate_up = torch.matmul(x_expert, self.gate_up_proj[expert_idx])
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
out_expert = torch.matmul(F.silu(gate) * up, self.down_proj[expert_idx])
|
||||
expert_outputs.append(out_expert)
|
||||
|
||||
# Concatenate real-token outputs, then re-append zero rows for the padding.
|
||||
out = torch.cat(expert_outputs, dim=0)
|
||||
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
|
||||
return out
|
||||
|
||||
def _run_experts_grouped_mm(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
num_tokens_per_expert: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute SwiGLU MoE expert outputs using fused grouped GEMM kernels.
|
||||
|
||||
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
|
||||
and so on — the same layout required by :meth:`_run_experts_for_loop`.
|
||||
|
||||
This method is fully device-resident (no host-device sync) and is compatible with ``torch.compile``.
|
||||
|
||||
``F.grouped_mm`` is called with *exclusive end* offsets: ``offsets[k]`` is the exclusive end index of expert
|
||||
``k``'s token range in ``x`` (equivalently the inclusive start of expert ``k+1``'s range). This is the
|
||||
cumulative sum of ``num_tokens_per_expert``.
|
||||
|
||||
SwiGLU formula::
|
||||
|
||||
gate, up = (x @ gate_up_proj).chunk(2) out = (silu(gate) * up) @ down_proj
|
||||
|
||||
Args:
|
||||
x (Tensor): Pre-permuted input tokens of shape
|
||||
``(total_tokens, hidden_dim)``. No padding rows expected; ``total_tokens`` must equal
|
||||
``num_tokens_per_expert.sum()``.
|
||||
num_tokens_per_expert (Tensor): 1-D integer tensor of length
|
||||
``num_experts`` giving the number of tokens assigned to each expert.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(total_tokens, hidden_dim)`` with dtype matching ``x``.
|
||||
"""
|
||||
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
|
||||
|
||||
gate_up = F.grouped_mm(x, self.gate_up_proj, offs=offsets)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
out = F.grouped_mm(F.silu(gate) * up, self.down_proj, offs=offsets)
|
||||
|
||||
return out.type_as(x)
|
||||
|
||||
def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_grouped_mm:
|
||||
return self._run_experts_grouped_mm(x, num_tokens_per_expert)
|
||||
return self._run_experts_for_loop(x, num_tokens_per_expert)
|
||||
|
||||
|
||||
class NucleusMoELayer(nn.Module):
|
||||
"""
|
||||
Mixture-of-Experts layer with expert-choice routing and a shared expert.
|
||||
|
||||
Routed expert weights live in :class:`SwiGLUExperts`. The router concatenates a timestep embedding with the
|
||||
(unmodulated) hidden state to produce per-token affinity scores, then selects the top-C tokens per expert
|
||||
(expert-choice routing). A shared expert processes all tokens in parallel and its output is combined with the
|
||||
routed expert outputs via scatter-add.
|
||||
|
||||
SwiGLU expert computation is implemented by :class:`SwiGLUExperts`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
moe_intermediate_dim: int,
|
||||
num_experts: int,
|
||||
capacity_factor: float,
|
||||
use_sigmoid: bool,
|
||||
route_scale: float,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.moe_intermediate_dim = moe_intermediate_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.capacity_factor = capacity_factor
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.route_scale = route_scale
|
||||
|
||||
self.gate = nn.Linear(hidden_size * 2, num_experts, bias=False)
|
||||
|
||||
self.experts = SwiGLUExperts(
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
num_experts=num_experts,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
|
||||
self.shared_expert = FeedForward(
|
||||
dim=hidden_size,
|
||||
dim_out=hidden_size,
|
||||
inner_dim=moe_intermediate_dim,
|
||||
activation_fn="swiglu",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_unmodulated: torch.Tensor,
|
||||
timestep: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
bs, slen, dim = hidden_states.shape
|
||||
|
||||
if timestep is not None:
|
||||
timestep_expanded = timestep.unsqueeze(1).expand(-1, slen, -1)
|
||||
router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1)
|
||||
else:
|
||||
router_input = hidden_states_unmodulated
|
||||
|
||||
logits = self.gate(router_input)
|
||||
|
||||
if self.use_sigmoid:
|
||||
scores = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||
else:
|
||||
scores = F.softmax(logits.float(), dim=-1).to(logits.dtype)
|
||||
|
||||
affinity = scores.transpose(1, 2) # (B, E, S)
|
||||
capacity = max(1, math.ceil(self.capacity_factor * slen / self.num_experts))
|
||||
|
||||
topk = torch.topk(affinity, k=capacity, dim=-1)
|
||||
top_indices = topk.indices # (B, E, C)
|
||||
gating = affinity.gather(dim=-1, index=top_indices) # (B, E, C)
|
||||
|
||||
batch_offsets = torch.arange(bs, device=hidden_states.device, dtype=torch.long).view(bs, 1, 1) * slen
|
||||
global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
|
||||
gating_flat = gating.transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
|
||||
|
||||
token_score_sums = torch.zeros(bs * slen, device=hidden_states.device, dtype=gating_flat.dtype)
|
||||
token_score_sums.scatter_add_(0, global_token_indices, gating_flat)
|
||||
gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12)
|
||||
gating_flat = gating_flat * self.route_scale
|
||||
|
||||
x_flat = hidden_states.reshape(bs * slen, dim)
|
||||
routed_input = x_flat[global_token_indices]
|
||||
|
||||
tokens_per_expert = bs * capacity
|
||||
num_tokens_per_expert = torch.full(
|
||||
(self.num_experts,),
|
||||
tokens_per_expert,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
routed_output = self.experts(routed_input, num_tokens_per_expert)
|
||||
routed_output = (routed_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype)
|
||||
|
||||
out = self.shared_expert(hidden_states).reshape(bs * slen, dim)
|
||||
|
||||
scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, dim)
|
||||
out = out.scatter_add(dim=0, index=scatter_idx, src=routed_output)
|
||||
out = out.reshape(bs, slen, dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class NucleusMoEImageTransformerBlock(nn.Module):
|
||||
"""
|
||||
Single-stream DiT block with optional Mixture-of-Experts MLP. Only the image stream receives adaptive modulation;
|
||||
the text context is projected per-block and used as cross-attention keys/values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_key_value_heads: int | None = None,
|
||||
joint_attention_dim: int = 3584,
|
||||
qk_norm: str = "rms_norm",
|
||||
eps: float = 1e-6,
|
||||
mlp_ratio: float = 4.0,
|
||||
moe_enabled: bool = False,
|
||||
num_experts: int = 128,
|
||||
moe_intermediate_dim: int = 1344,
|
||||
capacity_factor: float = 8.0,
|
||||
use_sigmoid: bool = False,
|
||||
route_scale: float = 2.5,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.moe_enabled = moe_enabled
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
self.encoder_proj = nn.Linear(joint_attention_dim, dim)
|
||||
|
||||
self.pre_attn_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_key_value_heads,
|
||||
dim_head=attention_head_dim,
|
||||
added_kv_proj_dim=dim,
|
||||
added_proj_bias=False,
|
||||
out_dim=dim,
|
||||
out_bias=False,
|
||||
bias=False,
|
||||
processor=NucleusMoEAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
context_pre_only=None,
|
||||
)
|
||||
|
||||
self.pre_mlp_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
|
||||
|
||||
if moe_enabled:
|
||||
self.img_mlp = NucleusMoELayer(
|
||||
hidden_size=dim,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
num_experts=num_experts,
|
||||
capacity_factor=capacity_factor,
|
||||
use_sigmoid=use_sigmoid,
|
||||
route_scale=route_scale,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
else:
|
||||
mlp_inner_dim = int(dim * mlp_ratio * 2 / 3) // 128 * 128
|
||||
self.img_mlp = FeedForward(
|
||||
dim=dim,
|
||||
dim_out=dim,
|
||||
inner_dim=mlp_inner_dim,
|
||||
activation_fn="swiglu",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
) -> torch.Tensor:
|
||||
scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1)
|
||||
|
||||
gate1 = gate1.clamp(min=-2.0, max=2.0)
|
||||
gate2 = gate2.clamp(min=-2.0, max=2.0)
|
||||
|
||||
attn_kwargs = attention_kwargs or {}
|
||||
context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states)
|
||||
|
||||
img_normed = self.pre_attn_norm(hidden_states)
|
||||
img_modulated = img_normed * (1 + scale1)
|
||||
|
||||
img_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=context,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate1.tanh() * img_attn_output
|
||||
|
||||
img_normed2 = self.pre_mlp_norm(hidden_states)
|
||||
img_modulated2 = img_normed2 * (1 + scale2)
|
||||
|
||||
if self.moe_enabled:
|
||||
img_mlp_output = self.img_mlp(img_modulated2, img_normed2, timestep=temb)
|
||||
else:
|
||||
img_mlp_output = self.img_mlp(img_modulated2)
|
||||
|
||||
hidden_states = hidden_states + gate2.tanh() * img_mlp_output
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
fp16_finfo = torch.finfo(torch.float16)
|
||||
hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NucleusMoEImageTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
"""
|
||||
Nucleus MoE Transformer for image generation. Single-stream DiT with cross-attention to text and optional
|
||||
Mixture-of-Experts feed-forward layers.
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `24`):
|
||||
The number of transformer blocks.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `16`):
|
||||
The number of attention heads to use.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
The number of key/value heads for grouped-query attention. Defaults to `num_attention_heads`.
|
||||
joint_attention_dim (`int`, defaults to `3584`):
|
||||
The embedding dimension of the encoder hidden states (text).
|
||||
axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
Multiplier for the MLP hidden dimension in dense (non-MoE) blocks.
|
||||
moe_enabled (`bool`, defaults to `True`):
|
||||
Whether to use Mixture-of-Experts layers.
|
||||
dense_moe_strategy (`str`, defaults to ``"leave_first_three_and_last_block_dense"``):
|
||||
Strategy for choosing which layers are MoE vs dense.
|
||||
num_experts (`int`, defaults to `128`):
|
||||
Number of experts per MoE layer.
|
||||
moe_intermediate_dim (`int`, defaults to `1344`):
|
||||
Hidden dimension inside each expert.
|
||||
capacity_factors (`float | list[float]`, defaults to `8.0`):
|
||||
Expert-choice capacity factor per layer.
|
||||
use_sigmoid (`bool`, defaults to `False`):
|
||||
Use sigmoid instead of softmax for routing scores.
|
||||
route_scale (`float`, defaults to `2.5`):
|
||||
Scaling factor applied to routing weights.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NucleusMoEImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["NucleusMoEImageTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: int | None = None,
|
||||
num_layers: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int | None = None,
|
||||
joint_attention_dim: int = 3584,
|
||||
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
|
||||
mlp_ratio: float = 4.0,
|
||||
moe_enabled: bool = True,
|
||||
dense_moe_strategy: str = "leave_first_three_and_last_block_dense",
|
||||
num_experts: int = 128,
|
||||
moe_intermediate_dim: int = 1344,
|
||||
capacity_factors: float | list[float] = 8.0,
|
||||
use_sigmoid: bool = False,
|
||||
route_scale: float = 2.5,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
capacity_factors = capacity_factors if isinstance(capacity_factors, list) else [capacity_factors] * num_layers
|
||||
|
||||
self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
|
||||
self.time_text_embed = NucleusMoETimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
NucleusMoEImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
moe_enabled=moe_enabled and _is_moe_layer(dense_moe_strategy, idx, num_layers),
|
||||
num_experts=num_experts,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
capacity_factor=capacity_factors[idx],
|
||||
use_sigmoid=use_sigmoid,
|
||||
route_scale=route_scale,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
for idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
img_shapes: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`NucleusMoEImageTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
img_shapes (`list[tuple[int, int, int]]`, *optional*):
|
||||
Image shapes ``(frame, height, width)`` for RoPE computation.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Boolean mask for the encoder hidden states.
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Extra kwargs forwarded to the attention processor.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`].
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
|
||||
text_seq_len, _, encoder_hidden_states_mask = _compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
block_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=block_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -420,6 +420,7 @@ else:
|
||||
"SkyReelsV2ImageToVideoPipeline",
|
||||
"SkyReelsV2Pipeline",
|
||||
]
|
||||
_import_structure["nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImagePipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
@@ -768,6 +769,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MarigoldNormalsPipeline,
|
||||
)
|
||||
from .mochi import MochiPipeline
|
||||
from .nucleusmoe_image import NucleusMoEImagePipeline
|
||||
from .omnigen import OmniGenPipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
|
||||
@@ -77,6 +77,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaPipeline
|
||||
from .lumina2 import Lumina2Pipeline
|
||||
from .nucleusmoe_image import NucleusMoEImagePipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -179,6 +180,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("helios", HeliosPipeline),
|
||||
("helios-pyramid", HeliosPyramidPipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("nucleusmoe-image", NucleusMoEImagePipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
("z-image", ZImagePipeline),
|
||||
|
||||
48
src/diffusers/pipelines/nucleusmoe_image/__init__.py
Normal file
48
src/diffusers/pipelines/nucleusmoe_image/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["NucleusMoEImagePipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_nucleusmoe_image import NucleusMoEImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,644 @@
|
||||
# Copyright 2025 Nucleus-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLQwenImage, NucleusMoEImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import NucleusMoEImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects."
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import NucleusMoEImagePipeline
|
||||
|
||||
>>> pipe = NucleusMoEImagePipeline.from_pretrained("NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("nucleus_moe.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class NucleusMoEImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using NucleusMoE.
|
||||
|
||||
This pipeline uses a single-stream DiT with Mixture-of-Experts feed-forward layers, cross-attention to a Qwen3-VL
|
||||
text encoder, and a flow-matching Euler discrete scheduler.
|
||||
|
||||
Args:
|
||||
transformer ([`NucleusMoEImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLQwenImage`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3VLForConditionalGeneration`]):
|
||||
Text encoder for computing prompt embeddings.
|
||||
processor ([`Qwen3VLProcessor`]):
|
||||
Processor for tokenizing text inputs.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: NucleusMoEImageTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen3VLForConditionalGeneration,
|
||||
processor: Qwen3VLProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
processor=processor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
self.default_sample_size = 128
|
||||
self.default_max_sequence_length = 1024
|
||||
self.default_return_index = -8
|
||||
|
||||
def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_SYSTEM_PROMPT
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_mask: torch.Tensor | None = None,
|
||||
max_sequence_length: int | None = None,
|
||||
return_index: int | None = None,
|
||||
):
|
||||
r"""
|
||||
Encode text prompt(s) into embeddings using the Qwen3-VL text encoder.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to encode.
|
||||
device (`torch.device`, *optional*):
|
||||
Torch device for the resulting tensors.
|
||||
num_images_per_prompt (`int`, defaults to 1):
|
||||
Number of images to generate per prompt.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Skips encoding when provided.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated embeddings.
|
||||
max_sequence_length (`int`, defaults to 1024):
|
||||
Maximum token length for the encoded prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
return_index = return_index or self.default_return_index
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
formatted = [self._format_prompt(p) for p in prompt]
|
||||
|
||||
inputs = self.processor(
|
||||
text=formatted,
|
||||
padding="longest",
|
||||
pad_to_multiple_of=8,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
).to(device=device)
|
||||
|
||||
prompt_embeds_mask = inputs.attention_mask
|
||||
|
||||
outputs = self.text_encoder(**inputs, use_cache=False, return_dict=True, output_hidden_states=True)
|
||||
prompt_embeds = outputs.hidden_states[return_index]
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.to(device=device)
|
||||
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
return_index=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} "
|
||||
f"but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, "
|
||||
f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and "
|
||||
f"`negative_prompt_embeds`: {negative_prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if return_index is not None and abs(return_index) >= self.text_encoder.config.text_config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"absolute value of `return_index` cannot be >= {self.text_encoder.config.text_config.num_hidden_layers} "
|
||||
f"but is {abs(return_index)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size):
|
||||
latents = latents.view(
|
||||
batch_size, num_channels_latents, height // patch_size, patch_size, width // patch_size, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(
|
||||
batch_size, (height // patch_size) * (width // patch_size), num_channels_latents * patch_size * patch_size
|
||||
)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, patch_size, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
|
||||
return latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
patch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = patch_size * (int(height) // (self.vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (self.vae_scale_factor * patch_size))
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
negative_prompt: str | list[str] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: list[float] | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int | None = None,
|
||||
return_index: int | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_mask: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds_mask: torch.Tensor | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
negative_prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, an empty string is used when
|
||||
`true_cfg_scale > 1`.
|
||||
true_cfg_scale (`float`, *optional*, defaults to 4.0):
|
||||
Classifier-free guidance scale. Values greater than 1 enable CFG.
|
||||
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas for the denoising schedule. If not defined, a linear schedule is used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
|
||||
One or a list of torch generators to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for image generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings.
|
||||
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `"pil"`, `"np"`, or `"latent"`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`NucleusMoEImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Kwargs passed to the attention processor.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`list`, *optional*):
|
||||
Tensor inputs for the `callback_on_step_end` function.
|
||||
max_sequence_length (`int`, defaults to 512):
|
||||
Maximum sequence length for the text prompt.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`NucleusMoEImagePipelineOutput`] or `tuple`:
|
||||
[`NucleusMoEImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple` where the first element
|
||||
is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
max_sequence_length = max_sequence_length or self.default_max_sequence_length
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs or {}
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
do_cfg = guidance_scale > 1
|
||||
|
||||
if do_cfg and not has_neg_prompt:
|
||||
negative_prompt = [""] * batch_size
|
||||
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
if do_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
patch_size = self.transformer.config.patch_size
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
patch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
img_shapes = [
|
||||
(1, height // self.vae_scale_factor // patch_size, width // self.vae_scale_factor // patch_size)
|
||||
] * (batch_size * num_images_per_prompt)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
if self.transformer.is_cache_enabled:
|
||||
self.transformer._reset_stateful_cache()
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / self.scheduler.config.num_train_timesteps,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_cfg:
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / self.scheduler.config.num_train_timesteps,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
|
||||
noise_pred = -noise_pred
|
||||
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, patch_size, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return NucleusMoEImagePipelineOutput(images=image)
|
||||
20
src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py
Normal file
20
src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class NucleusMoEImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for NucleusMoE Image pipelines.
|
||||
|
||||
Args:
|
||||
images (`list[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: list[PIL.Image.Image] | np.ndarray
|
||||
@@ -287,6 +287,21 @@ class TaylorSeerCacheConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class TextKVCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
@@ -311,6 +326,10 @@ def apply_taylorseer_cache(*args, **kwargs):
|
||||
requires_backends(apply_taylorseer_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_text_kv_cache(*args, **kwargs):
|
||||
requires_backends(apply_text_kv_cache, ["torch"])
|
||||
|
||||
|
||||
class InpaintProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1511,6 +1530,21 @@ class MultiControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class NucleusMoEImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class OmniGenTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2567,6 +2567,21 @@ class MusicLDMPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class NucleusMoEImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class OmniGenPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import NucleusMoEImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return NucleusMoEImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
"moe_enabled": False,
|
||||
"capacity_factors": [8.0, 8.0],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_with_attention_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Mask out some text tokens
|
||||
mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
mask[:, 4:] = 0
|
||||
inputs["encoder_hidden_states_mask"] = mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_without_attention_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = None
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerLoRAHotSwap(
|
||||
NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin
|
||||
):
|
||||
"""LoRA hot-swapping tests for NucleusMoE Image Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for NucleusMoE Image Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for NucleusMoE Image Transformer."""
|
||||
0
tests/pipelines/nucleusmoe_image/__init__.py
Normal file
0
tests/pipelines/nucleusmoe_image/__init__.py
Normal file
337
tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py
Normal file
337
tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
NucleusMoEImagePipeline,
|
||||
NucleusMoEImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = NucleusMoEImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = NucleusMoEImageTransformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=4,
|
||||
joint_attention_dim=16,
|
||||
axes_dims_rope=(8, 4, 4),
|
||||
moe_enabled=False,
|
||||
capacity_factors=[8.0, 8.0],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
z_dim = 4
|
||||
vae = AutoencoderKLQwenImage(
|
||||
base_dim=z_dim * 6,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True],
|
||||
# fmt: off
|
||||
latents_mean=[0.0] * z_dim,
|
||||
latents_std=[1.0] * z_dim,
|
||||
# fmt: on
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen3VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [1, 1, 2],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"vocab_size": 151936,
|
||||
"head_dim": 8,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_heads": 2,
|
||||
"out_channels": 16,
|
||||
},
|
||||
)
|
||||
text_encoder = Qwen3VLForConditionalGeneration(config).eval()
|
||||
processor = Qwen3VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"processor": processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A cat sitting on a mat",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"return_index": -1,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_true_cfg(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["guidance_scale"] = 4.0
|
||||
inputs["negative_prompt"] = "low quality"
|
||||
image = pipe(**inputs).images
|
||||
self.assertEqual(image[0].shape, (3, 32, 32))
|
||||
|
||||
def test_prompt_embeds(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
|
||||
prompt=inputs["prompt"],
|
||||
device=device,
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
)
|
||||
|
||||
inputs_with_embeds = self.get_dummy_inputs(device)
|
||||
inputs_with_embeds.pop("prompt")
|
||||
inputs_with_embeds["prompt_embeds"] = prompt_embeds
|
||||
inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask
|
||||
|
||||
image = pipe(**inputs_with_embeds).images
|
||||
self.assertEqual(image[0].shape, (3, 32, 32))
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
# PipelineTesterMixin compares outputs with assert_mean_pixel_difference, which assumes HWC numpy/PIL layout.
|
||||
# With output_type="pt", tensors are CHW; numpy_to_pil then fails. Match QwenImage: only assert max diff.
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
|
||||
# PipelineTesterMixin only keeps components whose keys contain "text" or "tokenizer"; this pipeline also
|
||||
# needs `processor` for encode_prompt (apply_chat_template). Mirror the mixin with that key included.
|
||||
if not hasattr(self.pipeline_class, "encode_prompt"):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
for key in components:
|
||||
if "text_encoder" in key and hasattr(components[key], "eval"):
|
||||
components[key].eval()
|
||||
|
||||
def _is_text_stack_component(k):
|
||||
return "text" in k or "tokenizer" in k or k == "processor"
|
||||
|
||||
components_with_text_encoders = {}
|
||||
for k in components:
|
||||
if _is_text_stack_component(k):
|
||||
components_with_text_encoders[k] = components[k]
|
||||
else:
|
||||
components_with_text_encoders[k] = None
|
||||
pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
|
||||
pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt)
|
||||
encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
|
||||
|
||||
required_params = []
|
||||
for param in encode_prompt_parameters:
|
||||
if param.name == "self" or param.name == "kwargs":
|
||||
continue
|
||||
if param.default is inspect.Parameter.empty:
|
||||
required_params.append(param.name)
|
||||
|
||||
encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
|
||||
input_keys = list(inputs.keys())
|
||||
encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
|
||||
|
||||
pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
|
||||
pipe_call_parameters = pipe_call_signature.parameters
|
||||
|
||||
for required_param_name in required_params:
|
||||
if required_param_name not in encode_prompt_inputs:
|
||||
pipe_call_param = pipe_call_parameters.get(required_param_name, None)
|
||||
if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
|
||||
encode_prompt_inputs[required_param_name] = pipe_call_param.default
|
||||
elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict):
|
||||
encode_prompt_inputs[required_param_name] = extra_required_param_value_dict[required_param_name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Required parameter '{required_param_name}' in "
|
||||
f"encode_prompt has no default in either encode_prompt or __call__."
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
|
||||
|
||||
ast_visitor = ReturnNameVisitor()
|
||||
encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class)
|
||||
ast_visitor.visit(encode_prompt_tree)
|
||||
prompt_embed_kwargs = ast_visitor.return_names
|
||||
prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))
|
||||
|
||||
adapted_prompt_embeds_kwargs = {
|
||||
k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
|
||||
}
|
||||
|
||||
components_with_text_encoders = {}
|
||||
for k in components:
|
||||
if _is_text_stack_component(k):
|
||||
components_with_text_encoders[k] = None
|
||||
else:
|
||||
components_with_text_encoders[k] = components[k]
|
||||
pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
|
||||
|
||||
pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs}
|
||||
if (
|
||||
pipe_call_parameters.get("negative_prompt", None) is not None
|
||||
and pipe_call_parameters.get("negative_prompt").default is not None
|
||||
):
|
||||
pipe_without_tes_inputs.update({"negative_prompt": None})
|
||||
|
||||
if (
|
||||
pipe_call_parameters.get("prompt", None) is not None
|
||||
and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty
|
||||
and pipe_call_parameters.get("prompt_embeds", None) is not None
|
||||
and pipe_call_parameters.get("prompt_embeds").default is None
|
||||
):
|
||||
pipe_without_tes_inputs.update({"prompt": None})
|
||||
|
||||
pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]
|
||||
|
||||
full_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
pipe_out_2 = full_pipe(**inputs)[0]
|
||||
|
||||
if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray):
|
||||
self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
|
||||
elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor):
|
||||
self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
|
||||
Reference in New Issue
Block a user