mirror of
https://github.com/huggingface/diffusers.git
synced 2026-05-28 00:39:35 +08:00
Add ACE-Step pipeline for text-to-music generation (#13095)
* Add ACE-Step pipeline for text-to-music generation
Rebased on origin/main from the original pr-13095 branch (3 commits squashed).
- AceStepDiTModel: Diffusion Transformer with RoPE, GQA, sliding window,
AdaLN timestep conditioning, and cross-attention.
- AceStepConditionEncoder: fuses text / lyric / timbre into a single
cross-attention sequence.
- AceStepPipeline: text2music / cover / repaint / extract / lego / complete.
- Conversion script for the original checkpoint layout.
- Docs + tests.
* Fix ACE-Step pipeline audio quality and auto-detect turbo/base/sft variants
The PR's original inference produced low-quality audio on turbo because the
pipeline (a) mangled the SFT prompt format, (b) applied classifier-free guidance
with the wrong unconditional embedding (empty-string encoded vs. the learned
`null_condition_emb`), and (c) hardcoded turbo defaults even when loading a
base/SFT checkpoint.
Changes:
* Converter preserves `null_condition_emb` (stored under the condition encoder)
and propagates `is_turbo`/`model_version` into the transformer config so the
pipeline can route per-variant defaults.
* `AceStepConditionEncoder` registers `null_condition_emb` as a learned
parameter matching the original module.
* Pipeline auto-detects variant via `is_turbo`/`model_version` and picks
defaults that match `acestep/inference.py`:
* turbo: steps=8, shift=3.0, guidance_scale=1.0 (no CFG)
* base/SFT: steps=27, shift=1.0, guidance_scale=7.0
* Base/SFT timestep schedule uses the linear+shift transform from
`acestep/models/base/modeling_acestep_v15_base.py`; turbo still uses the
hardcoded 8-step `SHIFT_TIMESTEPS` table.
* CFG reuses the learned `null_condition_emb` and batches the
conditional+unconditional forwards into a single transformer call.
* `SFT_GEN_PROMPT` matches the newline layout in `acestep/constants.py` so the
text encoder sees the same prompt distribution it was trained on.
DiT parity vs. the original ACE-Step 1.5 turbo DiT is bit-identical
(max_abs=0.0 in fp32 eager/SDPA across 4 seed/shape cases) — see
scripts/dit_parity_test.py.
* Add ACE-Step parity test scripts
Two developer-facing parity harnesses live under scripts/:
* dit_parity_test.py — loads the same converted turbo weights into the
original AceStepDiTModel and the diffusers AceStepDiTModel, drives
identical (hidden_states, timestep, timestep_r, encoder_hidden_states,
context_latents) inputs, and asserts max-abs-diff ≤ 1e-5 in fp32
eager/SDPA. Currently passes bit-identical (max_abs=0) across four
shape/seed cases including batched + odd-length paths.
* audio_parity_jieyue.py — full end-to-end audio parity. Given the same
JSON example, runs both the original ACE-Step 1.5 pipeline and the
diffusers AceStepPipeline at matched seed/precision (bf16 + FA2 by
default) and saves side-by-side .wav files for listening verification.
Supports text2music / cover / repaint × turbo / base / sft via a
--matrix mode that writes 18 wavs named
{variant}_{task}_{official,diffusers}.wav.
* Route SFT parity to acestep-v15-sft checkpoint
On jieyue the release tree has a dedicated SFT checkpoint at
checkpoints/acestep-v15-sft with its own modeling_acestep_v15_base.py
shipped under acestep/models/sft/. Point the SFT row of the parity matrix
at that checkpoint / module so we're testing the actual SFT weights, not
the plain base ones.
* audio_parity_jieyue: fix doubled 'acestep-' in cache path; --converted-root flag
Previously the converted-pipeline cache dir was
`/tmp/acestep-<variant>-diffusers` but <variant> already starts with
"acestep-", giving `/tmp/acestep-acestep-v15-turbo-diffusers`. Drop the
prefix.
On jieyue the overlay rootfs (including /tmp) only has a few GB free; a
full turbo conversion needs ~5 GB per variant. Add --converted-root (env
ACESTEP_CONVERTED_ROOT) so the cache can live on vepfs.
* audio_parity_jieyue: two-phase matrix bootstraps cover/repaint from text2music
The ACE-Step release bundle on jieyue doesn't ship sample .wav/.mp3
files, so matrix mode had no default --src-audio and would skip
cover/repaint entirely. Run text2music first for every variant, then
reuse the TURBO official text2music output as the shared source for the
cover/repaint rows. Users can still override with --src-audio.
* audio_parity_jieyue: seed the diffusers generator on the pipeline device
The ORIGINAL ACE-Step pipeline seeds on the execution device
(`torch.Generator(device=device).manual_seed(seed)`), i.e. the CUDA RNG
stream when running on GPU. Previously the parity harness seeded the
diffusers side with a CPU generator, so even though the seed integer
matched, the two sides drew different noise from the outset and the
final outputs were essentially uncorrelated. Use the execution-device
generator on both sides for a fair comparison.
* Fix ACE-Step pipeline: switch to APG guidance + peak normalization
Two issues found after the first jieyue audio parity run:
1. The original base/SFT pipeline uses APG (Adaptive Projected Guidance,
acestep/models/common/apg_guidance.py) with a stateful momentum
buffer and norm/projection steps — NOT vanilla CFG. Using vanilla CFG
produced uncorrelated outputs vs. the reference (pearson ~0.0 on
20 s samples); this PR ports `_apg_forward` + `_APGMomentumBuffer`
and plugs them into the denoising loop when `guidance_scale > 1`.
Momentum is instantiated once per pipeline call (persists across
denoising steps) to match the reference semantics.
2. The post-VAE "anti-clipping normalization" in this pipeline was
`audio /= std * 5` with a `std<1 -> std=1` guard. The original
post-processing in
acestep/core/generation/handler/generate_music_decode.py is simple
peak normalization: `if audio.abs().max() > 1: audio /= peak`. The
std-based proxy both (a) let clips with peak < 1 leak through
unchanged (over-quiet) and (b) failed to bring clipping peaks to
exactly 1 in a bunch of base/SFT cases (observed max=1.000, std=0.200
repeatedly in the first parity run). Switch to peak normalization on
both sides.
Tested via scripts/audio_parity_jieyue.py on A800; re-run pending to
confirm the base/SFT correlation improvements.
* Fix ACE-Step chunk mask values to match the original pipeline
The DiT receives `context_latents = concat(src_latents, chunk_mask)` on the
channel dim, and was trained with chunk_mask values drawn from the three
sentinels documented in acestep/inference.py:
2.0 -> model-decided (default for text2music / cover / full-generation)
1.0 -> keep this latent frame from src_latents (repaint preserved region)
0.0 -> explicitly repaint this frame (only inside the repaint window)
Previously _build_chunk_mask returned all-1.0 for text2music (and cover /
lego), and an inverted 0/1 mask for repaint (1 inside the window, 0 outside).
Either case puts context_latents out of distribution. Switch text2music /
cover to the 2.0 sentinel and flip the repaint mask so it's 1.0 outside /
0.0 inside. Update the repaint src_latents zero-out to multiply by the new
mask (was `1 - chunk_mask`) so the zero region still lines up with the
repaint window.
* Add direct invoker for ACE-Step generate_music (ground truth)
Our earlier audio_parity_jieyue.py reconstructs the original pipeline by
calling AceStepConditionGenerationModel.generate_audio() directly, which
silently skips a lot of the real handler plumbing (conditioning masks,
silence-latent tiling, cover/repaint pre-processing, etc.). That made the
'official' wavs we saved sound wrong — flat, drone-like, not music.
This new script calls acestep.inference.generate_music end-to-end through
the real AceStepHandler, with LM + CoT explicitly disabled so we still have
a deterministic comparison. Use it to generate the ground-truth 'official'
wav for a given JSON example, then separately run the diffusers pipeline
with the same inputs and diff the two.
* run_official_generate_music: call initialize_service to bind a DiT variant
AceStepHandler() is a shell — you have to call handler.initialize_service(
project_root=..., config_path=..., device=..., use_flash_attention=..., ...)
before generate_music will work. Mirror what cli.py does at the equivalent
spot (around cli.py:1400).
* Fix silence-reference for ACE-Step timbre encoder
The root cause for the flat / drone-like outputs I was seeing (including
in my 'official' reconstruction): when no reference_audio is provided the
pipeline was feeding literal zeros to the timbre encoder. The real
handler feeds a slice of the learned `silence_latent` tensor.
The handler also transposes silence_latent on load (see
acestep/core/generation/handler/init_service_loader.py:214:
self.silence_latent = torch.load(...).transpose(1, 2)
) converting [1, 64, 15000] -> [1, 15000, 64] so that
`silence_latent[:, :750, :]` yields the expected [1, 750, 64] shape.
Changes:
* Converter: load silence_latent.pt, transpose to [1, T, C], bake into
the condition_encoder safetensors under key `silence_latent`.
(Also keeps the raw .pt file at the pipeline root for debugging.)
* AceStepConditionEncoder: register `silence_latent` as a persistent
buffer so from_pretrained loads it alongside the trained weights.
* Pipeline: when reference_audio is None, slice
`condition_encoder.silence_latent[:, :timbre_fix_frame, :]` and
broadcast across the batch instead of zeros. Emits a loud warning
(and falls back to zeros) if the buffer is all-zero — that means the
checkpoint was produced by an older converter and should be rebuilt.
* audio_parity_jieyue.py: the reference path now matches the handler's
silence-latent slicing.
Without this fix, every variant/task combo produced drone-like audio
even when my numeric DiT-forward parity claimed they were identical.
* Fix three more ACE-Step pipeline bugs I found by dumping real inputs
Instrumented the live generate_audio call in the real ACE-Step handler and
observed the exact tensors it sees — my diffusers pipeline was wrong in
three independent ways:
1. src_latents for text2music should be silence_latent tiled to
latent_length, NOT zeros. The handler fills no-target cases from
silence_latent_tiled (observed std=0.96). Zeros are OOD for the DiT
context_latents concat and produce drone-like outputs.
2. chunk_mask values cap at 1.0 (not 2.0). The handler starts with a
bool tensor (True inside the generate span, False outside); the
chunk_mask_modes=auto -> 2.0 override does NOT take effect because
the underlying tensor is bool, so setting entry = 2.0 casts to True.
After the later .to(dtype) float cast, the DiT sees 1.0/0.0 — exactly
what I observed in the captured tensor (unique values = [True]).
3. Default shift is 1.0 for ALL variants, including turbo. I was
defaulting turbo to shift=3.0 which picks a different SHIFT_TIMESTEPS
table (the 8-step schedule is keyed by shift, not variant).
Also:
* Added _silence_latent_tiled() helper that slices / tiles the learned
silence_latent (now loaded as a buffer on the condition encoder) to
the requested latent length.
* Repaint path now substitutes silence_latent (not raw zeros) inside
the repaint window — matches conditioning_masks.py.
* audio_parity_jieyue.py mirrors the same src/chunk/shift choices on
its 'original' leg for apples-to-apples parity once the buggy
reconstruction is removed from the picture.
* Add peak+loudness post-normalization to AceStepPipeline
The real pipeline normalizes audio in two stages (see
acestep/audio_utils.py:72 normalize_audio + generate_music_decode.py):
1. if peak > 1: audio /= peak (anti-clip)
2. audio *= target_amp / peak (target_amp = 10 ** (-1/20) ~ 0.891)
Step 2 is loudness normalization to -1 dBFS. Without it diffusers outputs
had peak=1.0 vs the real 0.891 — same music content (pearson was ~0.86
already), just 1.12x louder. Add step 2 after the existing anti-clip step.
* Match acestep/inference.py inference_steps=8 for ALL variants
GenerationParams.inference_steps default is 8 — turbo AND base/SFT. I had
base/SFT defaulting to 27 here, so every base/SFT parity run was comparing
a 27-step diffusers trajectory against an 8-step real trajectory. Different
number of denoising steps means different audio even at fixed seed.
This likely explains the lower base/SFT correlation in my earlier jieyue
runs (turbo was 0.86, base/SFT were 0.32-0.34). Aligning step counts
should bring base/SFT closer to turbo parity.
* Address PR #13095 review: rename classes + reuse diffusers primitives
Response to dg845's PR comments batch 1+2. DiT parity harness still bit-identical
(max_abs=0 on fp32 / SDPA across 4 shape cases).
Transformer file:
* Rename AceStepDiTModel -> AceStepTransformer1DModel (alias kept).
* Rename AceStepDiTLayer -> AceStepTransformerBlock (alias kept).
* Inherit AttentionMixin + CacheMixin on the DiT model.
* Swap in diffusers.models.normalization.RMSNorm for the hand-rolled
AceStepRMSNorm (weight-key-compatible).
* Swap the hand-rolled rotary embedding + apply_rotary for diffusers'
get_1d_rotary_pos_embed + apply_rotary_emb (use_real_unbind_dim=-2 to
match the cat-half convention ACE-Step inherits from Qwen3).
* Use get_timestep_embedding with flip_sin_to_cos=True — keeps the
(cos, sin) ordering of the original sinusoidal. State-dict-compatible.
* Drop max_position_embeddings arg from DiT config (RoPE computes freqs
per call based on seq_len); converter drops it.
* Gradient-checkpoint call now takes just the layer module (matches the
Flux2 idiom).
Pipeline modeling file (pipelines/ace_step/modeling_ace_step.py):
* Moved _pack_sequences + AceStepEncoderLayer here — they aren't used
by the DiT, so they shouldn't live in the transformer file.
* AceStepLyricEncoder + AceStepTimbreEncoder set
_supports_gradient_checkpointing = True and wrap encoder-layer calls
through the checkpointing func when enabled.
* Use diffusers RMSNorm + the RoPE helper from the transformer file
(shared single implementation).
Converter (scripts/convert_ace_step_to_diffusers.py):
* model_index.json now carries AceStepTransformer1DModel.
* Drop max_position_embeddings / use_sliding_window from the emitted
configs.
No numerical regressions: scripts/dit_parity_test.py PASSES with
max_abs=0.0 on fp32/SDPA across short, long, batched, and
padding-path shape variants.
* Address PR #13095 review: pipeline polish + converter HF-hub support
Response to dg845 review comments on the pipeline side. DiT parity still
bit-identical (max_abs=0 across 4 shape cases).
Pipeline (pipelines/ace_step/pipeline_ace_step.py):
* Add `sample_rate` + `latents_per_second` properties sourced from the
VAE config so the pipeline no longer hardcodes 48000 / 25 / 1920.
Propagates through prepare_latents, chunk_mask window math, and the
audio-duration round-trip.
* Add `do_classifier_free_guidance` property (matches LTX2 et al.).
* Add `check_inputs(...)` called from `__call__` before allocating noise.
Validates prompt type, lyrics type, task_type, step count, guidance
scale, shift, cfg interval bounds and repaint window ordering.
* Add `callback_on_step_end` + `callback_on_step_end_tensor_inputs` —
the modern callback form. The legacy `callback` / `callback_steps`
pair is kept for back-compat. Setting `pipe._interrupt = True` inside
the callback stops the loop early.
* Expose `encode_audio(audio)` as a public helper that wraps the tiled
VAE encode + (B, T, D) transpose the pipeline performs internally.
Converter (scripts/convert_ace_step_to_diffusers.py):
* Accept a Hugging Face Hub repo id for `--checkpoint_dir`; resolves it
via `huggingface_hub.snapshot_download` when the argument isn't a
local path.
Exports:
* Register `AceStepTransformer1DModel` in the top-level __init__,
models/__init__, models/transformers/__init__, and dummy_pt_objects so
`from diffusers import AceStepTransformer1DModel` works and the
pipeline loader resolves the new class name from model_index.json.
Deferred for a follow-up (commented inline in the PR): full
`Attention + AttnProcessor + dispatch_attention_fn` refactor and
`FlowMatchEulerDiscreteScheduler` migration — both would benefit from a
dedicated parity re-run and review.
* Fix stale ACE-Step 1.0-era docs / class names in the 1.5 integration
Docs and docstrings still carried a mix of 1.0 paper title, non-existent
`ACE-Step/ACE-Step-v1-5-turbo` hub id, `shift=3.0` turbo default, and
the old `AceStepDiTModel` class name. Cleaned up to match the actual
1.5 release:
* pipelines/ace_step.md: correct citation title ("ACE-Step 1.5: Pushing
the Boundaries of Open-Source Music Generation"), correct repo
(`ace-step/ACE-Step-1.5`), new variants table with real HF ids
(`Ace-Step1.5` / `acestep-v15-base` / `acestep-v15-sft`) and their
per-variant step/CFG defaults, drop the wrong `shift=3.0` tip.
* models/ace_step_transformer.md: page renamed to
`AceStepTransformer1DModel` with a short 1.5-specific description;
`AceStepDiTModel` noted as a backwards-compat alias.
* pipeline_ace_step.py: import, docstring, `Args`, and `__init__`
annotation reference `AceStepTransformer1DModel`; example model id
now `ACE-Step/Ace-Step1.5`; `_variant_defaults` docstring and the
`__call__` variant-fallback comment no longer claim `shift=3.0` /
`27 steps` — real defaults are 8 steps / shift=1.0 across all
variants, guidance=1.0 (turbo) vs 7.0 (base+sft).
* Address PR #13095 review: VAE tiling on AutoencoderOobleck + Timesteps class
Two more deferred review threads from dg845 addressed:
* Move tiled encode/decode onto AutoencoderOobleck
(https://github.com/huggingface/diffusers/pull/13095#discussion_r2785513647).
AutoencoderOobleck now carries `use_tiling` + `tile_sample_min_length` /
`tile_sample_overlap` / `tile_latent_min_length` / `tile_latent_overlap`
attributes and private `_tiled_encode` / `_tiled_decode` methods; the
existing `encode` / `_decode` dispatch to them when tiling is enabled and
the input exceeds the threshold. `AutoencoderMixin.enable_tiling()` is
already inherited.
AceStepPipeline's private `_tiled_encode` / `_tiled_decode` and the
`use_tiled_decode` `__call__` arg are gone; `__init__` now calls
`self.vae.enable_tiling()` so the long-audio memory behaviour is preserved
by default. Users can opt out with `pipe.vae.disable_tiling()`.
Note: the VAE-side tiling concatenates encoder features (h) and samples
the posterior once, instead of the old per-tile `.sample()` calls. This
is the standard diffusers pattern; numerically differs only in the
structure of the noise across tile boundaries.
* Use the Timesteps nn.Module for the sinusoid
(https://github.com/huggingface/diffusers/pull/13095#discussion_r2785420234).
`AceStepTimestepEmbedding` wraps `Timesteps(in_channels, flip_sin_to_cos=
True, downscale_freq_shift=0)` instead of calling `get_timestep_embedding`
directly — reviewer asked for the Module form.
* Address PR #13095 review: refactor AceStepAttention to Attention + AttnProcessor
Splits the monolithic AceStepAttention into the diffusers standard
Attention + AttnProcessor layout:
- AceStepAttention (torch.nn.Module, AttentionModuleMixin) holds the
to_q/to_k/to_v/to_out projections and norm_q/norm_k RMSNorms.
- AceStepAttnProcessor2_0 runs the attention dispatch through
dispatch_attention_fn so users can pick flash / sage / native backends
via model.set_attention_backend(...) or the attention_backend context
manager.
GQA (Q has 16 heads / K,V have 8) is preserved by passing enable_gqa=True
to dispatch_attention_fn instead of repeat_interleave; fusion is disabled
(_supports_qkv_fusion = False) because Q and K,V have different output
sizes.
The converter is updated to rename the six attention sub-keys
(q_proj -> to_q, k_proj -> to_k, v_proj -> to_v, o_proj -> to_out.0,
q_norm -> norm_q, k_norm -> norm_k) on both the DiT decoder path and the
condition encoder path, since AceStepLyricEncoder / AceStepTimbreEncoder
share the same AceStepAttention class.
Addresses review comments r2785433213 and r2785450463.
* Address PR #13095 review: migrate to FlowMatchEulerDiscreteScheduler
Replace the hand-rolled flow-matching Euler loop with
`FlowMatchEulerDiscreteScheduler`. ACE-Step still computes its own shifted /
turbo sigma schedule via `_get_timestep_schedule`, but now passes it to
`scheduler.set_timesteps(sigmas=...)` and delegates the ODE step to
`scheduler.step()`. The scheduler is configured with `num_train_timesteps=1`
and `shift=1.0` so `scheduler.timesteps` stays in `[0, 1]` (the convention the
DiT was trained on) and the scheduler doesn't re-shift already-shifted sigmas.
The scheduler's appended terminal `sigma=0` reproduces the old loop's
final-step "project to x0" case exactly: `prev = x + (0 - t_curr) * v`.
Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music, 8 steps):
waveform Pearson = 0.999999
spectral Pearson = 1.000000
max |diff| = 2.5e-3 (fp32 step-math vs previous bf16 step-math)
fp32 Euler-loop A/B against the hand-rolled path: max |diff| = 3.6e-7.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review: move DiT tests + drop stale test kwargs
- Move the DiT transformer tests out of the pipeline test file into a new
tests/models/transformers/test_models_transformer_ace_step.py that follows
the standard BaseModelTesterConfig + ModelTesterMixin scaffold (matches
test_models_transformer_longcat_audio_dit.py).
- Drop `max_position_embeddings` from the remaining AceStepDiTModel and
AceStepConditionEncoder test fixtures — neither constructor accepts that
argument anymore.
- Drop `use_sliding_window` from the same fixtures — also no longer a
constructor argument (the actual `sliding_window` int kwarg is kept).
- Wire `FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)`
into `get_dummy_components()` now that the pipeline requires it.
Resolves https://github.com/huggingface/diffusers/pull/13095#discussion_r3115653554,
r3115664850, r3115673059, r3115676580, r3115680700.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review from dg845 (2026-04-23)
Fixes 5 review threads + style:
1. Converter now builds `AceStepPipeline` in memory and calls
`save_pretrained`. Previously the hand-written `model_index.json` was
missing the `scheduler` entry — fresh converter output couldn't be loaded
by `AceStepPipeline.from_pretrained` (r3127767785). This also makes the
converter robust to future `__init__` signature changes.
2. `latent_length` uses `math.ceil(...)` instead of `int(...)` so non-integer
products (e.g. `latents_per_second=2.0, audio_duration=0.4 → 0.8`) round up
to `1` instead of truncating to `0` and crashing shape checks (r3127790939).
3. Add `_callback_tensor_inputs = ["latents"]` on `AceStepPipeline` so the
standard diffusers callback tests pick up the right tensor (r3127795954).
4. `AceStepConditionEncoder.silence_latent` no longer hard-codes the channel
dim to 64. The placeholder buffer now uses the `timbre_hidden_dim`
constructor argument, so smaller test configs with `timbre_hidden_dim != 64`
load without shape errors (r3127812932).
5. Revert `self.vae.enable_tiling()` from `AceStepPipeline.__init__`. Users can
call `pipe.vae.enable_tiling()` themselves for long-form generation; that
matches the opt-in convention used by the rest of diffusers (r3127777296).
6. `ruff check --fix` + `ruff format` over all ACE-Step sources (the style fix
dg845 asked for via `@bot /style`).
Also: converter now accepts sharded `model.safetensors.index.json` layouts
alongside the single-file `model.safetensors`, so the 5B XL turbo variant
converts without a pre-processing step.
Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music 160s, fresh
converter output loaded via `from_pretrained`):
waveform Pearson = 0.999954
spectral Pearson = 0.999977
max |a-b| bf16 = 4.3e-02 (dominated by the VAE tiling default flip)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review from yiyixuxu (2026-04-23)
Code-level (22 threads):
1. Delete 3 dev/parity scripts (`scripts/audio_parity_jieyue.py`,
`scripts/dit_parity_test.py`, `scripts/run_official_generate_music.py`)
that shouldn't have been committed.
2. Rename `AutoencoderOobleck._encode_one` → `_encode` to match the convention
used by other diffusers VAEs.
3. Delete the hard-coded `SHIFT_TIMESTEPS` / `VALID_SHIFTS` table in
`pipeline_ace_step.py`: the per-shift turbo schedules are recovered
exactly by `linspace(1, 0, N+1)[:-1]` plus the flow-match shift formula
that the non-turbo branch already uses, so a single code path covers both.
4. Drop the backwards-compat `AceStepDiTModel` / `AceStepDiTLayer` aliases
and every reference (top-level `__init__`, `models/__init__`,
`transformers/__init__`, dummy objects, tests, docs toctree, model card).
`AceStepTransformer1DModel` is the only exported name now.
5. Remove the unused `attention_mask` / `encoder_attention_mask` args from
`AceStepTransformer1DModel.forward`; the model rebuilds its masks from
the sequence shape and never consumed them.
6. In the DiT forward and both encoders, pass `None` instead of an all-zero
`full_attn_mask` / `encoder_4d_mask` to non-sliding attention layers — SDPA
dispatches to a faster kernel when the mask is None.
7. Inline the shared `_run_encoder_layers` helper directly into
`AceStepLyricEncoder.forward` / `AceStepTimbreEncoder.forward` so layer
calls are visible at the forward boundary (diffusers style).
8. Move `is_turbo` / `sample_rate` / `latents_per_second` from `@property`s
that re-read module configs each call to cached attributes populated in
`__init__` (Flux2-style), with a default-ACE-Step fallback when
`self.vae` is offloaded. Drop the now-unused `SAMPLE_RATE = 48000`
module-level constant and the three property definitions.
9. Warn + coerce `guidance_scale` to 1.0 on turbo (guidance-distilled)
checkpoints, following `pipeline_flux2_klein`. Prevents over-guided
audio when users forward their base/sft CFG settings to a turbo pipe.
10. Remove the `logger.warning(...)` paths that triggered on
`silence_latent` missing/zero — those only fired for author-side
unconverted checkpoints and tests; end users always load converted
weights where the buffer is baked in.
11. Drop the redundant `with torch.no_grad():` wrappers inside
`encode_prompt` — the pipeline's `__call__` runs under `torch.no_grad`
already.
12. Strip "reviewer comment on PR #13095" attribution comments from three
docstrings (here and everywhere).
Parity on jieyue (seed=42, bf16 + flash-attn, XL turbo 160s text2music):
waveform Pearson = 0.9747
spectral Pearson = 0.9895
The shift comes from full-attention layers switching `attn_mask=0_tensor` →
`attn_mask=None`, which dispatches to a different SDPA kernel on bf16. The
two outputs are algebraically equivalent for fp32 eager; on bf16+FA the
delta is dominated by kernel-level ULPs, well within the sampler-noise
band (ear-check on the 160s example confirms no audible regression).
Still open — AudioTokenizer/Detokenizer (deferred) + APG guider follow-up
(dims differ from `diffusers.guiders.adaptive_projected_guidance`, not a
drop-in; worth a separate PR).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address ACE-Step audio token and APG review
* Fix ACE-Step docs CI
* Address ACE-Step pipeline cleanup review
* Fix ACE-Step flash attention sliding windows
* Add ACE-Step callback properties
* Address ACE-Step final review comments
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
This commit is contained in:
@@ -324,6 +324,8 @@
|
||||
title: SparseControlNetModel
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/ace_step_transformer
|
||||
title: AceStepTransformer1DModel
|
||||
- local: api/models/allegro_transformer3d
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
@@ -488,6 +490,8 @@
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- sections:
|
||||
- local: api/pipelines/ace_step
|
||||
title: ACE-Step
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/longcat_audio_dit
|
||||
|
||||
19
docs/source/en/api/models/ace_step_transformer.md
Normal file
19
docs/source/en/api/models/ace_step_transformer.md
Normal file
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AceStepTransformer1DModel
|
||||
|
||||
A 1D Diffusion Transformer for music generation from [ACE-Step 1.5](https://github.com/ace-step/ACE-Step-1.5). The model operates on the 25 Hz stereo latents produced by [`AutoencoderOobleck`] using flow matching, and is trained with a Qwen3-derived backbone (grouped-query attention, rotary position embedding, RMSNorm, AdaLN-Zero timestep conditioning) plus cross-attention to the text / lyric / timbre conditions built by `AceStepConditionEncoder`.
|
||||
|
||||
## AceStepTransformer1DModel
|
||||
|
||||
[[autodoc]] AceStepTransformer1DModel
|
||||
72
docs/source/en/api/pipelines/ace_step.md
Normal file
72
docs/source/en/api/pipelines/ace_step.md
Normal file
@@ -0,0 +1,72 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ACE-Step 1.5
|
||||
|
||||
ACE-Step 1.5 was introduced in [ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation](https://arxiv.org/abs/2602.00744) by the ACE-Step Team (ACE Studio and StepFun). It is an open-source music foundation model that generates commercial-grade stereo music with lyrics from text prompts.
|
||||
|
||||
ACE-Step 1.5 generates variable-length stereo audio at 48 kHz (10 seconds to 10 minutes) from text prompts and optional lyrics. The full system pairs a Language Model planner with a Diffusion Transformer (DiT) synthesizer; this pipeline wraps the DiT half of that stack, and consists of three components: an [`AutoencoderOobleck`] VAE that compresses waveforms into 25 Hz stereo latents, a Qwen3-based text encoder for prompt and lyric conditioning, and an [`AceStepTransformer1DModel`] DiT that operates in the VAE latent space using flow matching.
|
||||
|
||||
The model supports 50+ languages for lyrics — including English, Chinese, Japanese, Korean, French, German, Spanish, Italian, Portuguese, and Russian — and runs on consumer GPUs (under 4 GB of VRAM when offloaded).
|
||||
|
||||
This pipeline was contributed by the [ACE-Step Team](https://github.com/ace-step). The original codebase can be found at [ace-step/ACE-Step-1.5](https://github.com/ace-step/ACE-Step-1.5).
|
||||
|
||||
## Variants
|
||||
|
||||
ACE-Step 1.5 ships three DiT checkpoints that share the same transformer architecture but differ in guidance behavior; the pipeline auto-detects turbo checkpoints from the loaded transformer config and ignores CFG guidance for those guidance-distilled weights.
|
||||
|
||||
| Variant | CFG | Default steps | Default `guidance_scale` | Default `shift` | HF repo |
|
||||
|---------|:---:|:-------------:|:------------------------:|:---------------:|---------|
|
||||
| `turbo` (guidance-distilled) | off | 8 | ignored | 3.0 | [`ACE-Step/Ace-Step1.5`](https://huggingface.co/ACE-Step/Ace-Step1.5) |
|
||||
| `base` | on | 8 | 7.0 | 3.0 | [`ACE-Step/acestep-v15-base`](https://huggingface.co/ACE-Step/acestep-v15-base) |
|
||||
| `sft` | on | 8 | 7.0 | 3.0 | [`ACE-Step/acestep-v15-sft`](https://huggingface.co/ACE-Step/acestep-v15-sft) |
|
||||
|
||||
Base and SFT use the learned `null_condition_emb` for classifier-free guidance (APG, not vanilla CFG). Users commonly override `num_inference_steps` to 30–60 on base/sft for higher quality.
|
||||
|
||||
## Tips
|
||||
|
||||
When constructing a prompt, keep in mind:
|
||||
|
||||
* Descriptive prompt inputs work best; use adjectives to describe the music style, instruments, mood, and tempo.
|
||||
* The prompt should describe the overall musical characteristics (e.g., "upbeat pop song with electric guitar and drums").
|
||||
* Lyrics should be structured with tags like `[verse]`, `[chorus]`, `[bridge]`, etc.
|
||||
|
||||
During inference:
|
||||
|
||||
* `num_inference_steps`, `guidance_scale`, and `shift` default to the values shown above. For turbo checkpoints, `guidance_scale > 1.0` is ignored with a warning because guidance is distilled into the weights.
|
||||
* The `audio_duration` parameter controls the length of the generated music in seconds.
|
||||
* The `vocal_language` parameter should match the language of the lyrics.
|
||||
* `pipe.sample_rate` and `pipe.latents_per_second` are sourced from the VAE config (48000 Hz and 25 fps for the released checkpoints).
|
||||
* For audio-to-audio tasks, pass `src_audio` and `reference_audio` as preprocessed stereo tensors at `pipe.sample_rate`.
|
||||
* `flash` and `flash_hub` use FlashAttention's native sliding-window support for ACE-Step's self-attention and expect unpadded text batches. If a batched prompt contains padding, use `flash_varlen` or `flash_varlen_hub` instead. Single-prompt inference with `padding="longest"` is normally unpadded.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from diffusers import AceStepPipeline
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained("ACE-Step/Ace-Step1.5", torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
audio = pipe(
|
||||
prompt="A beautiful piano piece with soft melodies and gentle rhythm",
|
||||
lyrics="[verse]\nSoft notes in the morning light\nDancing through the air so bright\n[chorus]\nMusic fills the air tonight\nEvery note feels just right",
|
||||
audio_duration=30.0,
|
||||
).audios
|
||||
|
||||
sf.write("output.wav", audio[0].T.cpu().float().numpy(), pipe.sample_rate)
|
||||
```
|
||||
|
||||
## AceStepPipeline
|
||||
[[autodoc]] AceStepPipeline
|
||||
- all
|
||||
- __call__
|
||||
454
scripts/convert_ace_step_to_diffusers.py
Normal file
454
scripts/convert_ace_step_to_diffusers.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# Run this script to convert ACE-Step model weights to a diffusers pipeline.
|
||||
#
|
||||
# Usage:
|
||||
# python scripts/convert_ace_step_to_diffusers.py \
|
||||
# --checkpoint_dir /path/to/ACE-Step-1.5/checkpoints \
|
||||
# --dit_config acestep-v15-turbo \
|
||||
# --output_dir /path/to/output/ACE-Step-v1-5-turbo \
|
||||
# --dtype bf16
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="bf16"):
|
||||
"""
|
||||
Convert ACE-Step checkpoint weights into a Diffusers-compatible pipeline layout.
|
||||
|
||||
The original ACE-Step model stores all weights in a single `model.safetensors` file
|
||||
under `checkpoints/<dit_config>/`. This script splits the weights into separate
|
||||
sub-model directories that can be loaded by `AceStepPipeline.from_pretrained()`.
|
||||
|
||||
Expected input layout:
|
||||
checkpoint_dir/
|
||||
<dit_config>/ # e.g., acestep-v15-turbo
|
||||
config.json
|
||||
model.safetensors
|
||||
silence_latent.pt
|
||||
vae/
|
||||
config.json
|
||||
diffusion_pytorch_model.safetensors
|
||||
Qwen3-Embedding-0.6B/
|
||||
config.json
|
||||
model.safetensors
|
||||
tokenizer.json
|
||||
...
|
||||
|
||||
Output layout:
|
||||
output_dir/
|
||||
model_index.json
|
||||
transformer/
|
||||
config.json
|
||||
diffusion_pytorch_model.safetensors
|
||||
condition_encoder/
|
||||
config.json
|
||||
diffusion_pytorch_model.safetensors
|
||||
vae/
|
||||
config.json
|
||||
diffusion_pytorch_model.safetensors
|
||||
text_encoder/
|
||||
config.json
|
||||
model.safetensors
|
||||
...
|
||||
tokenizer/
|
||||
tokenizer.json
|
||||
...
|
||||
"""
|
||||
# Support `--checkpoint_dir <repo-id>` by snapshot-downloading it first. A
|
||||
# local path that happens not to exist still raises the clearer FileNotFoundError
|
||||
# below, so we only fall through to the Hub if the path is missing AND looks like
|
||||
# a repo id (namespace/name).
|
||||
if not os.path.exists(checkpoint_dir) and "/" in checkpoint_dir and not checkpoint_dir.startswith((".", "~", "/")):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
print(f"Downloading `{checkpoint_dir}` from the Hugging Face Hub ...")
|
||||
checkpoint_dir = snapshot_download(repo_id=checkpoint_dir)
|
||||
print(f" -> local snapshot at {checkpoint_dir}")
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To use a Hugging Face Hub repo id for --checkpoint_dir, install `huggingface_hub`."
|
||||
) from e
|
||||
|
||||
# Resolve paths
|
||||
dit_dir = os.path.join(checkpoint_dir, dit_config)
|
||||
vae_dir = os.path.join(checkpoint_dir, "vae")
|
||||
text_encoder_dir = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
|
||||
|
||||
# The DiT weights ship either as a single `model.safetensors` (the smaller turbo
|
||||
# variant) or as sharded safetensors keyed by `model.safetensors.index.json`
|
||||
# (the 5B XL variant). Resolve both layouts to `dit_weight_files` and load below.
|
||||
single_model_path = os.path.join(dit_dir, "model.safetensors")
|
||||
sharded_index_path = os.path.join(dit_dir, "model.safetensors.index.json")
|
||||
config_path = os.path.join(dit_dir, "config.json")
|
||||
if os.path.exists(single_model_path):
|
||||
dit_weight_files = [single_model_path]
|
||||
elif os.path.exists(sharded_index_path):
|
||||
with open(sharded_index_path) as f:
|
||||
shard_index = json.load(f)
|
||||
dit_weight_files = [os.path.join(dit_dir, s) for s in sorted(set(shard_index["weight_map"].values()))]
|
||||
for p in dit_weight_files:
|
||||
if not os.path.exists(p):
|
||||
raise FileNotFoundError(f"sharded DiT weight missing: {p}")
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"DiT weights not found at: {single_model_path} or {sharded_index_path}. "
|
||||
"Expected either a single `model.safetensors` or a sharded "
|
||||
"`model.safetensors.index.json` + per-shard files."
|
||||
)
|
||||
for path, name in [
|
||||
(config_path, "config"),
|
||||
(vae_dir, "VAE"),
|
||||
(text_encoder_dir, "text encoder"),
|
||||
]:
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"{name} not found at: {path}")
|
||||
|
||||
# Select dtype
|
||||
dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
if dtype_str not in dtype_map:
|
||||
raise ValueError(f"Unsupported dtype: {dtype_str}. Choose from {list(dtype_map.keys())}")
|
||||
target_dtype = dtype_map[dtype_str]
|
||||
|
||||
# Load original config
|
||||
with open(config_path) as f:
|
||||
original_config = json.load(f)
|
||||
|
||||
print(f"Loading DiT weights from {len(dit_weight_files)} file(s) ...")
|
||||
state_dict = {}
|
||||
for p in dit_weight_files:
|
||||
print(f" loading {os.path.basename(p)}")
|
||||
state_dict.update(load_file(p))
|
||||
print(f" Total keys: {len(state_dict)}")
|
||||
|
||||
# =========================================================================
|
||||
# 1. Split weights by prefix
|
||||
# =========================================================================
|
||||
transformer_sd = {}
|
||||
condition_encoder_sd = {}
|
||||
audio_tokenizer_sd = {}
|
||||
audio_token_detokenizer_sd = {}
|
||||
other_sd = {}
|
||||
|
||||
# Rename original ACE-Step attention keys to the diffusers `Attention` +
|
||||
# `AttnProcessor` convention (`to_q`/`to_k`/`to_v`/`to_out.0`/`norm_q`/`norm_k`).
|
||||
# Applies uniformly to both the DiT (self-attn and cross-attn) and the
|
||||
# condition-encoder self-attention, since both use `AceStepAttention`.
|
||||
_ATTN_KEY_RENAMES = [
|
||||
(".q_proj.", ".to_q."),
|
||||
(".k_proj.", ".to_k."),
|
||||
(".v_proj.", ".to_v."),
|
||||
(".o_proj.", ".to_out.0."),
|
||||
(".q_norm.", ".norm_q."),
|
||||
(".k_norm.", ".norm_k."),
|
||||
]
|
||||
|
||||
def _rename_attn_keys(key: str) -> str:
|
||||
for old, new in _ATTN_KEY_RENAMES:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith("decoder."):
|
||||
# Strip "decoder." prefix for the transformer
|
||||
new_key = key[len("decoder.") :]
|
||||
# The original model uses nn.Sequential for proj_in/proj_out:
|
||||
# proj_in = Sequential(Lambda, Conv1d, Lambda)
|
||||
# proj_out = Sequential(Lambda, ConvTranspose1d, Lambda)
|
||||
# Only the Conv1d/ConvTranspose1d (index 1) has parameters.
|
||||
# In diffusers, we use standalone Conv1d/ConvTranspose1d named proj_in_conv/proj_out_conv.
|
||||
new_key = new_key.replace("proj_in.1.", "proj_in_conv.")
|
||||
new_key = new_key.replace("proj_out.1.", "proj_out_conv.")
|
||||
new_key = _rename_attn_keys(new_key)
|
||||
transformer_sd[new_key] = value.to(target_dtype)
|
||||
elif key.startswith("encoder."):
|
||||
# Strip "encoder." prefix for the condition encoder
|
||||
new_key = key[len("encoder.") :]
|
||||
new_key = _rename_attn_keys(new_key)
|
||||
condition_encoder_sd[new_key] = value.to(target_dtype)
|
||||
elif key == "null_condition_emb":
|
||||
# Learned unconditional embedding (used by the base/SFT CFG path).
|
||||
# Keep it co-located with the condition encoder since that is where the
|
||||
# pipeline pulls unconditional sequences from.
|
||||
condition_encoder_sd["null_condition_emb"] = value.to(target_dtype)
|
||||
elif key.startswith("tokenizer."):
|
||||
new_key = key[len("tokenizer.") :]
|
||||
new_key = _rename_attn_keys(new_key)
|
||||
audio_tokenizer_sd[new_key] = value.to(target_dtype)
|
||||
elif key.startswith("detokenizer."):
|
||||
new_key = key[len("detokenizer.") :]
|
||||
new_key = _rename_attn_keys(new_key)
|
||||
audio_token_detokenizer_sd[new_key] = value.to(target_dtype)
|
||||
else:
|
||||
other_sd[key] = value.to(target_dtype)
|
||||
|
||||
print(f" Transformer keys: {len(transformer_sd)}")
|
||||
print(f" Condition encoder keys: {len(condition_encoder_sd)}")
|
||||
print(f" Audio tokenizer keys: {len(audio_tokenizer_sd)}")
|
||||
print(f" Audio token detokenizer keys: {len(audio_token_detokenizer_sd)}")
|
||||
print(f" Other keys: {len(other_sd)} ({list(other_sd.keys())[:5]}...)")
|
||||
|
||||
# =========================================================================
|
||||
# 2. Build configs for each sub-model
|
||||
# =========================================================================
|
||||
|
||||
# On the 5B XL turbo the condition encoder is narrower than the DiT
|
||||
# (`encoder_hidden_size=2048` feeding a `hidden_size=2560` DiT). Non-XL
|
||||
# turbo / base checkpoints don't set this field, so fall back to
|
||||
# `hidden_size` — that makes the DiT's `condition_embedder` an identity-width
|
||||
# Linear as before. Similarly `encoder_intermediate_size` /
|
||||
# `encoder_num_attention_heads` / `encoder_num_key_value_heads` describe the
|
||||
# condition encoder on XL only.
|
||||
encoder_hidden_size = original_config.get("encoder_hidden_size", original_config["hidden_size"])
|
||||
encoder_intermediate_size = original_config.get("encoder_intermediate_size", original_config["intermediate_size"])
|
||||
encoder_num_attention_heads = original_config.get(
|
||||
"encoder_num_attention_heads", original_config["num_attention_heads"]
|
||||
)
|
||||
encoder_num_key_value_heads = original_config.get(
|
||||
"encoder_num_key_value_heads", original_config["num_key_value_heads"]
|
||||
)
|
||||
|
||||
# Transformer (DiT) config. `is_turbo` / `model_version` propagate the variant so
|
||||
# the pipeline can pick the right CFG / shift / step-count defaults at inference.
|
||||
# Note: `max_position_embeddings` is dropped (RoPE computes freqs on-the-fly per call),
|
||||
# and `use_sliding_window` is implied by the mix of `layer_types`.
|
||||
transformer_config = {
|
||||
"_class_name": "AceStepTransformer1DModel",
|
||||
"_diffusers_version": "0.33.0.dev0",
|
||||
"hidden_size": original_config["hidden_size"],
|
||||
"intermediate_size": original_config["intermediate_size"],
|
||||
"num_hidden_layers": original_config["num_hidden_layers"],
|
||||
"num_attention_heads": original_config["num_attention_heads"],
|
||||
"num_key_value_heads": original_config["num_key_value_heads"],
|
||||
"head_dim": original_config["head_dim"],
|
||||
"in_channels": original_config["in_channels"],
|
||||
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
|
||||
"patch_size": original_config["patch_size"],
|
||||
"rope_theta": original_config["rope_theta"],
|
||||
"attention_bias": original_config["attention_bias"],
|
||||
"attention_dropout": original_config["attention_dropout"],
|
||||
"rms_norm_eps": original_config["rms_norm_eps"],
|
||||
"sliding_window": original_config["sliding_window"],
|
||||
"layer_types": original_config["layer_types"],
|
||||
"encoder_hidden_size": encoder_hidden_size,
|
||||
"is_turbo": bool(original_config.get("is_turbo", False)),
|
||||
"model_version": original_config.get("model_version"),
|
||||
}
|
||||
|
||||
# Condition encoder config
|
||||
condition_encoder_config = {
|
||||
"_class_name": "AceStepConditionEncoder",
|
||||
"_diffusers_version": "0.33.0.dev0",
|
||||
"hidden_size": encoder_hidden_size,
|
||||
"intermediate_size": encoder_intermediate_size,
|
||||
"text_hidden_dim": original_config["text_hidden_dim"],
|
||||
"timbre_hidden_dim": original_config["timbre_hidden_dim"],
|
||||
"num_lyric_encoder_hidden_layers": original_config["num_lyric_encoder_hidden_layers"],
|
||||
"num_timbre_encoder_hidden_layers": original_config["num_timbre_encoder_hidden_layers"],
|
||||
"num_attention_heads": encoder_num_attention_heads,
|
||||
"num_key_value_heads": encoder_num_key_value_heads,
|
||||
"head_dim": original_config["head_dim"],
|
||||
"rope_theta": original_config["rope_theta"],
|
||||
"attention_bias": original_config["attention_bias"],
|
||||
"attention_dropout": original_config["attention_dropout"],
|
||||
"rms_norm_eps": original_config["rms_norm_eps"],
|
||||
"sliding_window": original_config["sliding_window"],
|
||||
}
|
||||
|
||||
audio_tokenizer_config = {
|
||||
"_class_name": "AceStepAudioTokenizer",
|
||||
"_diffusers_version": "0.33.0.dev0",
|
||||
"hidden_size": encoder_hidden_size,
|
||||
"intermediate_size": encoder_intermediate_size,
|
||||
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
|
||||
"pool_window_size": original_config.get("pool_window_size", 5),
|
||||
"fsq_dim": original_config.get("fsq_dim", encoder_hidden_size),
|
||||
"fsq_input_levels": original_config.get("fsq_input_levels", [8, 8, 8, 5, 5, 5]),
|
||||
"fsq_input_num_quantizers": original_config.get("fsq_input_num_quantizers", 1),
|
||||
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
|
||||
"num_attention_heads": encoder_num_attention_heads,
|
||||
"num_key_value_heads": encoder_num_key_value_heads,
|
||||
"head_dim": original_config["head_dim"],
|
||||
"rope_theta": original_config["rope_theta"],
|
||||
"attention_bias": original_config["attention_bias"],
|
||||
"attention_dropout": original_config["attention_dropout"],
|
||||
"rms_norm_eps": original_config["rms_norm_eps"],
|
||||
"sliding_window": original_config["sliding_window"],
|
||||
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
|
||||
}
|
||||
|
||||
audio_token_detokenizer_config = {
|
||||
"_class_name": "AceStepAudioTokenDetokenizer",
|
||||
"_diffusers_version": "0.33.0.dev0",
|
||||
"hidden_size": encoder_hidden_size,
|
||||
"intermediate_size": encoder_intermediate_size,
|
||||
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
|
||||
"pool_window_size": original_config.get("pool_window_size", 5),
|
||||
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
|
||||
"num_attention_heads": encoder_num_attention_heads,
|
||||
"num_key_value_heads": encoder_num_key_value_heads,
|
||||
"head_dim": original_config["head_dim"],
|
||||
"rope_theta": original_config["rope_theta"],
|
||||
"attention_bias": original_config["attention_bias"],
|
||||
"attention_dropout": original_config["attention_dropout"],
|
||||
"rms_norm_eps": original_config["rms_norm_eps"],
|
||||
"sliding_window": original_config["sliding_window"],
|
||||
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# 3. Bake silence_latent into the condition_encoder state dict.
|
||||
#
|
||||
# The original loader in
|
||||
# acestep/core/generation/handler/init_service_loader.py:214 does
|
||||
# self.silence_latent = torch.load(...).transpose(1, 2)
|
||||
# converting the stored [B, C=64, T=15000] tensor to [B, T, C=64] before any
|
||||
# downstream slicing. Do the same transpose here and register it as the
|
||||
# `silence_latent` buffer on AceStepConditionEncoder — the pipeline slices
|
||||
# `silence_latent[:, :timbre_fix_frame, :]` to build the "silence" input to the
|
||||
# timbre encoder when no reference audio is supplied. Passing literal zeros
|
||||
# produces drone-like audio.
|
||||
silence_latent_src = os.path.join(dit_dir, "silence_latent.pt")
|
||||
if os.path.exists(silence_latent_src):
|
||||
silence_raw = torch.load(silence_latent_src, weights_only=True, map_location="cpu")
|
||||
silence_latent = silence_raw.transpose(1, 2).to(target_dtype).contiguous()
|
||||
print(f" silence_latent raw shape: {tuple(silence_raw.shape)} -> baked shape: {tuple(silence_latent.shape)}")
|
||||
condition_encoder_sd["silence_latent"] = silence_latent
|
||||
|
||||
# =========================================================================
|
||||
# 4. Build the AceStepPipeline in memory and save via `save_pretrained`.
|
||||
# Assembling the pipeline directly (rather than hand-writing model_index.json)
|
||||
# ensures the saved repo stays in sync with the `AceStepPipeline.__init__`
|
||||
# signature — e.g. a future sub-module added to the pipeline can't silently
|
||||
# drift out of `model_index.json`.
|
||||
# =========================================================================
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AceStepPipeline,
|
||||
AceStepTransformer1DModel,
|
||||
AutoencoderOobleck,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.pipelines.ace_step import (
|
||||
AceStepAudioTokenDetokenizer,
|
||||
AceStepAudioTokenizer,
|
||||
AceStepConditionEncoder,
|
||||
)
|
||||
|
||||
# Drop metadata keys — they're re-populated by `save_pretrained` at save time.
|
||||
transformer_init_kwargs = {k: v for k, v in transformer_config.items() if not k.startswith("_")}
|
||||
condition_encoder_init_kwargs = {k: v for k, v in condition_encoder_config.items() if not k.startswith("_")}
|
||||
audio_tokenizer_init_kwargs = {k: v for k, v in audio_tokenizer_config.items() if not k.startswith("_")}
|
||||
audio_token_detokenizer_init_kwargs = {
|
||||
k: v for k, v in audio_token_detokenizer_config.items() if not k.startswith("_")
|
||||
}
|
||||
|
||||
print("\nConstructing transformer ...")
|
||||
transformer = AceStepTransformer1DModel(**transformer_init_kwargs).to(target_dtype)
|
||||
transformer.load_state_dict(transformer_sd, strict=True)
|
||||
|
||||
print("Constructing condition_encoder ...")
|
||||
condition_encoder = AceStepConditionEncoder(**condition_encoder_init_kwargs).to(target_dtype)
|
||||
condition_encoder.load_state_dict(condition_encoder_sd, strict=True)
|
||||
|
||||
print("Constructing audio_tokenizer ...")
|
||||
audio_tokenizer = AceStepAudioTokenizer(**audio_tokenizer_init_kwargs).to(target_dtype)
|
||||
audio_tokenizer.load_state_dict(audio_tokenizer_sd, strict=True)
|
||||
|
||||
print("Constructing audio_token_detokenizer ...")
|
||||
audio_token_detokenizer = AceStepAudioTokenDetokenizer(**audio_token_detokenizer_init_kwargs).to(target_dtype)
|
||||
audio_token_detokenizer.load_state_dict(audio_token_detokenizer_sd, strict=True)
|
||||
|
||||
print("Loading VAE ...")
|
||||
vae = AutoencoderOobleck.from_pretrained(vae_dir).to(target_dtype)
|
||||
|
||||
print("Loading text encoder ...")
|
||||
text_encoder = AutoModel.from_pretrained(text_encoder_dir, torch_dtype=target_dtype)
|
||||
|
||||
print("Loading tokenizer ...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
|
||||
|
||||
# ACE-Step drives the DiT with t ∈ [0, 1] and computes its own shifted / turbo
|
||||
# sigma schedule, which it passes to `scheduler.set_timesteps(sigmas=...)` at
|
||||
# sampling time. So the scheduler needs `num_train_timesteps=1` (so
|
||||
# `scheduler.timesteps == sigmas`) and `shift=1.0` (so it doesn't re-shift
|
||||
# already-shifted sigmas). All other defaults are fine.
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)
|
||||
|
||||
pipe = AceStepPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
condition_encoder=condition_encoder,
|
||||
scheduler=scheduler,
|
||||
audio_tokenizer=audio_tokenizer,
|
||||
audio_token_detokenizer=audio_token_detokenizer,
|
||||
)
|
||||
|
||||
print(f"\nSaving pipeline -> {output_dir}")
|
||||
pipe.save_pretrained(output_dir, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
# Keep the raw silence_latent.pt at the pipeline root for debugging — not
|
||||
# required by `from_pretrained`, but makes it easy to re-derive the buffer
|
||||
# without re-running the full conversion.
|
||||
if os.path.exists(silence_latent_src):
|
||||
shutil.copy2(silence_latent_src, os.path.join(output_dir, "silence_latent.pt"))
|
||||
print(f" kept raw silence_latent copy at {output_dir}/silence_latent.pt")
|
||||
|
||||
# Report any keys that were not saved to registered pipeline modules.
|
||||
if other_sd:
|
||||
print(f"\nNote: {len(other_sd)} keys were dropped:")
|
||||
for key in sorted(other_sd.keys())[:10]:
|
||||
print(f" {key}")
|
||||
if len(other_sd) > 10:
|
||||
print(f" ... ({len(other_sd) - 10} more)")
|
||||
|
||||
print(f"\nConversion complete! Output saved to: {output_dir}")
|
||||
print("\nTo load the pipeline:")
|
||||
print(" from diffusers import AceStepPipeline")
|
||||
print(f" pipe = AceStepPipeline.from_pretrained('{output_dir}', torch_dtype=torch.bfloat16)")
|
||||
print(" pipe = pipe.to('cuda')")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert ACE-Step model weights to Diffusers pipeline format")
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the ACE-Step checkpoints directory (containing vae/, Qwen3-Embedding-0.6B/, and dit config dirs)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dit_config",
|
||||
type=str,
|
||||
default="acestep-v15-turbo",
|
||||
help="Name of the DiT config directory (default: acestep-v15-turbo)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to save the converted Diffusers pipeline",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bf16",
|
||||
choices=["fp32", "fp16", "bf16"],
|
||||
help="Data type for saved weights (default: bf16)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_ace_step_weights(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
dit_config=args.dit_config,
|
||||
output_dir=args.output_dir,
|
||||
dtype_str=args.dtype,
|
||||
)
|
||||
@@ -188,6 +188,7 @@ else:
|
||||
]
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AceStepTransformer1DModel",
|
||||
"AllegroTransformer3DModel",
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AttentionBackendName",
|
||||
@@ -488,6 +489,10 @@ else:
|
||||
)
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"AceStepAudioTokenDetokenizer",
|
||||
"AceStepAudioTokenizer",
|
||||
"AceStepConditionEncoder",
|
||||
"AceStepPipeline",
|
||||
"AllegroPipeline",
|
||||
"AltDiffusionImg2ImgPipeline",
|
||||
"AltDiffusionPipeline",
|
||||
@@ -1000,6 +1005,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VaeImageProcessorLDM3D,
|
||||
)
|
||||
from .models import (
|
||||
AceStepTransformer1DModel,
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
AttentionBackendName,
|
||||
@@ -1277,6 +1283,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageModularPipeline,
|
||||
)
|
||||
from .pipelines import (
|
||||
AceStepAudioTokenDetokenizer,
|
||||
AceStepAudioTokenizer,
|
||||
AceStepConditionEncoder,
|
||||
AceStepPipeline,
|
||||
AllegroPipeline,
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
|
||||
@@ -40,6 +40,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
adaptive_projected_guidance_norm_dim (`int` or `tuple[int]`, *optional*):
|
||||
Dimension(s) over which to compute the APG norm and projection. If omitted, all non-batch dimensions are
|
||||
used, preserving the original behavior.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
@@ -62,6 +65,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: float | None = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
adaptive_projected_guidance_norm_dim: int | tuple[int, ...] | None = None,
|
||||
eta: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
@@ -74,6 +78,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.adaptive_projected_guidance_norm_dim = adaptive_projected_guidance_norm_dim
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
@@ -117,6 +122,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
self.adaptive_projected_guidance_norm_dim,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
@@ -210,9 +216,15 @@ def normalized_guidance(
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
norm_dim: int | tuple[int, ...] | None = None,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
if norm_dim is None:
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
elif isinstance(norm_dim, int):
|
||||
dim = [norm_dim]
|
||||
else:
|
||||
dim = list(norm_dim)
|
||||
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
@@ -224,11 +236,15 @@ def normalized_guidance(
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
if diff.device.type in {"mps", "npu"}:
|
||||
v0, v1 = diff.cpu().double(), pred_cond.cpu().double()
|
||||
else:
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
diff_parallel = v0_parallel.to(device=diff.device, dtype=diff.dtype)
|
||||
diff_orthogonal = v0_orthogonal.to(device=diff.device, dtype=diff.dtype)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
|
||||
@@ -79,6 +79,7 @@ if is_torch_available():
|
||||
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["transformers.ace_step_transformer"] = ["AceStepTransformer1DModel"]
|
||||
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
|
||||
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
|
||||
_import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"]
|
||||
@@ -209,6 +210,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
AceStepTransformer1DModel,
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
|
||||
@@ -1091,14 +1091,14 @@ def _flash_attention_forward_op(
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
*,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
|
||||
|
||||
# Hardcoded for now
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
@@ -1191,6 +1191,8 @@ def _flash_attention_hub_forward_op(
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
*,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||
@@ -1209,7 +1211,6 @@ def _flash_attention_hub_forward_op(
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
@@ -2453,6 +2454,7 @@ def _flash_attention(
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -2468,11 +2470,13 @@ def _flash_attention(
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
forward_op = functools.partial(_flash_attention_forward_op, window_size=window_size)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
@@ -2483,7 +2487,7 @@ def _flash_attention(
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_forward_op,
|
||||
forward_op=forward_op,
|
||||
backward_op=_flash_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
@@ -2506,6 +2510,7 @@ def _flash_attention_hub(
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -2522,11 +2527,13 @@ def _flash_attention_hub(
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
forward_op = functools.partial(_flash_attention_hub_forward_op, window_size=window_size)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
@@ -2537,7 +2544,7 @@ def _flash_attention_hub(
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_hub_forward_op,
|
||||
forward_op=forward_op,
|
||||
backward_op=_flash_attention_hub_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
@@ -2560,6 +2567,7 @@ def _flash_varlen_attention_hub(
|
||||
dropout_p: float = 0.0,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -2597,6 +2605,7 @@ def _flash_varlen_attention_hub(
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
@@ -2616,6 +2625,7 @@ def _flash_varlen_attention(
|
||||
dropout_p: float = 0.0,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -2652,6 +2662,7 @@ def _flash_varlen_attention(
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
@@ -355,6 +355,24 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# 1D time-axis tiling defaults. `tile_sample_min_length` is the raw-audio
|
||||
# threshold (in samples) above which `encode` splits the input; chunks are
|
||||
# `tile_sample_min_length` wide with `tile_sample_overlap` samples of overlap
|
||||
# on each side, trimmed back out after decoding. `tile_latent_min_length`
|
||||
# is the equivalent threshold on the decode side, expressed in latent frames.
|
||||
self.tile_sample_min_length = sampling_rate * 30 # 30 seconds
|
||||
self.tile_sample_overlap = sampling_rate * 2 # 2 seconds per side
|
||||
# Decode chunk is smaller than encode chunk because the decoder upsamples
|
||||
# back to raw audio and is more VRAM-heavy per frame.
|
||||
self.tile_latent_min_length = 512
|
||||
self.tile_latent_overlap = 64
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_tiling and x.shape[-1] > self.tile_sample_min_length:
|
||||
return self._tiled_encode(x)
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
@@ -373,10 +391,10 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = OobleckDiagonalGaussianDistribution(h)
|
||||
|
||||
@@ -385,14 +403,88 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
return AutoencoderOobleckOutput(latent_dist=posterior)
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a long audio waveform by splitting it into overlapping tiles along
|
||||
the time axis and concatenating the resulting encoder features. Used to keep memory bounded regardless of clip
|
||||
length. Not bit-identical to a single unsplit encode — each tile has its own receptive-field boundary — but the
|
||||
overlap/trim scheme keeps the joined feature map smooth.
|
||||
"""
|
||||
_B, _C, S = x.shape
|
||||
chunk = self.tile_sample_min_length
|
||||
overlap = self.tile_sample_overlap
|
||||
stride = chunk - 2 * overlap
|
||||
if stride <= 0:
|
||||
raise ValueError(
|
||||
f"tile_sample_min_length ({chunk}) must be greater than 2 * tile_sample_overlap ({overlap})"
|
||||
)
|
||||
|
||||
num_steps = math.ceil(S / stride)
|
||||
tiles = []
|
||||
hop = None
|
||||
|
||||
for i in range(num_steps):
|
||||
core_start = i * stride
|
||||
core_end = min(core_start + stride, S)
|
||||
win_start = max(0, core_start - overlap)
|
||||
win_end = min(S, core_end + overlap)
|
||||
|
||||
tile = self.encoder(x[:, :, win_start:win_end])
|
||||
|
||||
if hop is None:
|
||||
hop = (win_end - win_start) / tile.shape[-1]
|
||||
|
||||
trim_l = int(round((core_start - win_start) / hop))
|
||||
trim_r = int(round((win_end - core_end) / hop))
|
||||
end_idx = tile.shape[-1] - trim_r if trim_r > 0 else tile.shape[-1]
|
||||
tiles.append(tile[:, :, trim_l:end_idx])
|
||||
|
||||
return torch.cat(tiles, dim=-1)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> OobleckDecoderOutput | torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
if self.use_tiling and z.shape[-1] > self.tile_latent_min_length:
|
||||
dec = self._tiled_decode(z)
|
||||
else:
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return OobleckDecoderOutput(sample=dec)
|
||||
|
||||
def _tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
r"""Decode a long latent by splitting it into overlapping tiles along the
|
||||
time axis, decoding each, and concatenating the audio tiles back together."""
|
||||
_B, _C, T = z.shape
|
||||
chunk = self.tile_latent_min_length
|
||||
overlap = self.tile_latent_overlap
|
||||
stride = chunk - 2 * overlap
|
||||
if stride <= 0:
|
||||
raise ValueError(
|
||||
f"tile_latent_min_length ({chunk}) must be greater than 2 * tile_latent_overlap ({overlap})"
|
||||
)
|
||||
|
||||
num_steps = math.ceil(T / stride)
|
||||
tiles = []
|
||||
upsample = None
|
||||
|
||||
for i in range(num_steps):
|
||||
core_start = i * stride
|
||||
core_end = min(core_start + stride, T)
|
||||
win_start = max(0, core_start - overlap)
|
||||
win_end = min(T, core_end + overlap)
|
||||
|
||||
tile = self.decoder(z[:, :, win_start:win_end])
|
||||
|
||||
if upsample is None:
|
||||
upsample = tile.shape[-1] / (win_end - win_start)
|
||||
|
||||
trim_l = int(round((core_start - win_start) * upsample))
|
||||
trim_r = int(round((win_end - core_end) * upsample))
|
||||
end_idx = tile.shape[-1] - trim_r if trim_r > 0 else tile.shape[-1]
|
||||
tiles.append(tile[:, :, trim_l:end_idx])
|
||||
|
||||
return torch.cat(tiles, dim=-1)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
|
||||
@@ -2,6 +2,7 @@ from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .ace_step_transformer import AceStepTransformer1DModel
|
||||
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
|
||||
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
||||
from .consisid_transformer_3d import ConsisIDTransformer3DModel
|
||||
|
||||
626
src/diffusers/models/transformers/ace_step_transformer.py
Normal file
626
src/diffusers/models/transformers/ace_step_transformer.py
Normal file
@@ -0,0 +1,626 @@
|
||||
# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Diffusion Transformer (DiT) for ACE-Step 1.5 music generation."""
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import (
|
||||
AttentionBackendName,
|
||||
_AttentionBackendRegistry,
|
||||
dispatch_attention_fn,
|
||||
)
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_FLASH_ATTENTION_BACKENDS = {
|
||||
AttentionBackendName.FLASH,
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
AttentionBackendName.FLASH_VARLEN_HUB,
|
||||
}
|
||||
|
||||
_FLASH_ATTENTION_VARLEN_BACKENDS = {
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
AttentionBackendName.FLASH_VARLEN_HUB,
|
||||
}
|
||||
|
||||
|
||||
def _get_current_attention_backend(processor: Optional["AceStepAttnProcessor2_0"] = None) -> AttentionBackendName:
|
||||
backend = getattr(processor, "_attention_backend", None)
|
||||
if backend is None:
|
||||
backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
return AttentionBackendName(backend)
|
||||
|
||||
|
||||
def _is_flash_attention_backend(processor: Optional["AceStepAttnProcessor2_0"] = None) -> bool:
|
||||
return _get_current_attention_backend(processor) in _FLASH_ATTENTION_BACKENDS
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# attention-mask #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _create_4d_mask(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
is_sliding_window: bool = False,
|
||||
is_causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Build a `[B, 1, seq_len, seq_len]` additive mask (0.0 kept, -inf masked).
|
||||
|
||||
Mirrors the mask construction in ``acestep/models/turbo/modeling_acestep_v15_turbo.py::create_4d_mask`` so the DiT
|
||||
sees identical attention coverage regardless of whether SDPA, eager or flash attention is selected downstream.
|
||||
"""
|
||||
indices = torch.arange(seq_len, device=device)
|
||||
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
|
||||
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff >= 0)
|
||||
|
||||
if is_sliding_window and sliding_window is not None:
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff <= sliding_window)
|
||||
else:
|
||||
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
|
||||
|
||||
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
if attention_mask is not None:
|
||||
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
||||
valid_mask = valid_mask & padding_mask_4d
|
||||
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
||||
mask_tensor.masked_fill_(valid_mask, 0.0)
|
||||
return mask_tensor
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# RoPE helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _ace_step_rotary_freqs(
|
||||
seq_len: int, head_dim: int, theta: float, device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Build (cos, sin) freqs for ACE-Step RoPE using ``get_1d_rotary_pos_embed``.
|
||||
|
||||
The original ACE-Step DiT reuses Qwen3's rotary layout: ``freqs = cat([freq_half, freq_half], dim=-1)`` (not
|
||||
interleaved), and the rotate-half convention splits the last dim in two halves rather than unbinding pairs. That
|
||||
matches ``get_1d_rotary_pos_embed(..., use_real=True, repeat_interleave_real=False)`` + ``apply_rotary_emb(...,
|
||||
use_real_unbind_dim=-2)``.
|
||||
"""
|
||||
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
cos, sin = get_1d_rotary_pos_embed(head_dim, positions, theta=theta, use_real=True, repeat_interleave_real=False)
|
||||
return cos.to(dtype=dtype), sin.to(dtype=dtype)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# building blocks #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class AceStepMLP(nn.Module):
|
||||
"""SwiGLU MLP used in ACE-Step transformer blocks."""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class AceStepTimestepEmbedding(nn.Module):
|
||||
"""Sinusoidal timestep embedding + 2-layer MLP + 6-way AdaLN scale/shift projection.
|
||||
|
||||
Matches the original ACE-Step checkpoint layout exactly (``linear_1``, ``linear_2``, ``time_proj``) so the
|
||||
converter maps keys 1:1. The sinusoid itself is the shared ``Timesteps`` module (``flip_sin_to_cos=True`` for
|
||||
ACE-Step's ``cat([cos, sin])`` convention).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int = 256, time_embed_dim: int = 2048, scale: float = 1000.0):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.scale = scale
|
||||
self.time_sinusoid = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act1 = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
|
||||
self.act2 = nn.SiLU()
|
||||
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
|
||||
|
||||
def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
t_freq = self.time_sinusoid(t * self.scale)
|
||||
temb = self.linear_1(t_freq.to(t.dtype))
|
||||
temb = self.act1(temb)
|
||||
temb = self.linear_2(temb)
|
||||
timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
|
||||
return temb, timestep_proj
|
||||
|
||||
|
||||
class AceStepAttnProcessor2_0:
|
||||
"""Attention processor for ACE-Step GQA attention.
|
||||
|
||||
Dispatches the actual attention call through ``dispatch_attention_fn`` so users can pick flash / sage / native
|
||||
backends via ``model.set_attention_backend(...)`` or the ``attention_backend`` context manager. Uses the ``(B, L,
|
||||
H, D)`` tensor layout that the diffusers attention backends consume directly.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AceStepAttnProcessor2_0 requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "AceStepAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
is_cross = attn.is_cross_attention and encoder_hidden_states is not None
|
||||
kv_input = encoder_hidden_states if is_cross else hidden_states
|
||||
|
||||
# Project to (B, L, H, D). Q uses ``heads``; K/V use ``kv_heads`` (GQA).
|
||||
query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, attn.head_dim))
|
||||
key = attn.to_k(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
|
||||
value = attn.to_v(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# RoPE on self-attention only. Matches Qwen3 layout:
|
||||
# freqs = cat([freq_half, freq_half], dim=-1); rotate-half splits last dim.
|
||||
if not is_cross and image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
|
||||
|
||||
attention_kwargs = None
|
||||
backend = _get_current_attention_backend(self)
|
||||
dispatch_backend = self._attention_backend
|
||||
sliding_window = getattr(attn, "sliding_window", None)
|
||||
|
||||
if backend in _FLASH_ATTENTION_BACKENDS:
|
||||
if attention_mask is not None:
|
||||
if attention_mask.ndim == 2:
|
||||
padding_mask = attention_mask.to(torch.bool)
|
||||
elif attention_mask.ndim == 4:
|
||||
keep_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0
|
||||
padding_mask = keep_mask.any(dim=(1, 2))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported ACE-Step attention mask shape for flash attention: {attention_mask.shape}"
|
||||
)
|
||||
|
||||
has_padding = not torch.all(padding_mask).item()
|
||||
if has_padding:
|
||||
attention_mask = padding_mask
|
||||
if backend not in _FLASH_ATTENTION_VARLEN_BACKENDS:
|
||||
raise ValueError(
|
||||
"ACE-Step flash attention received a padded attention mask. Use `flash_varlen` or "
|
||||
"`flash_varlen_hub` for batched prompts with padding, or use an unpadded batch with `flash`."
|
||||
)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if not is_cross and sliding_window is not None and key.shape[1] > sliding_window:
|
||||
# ACE-Step's dense mask keeps `abs(i - j) <= sliding_window`; flash-attn uses the same inclusive
|
||||
# left/right window convention, so pass the configured value through directly.
|
||||
attention_kwargs = {"window_size": (sliding_window, sliding_window)}
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=attn.dropout if attn.training else 0.0,
|
||||
scale=attn.scaling,
|
||||
enable_gqa=attn.heads != attn.kv_heads,
|
||||
attention_kwargs=attention_kwargs,
|
||||
backend=dispatch_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AceStepAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""GQA attention with RMSNorm on query/key for ACE-Step 1.5.
|
||||
|
||||
Uses the diffusers ``Attention`` + ``AttnProcessor`` split: this module holds the projections and Q/K norm; the
|
||||
processor runs the attention dispatch. Self-attention applies RoPE on query/key; cross-attention reads K/V from
|
||||
``encoder_hidden_states`` and does not apply RoPE.
|
||||
|
||||
GQA means Q has ``heads * head_dim`` output while K/V have ``kv_heads * head_dim`` — QKV fusion is therefore
|
||||
disabled (``_supports_qkv_fusion = False``).
|
||||
"""
|
||||
|
||||
_default_processor_cls = AceStepAttnProcessor2_0
|
||||
_available_processors = [AceStepAttnProcessor2_0]
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
bias: bool = False,
|
||||
dropout: float = 0.0,
|
||||
eps: float = 1e-6,
|
||||
sliding_window: Optional[int] = None,
|
||||
is_cross_attention: bool = False,
|
||||
processor: Optional[AceStepAttnProcessor2_0] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = num_attention_heads
|
||||
self.kv_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.dropout = dropout
|
||||
self.scaling = head_dim**-0.5
|
||||
self.sliding_window = sliding_window
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
self.to_q = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=bias)
|
||||
self.to_k = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
|
||||
self.to_v = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
|
||||
self.to_out = nn.ModuleList(
|
||||
[nn.Linear(num_attention_heads * head_dim, hidden_size, bias=bias), nn.Dropout(0.0)]
|
||||
)
|
||||
self.norm_q = RMSNorm(head_dim, eps=eps)
|
||||
self.norm_k = RMSNorm(head_dim, eps=eps)
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in attn_parameters}
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AceStepTransformerBlock(nn.Module):
|
||||
"""ACE-Step DiT transformer block: self-attn (AdaLN) → cross-attn → MLP (AdaLN).
|
||||
|
||||
AdaLN parameters come from the shared ``scale_shift_table + timestep_proj`` chunked into 6 (3 for self-attn + 3 for
|
||||
MLP).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
intermediate_size: int,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: Optional[int] = None,
|
||||
use_cross_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
bias=attention_bias,
|
||||
dropout=attention_dropout,
|
||||
eps=rms_norm_eps,
|
||||
sliding_window=sliding_window,
|
||||
is_cross_attention=False,
|
||||
)
|
||||
|
||||
self.use_cross_attention = use_cross_attention
|
||||
if self.use_cross_attention:
|
||||
self.cross_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.cross_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
bias=attention_bias,
|
||||
dropout=attention_dropout,
|
||||
eps=rms_norm_eps,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
|
||||
self.mlp_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.mlp = AceStepMLP(hidden_size, intermediate_size)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb).chunk(
|
||||
6, dim=1
|
||||
)
|
||||
|
||||
# Self-attention with AdaLN.
|
||||
norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.self_attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
if self.use_cross_attention and encoder_hidden_states is not None:
|
||||
norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
|
||||
attn_output = self.cross_attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
|
||||
ff_output = self.mlp(norm_hidden_states)
|
||||
hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# main DiT model #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class AceStepTransformer1DModel(ModelMixin, ConfigMixin, AttentionMixin, CacheMixin):
|
||||
"""Diffusion Transformer for ACE-Step 1.5 music generation.
|
||||
|
||||
Generates audio latents conditioned on text, lyrics, and timbre. Uses 1D patch embedding (`Conv1d` with stride
|
||||
`patch_size`) followed by a stack of `AceStepTransformerBlock`s with alternating sliding-window / full attention on
|
||||
the self-attention branch. Cross-attention consumes the packed `encoder_hidden_states` produced by
|
||||
`AceStepConditionEncoder`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
in_channels: int = 192,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
patch_size: int = 2,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: Optional[List[str]] = None,
|
||||
# Dim of the condition encoder's output. Equal to `hidden_size` on the
|
||||
# non-XL turbo / base models, but the XL turbo has a smaller condition
|
||||
# encoder (`encoder_hidden_size=2048`) feeding a wider DiT
|
||||
# (`hidden_size=2560`), so `condition_embedder` needs to project it up.
|
||||
encoder_hidden_size: Optional[int] = None,
|
||||
# Variant metadata. Turbo models have guidance distilled into the weights and
|
||||
# should run without CFG; base/SFT models require CFG with the learned
|
||||
# `AceStepConditionEncoder.null_condition_emb`. The pipeline reads these to
|
||||
# pick default `guidance_scale`, `shift`, and `num_inference_steps`.
|
||||
is_turbo: bool = False,
|
||||
model_version: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if encoder_hidden_size is None:
|
||||
encoder_hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(num_hidden_layers)
|
||||
]
|
||||
self.layer_types = list(layer_types)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
AceStepTransformerBlock(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
|
||||
use_cross_attention=True,
|
||||
)
|
||||
for i in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Patchify: concat(src_latents, chunk_mask) on the channel dim then Conv1d with
|
||||
# stride=patch_size lifts (B, T, in_channels) -> (B, T/patch_size, hidden_size).
|
||||
self.proj_in_conv = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=hidden_size,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
# Dual-timestep conditioning: one path for `t`, one for `(t - r)` (mean-flow).
|
||||
self.time_embed = AceStepTimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
|
||||
self.time_embed_r = AceStepTimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
|
||||
|
||||
self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
|
||||
|
||||
self.norm_out = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.proj_out_conv = nn.ConvTranspose1d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=audio_acoustic_hidden_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
context_latents: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""The [`AceStepTransformer1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, channels)`):
|
||||
Noisy latent input for the diffusion process.
|
||||
timestep (`torch.Tensor` of shape `(batch_size,)`):
|
||||
Current diffusion timestep `t`.
|
||||
timestep_r (`torch.Tensor` of shape `(batch_size,)`):
|
||||
Reference timestep `r` (set equal to `t` for standard inference).
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, hidden_size)`):
|
||||
Conditioning embeddings from the condition encoder (text + lyrics + timbre).
|
||||
context_latents (`torch.Tensor` of shape `(batch_size, seq_len, context_dim)`):
|
||||
Context latents (source latents concatenated with chunk masks) — fed to the patchify conv alongside
|
||||
`hidden_states`.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether to return a `Transformer2DModelOutput` or a plain tuple.
|
||||
|
||||
Returns:
|
||||
`Transformer2DModelOutput` or `tuple`: The predicted velocity field.
|
||||
"""
|
||||
# Dual timestep embedding: t and (t - r). Sum both paths' AdaLN projections.
|
||||
temb_t, timestep_proj_t = self.time_embed(timestep)
|
||||
temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
|
||||
temb = temb_t + temb_r
|
||||
timestep_proj = timestep_proj_t + timestep_proj_r
|
||||
|
||||
# Context concatenation + padding to patch_size boundary + patchify.
|
||||
hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
|
||||
original_seq_len = hidden_states.shape[1]
|
||||
if hidden_states.shape[1] % self.patch_size != 0:
|
||||
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode="constant", value=0)
|
||||
hidden_states = self.proj_in_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
|
||||
|
||||
seq_len = hidden_states.shape[1]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
|
||||
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
|
||||
position_embeddings = (cos, sin)
|
||||
|
||||
sliding_attn_mask = None
|
||||
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
|
||||
sliding_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
# Full-attention layers see no mask; only the sliding-attention layers
|
||||
# need the banded mask. Cross-attention uses no padding mask.
|
||||
layer_attn_mask = sliding_attn_mask if self.layer_types[i] == "sliding_attention" else None
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer_module,
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
timestep_proj,
|
||||
layer_attn_mask,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
hidden_states = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
temb=timestep_proj,
|
||||
attention_mask=layer_attn_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
)
|
||||
|
||||
# Adaptive output normalization + de-patchify.
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
hidden_states = hidden_states[:, :original_seq_len, :]
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -149,6 +149,12 @@ else:
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["ace_step"] = [
|
||||
"AceStepAudioTokenDetokenizer",
|
||||
"AceStepAudioTokenizer",
|
||||
"AceStepConditionEncoder",
|
||||
"AceStepPipeline",
|
||||
]
|
||||
_import_structure["allegro"] = ["AllegroPipeline"]
|
||||
_import_structure["animatediff"] = [
|
||||
"AnimateDiffPipeline",
|
||||
@@ -574,6 +580,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .ace_step import (
|
||||
AceStepAudioTokenDetokenizer,
|
||||
AceStepAudioTokenizer,
|
||||
AceStepConditionEncoder,
|
||||
AceStepPipeline,
|
||||
)
|
||||
from .allegro import AllegroPipeline
|
||||
from .animatediff import (
|
||||
AnimateDiffControlNetPipeline,
|
||||
|
||||
54
src/diffusers/pipelines/ace_step/__init__.py
Normal file
54
src/diffusers/pipelines/ace_step/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modeling_ace_step"] = [
|
||||
"AceStepAudioTokenDetokenizer",
|
||||
"AceStepAudioTokenizer",
|
||||
"AceStepConditionEncoder",
|
||||
]
|
||||
_import_structure["pipeline_ace_step"] = ["AceStepPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder
|
||||
from .pipeline_ace_step import AceStepPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
856
src/diffusers/pipelines/ace_step/modeling_ace_step.py
Normal file
856
src/diffusers/pipelines/ace_step/modeling_ace_step.py
Normal file
@@ -0,0 +1,856 @@
|
||||
# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pipeline-specific models for ACE-Step 1.5.
|
||||
|
||||
Holds the condition encoder (lyric + timbre + text packing), the encoder layer (``AceStepEncoderLayer`` — not used by
|
||||
the DiT itself, hence kept here), the audio tokenizer / detokenizer used by cover conditioning, and the
|
||||
``_pack_sequences`` helper. The DiT uses the RoPE helper, ``AceStepAttention``, and ``_create_4d_mask`` from
|
||||
``diffusers/models/transformers/ace_step_transformer.py``.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import RMSNorm
|
||||
from ...models.transformers.ace_step_transformer import (
|
||||
AceStepAttention,
|
||||
AceStepMLP,
|
||||
_ace_step_rotary_freqs,
|
||||
_create_4d_mask,
|
||||
_is_flash_attention_backend,
|
||||
)
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers used only by condition encoder #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _pack_sequences(
|
||||
hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Pack two masked sequences into one with all valid tokens first.
|
||||
|
||||
Concatenates ``hidden1`` + ``hidden2`` along the sequence dim, then stably sorts each batch so mask=1 tokens come
|
||||
before mask=0 tokens. Returns the packed hidden states plus a fresh contiguous mask.
|
||||
"""
|
||||
hidden_cat = torch.cat([hidden1, hidden2], dim=1)
|
||||
mask_cat = torch.cat([mask1, mask2], dim=1)
|
||||
|
||||
B, L, D = hidden_cat.shape
|
||||
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
|
||||
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
|
||||
lengths = mask_cat.sum(dim=1)
|
||||
new_mask = torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)
|
||||
return hidden_left, new_mask
|
||||
|
||||
|
||||
class AceStepEncoderLayer(nn.Module):
|
||||
"""Pre-LN transformer block used by the lyric and timbre encoders."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
intermediate_size: int,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
bias=attention_bias,
|
||||
dropout=attention_dropout,
|
||||
eps=rms_norm_eps,
|
||||
sliding_window=sliding_window,
|
||||
is_cross_attention=False,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.mlp = AceStepMLP(hidden_size, intermediate_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
image_rotary_emb=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# encoders #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class AceStepLyricEncoder(ModelMixin, ConfigMixin):
|
||||
"""Lyric encoder: projects Qwen3 lyric embeddings and runs a small transformer.
|
||||
|
||||
Output feeds the DiT cross-attention (after packing with text + timbre).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
text_hidden_dim: int = 1024,
|
||||
num_lyric_encoder_hidden_layers: int = 8,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
|
||||
for i in range(num_lyric_encoder_hidden_layers)
|
||||
]
|
||||
|
||||
self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
|
||||
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
|
||||
)
|
||||
for i in range(num_lyric_encoder_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self._layer_types = layer_types
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(inputs_embeds)
|
||||
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
dtype = inputs_embeds.dtype
|
||||
device = inputs_embeds.device
|
||||
|
||||
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
|
||||
position_embeddings = (cos, sin)
|
||||
|
||||
if _is_flash_attention_backend(self.layers[0].self_attn.processor):
|
||||
full_attn_mask = attention_mask
|
||||
sliding_attn_mask = attention_mask
|
||||
else:
|
||||
full_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, is_causal=False
|
||||
)
|
||||
sliding_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window=self.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else full_attn_mask
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer_module, hidden_states, position_embeddings, mask
|
||||
)
|
||||
else:
|
||||
hidden_states = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=mask,
|
||||
)
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
class AceStepTimbreEncoder(ModelMixin, ConfigMixin):
|
||||
"""Timbre encoder: consumes VAE-encoded reference-audio latents and returns a
|
||||
pooled per-batch timbre embedding (plus a presence mask).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
timbre_hidden_dim: int = 64,
|
||||
num_timbre_encoder_hidden_layers: int = 4,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
|
||||
for i in range(num_timbre_encoder_hidden_layers)
|
||||
]
|
||||
|
||||
self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
|
||||
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
|
||||
)
|
||||
for i in range(num_timbre_encoder_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self._layer_types = layer_types
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@staticmethod
|
||||
def unpack_timbre_embeddings(
|
||||
timbre_embs_packed: torch.Tensor, refer_audio_order_mask: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
N, d = timbre_embs_packed.shape
|
||||
device = timbre_embs_packed.device
|
||||
dtype = timbre_embs_packed.dtype
|
||||
|
||||
B = int(refer_audio_order_mask.max().item() + 1)
|
||||
counts = torch.bincount(refer_audio_order_mask, minlength=B)
|
||||
max_count = counts.max().item()
|
||||
|
||||
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
|
||||
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
|
||||
|
||||
positions = torch.arange(N, device=device)
|
||||
batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
|
||||
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
|
||||
|
||||
inverse_indices = torch.empty_like(sorted_indices)
|
||||
inverse_indices[sorted_indices] = torch.arange(N, device=device)
|
||||
positions_in_batch = positions_in_sorted[inverse_indices]
|
||||
|
||||
indices_2d = refer_audio_order_mask * max_count + positions_in_batch
|
||||
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
|
||||
|
||||
timbre_embs_flat = one_hot.t() @ timbre_embs_packed
|
||||
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
|
||||
|
||||
mask_flat = (one_hot.sum(dim=0) > 0).long()
|
||||
new_mask = mask_flat.reshape(B, max_count)
|
||||
return timbre_embs_unpack, new_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
|
||||
refer_audio_order_mask: torch.LongTensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
inputs_embeds = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
|
||||
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
dtype = inputs_embeds.dtype
|
||||
device = inputs_embeds.device
|
||||
|
||||
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
|
||||
position_embeddings = (cos, sin)
|
||||
|
||||
sliding_attn_mask = None
|
||||
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
|
||||
sliding_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=None,
|
||||
sliding_window=self.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
# No padding mask on timbre input (pre-packed), so full-attention layers see None.
|
||||
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer_module, hidden_states, position_embeddings, mask
|
||||
)
|
||||
else:
|
||||
hidden_states = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=mask,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
# CLS-like pooling: first-token embedding per packed sequence.
|
||||
hidden_states = hidden_states[:, 0, :]
|
||||
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
|
||||
return timbre_embs_unpack, timbre_embs_mask
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# audio tokenizer / detokenizer #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class _AceStepResidualFSQ(nn.Module):
|
||||
"""Minimal ResidualFSQ compatible with ACE-Step's saved tokenizer weights."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2048,
|
||||
levels: Optional[list] = None,
|
||||
num_quantizers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if levels is None:
|
||||
levels = [8, 8, 8, 5, 5, 5]
|
||||
|
||||
self.levels = levels
|
||||
self.num_quantizers = num_quantizers
|
||||
self.codebook_dim = len(levels)
|
||||
|
||||
self.project_in = nn.Linear(dim, self.codebook_dim)
|
||||
self.project_out = nn.Linear(self.codebook_dim, dim)
|
||||
|
||||
levels_tensor = torch.tensor(levels, dtype=torch.long)
|
||||
basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.long), dim=0)
|
||||
scales = torch.stack([levels_tensor.float() ** -i for i in range(num_quantizers)])
|
||||
self.register_buffer("_levels", levels_tensor, persistent=False)
|
||||
self.register_buffer("_basis", basis, persistent=False)
|
||||
self.register_buffer("scales", scales, persistent=False)
|
||||
|
||||
@property
|
||||
def codebook_size(self) -> int:
|
||||
return int(torch.prod(self._levels).item())
|
||||
|
||||
def _indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
levels = self._levels.to(device=indices.device)
|
||||
basis = self._basis.to(device=indices.device)
|
||||
level_indices = (indices.long().unsqueeze(-1) // basis) % levels
|
||||
scale = 2.0 / (levels.to(dtype=torch.float32) - 1.0)
|
||||
return level_indices.to(dtype=torch.float32) * scale - 1.0
|
||||
|
||||
def _codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
|
||||
levels = self._levels.to(device=codes.device, dtype=codes.dtype)
|
||||
basis = self._basis.to(device=codes.device, dtype=codes.dtype)
|
||||
level_indices = (codes + 1.0) / (2.0 / (levels - 1.0))
|
||||
return (level_indices * basis).sum(dim=-1).round().to(torch.long)
|
||||
|
||||
def _quantize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
levels = self._levels.to(device=x.device, dtype=x.dtype)
|
||||
levels_minus_one = levels - 1.0
|
||||
step = 2.0 / levels_minus_one
|
||||
bracket = levels_minus_one * (x.clamp(-1.0, 1.0) + 1.0) / 2.0 + 0.5
|
||||
return step * torch.floor(bracket) - 1.0
|
||||
|
||||
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
if indices.ndim == 2:
|
||||
indices = indices.unsqueeze(-1)
|
||||
if indices.shape[-1] != self.num_quantizers:
|
||||
raise ValueError(
|
||||
f"Expected audio code indices with last dimension {self.num_quantizers}, got {indices.shape[-1]}."
|
||||
)
|
||||
|
||||
codes = []
|
||||
for quantizer_idx in range(self.num_quantizers):
|
||||
code = self._indices_to_codes(indices[..., quantizer_idx])
|
||||
scale = self.scales[quantizer_idx].to(device=code.device, dtype=code.dtype)
|
||||
codes.append(code * scale)
|
||||
return torch.stack(codes, dim=0)
|
||||
|
||||
def get_output_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
codes = self.get_codes_from_indices(indices).sum(dim=0)
|
||||
weight = self.project_out.weight.float()
|
||||
bias = self.project_out.bias.float() if self.project_out.bias is not None else None
|
||||
output = F.linear(codes.float(), weight, bias)
|
||||
return output.to(dtype=self.project_out.weight.dtype)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
input_dtype = hidden_states.dtype
|
||||
weight = self.project_in.weight.float()
|
||||
bias = self.project_in.bias.float() if self.project_in.bias is not None else None
|
||||
hidden_states = F.linear(hidden_states.float(), weight, bias)
|
||||
|
||||
levels = self._levels.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
soft_clamp = 1.0 + (1.0 / (levels - 1.0))
|
||||
hidden_states = (hidden_states / soft_clamp).tanh() * soft_clamp
|
||||
|
||||
quantized_out = torch.zeros_like(hidden_states)
|
||||
residual = hidden_states
|
||||
all_indices = []
|
||||
for scale in self.scales.to(device=hidden_states.device, dtype=hidden_states.dtype):
|
||||
quantized = self._quantize(residual / scale) * scale
|
||||
residual = residual - quantized.detach()
|
||||
quantized_out = quantized_out + quantized
|
||||
all_indices.append(self._codes_to_indices(quantized / scale))
|
||||
|
||||
weight = self.project_out.weight.float()
|
||||
bias = self.project_out.bias.float() if self.project_out.bias is not None else None
|
||||
quantized_out = F.linear(quantized_out.float(), weight, bias).to(dtype=input_dtype)
|
||||
all_indices = torch.stack(all_indices, dim=-1)
|
||||
return quantized_out, all_indices
|
||||
|
||||
|
||||
class AceStepAttentionPooler(nn.Module):
|
||||
"""Attention pooler used by the ACE-Step audio tokenizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
|
||||
for i in range(num_attention_pooler_hidden_layers)
|
||||
]
|
||||
|
||||
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
|
||||
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
|
||||
)
|
||||
for i in range(num_attention_pooler_hidden_layers)
|
||||
]
|
||||
)
|
||||
self._layer_types = layer_types
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_patches, patch_size, _ = hidden_states.shape
|
||||
hidden_states = self.embed_tokens(hidden_states)
|
||||
special_token = self.special_token.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
special_token = special_token.expand(batch_size, num_patches, -1, -1)
|
||||
hidden_states = torch.cat([special_token, hidden_states], dim=2)
|
||||
hidden_states = hidden_states.reshape(batch_size * num_patches, patch_size + 1, -1)
|
||||
|
||||
seq_len = hidden_states.shape[1]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
|
||||
sliding_attn_mask = None
|
||||
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
|
||||
sliding_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=None,
|
||||
sliding_window=self.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
|
||||
hidden_states = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=mask,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states[:, 0, :]
|
||||
return hidden_states.reshape(batch_size, num_patches, -1)
|
||||
|
||||
|
||||
class AceStepAudioTokenDetokenizer(ModelMixin, ConfigMixin):
|
||||
"""Expands ACE-Step 5 Hz audio tokens back to 25 Hz acoustic conditioning."""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
pool_window_size: int = 5,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = [
|
||||
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
|
||||
for i in range(num_attention_pooler_hidden_layers)
|
||||
]
|
||||
|
||||
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
|
||||
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
|
||||
self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
self.pool_window_size = pool_window_size
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
intermediate_size=intermediate_size,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
|
||||
)
|
||||
for i in range(num_attention_pooler_hidden_layers)
|
||||
]
|
||||
)
|
||||
self._layer_types = layer_types
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_tokens, _ = hidden_states.shape
|
||||
hidden_states = self.embed_tokens(hidden_states)
|
||||
hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.pool_window_size, -1)
|
||||
special_tokens = self.special_tokens.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states + special_tokens.unsqueeze(0)
|
||||
hidden_states = hidden_states.reshape(batch_size * num_tokens, self.pool_window_size, -1)
|
||||
|
||||
seq_len = hidden_states.shape[1]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
|
||||
sliding_attn_mask = None
|
||||
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
|
||||
sliding_attn_mask = _create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=None,
|
||||
sliding_window=self.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
for i, layer_module in enumerate(self.layers):
|
||||
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer_module, hidden_states, position_embeddings, mask
|
||||
)
|
||||
else:
|
||||
hidden_states = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=mask,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states.reshape(batch_size, num_tokens * self.pool_window_size, -1)
|
||||
|
||||
|
||||
class AceStepAudioTokenizer(ModelMixin, ConfigMixin):
|
||||
"""Converts 25 Hz acoustic latents to ACE-Step 5 Hz audio tokens."""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
pool_window_size: int = 5,
|
||||
fsq_dim: int = 2048,
|
||||
fsq_input_levels: list = None,
|
||||
fsq_input_num_quantizers: int = 1,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if fsq_input_levels is None:
|
||||
fsq_input_levels = [8, 8, 8, 5, 5, 5]
|
||||
|
||||
self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
|
||||
self.attention_pooler = AceStepAttentionPooler(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window,
|
||||
layer_types=layer_types,
|
||||
)
|
||||
self.quantizer = _AceStepResidualFSQ(
|
||||
dim=fsq_dim,
|
||||
levels=fsq_input_levels,
|
||||
num_quantizers=fsq_input_num_quantizers,
|
||||
)
|
||||
self.pool_window_size = pool_window_size
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = self.audio_acoustic_proj(hidden_states)
|
||||
hidden_states = self.attention_pooler(hidden_states)
|
||||
quantized, indices = self.quantizer(hidden_states)
|
||||
return quantized.to(dtype=input_dtype), indices
|
||||
|
||||
def tokenize(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
silence_latent: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, latent_length, acoustic_dim = hidden_states.shape
|
||||
pad_len = (-latent_length) % self.pool_window_size
|
||||
if pad_len:
|
||||
if silence_latent is not None and silence_latent.shape[-1] == acoustic_dim:
|
||||
pad = silence_latent[:, :pad_len, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
pad = pad.expand(batch_size, -1, -1)
|
||||
else:
|
||||
pad = torch.zeros(
|
||||
batch_size, pad_len, acoustic_dim, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=1)
|
||||
|
||||
num_patches = hidden_states.shape[1] // self.pool_window_size
|
||||
hidden_states = hidden_states.reshape(batch_size, num_patches, self.pool_window_size, acoustic_dim)
|
||||
return self(hidden_states)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# condition encoder #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class AceStepConditionEncoder(ModelMixin, ConfigMixin):
|
||||
"""Fuses text + lyric + timbre conditioning into the packed sequence used by
|
||||
the DiT's cross-attention.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
text_hidden_dim: int = 1024,
|
||||
timbre_hidden_dim: int = 64,
|
||||
num_lyric_encoder_hidden_layers: int = 8,
|
||||
num_timbre_encoder_hidden_layers: int = 4,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 1000000.0,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
sliding_window: int = 128,
|
||||
layer_types: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
|
||||
|
||||
self.lyric_encoder = AceStepLyricEncoder(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
text_hidden_dim=text_hidden_dim,
|
||||
num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window,
|
||||
layer_types=layer_types,
|
||||
)
|
||||
|
||||
self.timbre_encoder = AceStepTimbreEncoder(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
timbre_hidden_dim=timbre_hidden_dim,
|
||||
num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
# Learned null-condition embedding for classifier-free guidance, trained with
|
||||
# `cfg_ratio=0.15` in the original model. Broadcast along the sequence dim when used.
|
||||
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
|
||||
|
||||
# Silence latent — VAE-encoded audio-silence, stored as (1, T_long, timbre_hidden_dim).
|
||||
# When no reference audio is provided, the pipeline slices `silence_latent[:, :timbre_fix_frame, :]`
|
||||
# and feeds that to the timbre encoder. Passing literal zeros puts the timbre encoder
|
||||
# OOD and produces drone-like audio (observed on all text2music outputs before this fix).
|
||||
# The placeholder here is overwritten by the converter with the real encoded silence,
|
||||
# so its shape just needs to match the timbre-encoder input: last dim is
|
||||
# `timbre_hidden_dim` (so smaller test configs with `timbre_hidden_dim != 64` also load).
|
||||
self.register_buffer(
|
||||
"silence_latent",
|
||||
torch.zeros(1, 15000, timbre_hidden_dim),
|
||||
persistent=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_hidden_states: torch.FloatTensor,
|
||||
text_attention_mask: torch.Tensor,
|
||||
lyric_hidden_states: torch.FloatTensor,
|
||||
lyric_attention_mask: torch.Tensor,
|
||||
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
|
||||
refer_audio_order_mask: torch.LongTensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_hidden_states = self.text_projector(text_hidden_states)
|
||||
|
||||
lyric_hidden_states = self.lyric_encoder(
|
||||
inputs_embeds=lyric_hidden_states, attention_mask=lyric_attention_mask
|
||||
)
|
||||
|
||||
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
|
||||
refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
|
||||
)
|
||||
|
||||
encoder_hidden_states, encoder_attention_mask = _pack_sequences(
|
||||
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
|
||||
)
|
||||
encoder_hidden_states, encoder_attention_mask = _pack_sequences(
|
||||
encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
|
||||
)
|
||||
|
||||
return encoder_hidden_states, encoder_attention_mask
|
||||
1271
src/diffusers/pipelines/ace_step/pipeline_ace_step.py
Normal file
1271
src/diffusers/pipelines/ace_step/pipeline_ace_step.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -405,6 +405,21 @@ class VaeImageProcessorLDM3D(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AceStepTransformer1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AllegroTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -632,6 +632,66 @@ class ZImageModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AceStepAudioTokenDetokenizer(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AceStepAudioTokenizer(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AceStepConditionEncoder(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AceStepPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AllegroPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AceStepTransformer1DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AceStepTransformer1DModelTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return AceStepTransformer1DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | float | bool]:
|
||||
return {
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 64,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 8,
|
||||
"in_channels": 24, # audio_acoustic_hidden_dim * 3 (hidden + context_latents)
|
||||
"audio_acoustic_hidden_dim": 8,
|
||||
"patch_size": 2,
|
||||
"rope_theta": 10000.0,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"sliding_window": 16,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 2
|
||||
seq_len = 8
|
||||
encoder_seq_len = 10
|
||||
acoustic_dim = 8
|
||||
hidden_size = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, seq_len, acoustic_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs(),
|
||||
"timestep_r": randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs(),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, encoder_seq_len, hidden_size), generator=self.generator, device=torch_device
|
||||
),
|
||||
"context_latents": randn_tensor(
|
||||
(batch_size, seq_len, acoustic_dim * 2), generator=self.generator, device=torch_device
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestAceStepTransformer1DModel(AceStepTransformer1DModelTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
0
tests/pipelines/ace_step/__init__.py
Normal file
0
tests/pipelines/ace_step/__init__.py
Normal file
486
tests/pipelines/ace_step/test_ace_step.py
Normal file
486
tests/pipelines/ace_step/test_ace_step.py
Normal file
@@ -0,0 +1,486 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Qwen3Config, Qwen3Model
|
||||
|
||||
from diffusers import AutoencoderOobleck, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.models.transformers.ace_step_transformer import AceStepTransformer1DModel
|
||||
from diffusers.pipelines.ace_step import (
|
||||
AceStepAudioTokenDetokenizer,
|
||||
AceStepAudioTokenizer,
|
||||
AceStepConditionEncoder,
|
||||
AceStepPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AceStepConditionEncoderTests(unittest.TestCase):
|
||||
"""Fast tests for the AceStepConditionEncoder."""
|
||||
|
||||
def get_tiny_config(self):
|
||||
return {
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 64,
|
||||
"text_hidden_dim": 16,
|
||||
"timbre_hidden_dim": 8,
|
||||
"num_lyric_encoder_hidden_layers": 2,
|
||||
"num_timbre_encoder_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 8,
|
||||
"rope_theta": 10000.0,
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"sliding_window": 16,
|
||||
}
|
||||
|
||||
def test_forward_shape(self):
|
||||
"""Test that the condition encoder produces packed hidden states."""
|
||||
config = self.get_tiny_config()
|
||||
encoder = AceStepConditionEncoder(**config)
|
||||
encoder.eval()
|
||||
|
||||
batch_size = 2
|
||||
text_seq_len = 8
|
||||
lyric_seq_len = 12
|
||||
text_dim = config["text_hidden_dim"]
|
||||
timbre_dim = config["timbre_hidden_dim"]
|
||||
timbre_time = 10
|
||||
|
||||
text_hidden_states = torch.randn(batch_size, text_seq_len, text_dim)
|
||||
text_attention_mask = torch.ones(batch_size, text_seq_len)
|
||||
lyric_hidden_states = torch.randn(batch_size, lyric_seq_len, text_dim)
|
||||
lyric_attention_mask = torch.ones(batch_size, lyric_seq_len)
|
||||
|
||||
# Packed reference audio: 3 references across 2 batch items
|
||||
refer_audio = torch.randn(3, timbre_time, timbre_dim)
|
||||
refer_order_mask = torch.tensor([0, 0, 1], dtype=torch.long)
|
||||
|
||||
with torch.no_grad():
|
||||
enc_hidden, enc_mask = encoder(
|
||||
text_hidden_states=text_hidden_states,
|
||||
text_attention_mask=text_attention_mask,
|
||||
lyric_hidden_states=lyric_hidden_states,
|
||||
lyric_attention_mask=lyric_attention_mask,
|
||||
refer_audio_acoustic_hidden_states_packed=refer_audio,
|
||||
refer_audio_order_mask=refer_order_mask,
|
||||
)
|
||||
|
||||
# Output should be packed: batch_size x (lyric + timbre + text seq_len) x hidden_size
|
||||
self.assertEqual(enc_hidden.shape[0], batch_size)
|
||||
self.assertEqual(enc_hidden.shape[2], config["hidden_size"])
|
||||
self.assertEqual(enc_mask.shape[0], batch_size)
|
||||
self.assertEqual(enc_mask.shape[1], enc_hidden.shape[1])
|
||||
|
||||
def test_save_load_config(self):
|
||||
"""Test that the condition encoder config can be saved and loaded."""
|
||||
import tempfile
|
||||
|
||||
config = self.get_tiny_config()
|
||||
encoder = AceStepConditionEncoder(**config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
encoder.save_config(tmpdir)
|
||||
loaded = AceStepConditionEncoder.from_config(tmpdir)
|
||||
|
||||
self.assertEqual(encoder.config.hidden_size, loaded.config.hidden_size)
|
||||
self.assertEqual(encoder.config.text_hidden_dim, loaded.config.text_hidden_dim)
|
||||
self.assertEqual(encoder.config.timbre_hidden_dim, loaded.config.timbre_hidden_dim)
|
||||
|
||||
|
||||
class AceStepPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"""Fast end-to-end tests for AceStepPipeline with tiny models."""
|
||||
|
||||
pipeline_class = AceStepPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"lyrics",
|
||||
"audio_duration",
|
||||
"vocal_language",
|
||||
"guidance_scale",
|
||||
"shift",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "lyrics"])
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
)
|
||||
|
||||
# ACE-Step uses custom attention, not standard diffusers attention processors
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = AceStepTransformer1DModel(
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
head_dim=8,
|
||||
in_channels=24,
|
||||
audio_acoustic_hidden_dim=8,
|
||||
patch_size=2,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=16,
|
||||
)
|
||||
|
||||
# Create a tiny Qwen3Model for testing (matching the real Qwen3-Embedding-0.6B architecture)
|
||||
torch.manual_seed(0)
|
||||
qwen3_config = Qwen3Config(
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
head_dim=8,
|
||||
vocab_size=151936, # Qwen3 vocab size
|
||||
max_position_embeddings=256,
|
||||
)
|
||||
text_encoder = Qwen3Model(qwen3_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
|
||||
text_hidden_dim = qwen3_config.hidden_size # 32
|
||||
|
||||
torch.manual_seed(0)
|
||||
condition_encoder = AceStepConditionEncoder(
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
text_hidden_dim=text_hidden_dim,
|
||||
timbre_hidden_dim=8,
|
||||
num_lyric_encoder_hidden_layers=2,
|
||||
num_timbre_encoder_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
head_dim=8,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=16,
|
||||
)
|
||||
|
||||
audio_tokenizer_kwargs = {
|
||||
"hidden_size": 32,
|
||||
"intermediate_size": 64,
|
||||
"audio_acoustic_hidden_dim": 8,
|
||||
"pool_window_size": 2,
|
||||
"fsq_dim": 32,
|
||||
"fsq_input_levels": [4, 4, 4],
|
||||
"fsq_input_num_quantizers": 1,
|
||||
"num_attention_pooler_hidden_layers": 1,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 8,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 16,
|
||||
}
|
||||
torch.manual_seed(0)
|
||||
audio_tokenizer = AceStepAudioTokenizer(**audio_tokenizer_kwargs)
|
||||
torch.manual_seed(0)
|
||||
audio_token_detokenizer = AceStepAudioTokenDetokenizer(
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
audio_acoustic_hidden_dim=8,
|
||||
pool_window_size=2,
|
||||
num_attention_pooler_hidden_layers=1,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
head_dim=8,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=16,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderOobleck(
|
||||
encoder_hidden_size=6,
|
||||
downsampling_ratios=[1, 2],
|
||||
decoder_channels=3,
|
||||
decoder_input_channels=8,
|
||||
audio_channels=2,
|
||||
channel_multiples=[2, 4],
|
||||
sampling_rate=4,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"condition_encoder": condition_encoder,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"audio_tokenizer": audio_tokenizer,
|
||||
"audio_token_detokenizer": audio_token_detokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A beautiful piano piece",
|
||||
"lyrics": "[verse]\nSoft notes in the morning",
|
||||
"audio_duration": 0.4, # Very short for fast test (10 latent frames at 25Hz)
|
||||
"num_inference_steps": 2,
|
||||
"generator": generator,
|
||||
"max_text_length": 32,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_ace_step_basic(self):
|
||||
"""Test basic text-to-music generation."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt="A beautiful piano piece",
|
||||
lyrics="[verse]\nSoft notes in the morning",
|
||||
audio_duration=0.4,
|
||||
num_inference_steps=2,
|
||||
generator=generator,
|
||||
max_text_length=32,
|
||||
)
|
||||
audio = output.audios
|
||||
self.assertIsNotNone(audio)
|
||||
self.assertEqual(audio.ndim, 3) # [batch, channels, samples]
|
||||
|
||||
def test_ace_step_batch(self):
|
||||
"""Test batch generation."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
output = pipe(
|
||||
prompt=["Piano piece", "Guitar solo"],
|
||||
lyrics=["[verse]\nHello", "[chorus]\nWorld"],
|
||||
audio_duration=0.4,
|
||||
num_inference_steps=2,
|
||||
generator=generator,
|
||||
max_text_length=32,
|
||||
)
|
||||
audio = output.audios
|
||||
self.assertIsNotNone(audio)
|
||||
self.assertEqual(audio.shape[0], 2) # batch size = 2
|
||||
|
||||
def test_ace_step_latent_output(self):
|
||||
"""Test that output_type='latent' returns latents."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt="A test prompt",
|
||||
lyrics="",
|
||||
audio_duration=0.4,
|
||||
num_inference_steps=2,
|
||||
generator=generator,
|
||||
output_type="latent",
|
||||
max_text_length=32,
|
||||
)
|
||||
latents = output.audios
|
||||
self.assertIsNotNone(latents)
|
||||
# Latent shape: [batch, latent_length, acoustic_dim]
|
||||
self.assertEqual(latents.ndim, 3)
|
||||
self.assertEqual(latents.shape[0], 1)
|
||||
|
||||
def test_ace_step_return_dict_false(self):
|
||||
"""Test that return_dict=False returns a tuple."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt="A test prompt",
|
||||
lyrics="",
|
||||
audio_duration=0.4,
|
||||
num_inference_steps=2,
|
||||
generator=generator,
|
||||
return_dict=False,
|
||||
max_text_length=32,
|
||||
)
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertEqual(len(output), 1)
|
||||
|
||||
def test_audio_codes_cover_path(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
|
||||
output = pipe(
|
||||
prompt="A test prompt",
|
||||
lyrics="",
|
||||
audio_codes="<|audio_code_1|><|audio_code_2|>",
|
||||
num_inference_steps=1,
|
||||
output_type="latent",
|
||||
max_text_length=32,
|
||||
)
|
||||
|
||||
self.assertEqual(output.audios.shape[1], 4)
|
||||
|
||||
def test_save_load_local(self, expected_max_difference=7e-3):
|
||||
# increase tolerance to account for large composite model
|
||||
super().test_save_load_local(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=7e-3):
|
||||
# increase tolerance to account for large composite model
|
||||
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=7e-3):
|
||||
# increase tolerance for audio pipeline
|
||||
super().test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=7e-3):
|
||||
# increase tolerance for audio pipeline
|
||||
super().test_dict_tuple_outputs_equivalent(
|
||||
expected_slice=expected_slice, expected_max_difference=expected_max_difference
|
||||
)
|
||||
|
||||
# ACE-Step does not use num_images_per_prompt
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
|
||||
# ACE-Step does not use standard schedulers
|
||||
@unittest.skip("ACE-Step uses built-in flow matching schedule, not diffusers schedulers")
|
||||
def test_karras_schedulers_shape(self):
|
||||
pass
|
||||
|
||||
# ACE-Step does not support prompt_embeds directly
|
||||
@unittest.skip("ACE-Step does not support prompt_embeds / negative_prompt_embeds")
|
||||
def test_cfg(self):
|
||||
pass
|
||||
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
super().test_float16_inference(expected_max_diff=expected_max_diff)
|
||||
|
||||
@unittest.skip(
|
||||
"ACE-Step __call__ does not accept prompt_embeds, so encode_prompt isolation test is not applicable"
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Sequential CPU offloading produces NaN with tiny random models")
|
||||
def test_sequential_cpu_offload_forward_pass(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Sequential CPU offloading produces NaN with tiny random models")
|
||||
def test_sequential_offload_forward_pass_twice(self):
|
||||
pass
|
||||
|
||||
def test_encode_prompt(self):
|
||||
"""Test that encode_prompt returns correct shapes."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
text_hidden, text_mask, lyric_hidden, lyric_mask = pipe.encode_prompt(
|
||||
prompt="A test prompt",
|
||||
lyrics="[verse]\nHello world",
|
||||
device=device,
|
||||
max_text_length=32,
|
||||
max_lyric_length=64,
|
||||
)
|
||||
|
||||
self.assertEqual(text_hidden.ndim, 3) # [batch, seq_len, hidden_dim]
|
||||
self.assertEqual(text_mask.ndim, 2) # [batch, seq_len]
|
||||
self.assertEqual(lyric_hidden.ndim, 3)
|
||||
self.assertEqual(lyric_mask.ndim, 2)
|
||||
self.assertEqual(text_hidden.shape[0], 1)
|
||||
self.assertEqual(lyric_hidden.shape[0], 1)
|
||||
|
||||
def test_prepare_latents(self):
|
||||
"""Test that prepare_latents returns correct shapes."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
latents = pipe.prepare_latents(
|
||||
batch_size=2,
|
||||
audio_duration=1.0,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected_length = math.ceil(1.0 * pipe.latents_per_second)
|
||||
self.assertEqual(latents.shape, (2, expected_length, 8))
|
||||
|
||||
def test_timestep_schedule(self):
|
||||
"""Test that the timestep schedule is generated correctly."""
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
|
||||
# Test standard schedule
|
||||
schedule = pipe._get_timestep_schedule(num_inference_steps=8, shift=3.0)
|
||||
self.assertEqual(len(schedule), 8)
|
||||
self.assertAlmostEqual(schedule[0].item(), 1.0, places=5)
|
||||
|
||||
# Test truncated schedule
|
||||
schedule = pipe._get_timestep_schedule(num_inference_steps=4, shift=3.0)
|
||||
self.assertEqual(len(schedule), 4)
|
||||
|
||||
def test_format_prompt(self):
|
||||
"""Test that prompt formatting works correctly."""
|
||||
components = self.get_dummy_components()
|
||||
pipe = AceStepPipeline(**components)
|
||||
|
||||
text, lyrics = pipe._format_prompt(
|
||||
prompt="A piano piece",
|
||||
lyrics="[verse]\nHello",
|
||||
vocal_language="en",
|
||||
audio_duration=30.0,
|
||||
)
|
||||
|
||||
self.assertIn("A piano piece", text)
|
||||
self.assertIn("30 seconds", text)
|
||||
self.assertIn("[verse]", lyrics)
|
||||
self.assertIn("Hello", lyrics)
|
||||
self.assertIn("en", lyrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user