Add ACE-Step pipeline for text-to-music generation (#13095)

* Add ACE-Step pipeline for text-to-music generation

Rebased on origin/main from the original pr-13095 branch (3 commits squashed).

- AceStepDiTModel: Diffusion Transformer with RoPE, GQA, sliding window,
  AdaLN timestep conditioning, and cross-attention.
- AceStepConditionEncoder: fuses text / lyric / timbre into a single
  cross-attention sequence.
- AceStepPipeline: text2music / cover / repaint / extract / lego / complete.
- Conversion script for the original checkpoint layout.
- Docs + tests.

* Fix ACE-Step pipeline audio quality and auto-detect turbo/base/sft variants

The PR's original inference produced low-quality audio on turbo because the
pipeline (a) mangled the SFT prompt format, (b) applied classifier-free guidance
with the wrong unconditional embedding (empty-string encoded vs. the learned
`null_condition_emb`), and (c) hardcoded turbo defaults even when loading a
base/SFT checkpoint.

Changes:

* Converter preserves `null_condition_emb` (stored under the condition encoder)
  and propagates `is_turbo`/`model_version` into the transformer config so the
  pipeline can route per-variant defaults.
* `AceStepConditionEncoder` registers `null_condition_emb` as a learned
  parameter matching the original module.
* Pipeline auto-detects variant via `is_turbo`/`model_version` and picks
  defaults that match `acestep/inference.py`:
    * turbo:  steps=8,  shift=3.0, guidance_scale=1.0 (no CFG)
    * base/SFT: steps=27, shift=1.0, guidance_scale=7.0
* Base/SFT timestep schedule uses the linear+shift transform from
  `acestep/models/base/modeling_acestep_v15_base.py`; turbo still uses the
  hardcoded 8-step `SHIFT_TIMESTEPS` table.
* CFG reuses the learned `null_condition_emb` and batches the
  conditional+unconditional forwards into a single transformer call.
* `SFT_GEN_PROMPT` matches the newline layout in `acestep/constants.py` so the
  text encoder sees the same prompt distribution it was trained on.

DiT parity vs. the original ACE-Step 1.5 turbo DiT is bit-identical
(max_abs=0.0 in fp32 eager/SDPA across 4 seed/shape cases) — see
scripts/dit_parity_test.py.

* Add ACE-Step parity test scripts

Two developer-facing parity harnesses live under scripts/:

* dit_parity_test.py — loads the same converted turbo weights into the
  original AceStepDiTModel and the diffusers AceStepDiTModel, drives
  identical (hidden_states, timestep, timestep_r, encoder_hidden_states,
  context_latents) inputs, and asserts max-abs-diff ≤ 1e-5 in fp32
  eager/SDPA. Currently passes bit-identical (max_abs=0) across four
  shape/seed cases including batched + odd-length paths.

* audio_parity_jieyue.py — full end-to-end audio parity. Given the same
  JSON example, runs both the original ACE-Step 1.5 pipeline and the
  diffusers AceStepPipeline at matched seed/precision (bf16 + FA2 by
  default) and saves side-by-side .wav files for listening verification.
  Supports text2music / cover / repaint × turbo / base / sft via a
  --matrix mode that writes 18 wavs named
  {variant}_{task}_{official,diffusers}.wav.

* Route SFT parity to acestep-v15-sft checkpoint

On jieyue the release tree has a dedicated SFT checkpoint at
checkpoints/acestep-v15-sft with its own modeling_acestep_v15_base.py
shipped under acestep/models/sft/. Point the SFT row of the parity matrix
at that checkpoint / module so we're testing the actual SFT weights, not
the plain base ones.

* audio_parity_jieyue: fix doubled 'acestep-' in cache path; --converted-root flag

Previously the converted-pipeline cache dir was
`/tmp/acestep-<variant>-diffusers` but <variant> already starts with
"acestep-", giving `/tmp/acestep-acestep-v15-turbo-diffusers`. Drop the
prefix.

On jieyue the overlay rootfs (including /tmp) only has a few GB free; a
full turbo conversion needs ~5 GB per variant. Add --converted-root (env
ACESTEP_CONVERTED_ROOT) so the cache can live on vepfs.

* audio_parity_jieyue: two-phase matrix bootstraps cover/repaint from text2music

The ACE-Step release bundle on jieyue doesn't ship sample .wav/.mp3
files, so matrix mode had no default --src-audio and would skip
cover/repaint entirely. Run text2music first for every variant, then
reuse the TURBO official text2music output as the shared source for the
cover/repaint rows. Users can still override with --src-audio.

* audio_parity_jieyue: seed the diffusers generator on the pipeline device

The ORIGINAL ACE-Step pipeline seeds on the execution device
(`torch.Generator(device=device).manual_seed(seed)`), i.e. the CUDA RNG
stream when running on GPU. Previously the parity harness seeded the
diffusers side with a CPU generator, so even though the seed integer
matched, the two sides drew different noise from the outset and the
final outputs were essentially uncorrelated. Use the execution-device
generator on both sides for a fair comparison.

* Fix ACE-Step pipeline: switch to APG guidance + peak normalization

Two issues found after the first jieyue audio parity run:

1. The original base/SFT pipeline uses APG (Adaptive Projected Guidance,
   acestep/models/common/apg_guidance.py) with a stateful momentum
   buffer and norm/projection steps — NOT vanilla CFG. Using vanilla CFG
   produced uncorrelated outputs vs. the reference (pearson ~0.0 on
   20 s samples); this PR ports `_apg_forward` + `_APGMomentumBuffer`
   and plugs them into the denoising loop when `guidance_scale > 1`.
   Momentum is instantiated once per pipeline call (persists across
   denoising steps) to match the reference semantics.

2. The post-VAE "anti-clipping normalization" in this pipeline was
   `audio /= std * 5` with a `std<1 -> std=1` guard. The original
   post-processing in
   acestep/core/generation/handler/generate_music_decode.py is simple
   peak normalization: `if audio.abs().max() > 1: audio /= peak`. The
   std-based proxy both (a) let clips with peak < 1 leak through
   unchanged (over-quiet) and (b) failed to bring clipping peaks to
   exactly 1 in a bunch of base/SFT cases (observed max=1.000, std=0.200
   repeatedly in the first parity run). Switch to peak normalization on
   both sides.

Tested via scripts/audio_parity_jieyue.py on A800; re-run pending to
confirm the base/SFT correlation improvements.

* Fix ACE-Step chunk mask values to match the original pipeline

The DiT receives `context_latents = concat(src_latents, chunk_mask)` on the
channel dim, and was trained with chunk_mask values drawn from the three
sentinels documented in acestep/inference.py:

  2.0 -> model-decided (default for text2music / cover / full-generation)
  1.0 -> keep this latent frame from src_latents (repaint preserved region)
  0.0 -> explicitly repaint this frame (only inside the repaint window)

Previously _build_chunk_mask returned all-1.0 for text2music (and cover /
lego), and an inverted 0/1 mask for repaint (1 inside the window, 0 outside).
Either case puts context_latents out of distribution. Switch text2music /
cover to the 2.0 sentinel and flip the repaint mask so it's 1.0 outside /
0.0 inside. Update the repaint src_latents zero-out to multiply by the new
mask (was `1 - chunk_mask`) so the zero region still lines up with the
repaint window.

* Add direct invoker for ACE-Step generate_music (ground truth)

Our earlier audio_parity_jieyue.py reconstructs the original pipeline by
calling AceStepConditionGenerationModel.generate_audio() directly, which
silently skips a lot of the real handler plumbing (conditioning masks,
silence-latent tiling, cover/repaint pre-processing, etc.). That made the
'official' wavs we saved sound wrong — flat, drone-like, not music.

This new script calls acestep.inference.generate_music end-to-end through
the real AceStepHandler, with LM + CoT explicitly disabled so we still have
a deterministic comparison. Use it to generate the ground-truth 'official'
wav for a given JSON example, then separately run the diffusers pipeline
with the same inputs and diff the two.

* run_official_generate_music: call initialize_service to bind a DiT variant

AceStepHandler() is a shell — you have to call handler.initialize_service(
project_root=..., config_path=..., device=..., use_flash_attention=..., ...)
before generate_music will work. Mirror what cli.py does at the equivalent
spot (around cli.py:1400).

* Fix silence-reference for ACE-Step timbre encoder

The root cause for the flat / drone-like outputs I was seeing (including
in my 'official' reconstruction): when no reference_audio is provided the
pipeline was feeding literal zeros to the timbre encoder. The real
handler feeds a slice of the learned `silence_latent` tensor.

The handler also transposes silence_latent on load (see
acestep/core/generation/handler/init_service_loader.py:214:
  self.silence_latent = torch.load(...).transpose(1, 2)
) converting [1, 64, 15000] -> [1, 15000, 64] so that
`silence_latent[:, :750, :]` yields the expected [1, 750, 64] shape.

Changes:

* Converter: load silence_latent.pt, transpose to [1, T, C], bake into
  the condition_encoder safetensors under key `silence_latent`.
  (Also keeps the raw .pt file at the pipeline root for debugging.)
* AceStepConditionEncoder: register `silence_latent` as a persistent
  buffer so from_pretrained loads it alongside the trained weights.
* Pipeline: when reference_audio is None, slice
  `condition_encoder.silence_latent[:, :timbre_fix_frame, :]` and
  broadcast across the batch instead of zeros. Emits a loud warning
  (and falls back to zeros) if the buffer is all-zero — that means the
  checkpoint was produced by an older converter and should be rebuilt.
* audio_parity_jieyue.py: the reference path now matches the handler's
  silence-latent slicing.

Without this fix, every variant/task combo produced drone-like audio
even when my numeric DiT-forward parity claimed they were identical.

* Fix three more ACE-Step pipeline bugs I found by dumping real inputs

Instrumented the live generate_audio call in the real ACE-Step handler and
observed the exact tensors it sees — my diffusers pipeline was wrong in
three independent ways:

1. src_latents for text2music should be silence_latent tiled to
   latent_length, NOT zeros. The handler fills no-target cases from
   silence_latent_tiled (observed std=0.96). Zeros are OOD for the DiT
   context_latents concat and produce drone-like outputs.

2. chunk_mask values cap at 1.0 (not 2.0). The handler starts with a
   bool tensor (True inside the generate span, False outside); the
   chunk_mask_modes=auto -> 2.0 override does NOT take effect because
   the underlying tensor is bool, so setting entry = 2.0 casts to True.
   After the later .to(dtype) float cast, the DiT sees 1.0/0.0 — exactly
   what I observed in the captured tensor (unique values = [True]).

3. Default shift is 1.0 for ALL variants, including turbo. I was
   defaulting turbo to shift=3.0 which picks a different SHIFT_TIMESTEPS
   table (the 8-step schedule is keyed by shift, not variant).

Also:
* Added _silence_latent_tiled() helper that slices / tiles the learned
  silence_latent (now loaded as a buffer on the condition encoder) to
  the requested latent length.
* Repaint path now substitutes silence_latent (not raw zeros) inside
  the repaint window — matches conditioning_masks.py.
* audio_parity_jieyue.py mirrors the same src/chunk/shift choices on
  its 'original' leg for apples-to-apples parity once the buggy
  reconstruction is removed from the picture.

* Add peak+loudness post-normalization to AceStepPipeline

The real pipeline normalizes audio in two stages (see
acestep/audio_utils.py:72 normalize_audio + generate_music_decode.py):
  1. if peak > 1: audio /= peak  (anti-clip)
  2. audio *= target_amp / peak   (target_amp = 10 ** (-1/20) ~ 0.891)

Step 2 is loudness normalization to -1 dBFS. Without it diffusers outputs
had peak=1.0 vs the real 0.891 — same music content (pearson was ~0.86
already), just 1.12x louder. Add step 2 after the existing anti-clip step.
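
The two stages can be sketched in plain Python as below. This is a minimal illustration under stated assumptions, not the code in `acestep/audio_utils.py`: the helper name `normalize_peak_and_loudness`, the list-of-floats input, and the guard for silent input are all made up for the example.

```python
def normalize_peak_and_loudness(samples, target_db=-1.0):
    """Two-stage normalization sketch (hypothetical helper, not the ACE-Step API)."""
    target_amp = 10 ** (target_db / 20)  # -1 dBFS -> roughly 0.891
    peak = max((abs(s) for s in samples), default=0.0)
    # Stage 1: anti-clip, applied only when the signal exceeds full scale.
    if peak > 1.0:
        samples = [s / peak for s in samples]
        peak = 1.0
    # Stage 2: scale so the peak lands on target_amp (assumed guard: skip silence).
    if peak > 0.0:
        samples = [s * target_amp / peak for s in samples]
    return samples
```

With both stages applied, a clipping clip and a quiet clip end up with the same peak of about 0.891, which is the behaviour described above for the real pipeline.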

* Match acestep/inference.py inference_steps=8 for ALL variants

GenerationParams.inference_steps default is 8 — turbo AND base/SFT. I had
base/SFT defaulting to 27 here, so every base/SFT parity run was comparing
a 27-step diffusers trajectory against an 8-step real trajectory. Different
number of denoising steps means different audio even at fixed seed.

This likely explains the lower base/SFT correlation in my earlier jieyue
runs (turbo was 0.86, base/SFT were 0.32-0.34). Aligning step counts
should bring base/SFT closer to turbo parity.

* Address PR #13095 review: rename classes + reuse diffusers primitives

Response to dg845's PR comments batch 1+2. DiT parity harness still bit-identical
(max_abs=0 on fp32 / SDPA across 4 shape cases).

Transformer file:
* Rename AceStepDiTModel -> AceStepTransformer1DModel (alias kept).
* Rename AceStepDiTLayer -> AceStepTransformerBlock (alias kept).
* Inherit AttentionMixin + CacheMixin on the DiT model.
* Swap in diffusers.models.normalization.RMSNorm for the hand-rolled
  AceStepRMSNorm (weight-key-compatible).
* Swap the hand-rolled rotary embedding + apply_rotary for diffusers'
  get_1d_rotary_pos_embed + apply_rotary_emb (use_real_unbind_dim=-2 to
  match the cat-half convention ACE-Step inherits from Qwen3).
* Use get_timestep_embedding with flip_sin_to_cos=True — keeps the
  (cos, sin) ordering of the original sinusoidal. State-dict-compatible.
* Drop max_position_embeddings arg from DiT config (RoPE computes freqs
  per call based on seq_len); converter drops it.
* Gradient-checkpoint call now takes just the layer module (matches the
  Flux2 idiom).

Pipeline modeling file (pipelines/ace_step/modeling_ace_step.py):
* Moved _pack_sequences + AceStepEncoderLayer here — they aren't used
  by the DiT, so they shouldn't live in the transformer file.
* AceStepLyricEncoder + AceStepTimbreEncoder set
  _supports_gradient_checkpointing = True and wrap encoder-layer calls
  through the checkpointing func when enabled.
* Use diffusers RMSNorm + the RoPE helper from the transformer file
  (shared single implementation).

Converter (scripts/convert_ace_step_to_diffusers.py):
* model_index.json now carries AceStepTransformer1DModel.
* Drop max_position_embeddings / use_sliding_window from the emitted
  configs.

No numerical regressions: scripts/dit_parity_test.py PASSES with
max_abs=0.0 on fp32/SDPA across short, long, batched, and
padding-path shape variants.

* Address PR #13095 review: pipeline polish + converter HF-hub support

Response to dg845 review comments on the pipeline side. DiT parity still
bit-identical (max_abs=0 across 4 shape cases).

Pipeline (pipelines/ace_step/pipeline_ace_step.py):
* Add `sample_rate` + `latents_per_second` properties sourced from the
  VAE config so the pipeline no longer hardcodes 48000 / 25 / 1920.
  Propagates through prepare_latents, chunk_mask window math, and the
  audio-duration round-trip.
* Add `do_classifier_free_guidance` property (matches LTX2 et al.).
* Add `check_inputs(...)` called from `__call__` before allocating noise.
  Validates prompt type, lyrics type, task_type, step count, guidance
  scale, shift, cfg interval bounds and repaint window ordering.
* Add `callback_on_step_end` + `callback_on_step_end_tensor_inputs` —
  the modern callback form. The legacy `callback` / `callback_steps`
  pair is kept for back-compat. Setting `pipe._interrupt = True` inside
  the callback stops the loop early.
* Expose `encode_audio(audio)` as a public helper that wraps the tiled
  VAE encode + (B, T, D) transpose the pipeline performs internally.

Converter (scripts/convert_ace_step_to_diffusers.py):
* Accept a Hugging Face Hub repo id for `--checkpoint_dir`; resolves it
  via `huggingface_hub.snapshot_download` when the argument isn't a
  local path.

Exports:
* Register `AceStepTransformer1DModel` in the top-level __init__,
  models/__init__, models/transformers/__init__, and dummy_pt_objects so
  `from diffusers import AceStepTransformer1DModel` works and the
  pipeline loader resolves the new class name from model_index.json.

Deferred for a follow-up (commented inline in the PR): full
`Attention + AttnProcessor + dispatch_attention_fn` refactor and
`FlowMatchEulerDiscreteScheduler` migration — both would benefit from a
dedicated parity re-run and review.

* Fix stale ACE-Step 1.0-era docs / class names in the 1.5 integration

Docs and docstrings still carried a mix of 1.0 paper title, non-existent
`ACE-Step/ACE-Step-v1-5-turbo` hub id, `shift=3.0` turbo default, and
the old `AceStepDiTModel` class name. Cleaned up to match the actual
1.5 release:

* pipelines/ace_step.md: correct citation title ("ACE-Step 1.5: Pushing
  the Boundaries of Open-Source Music Generation"), correct repo
  (`ace-step/ACE-Step-1.5`), new variants table with real HF ids
  (`Ace-Step1.5` / `acestep-v15-base` / `acestep-v15-sft`) and their
  per-variant step/CFG defaults, drop the wrong `shift=3.0` tip.
* models/ace_step_transformer.md: page renamed to
  `AceStepTransformer1DModel` with a short 1.5-specific description;
  `AceStepDiTModel` noted as a backwards-compat alias.
* pipeline_ace_step.py: import, docstring, `Args`, and `__init__`
  annotation reference `AceStepTransformer1DModel`; example model id
  now `ACE-Step/Ace-Step1.5`; `_variant_defaults` docstring and the
  `__call__` variant-fallback comment no longer claim `shift=3.0` /
  `27 steps` — real defaults are 8 steps / shift=1.0 across all
  variants, guidance=1.0 (turbo) vs 7.0 (base+sft).

* Address PR #13095 review: VAE tiling on AutoencoderOobleck + Timesteps class

Two more deferred review threads from dg845 addressed:

* Move tiled encode/decode onto AutoencoderOobleck
  (https://github.com/huggingface/diffusers/pull/13095#discussion_r2785513647).
  AutoencoderOobleck now carries `use_tiling` + `tile_sample_min_length` /
  `tile_sample_overlap` / `tile_latent_min_length` / `tile_latent_overlap`
  attributes and private `_tiled_encode` / `_tiled_decode` methods; the
  existing `encode` / `_decode` dispatch to them when tiling is enabled and
  the input exceeds the threshold. `AutoencoderMixin.enable_tiling()` is
  already inherited.

  AceStepPipeline's private `_tiled_encode` / `_tiled_decode` and the
  `use_tiled_decode` `__call__` arg are gone; `__init__` now calls
  `self.vae.enable_tiling()` so the long-audio memory behaviour is preserved
  by default. Users can opt out with `pipe.vae.disable_tiling()`.

  Note: the VAE-side tiling concatenates encoder features (h) and samples
  the posterior once, instead of the old per-tile `.sample()` calls. This
  is the standard diffusers pattern; numerically differs only in the
  structure of the noise across tile boundaries.

* Use the Timesteps nn.Module for the sinusoid
  (https://github.com/huggingface/diffusers/pull/13095#discussion_r2785420234).
  `AceStepTimestepEmbedding` wraps `Timesteps(in_channels, flip_sin_to_cos=
  True, downscale_freq_shift=0)` instead of calling `get_timestep_embedding`
  directly — reviewer asked for the Module form.

* Address PR #13095 review: refactor AceStepAttention to Attention + AttnProcessor

Splits the monolithic AceStepAttention into the diffusers standard
Attention + AttnProcessor layout:
  - AceStepAttention (torch.nn.Module, AttentionModuleMixin) holds the
    to_q/to_k/to_v/to_out projections and norm_q/norm_k RMSNorms.
  - AceStepAttnProcessor2_0 runs the attention dispatch through
    dispatch_attention_fn so users can pick flash / sage / native backends
    via model.set_attention_backend(...) or the attention_backend context
    manager.

GQA (Q has 16 heads / K,V have 8) is preserved by passing enable_gqa=True
to dispatch_attention_fn instead of repeat_interleave; fusion is disabled
(_supports_qkv_fusion = False) because Q and K,V have different output
sizes.
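
The head grouping can be illustrated without tensors. The sketch below assumes the usual GQA convention, in which consecutive query heads share one K/V head (the same layout `repeat_interleave` produces); `kv_head_for_query_head` is a hypothetical helper, not part of the diffusers API.

```python
def kv_head_for_query_head(q_head, num_q_heads=16, num_kv_heads=8):
    # Each K/V head serves a contiguous group of query heads.
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


# Query heads 0..15 mapped onto K/V heads 0..7, two query heads per K/V head.
mapping = [kv_head_for_query_head(i) for i in range(16)]
```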

The converter is updated to rename the six attention sub-keys
(q_proj -> to_q, k_proj -> to_k, v_proj -> to_v, o_proj -> to_out.0,
q_norm -> norm_q, k_norm -> norm_k) on both the DiT decoder path and the
condition encoder path, since AceStepLyricEncoder / AceStepTimbreEncoder
share the same AceStepAttention class.
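
A minimal sketch of such a rename pass, assuming state-dict keys are dot-separated module paths. The helper name and the example key prefixes are illustrative and are not taken from the converter script; only the six sub-key pairs come from the list above.

```python
ATTN_KEY_RENAMES = {
    "q_proj": "to_q",
    "k_proj": "to_k",
    "v_proj": "to_v",
    "o_proj": "to_out.0",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}


def rename_attention_key(key):
    # Rename whole dot-separated components only, so names that merely
    # contain one of the old sub-keys (e.g. "up_proj") are left untouched.
    return ".".join(ATTN_KEY_RENAMES.get(part, part) for part in key.split("."))


# Hypothetical key, for illustration only.
new_key = rename_attention_key("layers.0.self_attn.o_proj.weight")
```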

Addresses review comments r2785433213 and r2785450463.

* Address PR #13095 review: migrate to FlowMatchEulerDiscreteScheduler

Replace the hand-rolled flow-matching Euler loop with
`FlowMatchEulerDiscreteScheduler`. ACE-Step still computes its own shifted /
turbo sigma schedule via `_get_timestep_schedule`, but now passes it to
`scheduler.set_timesteps(sigmas=...)` and delegates the ODE step to
`scheduler.step()`. The scheduler is configured with `num_train_timesteps=1`
and `shift=1.0` so `scheduler.timesteps` stays in `[0, 1]` (the convention the
DiT was trained on) and the scheduler doesn't re-shift already-shifted sigmas.

The scheduler's appended terminal `sigma=0` reproduces the old loop's
final-step "project to x0" case exactly: `prev = x + (0 - t_curr) * v`.
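
A scalar sketch of why the terminal step matches. `euler_step` is a hypothetical stand-in for the flow-matching Euler update, not the scheduler's actual code, and the numbers are arbitrary.

```python
def euler_step(x, v, sigma, sigma_next):
    # One Euler step along the flow-matching ODE: move x by (sigma_next - sigma) * v.
    return x + (sigma_next - sigma) * v


x, v, t_curr = 0.75, 0.5, 0.25
# With the appended terminal sigma=0 the update reduces to x - t_curr * v,
# i.e. the same expression as the old loop's final step.
x0 = euler_step(x, v, t_curr, 0.0)
```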

Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music, 8 steps):
  waveform Pearson = 0.999999
  spectral Pearson = 1.000000
  max |diff|       = 2.5e-3  (fp32 step-math vs previous bf16 step-math)

fp32 Euler-loop A/B against the hand-rolled path: max |diff| = 3.6e-7.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address PR #13095 review: move DiT tests + drop stale test kwargs

- Move the DiT transformer tests out of the pipeline test file into a new
  tests/models/transformers/test_models_transformer_ace_step.py that follows
  the standard BaseModelTesterConfig + ModelTesterMixin scaffold (matches
  test_models_transformer_longcat_audio_dit.py).
- Drop `max_position_embeddings` from the remaining AceStepDiTModel and
  AceStepConditionEncoder test fixtures — neither constructor accepts that
  argument anymore.
- Drop `use_sliding_window` from the same fixtures — also no longer a
  constructor argument (the actual `sliding_window` int kwarg is kept).
- Wire `FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)`
  into `get_dummy_components()` now that the pipeline requires it.

Resolves https://github.com/huggingface/diffusers/pull/13095#discussion_r3115653554,
r3115664850, r3115673059, r3115676580, r3115680700.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address PR #13095 review from dg845 (2026-04-23)

Fixes 5 review threads + style:

1. Converter now builds `AceStepPipeline` in memory and calls
   `save_pretrained`. Previously the hand-written `model_index.json` was
   missing the `scheduler` entry — fresh converter output couldn't be loaded
   by `AceStepPipeline.from_pretrained` (r3127767785). This also makes the
   converter robust to future `__init__` signature changes.

2. `latent_length` uses `math.ceil(...)` instead of `int(...)` so non-integer
   products (e.g. `latents_per_second=2.0, audio_duration=0.4 → 0.8`) round up
   to `1` instead of truncating to `0` and crashing shape checks (r3127790939).

3. Add `_callback_tensor_inputs = ["latents"]` on `AceStepPipeline` so the
   standard diffusers callback tests pick up the right tensor (r3127795954).

4. `AceStepConditionEncoder.silence_latent` no longer hard-codes the channel
   dim to 64. The placeholder buffer now uses the `timbre_hidden_dim`
   constructor argument, so smaller test configs with `timbre_hidden_dim != 64`
   load without shape errors (r3127812932).

5. Revert `self.vae.enable_tiling()` from `AceStepPipeline.__init__`. Users can
   call `pipe.vae.enable_tiling()` themselves for long-form generation; that
   matches the opt-in convention used by the rest of diffusers (r3127777296).

6. `ruff check --fix` + `ruff format` over all ACE-Step sources (the style fix
   dg845 asked for via `@bot /style`).
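
The rounding change in item 2, shown in isolation. The helper name `latent_length` is used here only for illustration; the pipeline computes the value inline and its exact expression may differ.

```python
import math


def latent_length(audio_duration, latents_per_second):
    # Round up so any non-zero duration yields at least one latent frame.
    return math.ceil(audio_duration * latents_per_second)


truncated = int(0.4 * 2.0)            # old behaviour: 0.8 truncates to 0
rounded_up = latent_length(0.4, 2.0)  # new behaviour: 0.8 rounds up to 1
```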

Also: converter now accepts sharded `model.safetensors.index.json` layouts
alongside the single-file `model.safetensors`, so the 5B XL turbo variant
converts without a pre-processing step.

Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music 160s, fresh
converter output loaded via `from_pretrained`):
  waveform Pearson  = 0.999954
  spectral Pearson  = 0.999977
  max |a-b| bf16    = 4.3e-02  (dominated by the VAE tiling default flip)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address PR #13095 review from yiyixuxu (2026-04-23)

Code-level (22 threads):

1. Delete 3 dev/parity scripts (`scripts/audio_parity_jieyue.py`,
   `scripts/dit_parity_test.py`, `scripts/run_official_generate_music.py`)
   that shouldn't have been committed.
2. Rename `AutoencoderOobleck._encode_one` → `_encode` to match the convention
   used by other diffusers VAEs.
3. Delete the hard-coded `SHIFT_TIMESTEPS` / `VALID_SHIFTS` table in
   `pipeline_ace_step.py`: the per-shift turbo schedules are recovered
   exactly by `linspace(1, 0, N+1)[:-1]` plus the flow-match shift formula
   that the non-turbo branch already uses, so a single code path covers both.
4. Drop the backwards-compat `AceStepDiTModel` / `AceStepDiTLayer` aliases
   and every reference (top-level `__init__`, `models/__init__`,
   `transformers/__init__`, dummy objects, tests, docs toctree, model card).
   `AceStepTransformer1DModel` is the only exported name now.
5. Remove the unused `attention_mask` / `encoder_attention_mask` args from
   `AceStepTransformer1DModel.forward`; the model rebuilds its masks from
   the sequence shape and never consumed them.
6. In the DiT forward and both encoders, pass `None` instead of an all-zero
   `full_attn_mask` / `encoder_4d_mask` to non-sliding attention layers — SDPA
   dispatches to a faster kernel when the mask is None.
7. Inline the shared `_run_encoder_layers` helper directly into
   `AceStepLyricEncoder.forward` / `AceStepTimbreEncoder.forward` so layer
   calls are visible at the forward boundary (diffusers style).
8. Move `is_turbo` / `sample_rate` / `latents_per_second` from `@property`s
   that re-read module configs each call to cached attributes populated in
   `__init__` (Flux2-style), with a default-ACE-Step fallback when
   `self.vae` is offloaded. Drop the now-unused `SAMPLE_RATE = 48000`
   module-level constant and the three property definitions.
9. Warn + coerce `guidance_scale` to 1.0 on turbo (guidance-distilled)
   checkpoints, following `pipeline_flux2_klein`. Prevents over-guided
   audio when users forward their base/sft CFG settings to a turbo pipe.
10. Remove the `logger.warning(...)` paths that triggered on
    `silence_latent` missing/zero — those only fired for author-side
    unconverted checkpoints and tests; end users always load converted
    weights where the buffer is baked in.
11. Drop the redundant `with torch.no_grad():` wrappers inside
    `encode_prompt` — the pipeline's `__call__` runs under `torch.no_grad`
    already.
12. Strip "reviewer comment on PR #13095" attribution comments from three
    docstrings (here and everywhere).
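
A sketch of the single schedule path from item 3, written without torch. It assumes the shift transform is the usual flow-matching one, `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`; the function name `timestep_schedule` is hypothetical and the pipeline's own helper may differ in details.

```python
def timestep_schedule(num_steps, shift=1.0):
    # Same values as linspace(1, 0, num_steps + 1)[:-1]: from 1 downwards, excluding 0.
    sigmas = [1.0 - i / num_steps for i in range(num_steps)]
    # Assumed shift transform; shift=1.0 leaves the linear schedule unchanged.
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]


linear = timestep_schedule(8, shift=1.0)
shifted = timestep_schedule(8, shift=3.0)
```

A shift above 1 keeps the endpoints of the schedule but pushes the intermediate values towards 1, so more of the steps are spent at high noise levels.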

Parity on jieyue (seed=42, bf16 + flash-attn, XL turbo 160s text2music):
  waveform Pearson = 0.9747
  spectral Pearson = 0.9895

The shift comes from full-attention layers switching `attn_mask=0_tensor` →
`attn_mask=None`, which dispatches to a different SDPA kernel on bf16. The
two outputs are algebraically equivalent for fp32 eager; on bf16+FA the
delta is dominated by kernel-level ULPs, well within the sampler-noise
band (ear-check on the 160s example confirms no audible regression).

Still open — AudioTokenizer/Detokenizer (deferred) + APG guider follow-up
(dims differ from `diffusers.guiders.adaptive_projected_guidance`, not a
drop-in; worth a separate PR).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address ACE-Step audio token and APG review

* Fix ACE-Step docs CI

* Address ACE-Step pipeline cleanup review

* Fix ACE-Step flash attention sliding windows

* Add ACE-Step callback properties

* Address ACE-Step final review comments

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
This commit is contained in:
Gong Junmin
2026-05-01 12:30:44 +08:00
committed by GitHub
parent 303c1d8b04
commit 1a8a17b71b
20 changed files with 4156 additions and 11 deletions

@@ -324,6 +324,8 @@
title: SparseControlNetModel
title: ControlNets
- sections:
- local: api/models/ace_step_transformer
title: AceStepTransformer1DModel
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
@@ -488,6 +490,8 @@
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- sections:
- local: api/pipelines/ace_step
title: ACE-Step
- local: api/pipelines/audioldm2
title: AudioLDM 2
- local: api/pipelines/longcat_audio_dit

@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AceStepTransformer1DModel
A 1D Diffusion Transformer for music generation from [ACE-Step 1.5](https://github.com/ace-step/ACE-Step-1.5). The model operates on the 25 Hz stereo latents produced by [`AutoencoderOobleck`] using flow matching, and is built on a Qwen3-derived backbone (grouped-query attention, rotary position embedding, RMSNorm, AdaLN-Zero timestep conditioning) plus cross-attention to the text / lyric / timbre conditions built by `AceStepConditionEncoder`.
## AceStepTransformer1DModel
[[autodoc]] AceStepTransformer1DModel

@@ -0,0 +1,72 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ACE-Step 1.5
ACE-Step 1.5 was introduced in [ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation](https://arxiv.org/abs/2602.00744) by the ACE-Step Team (ACE Studio and StepFun). It is an open-source music foundation model that generates commercial-grade stereo music with lyrics from text prompts.
ACE-Step 1.5 generates variable-length stereo audio at 48 kHz (10 seconds to 10 minutes) from text prompts and optional lyrics. The full system pairs a Language Model planner with a Diffusion Transformer (DiT) synthesizer; this pipeline wraps the DiT half of that stack, and consists of three components: an [`AutoencoderOobleck`] VAE that compresses waveforms into 25 Hz stereo latents, a Qwen3-based text encoder for prompt and lyric conditioning, and an [`AceStepTransformer1DModel`] DiT that operates in the VAE latent space using flow matching.
The model supports 50+ languages for lyrics — including English, Chinese, Japanese, Korean, French, German, Spanish, Italian, Portuguese, and Russian — and runs on consumer GPUs (under 4 GB of VRAM when offloaded).
This pipeline was contributed by the [ACE-Step Team](https://github.com/ace-step). The original codebase can be found at [ace-step/ACE-Step-1.5](https://github.com/ace-step/ACE-Step-1.5).
## Variants
ACE-Step 1.5 ships three DiT checkpoints that share the same transformer architecture but differ in guidance behavior; the pipeline auto-detects turbo checkpoints from the loaded transformer config and disables classifier-free guidance for those guidance-distilled weights.
| Variant | CFG | Default steps | Default `guidance_scale` | Default `shift` | HF repo |
|---------|:---:|:-------------:|:------------------------:|:---------------:|---------|
| `turbo` (guidance-distilled) | off | 8 | ignored | 3.0 | [`ACE-Step/Ace-Step1.5`](https://huggingface.co/ACE-Step/Ace-Step1.5) |
| `base` | on | 8 | 7.0 | 3.0 | [`ACE-Step/acestep-v15-base`](https://huggingface.co/ACE-Step/acestep-v15-base) |
| `sft` | on | 8 | 7.0 | 3.0 | [`ACE-Step/acestep-v15-sft`](https://huggingface.co/ACE-Step/acestep-v15-sft) |
Base and SFT use the learned `null_condition_emb` for classifier-free guidance (APG, not vanilla CFG). Users commonly override `num_inference_steps` to a value between 30 and 60 on base/sft for higher quality.
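
The table above can be read as a small lookup keyed on whether the checkpoint is turbo. The sketch below is illustrative only: `SamplingDefaults` and `pick_defaults` are hypothetical names, not part of the pipeline API, and the values are copied from the table.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SamplingDefaults:
    num_inference_steps: int
    guidance_scale: Optional[float]  # None means the value is ignored (turbo)
    shift: float


def pick_defaults(is_turbo: bool) -> SamplingDefaults:
    """Return the per-variant defaults listed in the Variants table."""
    if is_turbo:
        # Guidance is distilled into the weights, so no CFG scale applies.
        return SamplingDefaults(num_inference_steps=8, guidance_scale=None, shift=3.0)
    # base and sft share the same defaults in the table.
    return SamplingDefaults(num_inference_steps=8, guidance_scale=7.0, shift=3.0)
```

A caller would still be free to override any field, for example by raising `num_inference_steps` on base/sft.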
## Tips
When constructing a prompt, keep in mind:
* Descriptive prompt inputs work best; use adjectives to describe the music style, instruments, mood, and tempo.
* The prompt should describe the overall musical characteristics (e.g., "upbeat pop song with electric guitar and drums").
* Lyrics should be structured with tags like `[verse]`, `[chorus]`, `[bridge]`, etc.
During inference:
* `num_inference_steps`, `guidance_scale`, and `shift` default to the values shown above. For turbo checkpoints, `guidance_scale > 1.0` is ignored with a warning because guidance is distilled into the weights.
* The `audio_duration` parameter controls the length of the generated music in seconds.
* The `vocal_language` parameter should match the language of the lyrics.
* `pipe.sample_rate` and `pipe.latents_per_second` are sourced from the VAE config (48000 Hz and 25 fps for the released checkpoints).
* For audio-to-audio tasks, pass `src_audio` and `reference_audio` as preprocessed stereo tensors at `pipe.sample_rate`.
* `flash` and `flash_hub` use FlashAttention's native sliding-window support for ACE-Step's self-attention and expect unpadded text batches. If a batched prompt contains padding, use `flash_varlen` or `flash_varlen_hub` instead. Single-prompt inference with `padding="longest"` is normally unpadded.
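
To give a feel for what `shift` does, here is a minimal sketch that assumes the static time-shift formula used by `FlowMatchEulerDiscreteScheduler`, `shift * sigma / (1 + (shift - 1) * sigma)`. Whether the ACE-Step pipeline builds its schedule exactly this way is an assumption, and `linear_sigmas` and `shift_sigmas` are illustrative names.

```python
def linear_sigmas(num_steps: int) -> list:
    """Evenly spaced sigmas from 1.0 (pure noise) down toward 0.0, excluding 0.0."""
    return [1.0 - i / num_steps for i in range(num_steps)]


def shift_sigmas(sigmas: list, shift: float) -> list:
    """Apply the static flow-matching time shift; shift=1.0 leaves sigmas unchanged."""
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]


if __name__ == "__main__":
    base = linear_sigmas(4)
    print(base)
    print(shift_sigmas(base, 3.0))
```

A shift above 1.0 pushes the intermediate sigmas toward 1.0, so more of the step budget is spent at high noise levels.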
```python
import torch
import soundfile as sf
from diffusers import AceStepPipeline
pipe = AceStepPipeline.from_pretrained("ACE-Step/Ace-Step1.5", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
audio = pipe(
    prompt="A beautiful piano piece with soft melodies and gentle rhythm",
    lyrics="[verse]\nSoft notes in the morning light\nDancing through the air so bright\n[chorus]\nMusic fills the air tonight\nEvery note feels just right",
    audio_duration=30.0,
).audios
sf.write("output.wav", audio[0].T.cpu().float().numpy(), pipe.sample_rate)
```
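
As a rough sanity check on sizes, the arithmetic below relates `audio_duration` to latent frames and raw samples, using the 48000 Hz and 25 frames-per-second figures quoted for the released checkpoints. The function names are hypothetical, and the exact rounding the pipeline applies is an assumption.

```python
def latent_frames(audio_duration: float, latents_per_second: int = 25) -> int:
    """Number of latent frames for a clip of the given length (rounding is assumed)."""
    return int(round(audio_duration * latents_per_second))


def samples_per_latent_frame(sample_rate: int = 48000, latents_per_second: int = 25) -> int:
    """How many raw audio samples one latent frame covers."""
    return sample_rate // latents_per_second


if __name__ == "__main__":
    print(latent_frames(30.0), samples_per_latent_frame())
```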
## AceStepPipeline
[[autodoc]] AceStepPipeline
	- all
	- __call__

@@ -0,0 +1,454 @@
# Run this script to convert ACE-Step model weights to a diffusers pipeline.
#
# Usage:
# python scripts/convert_ace_step_to_diffusers.py \
# --checkpoint_dir /path/to/ACE-Step-1.5/checkpoints \
# --dit_config acestep-v15-turbo \
# --output_dir /path/to/output/ACE-Step-v1-5-turbo \
# --dtype bf16
import argparse
import json
import os
import shutil
import torch
from safetensors.torch import load_file
def convert_ace_step_weights(checkpoint_dir, dit_config, output_dir, dtype_str="bf16"):
"""
Convert ACE-Step checkpoint weights into a Diffusers-compatible pipeline layout.
The original ACE-Step model stores all weights in a single `model.safetensors` file
under `checkpoints/<dit_config>/`. This script splits the weights into separate
sub-model directories that can be loaded by `AceStepPipeline.from_pretrained()`.
Expected input layout:
checkpoint_dir/
<dit_config>/ # e.g., acestep-v15-turbo
config.json
model.safetensors
silence_latent.pt
vae/
config.json
diffusion_pytorch_model.safetensors
Qwen3-Embedding-0.6B/
config.json
model.safetensors
tokenizer.json
...
Output layout:
output_dir/
model_index.json
transformer/
config.json
diffusion_pytorch_model.safetensors
condition_encoder/
config.json
diffusion_pytorch_model.safetensors
vae/
config.json
diffusion_pytorch_model.safetensors
text_encoder/
config.json
model.safetensors
...
tokenizer/
tokenizer.json
...
"""
# Support `--checkpoint_dir <repo-id>` by snapshot-downloading it first. A
# local path that happens not to exist still raises the clearer FileNotFoundError
# below, so we only fall through to the Hub if the path is missing AND looks like
# a repo id (namespace/name).
    if not os.path.exists(checkpoint_dir) and "/" in checkpoint_dir and not checkpoint_dir.startswith((".", "~", "/")):
        try:
            from huggingface_hub import snapshot_download

            print(f"Downloading `{checkpoint_dir}` from the Hugging Face Hub ...")
            checkpoint_dir = snapshot_download(repo_id=checkpoint_dir)
            print(f" -> local snapshot at {checkpoint_dir}")
        except ImportError as e:
            raise ImportError(
                "To use a Hugging Face Hub repo id for --checkpoint_dir, install `huggingface_hub`."
            ) from e
# Resolve paths
dit_dir = os.path.join(checkpoint_dir, dit_config)
vae_dir = os.path.join(checkpoint_dir, "vae")
text_encoder_dir = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
# The DiT weights ship either as a single `model.safetensors` (the smaller turbo
# variant) or as sharded safetensors keyed by `model.safetensors.index.json`
# (the 5B XL variant). Resolve both layouts to `dit_weight_files` and load below.
single_model_path = os.path.join(dit_dir, "model.safetensors")
sharded_index_path = os.path.join(dit_dir, "model.safetensors.index.json")
config_path = os.path.join(dit_dir, "config.json")
if os.path.exists(single_model_path):
dit_weight_files = [single_model_path]
elif os.path.exists(sharded_index_path):
with open(sharded_index_path) as f:
shard_index = json.load(f)
dit_weight_files = [os.path.join(dit_dir, s) for s in sorted(set(shard_index["weight_map"].values()))]
for p in dit_weight_files:
if not os.path.exists(p):
raise FileNotFoundError(f"sharded DiT weight missing: {p}")
else:
raise FileNotFoundError(
f"DiT weights not found at: {single_model_path} or {sharded_index_path}. "
"Expected either a single `model.safetensors` or a sharded "
"`model.safetensors.index.json` + per-shard files."
)
for path, name in [
(config_path, "config"),
(vae_dir, "VAE"),
(text_encoder_dir, "text encoder"),
]:
if not os.path.exists(path):
raise FileNotFoundError(f"{name} not found at: {path}")
# Select dtype
dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
if dtype_str not in dtype_map:
raise ValueError(f"Unsupported dtype: {dtype_str}. Choose from {list(dtype_map.keys())}")
target_dtype = dtype_map[dtype_str]
# Load original config
with open(config_path) as f:
original_config = json.load(f)
print(f"Loading DiT weights from {len(dit_weight_files)} file(s) ...")
state_dict = {}
for p in dit_weight_files:
print(f" loading {os.path.basename(p)}")
state_dict.update(load_file(p))
print(f" Total keys: {len(state_dict)}")
# =========================================================================
# 1. Split weights by prefix
# =========================================================================
transformer_sd = {}
condition_encoder_sd = {}
audio_tokenizer_sd = {}
audio_token_detokenizer_sd = {}
other_sd = {}
# Rename original ACE-Step attention keys to the diffusers `Attention` +
# `AttnProcessor` convention (`to_q`/`to_k`/`to_v`/`to_out.0`/`norm_q`/`norm_k`).
# Applies uniformly to both the DiT (self-attn and cross-attn) and the
# condition-encoder self-attention, since both use `AceStepAttention`.
    _ATTN_KEY_RENAMES = [
        (".q_proj.", ".to_q."),
        (".k_proj.", ".to_k."),
        (".v_proj.", ".to_v."),
        (".o_proj.", ".to_out.0."),
        (".q_norm.", ".norm_q."),
        (".k_norm.", ".norm_k."),
    ]

    def _rename_attn_keys(key: str) -> str:
        for old, new in _ATTN_KEY_RENAMES:
            key = key.replace(old, new)
        return key
    for key, value in state_dict.items():
        if key.startswith("decoder."):
            # Strip "decoder." prefix for the transformer
            new_key = key[len("decoder.") :]
            # The original model uses nn.Sequential for proj_in/proj_out:
            #   proj_in = Sequential(Lambda, Conv1d, Lambda)
            #   proj_out = Sequential(Lambda, ConvTranspose1d, Lambda)
            # Only the Conv1d/ConvTranspose1d (index 1) has parameters.
            # In diffusers, we use standalone Conv1d/ConvTranspose1d named proj_in_conv/proj_out_conv.
            new_key = new_key.replace("proj_in.1.", "proj_in_conv.")
            new_key = new_key.replace("proj_out.1.", "proj_out_conv.")
            new_key = _rename_attn_keys(new_key)
            transformer_sd[new_key] = value.to(target_dtype)
        elif key.startswith("encoder."):
            # Strip "encoder." prefix for the condition encoder
            new_key = key[len("encoder.") :]
            new_key = _rename_attn_keys(new_key)
            condition_encoder_sd[new_key] = value.to(target_dtype)
        elif key == "null_condition_emb":
            # Learned unconditional embedding (used by the base/SFT CFG path).
            # Keep it co-located with the condition encoder since that is where the
            # pipeline pulls unconditional sequences from.
            condition_encoder_sd["null_condition_emb"] = value.to(target_dtype)
        elif key.startswith("tokenizer."):
            new_key = key[len("tokenizer.") :]
            new_key = _rename_attn_keys(new_key)
            audio_tokenizer_sd[new_key] = value.to(target_dtype)
        elif key.startswith("detokenizer."):
            new_key = key[len("detokenizer.") :]
            new_key = _rename_attn_keys(new_key)
            audio_token_detokenizer_sd[new_key] = value.to(target_dtype)
        else:
            other_sd[key] = value.to(target_dtype)
print(f" Transformer keys: {len(transformer_sd)}")
print(f" Condition encoder keys: {len(condition_encoder_sd)}")
print(f" Audio tokenizer keys: {len(audio_tokenizer_sd)}")
print(f" Audio token detokenizer keys: {len(audio_token_detokenizer_sd)}")
print(f" Other keys: {len(other_sd)} ({list(other_sd.keys())[:5]}...)")
# =========================================================================
# 2. Build configs for each sub-model
# =========================================================================
# On the 5B XL turbo the condition encoder is narrower than the DiT
# (`encoder_hidden_size=2048` feeding a `hidden_size=2560` DiT). Non-XL
# turbo / base checkpoints don't set this field, so fall back to
# `hidden_size` — that makes the DiT's `condition_embedder` an identity-width
# Linear as before. Similarly `encoder_intermediate_size` /
# `encoder_num_attention_heads` / `encoder_num_key_value_heads` describe the
# condition encoder on XL only.
encoder_hidden_size = original_config.get("encoder_hidden_size", original_config["hidden_size"])
encoder_intermediate_size = original_config.get("encoder_intermediate_size", original_config["intermediate_size"])
encoder_num_attention_heads = original_config.get(
"encoder_num_attention_heads", original_config["num_attention_heads"]
)
encoder_num_key_value_heads = original_config.get(
"encoder_num_key_value_heads", original_config["num_key_value_heads"]
)
# Transformer (DiT) config. `is_turbo` / `model_version` propagate the variant so
# the pipeline can pick the right CFG / shift / step-count defaults at inference.
# Note: `max_position_embeddings` is dropped (RoPE computes freqs on-the-fly per call),
# and `use_sliding_window` is implied by the mix of `layer_types`.
transformer_config = {
"_class_name": "AceStepTransformer1DModel",
"_diffusers_version": "0.33.0.dev0",
"hidden_size": original_config["hidden_size"],
"intermediate_size": original_config["intermediate_size"],
"num_hidden_layers": original_config["num_hidden_layers"],
"num_attention_heads": original_config["num_attention_heads"],
"num_key_value_heads": original_config["num_key_value_heads"],
"head_dim": original_config["head_dim"],
"in_channels": original_config["in_channels"],
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
"patch_size": original_config["patch_size"],
"rope_theta": original_config["rope_theta"],
"attention_bias": original_config["attention_bias"],
"attention_dropout": original_config["attention_dropout"],
"rms_norm_eps": original_config["rms_norm_eps"],
"sliding_window": original_config["sliding_window"],
"layer_types": original_config["layer_types"],
"encoder_hidden_size": encoder_hidden_size,
"is_turbo": bool(original_config.get("is_turbo", False)),
"model_version": original_config.get("model_version"),
}
# Condition encoder config
condition_encoder_config = {
"_class_name": "AceStepConditionEncoder",
"_diffusers_version": "0.33.0.dev0",
"hidden_size": encoder_hidden_size,
"intermediate_size": encoder_intermediate_size,
"text_hidden_dim": original_config["text_hidden_dim"],
"timbre_hidden_dim": original_config["timbre_hidden_dim"],
"num_lyric_encoder_hidden_layers": original_config["num_lyric_encoder_hidden_layers"],
"num_timbre_encoder_hidden_layers": original_config["num_timbre_encoder_hidden_layers"],
"num_attention_heads": encoder_num_attention_heads,
"num_key_value_heads": encoder_num_key_value_heads,
"head_dim": original_config["head_dim"],
"rope_theta": original_config["rope_theta"],
"attention_bias": original_config["attention_bias"],
"attention_dropout": original_config["attention_dropout"],
"rms_norm_eps": original_config["rms_norm_eps"],
"sliding_window": original_config["sliding_window"],
}
audio_tokenizer_config = {
"_class_name": "AceStepAudioTokenizer",
"_diffusers_version": "0.33.0.dev0",
"hidden_size": encoder_hidden_size,
"intermediate_size": encoder_intermediate_size,
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
"pool_window_size": original_config.get("pool_window_size", 5),
"fsq_dim": original_config.get("fsq_dim", encoder_hidden_size),
"fsq_input_levels": original_config.get("fsq_input_levels", [8, 8, 8, 5, 5, 5]),
"fsq_input_num_quantizers": original_config.get("fsq_input_num_quantizers", 1),
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
"num_attention_heads": encoder_num_attention_heads,
"num_key_value_heads": encoder_num_key_value_heads,
"head_dim": original_config["head_dim"],
"rope_theta": original_config["rope_theta"],
"attention_bias": original_config["attention_bias"],
"attention_dropout": original_config["attention_dropout"],
"rms_norm_eps": original_config["rms_norm_eps"],
"sliding_window": original_config["sliding_window"],
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
}
audio_token_detokenizer_config = {
"_class_name": "AceStepAudioTokenDetokenizer",
"_diffusers_version": "0.33.0.dev0",
"hidden_size": encoder_hidden_size,
"intermediate_size": encoder_intermediate_size,
"audio_acoustic_hidden_dim": original_config["audio_acoustic_hidden_dim"],
"pool_window_size": original_config.get("pool_window_size", 5),
"num_attention_pooler_hidden_layers": original_config.get("num_attention_pooler_hidden_layers", 2),
"num_attention_heads": encoder_num_attention_heads,
"num_key_value_heads": encoder_num_key_value_heads,
"head_dim": original_config["head_dim"],
"rope_theta": original_config["rope_theta"],
"attention_bias": original_config["attention_bias"],
"attention_dropout": original_config["attention_dropout"],
"rms_norm_eps": original_config["rms_norm_eps"],
"sliding_window": original_config["sliding_window"],
"layer_types": original_config["layer_types"][: original_config.get("num_attention_pooler_hidden_layers", 2)],
}
# =========================================================================
# 3. Bake silence_latent into the condition_encoder state dict.
#
# The original loader in
# acestep/core/generation/handler/init_service_loader.py:214 does
# self.silence_latent = torch.load(...).transpose(1, 2)
# converting the stored [B, C=64, T=15000] tensor to [B, T, C=64] before any
# downstream slicing. Do the same transpose here and register it as the
# `silence_latent` buffer on AceStepConditionEncoder — the pipeline slices
# `silence_latent[:, :timbre_fix_frame, :]` to build the "silence" input to the
# timbre encoder when no reference audio is supplied. Passing literal zeros
# produces drone-like audio.
silence_latent_src = os.path.join(dit_dir, "silence_latent.pt")
if os.path.exists(silence_latent_src):
silence_raw = torch.load(silence_latent_src, weights_only=True, map_location="cpu")
silence_latent = silence_raw.transpose(1, 2).to(target_dtype).contiguous()
print(f" silence_latent raw shape: {tuple(silence_raw.shape)} -> baked shape: {tuple(silence_latent.shape)}")
condition_encoder_sd["silence_latent"] = silence_latent
# =========================================================================
# 4. Build the AceStepPipeline in memory and save via `save_pretrained`.
# Assembling the pipeline directly (rather than hand-writing model_index.json)
# ensures the saved repo stays in sync with the `AceStepPipeline.__init__`
# signature — e.g. a future sub-module added to the pipeline can't silently
# drift out of `model_index.json`.
# =========================================================================
from transformers import AutoModel, AutoTokenizer
from diffusers import (
AceStepPipeline,
AceStepTransformer1DModel,
AutoencoderOobleck,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.pipelines.ace_step import (
AceStepAudioTokenDetokenizer,
AceStepAudioTokenizer,
AceStepConditionEncoder,
)
# Drop metadata keys — they're re-populated by `save_pretrained` at save time.
transformer_init_kwargs = {k: v for k, v in transformer_config.items() if not k.startswith("_")}
condition_encoder_init_kwargs = {k: v for k, v in condition_encoder_config.items() if not k.startswith("_")}
audio_tokenizer_init_kwargs = {k: v for k, v in audio_tokenizer_config.items() if not k.startswith("_")}
audio_token_detokenizer_init_kwargs = {
k: v for k, v in audio_token_detokenizer_config.items() if not k.startswith("_")
}
print("\nConstructing transformer ...")
transformer = AceStepTransformer1DModel(**transformer_init_kwargs).to(target_dtype)
transformer.load_state_dict(transformer_sd, strict=True)
print("Constructing condition_encoder ...")
condition_encoder = AceStepConditionEncoder(**condition_encoder_init_kwargs).to(target_dtype)
condition_encoder.load_state_dict(condition_encoder_sd, strict=True)
print("Constructing audio_tokenizer ...")
audio_tokenizer = AceStepAudioTokenizer(**audio_tokenizer_init_kwargs).to(target_dtype)
audio_tokenizer.load_state_dict(audio_tokenizer_sd, strict=True)
print("Constructing audio_token_detokenizer ...")
audio_token_detokenizer = AceStepAudioTokenDetokenizer(**audio_token_detokenizer_init_kwargs).to(target_dtype)
audio_token_detokenizer.load_state_dict(audio_token_detokenizer_sd, strict=True)
print("Loading VAE ...")
vae = AutoencoderOobleck.from_pretrained(vae_dir).to(target_dtype)
print("Loading text encoder ...")
text_encoder = AutoModel.from_pretrained(text_encoder_dir, torch_dtype=target_dtype)
print("Loading tokenizer ...")
tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
# ACE-Step drives the DiT with t ∈ [0, 1] and computes its own shifted / turbo
# sigma schedule, which it passes to `scheduler.set_timesteps(sigmas=...)` at
# sampling time. So the scheduler needs `num_train_timesteps=1` (so
# `scheduler.timesteps == sigmas`) and `shift=1.0` (so it doesn't re-shift
# already-shifted sigmas). All other defaults are fine.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)
pipe = AceStepPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
condition_encoder=condition_encoder,
scheduler=scheduler,
audio_tokenizer=audio_tokenizer,
audio_token_detokenizer=audio_token_detokenizer,
)
print(f"\nSaving pipeline -> {output_dir}")
pipe.save_pretrained(output_dir, safe_serialization=True, max_shard_size="5GB")
# Keep the raw silence_latent.pt at the pipeline root for debugging — not
# required by `from_pretrained`, but makes it easy to re-derive the buffer
# without re-running the full conversion.
if os.path.exists(silence_latent_src):
shutil.copy2(silence_latent_src, os.path.join(output_dir, "silence_latent.pt"))
print(f" kept raw silence_latent copy at {output_dir}/silence_latent.pt")
# Report any keys that were not saved to registered pipeline modules.
if other_sd:
print(f"\nNote: {len(other_sd)} keys were dropped:")
for key in sorted(other_sd.keys())[:10]:
print(f" {key}")
if len(other_sd) > 10:
print(f" ... ({len(other_sd) - 10} more)")
print(f"\nConversion complete! Output saved to: {output_dir}")
print("\nTo load the pipeline:")
print(" from diffusers import AceStepPipeline")
print(f" pipe = AceStepPipeline.from_pretrained('{output_dir}', torch_dtype=torch.bfloat16)")
print(" pipe = pipe.to('cuda')")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert ACE-Step model weights to Diffusers pipeline format")
parser.add_argument(
"--checkpoint_dir",
type=str,
required=True,
help="Path to the ACE-Step checkpoints directory (containing vae/, Qwen3-Embedding-0.6B/, and dit config dirs)",
)
parser.add_argument(
"--dit_config",
type=str,
default="acestep-v15-turbo",
help="Name of the DiT config directory (default: acestep-v15-turbo)",
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Path to save the converted Diffusers pipeline",
)
parser.add_argument(
"--dtype",
type=str,
default="bf16",
choices=["fp32", "fp16", "bf16"],
help="Data type for saved weights (default: bf16)",
)
args = parser.parse_args()
convert_ace_step_weights(
checkpoint_dir=args.checkpoint_dir,
dit_config=args.dit_config,
output_dir=args.output_dir,
dtype_str=args.dtype,
)
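
The prefix routing and attention-key renaming performed by the script can be tried in isolation on plain strings. The sketch below mirrors the rules in `convert_ace_step_weights` but is not the script itself: `route_key` is a hypothetical helper, and the example key names in the usage lines are invented for illustration.

```python
ATTN_KEY_RENAMES = [
    (".q_proj.", ".to_q."),
    (".k_proj.", ".to_k."),
    (".v_proj.", ".to_v."),
    (".o_proj.", ".to_out.0."),
    (".q_norm.", ".norm_q."),
    (".k_norm.", ".norm_k."),
]

# Original checkpoint prefix -> destination sub-model (names follow the script).
PREFIX_TO_BUCKET = {
    "decoder.": "transformer",
    "encoder.": "condition_encoder",
    "tokenizer.": "audio_tokenizer",
    "detokenizer.": "audio_token_detokenizer",
}


def rename_attn_keys(key: str) -> str:
    for old, new in ATTN_KEY_RENAMES:
        key = key.replace(old, new)
    return key


def route_key(key: str) -> tuple:
    """Return (bucket, new_key) for one original state-dict key."""
    if key == "null_condition_emb":
        # Kept alongside the condition encoder, unchanged.
        return "condition_encoder", key
    for prefix, bucket in PREFIX_TO_BUCKET.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            if bucket == "transformer":
                # Index 1 of the original nn.Sequential holds the conv weights.
                new_key = new_key.replace("proj_in.1.", "proj_in_conv.")
                new_key = new_key.replace("proj_out.1.", "proj_out_conv.")
            return bucket, rename_attn_keys(new_key)
    return "other", key


if __name__ == "__main__":
    # Hypothetical key names, for illustration only.
    print(route_key("decoder.layers.0.self_attn.q_proj.weight"))
    print(route_key("decoder.proj_in.1.weight"))
```

Running the routing on key names alone is a cheap way to check a new checkpoint layout before loading any tensors.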

@@ -188,6 +188,7 @@ else:
]
_import_structure["models"].extend(
[
"AceStepTransformer1DModel",
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionBackendName",
@@ -488,6 +489,10 @@ else:
)
_import_structure["pipelines"].extend(
[
"AceStepAudioTokenDetokenizer",
"AceStepAudioTokenizer",
"AceStepConditionEncoder",
"AceStepPipeline",
"AllegroPipeline",
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
@@ -1000,6 +1005,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VaeImageProcessorLDM3D,
)
from .models import (
AceStepTransformer1DModel,
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AttentionBackendName,
@@ -1277,6 +1283,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageModularPipeline,
)
from .pipelines import (
AceStepAudioTokenDetokenizer,
AceStepAudioTokenizer,
AceStepConditionEncoder,
AceStepPipeline,
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,

@@ -40,6 +40,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
adaptive_projected_guidance_norm_dim (`int` or `tuple[int]`, *optional*):
Dimension(s) over which to compute the APG norm and projection. If omitted, all non-batch dimensions are
used, preserving the original behavior.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
@@ -62,6 +65,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: float | None = None,
adaptive_projected_guidance_rescale: float = 15.0,
adaptive_projected_guidance_norm_dim: int | tuple[int, ...] | None = None,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
@@ -74,6 +78,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.adaptive_projected_guidance_norm_dim = adaptive_projected_guidance_norm_dim
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
@@ -117,6 +122,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
self.adaptive_projected_guidance_norm_dim,
)
if self.guidance_rescale > 0.0:
@@ -210,9 +216,15 @@ def normalized_guidance(
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
norm_dim: int | tuple[int, ...] | None = None,
):
diff = pred_cond - pred_uncond
-dim = [-i for i in range(1, len(diff.shape))]
+if norm_dim is None:
+    dim = [-i for i in range(1, len(diff.shape))]
+elif isinstance(norm_dim, int):
+    dim = [norm_dim]
+else:
+    dim = list(norm_dim)
if momentum_buffer is not None:
momentum_buffer.update(diff)
@@ -224,11 +236,15 @@ def normalized_guidance(
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
-v0, v1 = diff.double(), pred_cond.double()
+if diff.device.type in {"mps", "npu"}:
+    v0, v1 = diff.cpu().double(), pred_cond.cpu().double()
+else:
+    v0, v1 = diff.double(), pred_cond.double()
 v1 = torch.nn.functional.normalize(v1, dim=dim)
 v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
 v0_orthogonal = v0 - v0_parallel
-diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+diff_parallel = v0_parallel.to(device=diff.device, dtype=diff.dtype)
+diff_orthogonal = v0_orthogonal.to(device=diff.device, dtype=diff.dtype)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
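
The projection step in `normalized_guidance` splits the guidance difference into components parallel and orthogonal to the conditional prediction, and `norm_dim` decides which axes make up the vector being projected. The pure-Python sketch below works on flat lists rather than tensors; `project` and `apg_update` are illustrative names, and the `1e-12` floor imitates the default `eps` of `torch.nn.functional.normalize`.

```python
import math


def project(v0: list, v1: list) -> tuple:
    """Split v0 into (parallel, orthogonal) components relative to v1."""
    norm = max(math.sqrt(sum(x * x for x in v1)), 1e-12)
    unit = [x / norm for x in v1]
    dot = sum(a * b for a, b in zip(v0, unit))
    parallel = [dot * u for u in unit]
    orthogonal = [a - p for a, p in zip(v0, parallel)]
    return parallel, orthogonal


def apg_update(diff: list, pred_cond: list, eta: float = 1.0) -> list:
    """Orthogonal part plus eta times the parallel part, as in the hunk above."""
    parallel, orthogonal = project(diff, pred_cond)
    return [o + eta * p for o, p in zip(orthogonal, parallel)]


if __name__ == "__main__":
    print(apg_update([1.0, 1.0], [2.0, 0.0], eta=0.0))
```

With `norm_dim=None` the whole non-batch tensor would be flattened into one such vector; passing a single dimension would instead apply the same split independently along that axis.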

@@ -79,6 +79,7 @@ if is_torch_available():
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.ace_step_transformer"] = ["AceStepTransformer1DModel"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
_import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"]
@@ -209,6 +210,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
AceStepTransformer1DModel,
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
BriaFiboTransformer2DModel,

@@ -1091,14 +1091,14 @@ def _flash_attention_forward_op(
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
# Hardcoded for now
-window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
@@ -1191,6 +1191,8 @@ def _flash_attention_hub_forward_op(
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
@@ -1209,7 +1211,6 @@ def _flash_attention_hub_forward_op(
if scale is None:
scale = query.shape[-1] ** (-0.5)
-window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
@@ -2453,6 +2454,7 @@ def _flash_attention(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
window_size: tuple[int, int] = (-1, -1),
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
@@ -2468,11 +2470,13 @@ def _flash_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
forward_op = functools.partial(_flash_attention_forward_op, window_size=window_size)
out = _templated_context_parallel_attention(
query,
key,
@@ -2483,7 +2487,7 @@ def _flash_attention(
scale,
False,
return_lse,
-forward_op=_flash_attention_forward_op,
+forward_op=forward_op,
backward_op=_flash_attention_backward_op,
_parallel_config=_parallel_config,
)
@@ -2506,6 +2510,7 @@ def _flash_attention_hub(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
window_size: tuple[int, int] = (-1, -1),
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
@@ -2522,11 +2527,13 @@ def _flash_attention_hub(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
forward_op = functools.partial(_flash_attention_hub_forward_op, window_size=window_size)
out = _templated_context_parallel_attention(
query,
key,
@@ -2537,7 +2544,7 @@ def _flash_attention_hub(
scale,
False,
return_lse,
-forward_op=_flash_attention_hub_forward_op,
+forward_op=forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
@@ -2560,6 +2567,7 @@ def _flash_varlen_attention_hub(
dropout_p: float = 0.0,
scale: float | None = None,
is_causal: bool = False,
window_size: tuple[int, int] = (-1, -1),
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
@@ -2597,6 +2605,7 @@ def _flash_varlen_attention_hub(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
@@ -2616,6 +2625,7 @@ def _flash_varlen_attention(
dropout_p: float = 0.0,
scale: float | None = None,
is_causal: bool = False,
window_size: tuple[int, int] = (-1, -1),
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
@@ -2652,6 +2662,7 @@ def _flash_varlen_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
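
For readers unfamiliar with the `window_size=(left, right)` convention threaded through these hunks: in FlashAttention, a query at position `i` attends to keys in `[i - left, i + right]`, and `-1` on a side means unbounded. The sketch below builds the equivalent boolean mask in plain Python, assuming equal query and key lengths; `sliding_window_mask` is an illustrative helper, not a diffusers function.

```python
def sliding_window_mask(seq_len: int, window_size: tuple = (-1, -1)) -> list:
    """mask[i][j] is True when query i may attend to key j."""
    left, right = window_size
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            within_left = left < 0 or j >= i - left
            within_right = right < 0 or j <= i + right
            row.append(within_left and within_right)
        mask.append(row)
    return mask


if __name__ == "__main__":
    for row in sliding_window_mask(4, (1, 1)):
        print(row)
```

The default `(-1, -1)` yields a full mask, which is why the previous hardcoded value behaved like ordinary global attention.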

@@ -355,6 +355,24 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
)
self.use_slicing = False
self.use_tiling = False
# 1D time-axis tiling defaults. `tile_sample_min_length` is the raw-audio
# threshold (in samples) above which `encode` splits the input; chunks are
# `tile_sample_min_length` wide with `tile_sample_overlap` samples of overlap
# on each side, trimmed back out after decoding. `tile_latent_min_length`
# is the equivalent threshold on the decode side, expressed in latent frames.
self.tile_sample_min_length = sampling_rate * 30 # 30 seconds
self.tile_sample_overlap = sampling_rate * 2 # 2 seconds per side
# Decode chunk is smaller than encode chunk because the decoder upsamples
# back to raw audio and is more VRAM-heavy per frame.
self.tile_latent_min_length = 512
self.tile_latent_overlap = 64
def _encode(self, x: torch.Tensor) -> torch.Tensor:
if self.use_tiling and x.shape[-1] > self.tile_sample_min_length:
return self._tiled_encode(x)
return self.encoder(x)
@apply_forward_hook
def encode(
@@ -373,10 +391,10 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
 if self.use_slicing and x.shape[0] > 1:
-    encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+    encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
     h = torch.cat(encoded_slices)
 else:
-    h = self.encoder(x)
+    h = self._encode(x)
posterior = OobleckDiagonalGaussianDistribution(h)
@@ -385,14 +403,88 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
return AutoencoderOobleckOutput(latent_dist=posterior)
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a long audio waveform by splitting it into overlapping tiles along
the time axis and concatenating the resulting encoder features. Used to keep memory bounded regardless of clip
length. Not bit-identical to a single unsplit encode — each tile has its own receptive-field boundary — but the
overlap/trim scheme keeps the joined feature map smooth.
"""
_B, _C, S = x.shape
chunk = self.tile_sample_min_length
overlap = self.tile_sample_overlap
stride = chunk - 2 * overlap
if stride <= 0:
raise ValueError(
f"tile_sample_min_length ({chunk}) must be greater than 2 * tile_sample_overlap ({overlap})"
)
num_steps = math.ceil(S / stride)
tiles = []
hop = None
for i in range(num_steps):
core_start = i * stride
core_end = min(core_start + stride, S)
win_start = max(0, core_start - overlap)
win_end = min(S, core_end + overlap)
tile = self.encoder(x[:, :, win_start:win_end])
if hop is None:
hop = (win_end - win_start) / tile.shape[-1]
trim_l = int(round((core_start - win_start) / hop))
trim_r = int(round((win_end - core_end) / hop))
end_idx = tile.shape[-1] - trim_r if trim_r > 0 else tile.shape[-1]
tiles.append(tile[:, :, trim_l:end_idx])
return torch.cat(tiles, dim=-1)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> OobleckDecoderOutput | torch.Tensor:
- dec = self.decoder(z)
+ if self.use_tiling and z.shape[-1] > self.tile_latent_min_length:
+ dec = self._tiled_decode(z)
+ else:
+ dec = self.decoder(z)
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
def _tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""Decode a long latent by splitting it into overlapping tiles along the
time axis, decoding each, and concatenating the audio tiles back together."""
_B, _C, T = z.shape
chunk = self.tile_latent_min_length
overlap = self.tile_latent_overlap
stride = chunk - 2 * overlap
if stride <= 0:
raise ValueError(
f"tile_latent_min_length ({chunk}) must be greater than 2 * tile_latent_overlap ({overlap})"
)
num_steps = math.ceil(T / stride)
tiles = []
upsample = None
for i in range(num_steps):
core_start = i * stride
core_end = min(core_start + stride, T)
win_start = max(0, core_start - overlap)
win_end = min(T, core_end + overlap)
tile = self.decoder(z[:, :, win_start:win_end])
if upsample is None:
upsample = tile.shape[-1] / (win_end - win_start)
trim_l = int(round((core_start - win_start) * upsample))
trim_r = int(round((win_end - core_end) * upsample))
end_idx = tile.shape[-1] - trim_r if trim_r > 0 else tile.shape[-1]
tiles.append(tile[:, :, trim_l:end_idx])
return torch.cat(tiles, dim=-1)
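
The window arithmetic shared by `_tiled_encode` and `_tiled_decode` can be checked without running a model. The sketch below is a standalone illustration, not code from this PR: `plan_tiles` is a hypothetical helper that reproduces only the index computation (core span plus overlapped window) and leaves out the encoder/decoder call and the hop-based trimming.

```python
import math


def plan_tiles(length, chunk, overlap):
    """Return (core_start, core_end, win_start, win_end) for every tile.

    Hypothetical helper: mirrors the loop indices of the tiled encode/decode
    paths, but does no tensor work.
    """
    stride = chunk - 2 * overlap
    if stride <= 0:
        raise ValueError("chunk must be greater than 2 * overlap")
    plan = []
    for i in range(math.ceil(length / stride)):
        core_start = i * stride
        core_end = min(core_start + stride, length)
        # The window extends the core by `overlap` on each side, clamped to the clip.
        win_start = max(0, core_start - overlap)
        win_end = min(length, core_end + overlap)
        plan.append((core_start, core_end, win_start, win_end))
    return plan


# Decode-side defaults from the diff: 512 latent frames per tile, 64 of overlap.
plan = plan_tiles(length=1000, chunk=512, overlap=64)
for core_start, core_end, win_start, win_end in plan:
    print(f"core [{core_start}, {core_end})  window [{win_start}, {win_end})")
```

Under these assumptions the cores partition the input exactly once, and no window is longer than `chunk`, which is what bounds peak memory per tile.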
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None

@@ -2,6 +2,7 @@ from ...utils import is_torch_available
if is_torch_available():
from .ace_step_transformer import AceStepTransformer1DModel
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .consisid_transformer_3d import ConsisIDTransformer3DModel

@@ -0,0 +1,626 @@
# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion Transformer (DiT) for ACE-Step 1.5 music generation."""
import inspect
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import (
AttentionBackendName,
_AttentionBackendRegistry,
dispatch_attention_fn,
)
from ..cache_utils import CacheMixin
from ..embeddings import Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_FLASH_ATTENTION_BACKENDS = {
AttentionBackendName.FLASH,
AttentionBackendName.FLASH_HUB,
AttentionBackendName.FLASH_VARLEN,
AttentionBackendName.FLASH_VARLEN_HUB,
}
_FLASH_ATTENTION_VARLEN_BACKENDS = {
AttentionBackendName.FLASH_VARLEN,
AttentionBackendName.FLASH_VARLEN_HUB,
}
def _get_current_attention_backend(processor: Optional["AceStepAttnProcessor2_0"] = None) -> AttentionBackendName:
backend = getattr(processor, "_attention_backend", None)
if backend is None:
backend, _ = _AttentionBackendRegistry.get_active_backend()
return AttentionBackendName(backend)
def _is_flash_attention_backend(processor: Optional["AceStepAttnProcessor2_0"] = None) -> bool:
return _get_current_attention_backend(processor) in _FLASH_ATTENTION_BACKENDS
# --------------------------------------------------------------------------- #
# attention-mask #
# --------------------------------------------------------------------------- #
def _create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
"""Build a `[B, 1, seq_len, seq_len]` additive mask (0.0 kept, `torch.finfo(dtype).min` masked).
Mirrors the mask construction in ``acestep/models/turbo/modeling_acestep_v15_turbo.py::create_4d_mask`` so the DiT
sees identical attention coverage regardless of whether SDPA, eager or flash attention is selected downstream.
"""
indices = torch.arange(seq_len, device=device)
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
if is_causal:
valid_mask = valid_mask & (diff >= 0)
if is_sliding_window and sliding_window is not None:
if is_causal:
valid_mask = valid_mask & (diff <= sliding_window)
else:
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
valid_mask = valid_mask & padding_mask_4d
min_dtype = torch.finfo(dtype).min
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
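
To make the band rule in `_create_4d_mask` concrete, here is a small pure-Python sketch of just the keep/mask decision. It is an illustration under simplifying assumptions: `band_keep_mask` is a hypothetical name, there is no batch dimension or padding mask, and the result is a boolean grid rather than the additive tensor the real function returns.

```python
def band_keep_mask(seq_len, window, causal):
    """Return rows of booleans; True means query i may attend to key j."""
    rows = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            diff = i - j  # same sign convention as `diff` in `_create_4d_mask`
            if causal:
                keep = 0 <= diff <= window  # only look back, at most `window` steps
            else:
                keep = abs(diff) <= window  # look `window` steps in both directions
            row.append(keep)
        rows.append(row)
    return rows


bidirectional = band_keep_mask(seq_len=5, window=1, causal=False)
causal = band_keep_mask(seq_len=5, window=1, causal=True)
for row in bidirectional:
    print("".join("x" if keep else "." for keep in row))
```

The DiT and the lyric encoder both call the real function with `is_causal=False`, so they correspond to the `bidirectional` case.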
# --------------------------------------------------------------------------- #
# RoPE helpers #
# --------------------------------------------------------------------------- #
def _ace_step_rotary_freqs(
seq_len: int, head_dim: int, theta: float, device: torch.device, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Build (cos, sin) freqs for ACE-Step RoPE using ``get_1d_rotary_pos_embed``.
The original ACE-Step DiT reuses Qwen3's rotary layout: ``freqs = cat([freq_half, freq_half], dim=-1)`` (not
interleaved), and the rotate-half convention splits the last dim in two halves rather than unbinding pairs. That
matches ``get_1d_rotary_pos_embed(..., use_real=True, repeat_interleave_real=False)`` + ``apply_rotary_emb(...,
use_real_unbind_dim=-2)``.
"""
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
cos, sin = get_1d_rotary_pos_embed(head_dim, positions, theta=theta, use_real=True, repeat_interleave_real=False)
return cos.to(dtype=dtype), sin.to(dtype=dtype)
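
The rotate-half layout described in the docstring pairs element `k` with element `k + head_dim / 2`. The sketch below shows that pairing in plain Python for a single vector; `rope_rotate_half` and `dot` are hypothetical names used only for illustration, and the frequency formula `theta ** (-2k / dim)` is the standard RoPE choice, assumed here rather than copied from `get_1d_rotary_pos_embed`.

```python
import math


def rope_rotate_half(x, position, theta=1000000.0):
    """Rotate an even-length vector using the split-halves (rotate-half) layout."""
    dim = len(x)
    half = dim // 2
    out = [0.0] * dim
    for k in range(half):
        angle = position * theta ** (-2.0 * k / dim)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[k], x[k + half]  # element k is paired with element k + half
        out[k] = a * c - b * s
        out[k + half] = b * c + a * s
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.1, 0.4, -0.5, 0.9]
# Attention scores depend only on the relative offset (here 2), not on absolute positions.
score_a = dot(rope_rotate_half(q, 5), rope_rotate_half(k, 3))
score_b = dot(rope_rotate_half(q, 2), rope_rotate_half(k, 0))
print(abs(score_a - score_b) < 1e-9)
```

An interleaved layout would pair elements `2k` and `2k + 1` instead, which is why the checkpoint conversion has to match the unbind dimension to the original model.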
# --------------------------------------------------------------------------- #
# building blocks #
# --------------------------------------------------------------------------- #
class AceStepMLP(nn.Module):
"""SwiGLU MLP used in ACE-Step transformer blocks."""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class AceStepTimestepEmbedding(nn.Module):
"""Sinusoidal timestep embedding + 2-layer MLP + 6-way AdaLN shift/scale/gate projection.
Matches the original ACE-Step checkpoint layout exactly (``linear_1``, ``linear_2``, ``time_proj``) so the
converter maps keys 1:1. The sinusoid itself is the shared ``Timesteps`` module (``flip_sin_to_cos=True`` for
ACE-Step's ``cat([cos, sin])`` convention).
"""
def __init__(self, in_channels: int = 256, time_embed_dim: int = 2048, scale: float = 1000.0):
super().__init__()
self.in_channels = in_channels
self.scale = scale
self.time_sinusoid = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act1 = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
self.act2 = nn.SiLU()
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
def forward(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
t_freq = self.time_sinusoid(t * self.scale)
temb = self.linear_1(t_freq.to(t.dtype))
temb = self.act1(temb)
temb = self.linear_2(temb)
timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
return temb, timestep_proj
class AceStepAttnProcessor2_0:
"""Attention processor for ACE-Step GQA attention.
Dispatches the actual attention call through ``dispatch_attention_fn`` so users can pick flash / sage / native
backends via ``model.set_attention_backend(...)`` or the ``attention_backend`` context manager. Uses the ``(B, L,
H, D)`` tensor layout that the diffusers attention backends consume directly.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AceStepAttnProcessor2_0 requires PyTorch 2.0. Please upgrade your PyTorch version.")
def __call__(
self,
attn: "AceStepAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
is_cross = attn.is_cross_attention and encoder_hidden_states is not None
kv_input = encoder_hidden_states if is_cross else hidden_states
# Project to (B, L, H, D). Q uses ``heads``; K/V use ``kv_heads`` (GQA).
query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, attn.head_dim))
key = attn.to_k(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
value = attn.to_v(kv_input).unflatten(-1, (attn.kv_heads, attn.head_dim))
query = attn.norm_q(query)
key = attn.norm_k(key)
# RoPE on self-attention only. Matches Qwen3 layout:
# freqs = cat([freq_half, freq_half], dim=-1); rotate-half splits last dim.
if not is_cross and image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2, sequence_dim=1)
attention_kwargs = None
backend = _get_current_attention_backend(self)
dispatch_backend = self._attention_backend
sliding_window = getattr(attn, "sliding_window", None)
if backend in _FLASH_ATTENTION_BACKENDS:
if attention_mask is not None:
if attention_mask.ndim == 2:
padding_mask = attention_mask.to(torch.bool)
elif attention_mask.ndim == 4:
keep_mask = attention_mask if attention_mask.dtype == torch.bool else attention_mask == 0
padding_mask = keep_mask.any(dim=(1, 2))
else:
raise ValueError(
f"Unsupported ACE-Step attention mask shape for flash attention: {attention_mask.shape}"
)
has_padding = not torch.all(padding_mask).item()
if has_padding:
attention_mask = padding_mask
if backend not in _FLASH_ATTENTION_VARLEN_BACKENDS:
raise ValueError(
"ACE-Step flash attention received a padded attention mask. Use `flash_varlen` or "
"`flash_varlen_hub` for batched prompts with padding, or use an unpadded batch with `flash`."
)
else:
attention_mask = None
if not is_cross and sliding_window is not None and key.shape[1] > sliding_window:
# ACE-Step's dense mask keeps `abs(i - j) <= sliding_window`; flash-attn uses the same inclusive
# left/right window convention, so pass the configured value through directly.
attention_kwargs = {"window_size": (sliding_window, sliding_window)}
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=attn.dropout if attn.training else 0.0,
scale=attn.scaling,
enable_gqa=attn.heads != attn.kv_heads,
attention_kwargs=attention_kwargs,
backend=dispatch_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class AceStepAttention(torch.nn.Module, AttentionModuleMixin):
"""GQA attention with RMSNorm on query/key for ACE-Step 1.5.
Uses the diffusers ``Attention`` + ``AttnProcessor`` split: this module holds the projections and Q/K norm; the
processor runs the attention dispatch. Self-attention applies RoPE on query/key; cross-attention reads K/V from
``encoder_hidden_states`` and does not apply RoPE.
GQA means Q has ``heads * head_dim`` output while K/V have ``kv_heads * head_dim`` — QKV fusion is therefore
disabled (``_supports_qkv_fusion = False``).
"""
_default_processor_cls = AceStepAttnProcessor2_0
_available_processors = [AceStepAttnProcessor2_0]
_supports_qkv_fusion = False
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
bias: bool = False,
dropout: float = 0.0,
eps: float = 1e-6,
sliding_window: Optional[int] = None,
is_cross_attention: bool = False,
processor: Optional[AceStepAttnProcessor2_0] = None,
):
super().__init__()
self.heads = num_attention_heads
self.kv_heads = num_key_value_heads
self.head_dim = head_dim
self.dropout = dropout
self.scaling = head_dim**-0.5
self.sliding_window = sliding_window
self.is_cross_attention = is_cross_attention
self.to_q = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=bias)
self.to_k = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
self.to_v = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=bias)
self.to_out = nn.ModuleList(
[nn.Linear(num_attention_heads * head_dim, hidden_size, bias=bias), nn.Dropout(0.0)]
)
self.norm_q = RMSNorm(head_dim, eps=eps)
self.norm_k = RMSNorm(head_dim, eps=eps)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
kwargs = {k: v for k, v in kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
**kwargs,
)
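
The projection sizes that make QKV fusion impossible follow directly from the head counts. The sketch below works them out for the default config; `gqa_layout` is a hypothetical helper, and the query-to-KV head mapping assumes the consecutive grouping that PyTorch documents for `enable_gqa`, which this diff does not itself spell out.

```python
def gqa_layout(heads, kv_heads, head_dim):
    """Summarize grouped-query attention shapes for one attention module."""
    if heads % kv_heads != 0:
        raise ValueError("heads must be a multiple of kv_heads")
    group = heads // kv_heads
    return {
        "q_features": heads * head_dim,      # output width of to_q
        "kv_features": kv_heads * head_dim,  # output width of to_k and to_v
        # Assumed grouping: consecutive query heads share one K/V head.
        "kv_head_for_query_head": [h // group for h in range(heads)],
    }


layout = gqa_layout(heads=16, kv_heads=8, head_dim=128)
print(layout["q_features"], layout["kv_features"])
```

With the defaults, `to_q` is twice as wide as `to_k` and `to_v`, so a single fused weight matrix with three equal slices cannot represent them.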
class AceStepTransformerBlock(nn.Module):
"""ACE-Step DiT transformer block: self-attn (AdaLN) → cross-attn → MLP (AdaLN).
AdaLN parameters come from the shared ``scale_shift_table + timestep_proj`` chunked into 6 (3 for self-attn + 3 for
MLP).
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
intermediate_size: int,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: Optional[int] = None,
use_cross_attention: bool = True,
):
super().__init__()
self.self_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
bias=attention_bias,
dropout=attention_dropout,
eps=rms_norm_eps,
sliding_window=sliding_window,
is_cross_attention=False,
)
self.use_cross_attention = use_cross_attention
if self.use_cross_attention:
self.cross_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
bias=attention_bias,
dropout=attention_dropout,
eps=rms_norm_eps,
is_cross_attention=True,
)
self.mlp_norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = AceStepMLP(hidden_size, intermediate_size)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb).chunk(
6, dim=1
)
# Self-attention with AdaLN.
norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.self_attn(
hidden_states=norm_hidden_states,
image_rotary_emb=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
if self.use_cross_attention and encoder_hidden_states is not None:
norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
attn_output = self.cross_attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_output
norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.mlp(norm_hidden_states)
hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
return hidden_states
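
The AdaLN pattern used around the self-attention and MLP branches is `x + gate * sublayer(norm(x) * (1 + scale) + shift)`. The following sketch applies it to a single token in plain Python. It is a simplification under stated assumptions: `rms_norm` has no learned weight, `adaln_residual` is a hypothetical name, and the modulation vectors are per-feature lists rather than the `(batch, 1, hidden_size)` tensors the block broadcasts over the sequence.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMS normalization without a learned weight (simplifying assumption)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def adaln_residual(x, sublayer, shift, scale, gate):
    """Modulate the normalized input, run the sublayer, gate the residual update."""
    modulated = [n * (1.0 + sc) + sh for n, sc, sh in zip(rms_norm(x), scale, shift)]
    update = sublayer(modulated)
    return [xi + ui * g for xi, ui, g in zip(x, update, gate)]


def double(values):
    # Stand-in for attention or the MLP.
    return [2.0 * v for v in values]


x = [1.0, -2.0, 0.5, 3.0]
zeros = [0.0] * len(x)
ones = [1.0] * len(x)
closed = adaln_residual(x, double, shift=zeros, scale=zeros, gate=zeros)
opened = adaln_residual(x, double, shift=zeros, scale=zeros, gate=ones)
print(closed == x)
```

A zero gate turns the branch into an identity, so the timestep embedding controls how strongly each sublayer contributes at a given noise level.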
# --------------------------------------------------------------------------- #
# main DiT model #
# --------------------------------------------------------------------------- #
class AceStepTransformer1DModel(ModelMixin, ConfigMixin, AttentionMixin, CacheMixin):
"""Diffusion Transformer for ACE-Step 1.5 music generation.
Generates audio latents conditioned on text, lyrics, and timbre. Uses 1D patch embedding (`Conv1d` with stride
`patch_size`) followed by a stack of `AceStepTransformerBlock`s with alternating sliding-window / full attention on
the self-attention branch. Cross-attention consumes the packed `encoder_hidden_states` produced by
`AceStepConditionEncoder`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
in_channels: int = 192,
audio_acoustic_hidden_dim: int = 64,
patch_size: int = 2,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[List[str]] = None,
# Dim of the condition encoder's output. Equal to `hidden_size` on the
# non-XL turbo / base models, but the XL turbo has a smaller condition
# encoder (`encoder_hidden_size=2048`) feeding a wider DiT
# (`hidden_size=2560`), so `condition_embedder` needs to project it up.
encoder_hidden_size: Optional[int] = None,
# Variant metadata. Turbo models have guidance distilled into the weights and
# should run without CFG; base/SFT models require CFG with the learned
# `AceStepConditionEncoder.null_condition_emb`. The pipeline reads these to
# pick default `guidance_scale`, `shift`, and `num_inference_steps`.
is_turbo: bool = False,
model_version: Optional[str] = None,
):
super().__init__()
if encoder_hidden_size is None:
encoder_hidden_size = hidden_size
self.patch_size = patch_size
self.head_dim = head_dim
self.rope_theta = rope_theta
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(num_hidden_layers)
]
self.layer_types = list(layer_types)
self.layers = nn.ModuleList(
[
AceStepTransformerBlock(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
use_cross_attention=True,
)
for i in range(num_hidden_layers)
]
)
# Patchify: concat(src_latents, chunk_mask) on the channel dim then Conv1d with
# stride=patch_size lifts (B, T, in_channels) -> (B, T/patch_size, hidden_size).
self.proj_in_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
padding=0,
)
# Dual-timestep conditioning: one path for `t`, one for `(t - r)` (mean-flow).
self.time_embed = AceStepTimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
self.time_embed_r = AceStepTimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
self.norm_out = RMSNorm(hidden_size, eps=rms_norm_eps)
self.proj_out_conv = nn.ConvTranspose1d(
in_channels=hidden_size,
out_channels=audio_acoustic_hidden_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
)
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
timestep_r: torch.Tensor,
encoder_hidden_states: torch.Tensor,
context_latents: torch.Tensor,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""The [`AceStepTransformer1DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, channels)`):
Noisy latent input for the diffusion process.
timestep (`torch.Tensor` of shape `(batch_size,)`):
Current diffusion timestep `t`.
timestep_r (`torch.Tensor` of shape `(batch_size,)`):
Reference timestep `r` (set equal to `t` for standard inference).
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, encoder_seq_len, encoder_hidden_size)`):
Conditioning embeddings from the condition encoder (text + lyrics + timbre); projected to `hidden_size` by `condition_embedder` before cross-attention.
context_latents (`torch.Tensor` of shape `(batch_size, seq_len, context_dim)`):
Context latents (source latents concatenated with chunk masks) — fed to the patchify conv alongside
`hidden_states`.
return_dict (`bool`, defaults to `True`):
Whether to return a `Transformer2DModelOutput` or a plain tuple.
Returns:
`Transformer2DModelOutput` or `tuple`: The predicted velocity field.
"""
# Dual timestep embedding: t and (t - r). Sum both paths' AdaLN projections.
temb_t, timestep_proj_t = self.time_embed(timestep)
temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
temb = temb_t + temb_r
timestep_proj = timestep_proj_t + timestep_proj_r
# Context concatenation + padding to patch_size boundary + patchify.
hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
original_seq_len = hidden_states.shape[1]
if hidden_states.shape[1] % self.patch_size != 0:
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode="constant", value=0)
hidden_states = self.proj_in_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
seq_len = hidden_states.shape[1]
dtype = hidden_states.dtype
device = hidden_states.device
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
position_embeddings = (cos, sin)
sliding_attn_mask = None
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
sliding_attn_mask = _create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
sliding_window=self.config.sliding_window,
is_sliding_window=True,
is_causal=False,
)
for i, layer_module in enumerate(self.layers):
# Full-attention layers see no mask; only the sliding-attention layers
# need the banded mask. Cross-attention uses no padding mask.
layer_attn_mask = sliding_attn_mask if self.layer_types[i] == "sliding_attention" else None
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer_module,
hidden_states,
position_embeddings,
timestep_proj,
layer_attn_mask,
encoder_hidden_states,
None,
)
else:
hidden_states = layer_module(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
temb=timestep_proj,
attention_mask=layer_attn_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
)
# Adaptive output normalization + de-patchify.
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out_conv(hidden_states.transpose(1, 2)).transpose(1, 2)
hidden_states = hidden_states[:, :original_seq_len, :]
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
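
The pad, patchify, de-patchify, crop sequence in `forward` preserves the input length even when it is not a multiple of `patch_size`. The sketch below tracks only the lengths; `patchify_lengths` is a hypothetical helper, and it relies on the standard output-length formulas for `Conv1d` and `ConvTranspose1d` with `kernel_size == stride` and no padding.

```python
def patchify_lengths(seq_len, patch_size):
    """Return (pad, tokens, output_len) for the patchify round trip."""
    remainder = seq_len % patch_size
    pad = patch_size - remainder if remainder else 0
    tokens = (seq_len + pad) // patch_size  # Conv1d with kernel = stride = patch_size
    restored = tokens * patch_size          # ConvTranspose1d with the same kernel and stride
    output_len = min(restored, seq_len)     # final crop back to the original length
    return pad, tokens, output_len


odd = patchify_lengths(seq_len=7, patch_size=2)
even = patchify_lengths(seq_len=8, patch_size=2)
print(odd, even)
```

The transformer blocks therefore run on `tokens` positions, which is also the `seq_len` used for the RoPE table and the sliding-window mask.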

@@ -149,6 +149,12 @@ else:
"WuerstchenPriorPipeline",
]
)
_import_structure["ace_step"] = [
"AceStepAudioTokenDetokenizer",
"AceStepAudioTokenizer",
"AceStepConditionEncoder",
"AceStepPipeline",
]
_import_structure["allegro"] = ["AllegroPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
@@ -574,6 +580,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .ace_step import (
AceStepAudioTokenDetokenizer,
AceStepAudioTokenizer,
AceStepConditionEncoder,
AceStepPipeline,
)
from .allegro import AllegroPipeline
from .animatediff import (
AnimateDiffControlNetPipeline,

@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_ace_step"] = [
"AceStepAudioTokenDetokenizer",
"AceStepAudioTokenizer",
"AceStepConditionEncoder",
]
_import_structure["pipeline_ace_step"] = ["AceStepPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_ace_step import AceStepAudioTokenDetokenizer, AceStepAudioTokenizer, AceStepConditionEncoder
from .pipeline_ace_step import AceStepPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

@@ -0,0 +1,856 @@
# Copyright 2025 The ACE-Step Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline-specific models for ACE-Step 1.5.
Holds the condition encoder (lyric + timbre + text packing), the encoder layer (``AceStepEncoderLayer`` — not used by
the DiT itself, hence kept here), the audio tokenizer / detokenizer used by cover conditioning, and the
``_pack_sequences`` helper. The RoPE helper, ``AceStepAttention``, and ``_create_4d_mask`` are shared with the DiT
and imported from ``diffusers/models/transformers/ace_step_transformer.py``.
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...models.normalization import RMSNorm
from ...models.transformers.ace_step_transformer import (
AceStepAttention,
AceStepMLP,
_ace_step_rotary_freqs,
_create_4d_mask,
_is_flash_attention_backend,
)
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# --------------------------------------------------------------------------- #
# helpers used only by condition encoder #
# --------------------------------------------------------------------------- #
def _pack_sequences(
hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pack two masked sequences into one with all valid tokens first.
Concatenates ``hidden1`` + ``hidden2`` along the sequence dim, then stably sorts each batch row so mask=1 tokens
come before mask=0 tokens. Returns the packed hidden states plus a fresh contiguous mask.
"""
hidden_cat = torch.cat([hidden1, hidden2], dim=1)
mask_cat = torch.cat([mask1, mask2], dim=1)
B, L, D = hidden_cat.shape
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
lengths = mask_cat.sum(dim=1)
new_mask = torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)
return hidden_left, new_mask
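
The effect of `_pack_sequences` on a single batch row can be shown with plain lists. This is an illustrative sketch, not the tensor implementation: `pack_row` is a hypothetical name, tokens are strings instead of hidden vectors, and Python's stable `sorted` stands in for `argsort(..., stable=True)`.

```python
def pack_row(tokens1, tokens2, mask1, mask2):
    """Move all valid tokens to the front while keeping their relative order."""
    tokens = tokens1 + tokens2
    mask = mask1 + mask2
    # sorted() is stable, so valid tokens keep their original order.
    order = sorted(range(len(tokens)), key=lambda i: -mask[i])
    packed = [tokens[i] for i in order]
    valid = sum(mask)
    new_mask = [i < valid for i in range(len(tokens))]
    return packed, new_mask


packed, new_mask = pack_row(["t0", "t1", "pad"], ["l0", "pad"], [1, 1, 0], [1, 0])
print(packed, new_mask)
```

Packing this way means the returned mask is always a contiguous prefix, so downstream code only needs each row's valid length.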
class AceStepEncoderLayer(nn.Module):
"""Pre-LN transformer block used by the lyric and timbre encoders."""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
intermediate_size: int,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: Optional[int] = None,
):
super().__init__()
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
bias=attention_bias,
dropout=attention_dropout,
eps=rms_norm_eps,
sliding_window=sliding_window,
is_cross_attention=False,
)
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = AceStepMLP(hidden_size, intermediate_size)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
image_rotary_emb=position_embeddings,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
# --------------------------------------------------------------------------- #
# encoders #
# --------------------------------------------------------------------------- #
class AceStepLyricEncoder(ModelMixin, ConfigMixin):
"""Lyric encoder: projects Qwen3 lyric embeddings and runs a small transformer.
Output feeds the DiT cross-attention (after packing with text + timbre).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
text_hidden_dim: int = 1024,
num_lyric_encoder_hidden_layers: int = 8,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
for i in range(num_lyric_encoder_hidden_layers)
]
self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.head_dim = head_dim
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.layers = nn.ModuleList(
[
AceStepEncoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
)
for i in range(num_lyric_encoder_hidden_layers)
]
)
self._layer_types = layer_types
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds: torch.FloatTensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(inputs_embeds)
seq_len = inputs_embeds.shape[1]
dtype = inputs_embeds.dtype
device = inputs_embeds.device
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
position_embeddings = (cos, sin)
if _is_flash_attention_backend(self.layers[0].self_attn.processor):
full_attn_mask = attention_mask
sliding_attn_mask = attention_mask
else:
full_attn_mask = _create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device, attention_mask=attention_mask, is_causal=False
)
sliding_attn_mask = _create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False,
)
hidden_states = inputs_embeds
for i, layer_module in enumerate(self.layers):
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else full_attn_mask
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer_module, hidden_states, position_embeddings, mask
)
else:
hidden_states = layer_module(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=mask,
)
return self.norm(hidden_states)
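The default `layer_types` built in the encoder constructor alternates sliding-window and full attention, starting with a sliding layer. A minimal standalone sketch of that rule follows; the helper name `default_layer_types` is hypothetical and not part of this PR.

```python
def default_layer_types(num_layers: int) -> list[str]:
    """Reproduce the default pattern: layers 0, 2, 4, ... use sliding attention."""
    return [
        "sliding_attention" if (i + 1) % 2 else "full_attention"
        for i in range(num_layers)
    ]


if __name__ == "__main__":
    # With the lyric encoder's default of 8 layers, half of them are sliding.
    print(default_layer_types(8))
```

Only the layers tagged `"sliding_attention"` receive `sliding_window`; the others are constructed with `sliding_window=None`.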
class AceStepTimbreEncoder(ModelMixin, ConfigMixin):
"""Timbre encoder: consumes VAE-encoded reference-audio latents and returns one
pooled embedding per reference clip, unpacked to `(batch, max_refs, hidden_size)`,
plus a presence mask marking which slots hold a real reference.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
timbre_hidden_dim: int = 64,
num_timbre_encoder_hidden_layers: int = 4,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
for i in range(num_timbre_encoder_hidden_layers)
]
self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
self.head_dim = head_dim
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.layers = nn.ModuleList(
[
AceStepEncoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
)
for i in range(num_timbre_encoder_hidden_layers)
]
)
self._layer_types = layer_types
self.gradient_checkpointing = False
@staticmethod
def unpack_timbre_embeddings(
timbre_embs_packed: torch.Tensor, refer_audio_order_mask: torch.Tensor
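) -> Tuple[torch.Tensor, torch.Tensor]:
"""Scatter packed per-reference embeddings `(N, d)` into a padded `(B, max_count, d)` tensor.
`refer_audio_order_mask[i]` is the batch index of reference `i`. Returns the padded tensor and a
`(B, max_count)` mask that is 1 where a slot holds a real reference and 0 where it is padding.
"""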
N, d = timbre_embs_packed.shape
device = timbre_embs_packed.device
dtype = timbre_embs_packed.dtype
B = int(refer_audio_order_mask.max().item() + 1)
counts = torch.bincount(refer_audio_order_mask, minlength=B)
max_count = counts.max().item()
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
positions = torch.arange(N, device=device)
batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
inverse_indices = torch.empty_like(sorted_indices)
inverse_indices[sorted_indices] = torch.arange(N, device=device)
positions_in_batch = positions_in_sorted[inverse_indices]
indices_2d = refer_audio_order_mask * max_count + positions_in_batch
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
timbre_embs_flat = one_hot.t() @ timbre_embs_packed
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
mask_flat = (one_hot.sum(dim=0) > 0).long()
new_mask = mask_flat.reshape(B, max_count)
return timbre_embs_unpack, new_mask
def forward(
self,
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
refer_audio_order_mask: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
inputs_embeds = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
seq_len = inputs_embeds.shape[1]
dtype = inputs_embeds.dtype
device = inputs_embeds.device
cos, sin = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
position_embeddings = (cos, sin)
sliding_attn_mask = None
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
sliding_attn_mask = _create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=None,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False,
)
hidden_states = inputs_embeds
for i, layer_module in enumerate(self.layers):
# No padding mask on timbre input (pre-packed), so full-attention layers see None.
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer_module, hidden_states, position_embeddings, mask
)
else:
hidden_states = layer_module(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=mask,
)
hidden_states = self.norm(hidden_states)
# CLS-like pooling: first-token embedding per packed sequence.
hidden_states = hidden_states[:, 0, :]
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask
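`unpack_timbre_embeddings` does its scatter with a one-hot matrix product, which can be hard to follow. The sketch below is a loop-based, pure-Python reference of what that method appears to compute, written with plain lists instead of tensors; `unpack_reference` is a hypothetical name used only for illustration.

```python
def unpack_reference(
    packed: list[list[float]], order_mask: list[int]
) -> tuple[list[list[list[float]]], list[list[int]]]:
    """Group packed rows by batch id and zero-pad each group to the largest group."""
    num_batches = max(order_mask) + 1
    groups: list[list[list[float]]] = [[] for _ in range(num_batches)]
    # Rows keep their original relative order inside each batch item.
    for row, batch_id in zip(packed, order_mask):
        groups[batch_id].append(row)

    max_count = max(len(group) for group in groups)
    dim = len(packed[0])
    unpacked, mask = [], []
    for group in groups:
        pad = max_count - len(group)
        unpacked.append(group + [[0.0] * dim for _ in range(pad)])
        mask.append([1] * len(group) + [0] * pad)
    return unpacked, mask


if __name__ == "__main__":
    # Three references: two for batch item 0, one for batch item 1.
    embs, present = unpack_reference([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [0, 0, 1])
    print(embs, present)
```

The tensor version avoids Python loops and data-dependent indexing, at the cost of building an `(N, B * max_count)` one-hot matrix.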
# --------------------------------------------------------------------------- #
# audio tokenizer / detokenizer #
# --------------------------------------------------------------------------- #
class _AceStepResidualFSQ(nn.Module):
"""Minimal ResidualFSQ compatible with ACE-Step's saved tokenizer weights."""
def __init__(
self,
dim: int = 2048,
levels: Optional[list] = None,
num_quantizers: int = 1,
):
super().__init__()
if levels is None:
levels = [8, 8, 8, 5, 5, 5]
self.levels = levels
self.num_quantizers = num_quantizers
self.codebook_dim = len(levels)
self.project_in = nn.Linear(dim, self.codebook_dim)
self.project_out = nn.Linear(self.codebook_dim, dim)
levels_tensor = torch.tensor(levels, dtype=torch.long)
basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.long), dim=0)
scales = torch.stack([levels_tensor.float() ** -i for i in range(num_quantizers)])
self.register_buffer("_levels", levels_tensor, persistent=False)
self.register_buffer("_basis", basis, persistent=False)
self.register_buffer("scales", scales, persistent=False)
@property
def codebook_size(self) -> int:
return int(torch.prod(self._levels).item())
def _indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
levels = self._levels.to(device=indices.device)
basis = self._basis.to(device=indices.device)
level_indices = (indices.long().unsqueeze(-1) // basis) % levels
scale = 2.0 / (levels.to(dtype=torch.float32) - 1.0)
return level_indices.to(dtype=torch.float32) * scale - 1.0
def _codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor:
levels = self._levels.to(device=codes.device, dtype=codes.dtype)
basis = self._basis.to(device=codes.device, dtype=codes.dtype)
level_indices = (codes + 1.0) / (2.0 / (levels - 1.0))
return (level_indices * basis).sum(dim=-1).round().to(torch.long)
def _quantize(self, x: torch.Tensor) -> torch.Tensor:
levels = self._levels.to(device=x.device, dtype=x.dtype)
levels_minus_one = levels - 1.0
step = 2.0 / levels_minus_one
bracket = levels_minus_one * (x.clamp(-1.0, 1.0) + 1.0) / 2.0 + 0.5
return step * torch.floor(bracket) - 1.0
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
if indices.ndim == 2:
indices = indices.unsqueeze(-1)
if indices.shape[-1] != self.num_quantizers:
raise ValueError(
f"Expected audio code indices with last dimension {self.num_quantizers}, got {indices.shape[-1]}."
)
codes = []
for quantizer_idx in range(self.num_quantizers):
code = self._indices_to_codes(indices[..., quantizer_idx])
scale = self.scales[quantizer_idx].to(device=code.device, dtype=code.dtype)
codes.append(code * scale)
return torch.stack(codes, dim=0)
def get_output_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
codes = self.get_codes_from_indices(indices).sum(dim=0)
weight = self.project_out.weight.float()
bias = self.project_out.bias.float() if self.project_out.bias is not None else None
output = F.linear(codes.float(), weight, bias)
return output.to(dtype=self.project_out.weight.dtype)
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
input_dtype = hidden_states.dtype
weight = self.project_in.weight.float()
bias = self.project_in.bias.float() if self.project_in.bias is not None else None
hidden_states = F.linear(hidden_states.float(), weight, bias)
levels = self._levels.to(device=hidden_states.device, dtype=hidden_states.dtype)
soft_clamp = 1.0 + (1.0 / (levels - 1.0))
hidden_states = (hidden_states / soft_clamp).tanh() * soft_clamp
quantized_out = torch.zeros_like(hidden_states)
residual = hidden_states
all_indices = []
for scale in self.scales.to(device=hidden_states.device, dtype=hidden_states.dtype):
quantized = self._quantize(residual / scale) * scale
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
all_indices.append(self._codes_to_indices(quantized / scale))
weight = self.project_out.weight.float()
bias = self.project_out.bias.float() if self.project_out.bias is not None else None
quantized_out = F.linear(quantized_out.float(), weight, bias).to(dtype=input_dtype)
all_indices = torch.stack(all_indices, dim=-1)
return quantized_out, all_indices
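The index/code conversion in `_indices_to_codes` and `_codes_to_indices` is a mixed-radix encoding over `levels`, with each digit rescaled to `[-1, 1]`. The following pure-Python sketch illustrates that round trip for a single quantizer with the default levels; the function names are hypothetical, and it rounds each digit individually rather than rounding the final sum as the tensor code does.

```python
LEVELS = [8, 8, 8, 5, 5, 5]  # default `levels` of the quantizer above


def make_basis(levels: list[int]) -> list[int]:
    """Exclusive cumulative product, e.g. [1, 8, 64, 512, 2560, 12800]."""
    basis, running = [], 1
    for level in levels:
        basis.append(running)
        running *= level
    return basis


def index_to_codes(index: int, levels: list[int] = LEVELS) -> list[float]:
    """Split a flat index into per-dimension digits and map each to [-1, 1]."""
    digits = [(index // b) % lv for b, lv in zip(make_basis(levels), levels)]
    return [d * 2.0 / (lv - 1) - 1.0 for d, lv in zip(digits, levels)]


def codes_to_index(codes: list[float], levels: list[int] = LEVELS) -> int:
    """Invert `index_to_codes` by recovering the digits and recombining them."""
    digits = [round((c + 1.0) * (lv - 1) / 2.0) for c, lv in zip(codes, levels)]
    return sum(d * b for d, b in zip(digits, make_basis(levels)))


if __name__ == "__main__":
    codes = index_to_codes(12345)
    print(codes, codes_to_index(codes))
```

With these levels the codebook holds 8 * 8 * 8 * 5 * 5 * 5 = 64000 entries, which is what the `codebook_size` property returns for the default configuration.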
class AceStepAttentionPooler(nn.Module):
"""Attention pooler used by the ACE-Step audio tokenizer."""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_pooler_hidden_layers: int = 2,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
for i in range(num_attention_pooler_hidden_layers)
]
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
self.head_dim = head_dim
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.layers = nn.ModuleList(
[
AceStepEncoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
)
for i in range(num_attention_pooler_hidden_layers)
]
)
self._layer_types = layer_types
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_patches, patch_size, _ = hidden_states.shape
hidden_states = self.embed_tokens(hidden_states)
special_token = self.special_token.to(device=hidden_states.device, dtype=hidden_states.dtype)
special_token = special_token.expand(batch_size, num_patches, -1, -1)
hidden_states = torch.cat([special_token, hidden_states], dim=2)
hidden_states = hidden_states.reshape(batch_size * num_patches, patch_size + 1, -1)
seq_len = hidden_states.shape[1]
dtype = hidden_states.dtype
device = hidden_states.device
position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
sliding_attn_mask = None
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
sliding_attn_mask = _create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=None,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False,
)
for i, layer_module in enumerate(self.layers):
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
hidden_states = layer_module(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=mask,
)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, 0, :]
return hidden_states.reshape(batch_size, num_patches, -1)
class AceStepAudioTokenDetokenizer(ModelMixin, ConfigMixin):
"""Expands ACE-Step 5 Hz audio tokens back to 25 Hz acoustic conditioning."""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
num_attention_pooler_hidden_layers: int = 2,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
if layer_types is None:
layer_types = [
"sliding_attention" if bool((i + 1) % 2) else "full_attention"
for i in range(num_attention_pooler_hidden_layers)
]
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)
self.head_dim = head_dim
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.pool_window_size = pool_window_size
self.layers = nn.ModuleList(
[
AceStepEncoderLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
intermediate_size=intermediate_size,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window if layer_types[i] == "sliding_attention" else None,
)
for i in range(num_attention_pooler_hidden_layers)
]
)
self._layer_types = layer_types
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_tokens, _ = hidden_states.shape
hidden_states = self.embed_tokens(hidden_states)
hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.pool_window_size, -1)
special_tokens = self.special_tokens.to(device=hidden_states.device, dtype=hidden_states.dtype)
hidden_states = hidden_states + special_tokens.unsqueeze(0)
hidden_states = hidden_states.reshape(batch_size * num_tokens, self.pool_window_size, -1)
seq_len = hidden_states.shape[1]
dtype = hidden_states.dtype
device = hidden_states.device
position_embeddings = _ace_step_rotary_freqs(seq_len, self.head_dim, self.rope_theta, device, dtype)
sliding_attn_mask = None
if not _is_flash_attention_backend(self.layers[0].self_attn.processor):
sliding_attn_mask = _create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=None,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False,
)
for i, layer_module in enumerate(self.layers):
mask = sliding_attn_mask if self._layer_types[i] == "sliding_attention" else None
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer_module, hidden_states, position_embeddings, mask
)
else:
hidden_states = layer_module(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=mask,
)
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states.reshape(batch_size, num_tokens * self.pool_window_size, -1)
class AceStepAudioTokenizer(ModelMixin, ConfigMixin):
"""Converts 25 Hz acoustic latents to ACE-Step 5 Hz audio tokens."""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
fsq_dim: int = 2048,
fsq_input_levels: Optional[list] = None,
fsq_input_num_quantizers: int = 1,
num_attention_pooler_hidden_layers: int = 2,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
if fsq_input_levels is None:
fsq_input_levels = [8, 8, 8, 5, 5, 5]
self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
self.attention_pooler = AceStepAttentionPooler(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rope_theta=rope_theta,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window,
layer_types=layer_types,
)
self.quantizer = _AceStepResidualFSQ(
dim=fsq_dim,
levels=fsq_input_levels,
num_quantizers=fsq_input_num_quantizers,
)
self.pool_window_size = pool_window_size
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
input_dtype = hidden_states.dtype
hidden_states = self.audio_acoustic_proj(hidden_states)
hidden_states = self.attention_pooler(hidden_states)
quantized, indices = self.quantizer(hidden_states)
return quantized.to(dtype=input_dtype), indices
def tokenize(
self,
hidden_states: torch.Tensor,
silence_latent: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, latent_length, acoustic_dim = hidden_states.shape
pad_len = (-latent_length) % self.pool_window_size
if pad_len:
if silence_latent is not None and silence_latent.shape[-1] == acoustic_dim:
pad = silence_latent[:, :pad_len, :].to(device=hidden_states.device, dtype=hidden_states.dtype)
pad = pad.expand(batch_size, -1, -1)
else:
pad = torch.zeros(
batch_size, pad_len, acoustic_dim, device=hidden_states.device, dtype=hidden_states.dtype
)
hidden_states = torch.cat([hidden_states, pad], dim=1)
num_patches = hidden_states.shape[1] // self.pool_window_size
hidden_states = hidden_states.reshape(batch_size, num_patches, self.pool_window_size, acoustic_dim)
return self(hidden_states)
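The length bookkeeping in `tokenize` and in the detokenizer's final reshape reduces to simple integer arithmetic. A small sketch, assuming the default `pool_window_size=5` (25 Hz latents to 5 Hz tokens); both helper names are hypothetical.

```python
def pad_and_patch(latent_length: int, pool_window_size: int = 5) -> tuple[int, int]:
    """Return (pad_len, num_patches) as computed in `tokenize`."""
    # (-n) % w is the number of frames needed to reach the next multiple of w.
    pad_len = (-latent_length) % pool_window_size
    num_patches = (latent_length + pad_len) // pool_window_size
    return pad_len, num_patches


def detokenized_length(num_tokens: int, pool_window_size: int = 5) -> int:
    """Each token expands back into `pool_window_size` latent frames."""
    return num_tokens * pool_window_size


if __name__ == "__main__":
    pad_len, num_patches = pad_and_patch(253)
    print(pad_len, num_patches, detokenized_length(num_patches))
```

Note that a round trip through tokenizer and detokenizer returns the padded length, so a caller that needs the original length has to trim the trailing `pad_len` frames itself.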
# --------------------------------------------------------------------------- #
# condition encoder #
# --------------------------------------------------------------------------- #
class AceStepConditionEncoder(ModelMixin, ConfigMixin):
"""Fuses text + lyric + timbre conditioning into the packed sequence used by
the DiT's cross-attention.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
text_hidden_dim: int = 1024,
timbre_hidden_dim: int = 64,
num_lyric_encoder_hidden_layers: int = 8,
num_timbre_encoder_hidden_layers: int = 4,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
head_dim: int = 128,
rope_theta: float = 1000000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
sliding_window: int = 128,
layer_types: Optional[list] = None,
):
super().__init__()
self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
self.lyric_encoder = AceStepLyricEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
text_hidden_dim=text_hidden_dim,
num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rope_theta=rope_theta,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window,
layer_types=layer_types,
)
self.timbre_encoder = AceStepTimbreEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
timbre_hidden_dim=timbre_hidden_dim,
num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rope_theta=rope_theta,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rms_norm_eps=rms_norm_eps,
sliding_window=sliding_window,
)
# Learned null-condition embedding for classifier-free guidance, trained with
# `cfg_ratio=0.15` in the original model. Broadcast along the sequence dim when used.
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
# Silence latent: VAE-encoded silent audio, stored as (1, T_long, timbre_hidden_dim).
# When no reference audio is provided, the pipeline slices `silence_latent[:, :timbre_fix_frame, :]`
# and feeds that to the timbre encoder. Passing literal zeros instead puts the timbre encoder
# out of distribution and produces drone-like audio (observed on all text2music outputs before this fix).
# The zero placeholder registered here is overwritten by the converter with the real encoded
# silence, so only its last dim has to match the timbre-encoder input (`timbre_hidden_dim`),
# which lets smaller test configs with `timbre_hidden_dim != 64` load as well.
self.register_buffer(
"silence_latent",
torch.zeros(1, 15000, timbre_hidden_dim),
persistent=True,
)
def forward(
self,
text_hidden_states: torch.FloatTensor,
text_attention_mask: torch.Tensor,
lyric_hidden_states: torch.FloatTensor,
lyric_attention_mask: torch.Tensor,
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
refer_audio_order_mask: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
text_hidden_states = self.text_projector(text_hidden_states)
lyric_hidden_states = self.lyric_encoder(
inputs_embeds=lyric_hidden_states, attention_mask=lyric_attention_mask
)
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(
refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
)
encoder_hidden_states, encoder_attention_mask = _pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
)
encoder_hidden_states, encoder_attention_mask = _pack_sequences(
encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
)
return encoder_hidden_states, encoder_attention_mask
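The packed sequence returned by `AceStepConditionEncoder.forward` combines lyric, timbre, and text tokens. The sketch below estimates its length under the assumption that `_pack_sequences` (not shown in this section) concatenates along the sequence dimension without dropping padded positions; `packed_sequence_length` is a hypothetical helper.

```python
from collections import Counter


def packed_sequence_length(
    lyric_len: int, text_len: int, refer_audio_order_mask: list[int]
) -> int:
    """Estimate the cross-attention sequence length seen by the DiT.

    Assumes plain concatenation: lyric tokens, then one slot per reference
    clip (padded to the largest per-item count), then text tokens.
    """
    max_refs = max(Counter(refer_audio_order_mask).values())
    return lyric_len + max_refs + text_len


if __name__ == "__main__":
    # Shapes from the fast test: 12 lyric tokens, 8 text tokens, references [0, 0, 1].
    print(packed_sequence_length(12, 8, [0, 0, 1]))
```

If `_pack_sequences` instead trims fully padded columns, the real length can be shorter than this estimate, so treat it as an upper bound.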

File diff suppressed because it is too large

@@ -405,6 +405,21 @@ class VaeImageProcessorLDM3D(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AceStepTransformer1DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AllegroTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

@@ -632,6 +632,66 @@ class ZImageModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AceStepAudioTokenDetokenizer(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AceStepAudioTokenizer(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AceStepConditionEncoder(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AceStepPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import AceStepTransformer1DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
enable_full_determinism()
class AceStepTransformer1DModelTesterConfig(BaseModelTesterConfig):
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def model_class(self):
return AceStepTransformer1DModel
@property
def output_shape(self) -> tuple[int, ...]:
return (8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | float | bool]:
return {
"hidden_size": 32,
"intermediate_size": 64,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 8,
"in_channels": 24, # audio_acoustic_hidden_dim * 3 (hidden + context_latents)
"audio_acoustic_hidden_dim": 8,
"patch_size": 2,
"rope_theta": 10000.0,
"rms_norm_eps": 1e-6,
"sliding_window": 16,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
seq_len = 8
encoder_seq_len = 10
acoustic_dim = 8
hidden_size = 32
return {
"hidden_states": randn_tensor(
(batch_size, seq_len, acoustic_dim), generator=self.generator, device=torch_device
),
"timestep": randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs(),
"timestep_r": randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs(),
"encoder_hidden_states": randn_tensor(
(batch_size, encoder_seq_len, hidden_size), generator=self.generator, device=torch_device
),
"context_latents": randn_tensor(
(batch_size, seq_len, acoustic_dim * 2), generator=self.generator, device=torch_device
),
}
class TestAceStepTransformer1DModel(AceStepTransformer1DModelTesterConfig, ModelTesterMixin):
pass

@@ -0,0 +1,486 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import torch
from transformers import AutoTokenizer, Qwen3Config, Qwen3Model
from diffusers import AutoencoderOobleck, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.ace_step_transformer import AceStepTransformer1DModel
from diffusers.pipelines.ace_step import (
AceStepAudioTokenDetokenizer,
AceStepAudioTokenizer,
AceStepConditionEncoder,
AceStepPipeline,
)
from ...testing_utils import enable_full_determinism
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class AceStepConditionEncoderTests(unittest.TestCase):
"""Fast tests for the AceStepConditionEncoder."""
def get_tiny_config(self):
return {
"hidden_size": 32,
"intermediate_size": 64,
"text_hidden_dim": 16,
"timbre_hidden_dim": 8,
"num_lyric_encoder_hidden_layers": 2,
"num_timbre_encoder_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 8,
"rope_theta": 10000.0,
"attention_bias": False,
"attention_dropout": 0.0,
"rms_norm_eps": 1e-6,
"sliding_window": 16,
}
def test_forward_shape(self):
"""Test that the condition encoder produces packed hidden states."""
config = self.get_tiny_config()
encoder = AceStepConditionEncoder(**config)
encoder.eval()
batch_size = 2
text_seq_len = 8
lyric_seq_len = 12
text_dim = config["text_hidden_dim"]
timbre_dim = config["timbre_hidden_dim"]
timbre_time = 10
text_hidden_states = torch.randn(batch_size, text_seq_len, text_dim)
text_attention_mask = torch.ones(batch_size, text_seq_len)
lyric_hidden_states = torch.randn(batch_size, lyric_seq_len, text_dim)
lyric_attention_mask = torch.ones(batch_size, lyric_seq_len)
# Packed reference audio: 3 references across 2 batch items
refer_audio = torch.randn(3, timbre_time, timbre_dim)
refer_order_mask = torch.tensor([0, 0, 1], dtype=torch.long)
with torch.no_grad():
enc_hidden, enc_mask = encoder(
text_hidden_states=text_hidden_states,
text_attention_mask=text_attention_mask,
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyric_attention_mask,
refer_audio_acoustic_hidden_states_packed=refer_audio,
refer_audio_order_mask=refer_order_mask,
)
# Output should be packed: batch_size x (lyric + timbre + text seq_len) x hidden_size
self.assertEqual(enc_hidden.shape[0], batch_size)
self.assertEqual(enc_hidden.shape[2], config["hidden_size"])
self.assertEqual(enc_mask.shape[0], batch_size)
self.assertEqual(enc_mask.shape[1], enc_hidden.shape[1])
def test_save_load_config(self):
"""Test that the condition encoder config can be saved and loaded."""
import tempfile
config = self.get_tiny_config()
encoder = AceStepConditionEncoder(**config)
with tempfile.TemporaryDirectory() as tmpdir:
encoder.save_config(tmpdir)
loaded = AceStepConditionEncoder.from_config(tmpdir)
self.assertEqual(encoder.config.hidden_size, loaded.config.hidden_size)
self.assertEqual(encoder.config.text_hidden_dim, loaded.config.text_hidden_dim)
self.assertEqual(encoder.config.timbre_hidden_dim, loaded.config.timbre_hidden_dim)


class AceStepPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast end-to-end tests for AceStepPipeline with tiny models."""

    pipeline_class = AceStepPipeline
    params = frozenset(
        [
            "prompt",
            "lyrics",
            "audio_duration",
            "vocal_language",
            "guidance_scale",
            "shift",
        ]
    )
    batch_params = frozenset(["prompt", "lyrics"])
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "output_type",
            "return_dict",
        ]
    )

    # ACE-Step uses custom attention, not standard diffusers attention processors
    test_attention_slicing = False
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = AceStepTransformer1DModel(
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=8,
            in_channels=24,
            audio_acoustic_hidden_dim=8,
            patch_size=2,
            rope_theta=10000.0,
            sliding_window=16,
        )

        # Tiny Qwen3Model for testing: same architecture as Qwen3-Embedding-0.6B, scaled down.
        # The vocabulary size is kept at the real value so the pretrained tokenizer's ids stay in range.
        torch.manual_seed(0)
        qwen3_config = Qwen3Config(
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=8,
            vocab_size=151936,  # Qwen3 vocab size
            max_position_embeddings=256,
        )
        text_encoder = Qwen3Model(qwen3_config)
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
        text_hidden_dim = qwen3_config.hidden_size  # 32

        torch.manual_seed(0)
        condition_encoder = AceStepConditionEncoder(
            hidden_size=32,
            intermediate_size=64,
            text_hidden_dim=text_hidden_dim,
            timbre_hidden_dim=8,
            num_lyric_encoder_hidden_layers=2,
            num_timbre_encoder_hidden_layers=2,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=8,
            rope_theta=10000.0,
            sliding_window=16,
        )

        audio_tokenizer_kwargs = {
            "hidden_size": 32,
            "intermediate_size": 64,
            "audio_acoustic_hidden_dim": 8,
            "pool_window_size": 2,
            "fsq_dim": 32,
            "fsq_input_levels": [4, 4, 4],
            "fsq_input_num_quantizers": 1,
            "num_attention_pooler_hidden_layers": 1,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "head_dim": 8,
            "rope_theta": 10000.0,
            "sliding_window": 16,
        }
        torch.manual_seed(0)
        audio_tokenizer = AceStepAudioTokenizer(**audio_tokenizer_kwargs)

        torch.manual_seed(0)
        audio_token_detokenizer = AceStepAudioTokenDetokenizer(
            hidden_size=32,
            intermediate_size=64,
            audio_acoustic_hidden_dim=8,
            pool_window_size=2,
            num_attention_pooler_hidden_layers=1,
            num_attention_heads=4,
            num_key_value_heads=2,
            head_dim=8,
            rope_theta=10000.0,
            sliding_window=16,
        )

        torch.manual_seed(0)
        vae = AutoencoderOobleck(
            encoder_hidden_size=6,
            downsampling_ratios=[1, 2],
            decoder_channels=3,
            decoder_input_channels=8,
            audio_channels=2,
            channel_multiples=[2, 4],
            sampling_rate=4,
        )

        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)

        components = {
            "transformer": transformer,
            "condition_encoder": condition_encoder,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "audio_tokenizer": audio_tokenizer,
            "audio_token_detokenizer": audio_token_detokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A beautiful piano piece",
            "lyrics": "[verse]\nSoft notes in the morning",
            "audio_duration": 0.4,  # Very short for fast test (10 latent frames at 25Hz)
            "num_inference_steps": 2,
            "generator": generator,
            "max_text_length": 32,
        }
        return inputs

    def test_ace_step_basic(self):
        """Test basic text-to-music generation."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            prompt="A beautiful piano piece",
            lyrics="[verse]\nSoft notes in the morning",
            audio_duration=0.4,
            num_inference_steps=2,
            generator=generator,
            max_text_length=32,
        )
        audio = output.audios
        self.assertIsNotNone(audio)
        self.assertEqual(audio.ndim, 3)  # [batch, channels, samples]

    def test_ace_step_batch(self):
        """Test batch generation."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(42)
        output = pipe(
            prompt=["Piano piece", "Guitar solo"],
            lyrics=["[verse]\nHello", "[chorus]\nWorld"],
            audio_duration=0.4,
            num_inference_steps=2,
            generator=generator,
            max_text_length=32,
        )
        audio = output.audios
        self.assertIsNotNone(audio)
        self.assertEqual(audio.shape[0], 2)  # batch size = 2

    def test_ace_step_latent_output(self):
        """Test that output_type='latent' returns latents."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            prompt="A test prompt",
            lyrics="",
            audio_duration=0.4,
            num_inference_steps=2,
            generator=generator,
            output_type="latent",
            max_text_length=32,
        )
        latents = output.audios
        self.assertIsNotNone(latents)
        # Latent shape: [batch, latent_length, acoustic_dim]
        self.assertEqual(latents.ndim, 3)
        self.assertEqual(latents.shape[0], 1)

    def test_ace_step_return_dict_false(self):
        """Test that return_dict=False returns a tuple."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            prompt="A test prompt",
            lyrics="",
            audio_duration=0.4,
            num_inference_steps=2,
            generator=generator,
            return_dict=False,
            max_text_length=32,
        )
        self.assertIsInstance(output, tuple)
        self.assertEqual(len(output), 1)

    def test_audio_codes_cover_path(self):
        """Test that passing audio_codes yields latents of the expected length."""
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)

        output = pipe(
            prompt="A test prompt",
            lyrics="",
            audio_codes="<|audio_code_1|><|audio_code_2|>",
            num_inference_steps=1,
            output_type="latent",
            max_text_length=32,
        )
        # Latent shape: [batch, latent_length, acoustic_dim]
        self.assertEqual(output.audios.shape[1], 4)

    def test_save_load_local(self, expected_max_difference=7e-3):
        # increase tolerance to account for large composite model
        super().test_save_load_local(expected_max_difference=expected_max_difference)

    def test_save_load_optional_components(self, expected_max_difference=7e-3):
        # increase tolerance to account for large composite model
        super().test_save_load_optional_components(expected_max_difference=expected_max_difference)

    def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=7e-3):
        # increase tolerance for audio pipeline
        super().test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)

    def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=7e-3):
        # increase tolerance for audio pipeline
        super().test_dict_tuple_outputs_equivalent(
            expected_slice=expected_slice, expected_max_difference=expected_max_difference
        )

    # Report this as skipped rather than silently passing, consistent with the other inapplicable tests below
    @unittest.skip("ACE-Step does not use num_images_per_prompt")
    def test_num_images_per_prompt(self):
        pass

    # ACE-Step does not use standard schedulers
    @unittest.skip("ACE-Step uses built-in flow matching schedule, not diffusers schedulers")
    def test_karras_schedulers_shape(self):
        pass

    # ACE-Step does not support prompt_embeds directly
    @unittest.skip("ACE-Step does not support prompt_embeds / negative_prompt_embeds")
    def test_cfg(self):
        pass

    def test_float16_inference(self, expected_max_diff=5e-2):
        # increase tolerance for half-precision inference
        super().test_float16_inference(expected_max_diff=expected_max_diff)

    @unittest.skip(
        "ACE-Step __call__ does not accept prompt_embeds, so encode_prompt isolation test is not applicable"
    )
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Sequential CPU offloading produces NaN with tiny random models")
    def test_sequential_cpu_offload_forward_pass(self):
        pass

    @unittest.skip("Sequential CPU offloading produces NaN with tiny random models")
    def test_sequential_offload_forward_pass_twice(self):
        pass

    def test_encode_prompt(self):
        """Test that encode_prompt returns correct shapes."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)

        text_hidden, text_mask, lyric_hidden, lyric_mask = pipe.encode_prompt(
            prompt="A test prompt",
            lyrics="[verse]\nHello world",
            device=device,
            max_text_length=32,
            max_lyric_length=64,
        )

        self.assertEqual(text_hidden.ndim, 3)  # [batch, seq_len, hidden_dim]
        self.assertEqual(text_mask.ndim, 2)  # [batch, seq_len]
        self.assertEqual(lyric_hidden.ndim, 3)
        self.assertEqual(lyric_mask.ndim, 2)
        self.assertEqual(text_hidden.shape[0], 1)
        self.assertEqual(lyric_hidden.shape[0], 1)

    def test_prepare_latents(self):
        """Test that prepare_latents returns correct shapes."""
        device = "cpu"
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)
        pipe = pipe.to(device)

        latents = pipe.prepare_latents(
            batch_size=2,
            audio_duration=1.0,
            dtype=torch.float32,
            device=device,
        )
        # Expected shape: [batch, latent_length, acoustic_dim]; 8 is audio_acoustic_hidden_dim of the dummy transformer
        expected_length = math.ceil(1.0 * pipe.latents_per_second)
        self.assertEqual(latents.shape, (2, expected_length, 8))

    def test_timestep_schedule(self):
        """Test that the timestep schedule is generated correctly."""
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)

        # Standard schedule: one timestep per step, starting from pure noise (t = 1.0)
        schedule = pipe._get_timestep_schedule(num_inference_steps=8, shift=3.0)
        self.assertEqual(len(schedule), 8)
        self.assertAlmostEqual(schedule[0].item(), 1.0, places=5)

        # Truncated schedule: fewer steps than the standard 8
        schedule = pipe._get_timestep_schedule(num_inference_steps=4, shift=3.0)
        self.assertEqual(len(schedule), 4)

    def test_format_prompt(self):
        """Test that prompt formatting works correctly."""
        components = self.get_dummy_components()
        pipe = AceStepPipeline(**components)

        text, lyrics = pipe._format_prompt(
            prompt="A piano piece",
            lyrics="[verse]\nHello",
            vocal_language="en",
            audio_duration=30.0,
        )

        self.assertIn("A piano piece", text)
        self.assertIn("30 seconds", text)
        self.assertIn("[verse]", lyrics)
        self.assertIn("Hello", lyrics)
        self.assertIn("en", lyrics)


if __name__ == "__main__":
    unittest.main()