mirror of
https://github.com/huggingface/diffusers.git
synced 2026-06-02 00:01:34 +08:00
* [feat] JoyAI-JoyImage-Edit support * [fix] remove rearrange * [refactor] two pass when do cfg * [refactor] remove repa, use wantimetextembeding, refactor modulate code * [refactor] Joyimage Attention refactor * remove vae tiling and autocast * [fix] remove einops from setup.py * [refactor] Refactor JoyImageEditPipeline to use explicit arguments instead of namespace and remove _build_arg * [fix] remove deprecated method decode_latents * [refactor] refactor the image pre-processing logic into a separate VaeImageProcessor subclass * [refactor] add JoyImageAttention to align with Attention + AttnProcessor design and update conversion script for new weight key mapping (e.g. img_attn_qkv -> attn.img_attn_qkv) * [refactor] simplify bucket logic in JoyImageEditImageProcessor by replacing runtime generation with precomputed lookup tables * [fix] remove leftover training-only parameters * [fix] add layerwise casting and fp32 module patterns to JoyImageTransformer3DModel. Reference WanTransformer3DModel to fix layer casting errors during inference. 
* [test] add JoyImageEditPipeline fast tests and JoyImageEditTransformer3DModel model tests * [fix] fix some pipeline args to support batch inference * [fix] duplicate images to match batch size when fewer images than prompts in JoyImageEditPipeline * [fix] remove no longer used config parameters * Apply style fixes * [fix] remove unused dataclass and rewrite helpers as inline functions * [fix] make dummy objects for JoyImageEdit * [fix] allow test_torch_compile_repeated_blocks to pass * [fix] add examples on JoyImageEditPipeline * fix code style issues with ruff and black * Apply style fixes * [fix] change default num_inference_steps to 40 * [fix] use forward hook to extract pre-norm hidden states for transformers 5.x compatibility * [fix] change the assert to ValueError in pipeline * [fix] rename JoyImageTransformer3DModel to JoyImageEditTransformer3DModel, clean up anything about the alias * [fix] support gradient checkpointing * [refactor] simplify RoPE utilities, inline helpers, copy WanTimeTextImageEmbedding locally and remove unused parameters * [fix] remove _get_text_encoder_ckpt and qwen_processor * [fix] change nn.RMSNorm to FP32LayerNorm * [fix] small fixes for suggestions given by Claude * [refactor] build model using from _pretained instead of config * [refactor] auto-wrap prompt and support text-to-image in JoyImage Edit pipeline * make style, make quality and make fix-copies * [test] small fix to use vocab_size=1024 * [refactor] separate encode_prompt_multiple_images from encode_prompt, support prompt_embeds/prompt_embesd_mask/num_images_per_prompt in edit mode * [test] fix CI: use strict=False for xfail and add @require_torch_accelerator to group offloading test * [refactor] separate image_latents from latents in prepare_latents to align with flux2 * make style --------- Co-authored-by: zhangmaoquan.1 <zhangmaoquan.1@jd.com> Co-authored-by: huangfeice <huangfeice@gmail.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> 
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
367 lines
15 KiB
Python
367 lines
15 KiB
Python
import argparse
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import torch
|
|
from accelerate import init_empty_weights
|
|
from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration
|
|
|
|
from diffusers import (
|
|
AutoencoderKLWan,
|
|
JoyImageEditPipeline,
|
|
JoyImageEditTransformer3DModel,
|
|
)
|
|
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
|
|
FlowMatchEulerDiscreteScheduler,
|
|
)
|
|
|
|
|
|
# Adapted from convert_wan_to_diffusers.py so that the VAE checkpoint is read from a user-supplied path.
|
|
# Sub-layer renames shared by every residual block: the original checkpoint
# addresses the layers by their index inside a ``residual`` sequence, while the
# diffusers module uses named sub-modules.  Order matters: the first matching
# entry wins, exactly like an if/elif chain.
_VAE_RESIDUAL_RENAMES = (
    (".residual.0.gamma", ".norm1.gamma"),
    (".residual.2.bias", ".conv1.bias"),
    (".residual.2.weight", ".conv1.weight"),
    (".residual.3.gamma", ".norm2.gamma"),
    (".residual.6.bias", ".conv2.bias"),
    (".residual.6.weight", ".conv2.weight"),
)

# Extra renames that only apply to the encoder down blocks.
_VAE_SHORTCUT_RENAMES = (
    (".shortcut.bias", ".conv_shortcut.bias"),
    (".shortcut.weight", ".conv_shortcut.weight"),
)

# Original ``decoder.upsamples.<idx>`` entries that hold an upsampler, mapped to
# the ``decoder.up_blocks.<idx>`` that owns it in the diffusers layout.
_VAE_UPSAMPLER_BLOCKS = {3: 0, 7: 1, 11: 2}


def _build_vae_exact_key_map():
    """Return the one-to-one renames for the mid blocks, heads, conv_in and quant convs."""
    mapping = {}
    for side in ("encoder", "decoder"):
        # ``middle.0`` / ``middle.2`` are the two mid-block resnets ...
        for old_idx, resnet_idx in ((0, 0), (2, 1)):
            for old_suffix, new_suffix in _VAE_RESIDUAL_RENAMES:
                mapping[f"{side}.middle.{old_idx}{old_suffix}"] = (
                    f"{side}.mid_block.resnets.{resnet_idx}{new_suffix}"
                )
        # ... and ``middle.1`` is the mid-block attention.
        for leaf in ("norm.gamma", "to_qkv.weight", "to_qkv.bias", "proj.weight", "proj.bias"):
            mapping[f"{side}.middle.1.{leaf}"] = f"{side}.mid_block.attentions.0.{leaf}"
        # Output head: norm followed by the output convolution.
        mapping[f"{side}.head.0.gamma"] = f"{side}.norm_out.gamma"
        for leaf in ("bias", "weight"):
            mapping[f"{side}.head.2.{leaf}"] = f"{side}.conv_out.{leaf}"
            mapping[f"{side}.conv1.{leaf}"] = f"{side}.conv_in.{leaf}"
    # Top-level (un-prefixed) convolutions around the latent space.
    for leaf in ("weight", "bias"):
        mapping[f"conv1.{leaf}"] = f"quant_conv.{leaf}"
        mapping[f"conv2.{leaf}"] = f"post_quant_conv.{leaf}"
    return mapping


def _remap_vae_encoder_down_key(key):
    """Rename one ``encoder.downsamples.*`` key to its ``encoder.down_blocks.*`` name."""
    new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
    # The block index is kept as is; only the sub-layer naming changes.
    for old_suffix, new_suffix in _VAE_RESIDUAL_RENAMES + _VAE_SHORTCUT_RENAMES:
        if old_suffix in new_key:
            return new_key.replace(old_suffix, new_suffix)
    return new_key


def _remap_vae_decoder_up_key(key):
    """Rename one ``decoder.upsamples.*`` key to its ``decoder.up_blocks.*`` name."""
    block_idx = int(key.split(".")[2])

    if "residual" in key:
        # The flat list is laid out in groups of four entries: three resnets
        # (offsets 0-2) followed by one non-resnet entry (offset 3).  Only the
        # resnet slots of the first four groups (indices 0-14) are regrouped.
        group_idx, resnet_idx = divmod(block_idx, 4)
        if not (0 <= block_idx <= 14 and resnet_idx < 3):
            # Unknown position: keep the key untouched.
            return key
        for old_suffix, new_suffix in _VAE_RESIDUAL_RENAMES:
            if old_suffix in key:
                return f"decoder.up_blocks.{group_idx}.resnets.{resnet_idx}{new_suffix}"
        return key

    if ".shortcut." in key:
        if block_idx == 4:
            # Entry 4 is the first resnet of up block 1.
            new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
            return new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
        new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
        return new_key.replace(".shortcut.", ".conv_shortcut.")

    if ".resample." in key or ".time_conv." in key:
        if block_idx in _VAE_UPSAMPLER_BLOCKS:
            return key.replace(
                f"decoder.upsamples.{block_idx}",
                f"decoder.up_blocks.{_VAE_UPSAMPLER_BLOCKS[block_idx]}.upsamplers.0",
            )
        return key.replace("decoder.upsamples.", "decoder.up_blocks.")

    return key.replace("decoder.upsamples.", "decoder.up_blocks.")


def _remap_vae_key(key, exact_key_map):
    """Translate a single original VAE state-dict key; unknown keys are returned unchanged."""
    if key in exact_key_map:
        return exact_key_map[key]
    if key.startswith("encoder.downsamples."):
        return _remap_vae_encoder_down_key(key)
    if key.startswith("decoder.upsamples."):
        return _remap_vae_decoder_up_key(key)
    return key


def convert_vae(vae_ckpt_path):
    """Convert an original-layout VAE checkpoint into an ``AutoencoderKLWan``.

    Args:
        vae_ckpt_path: Path of the checkpoint file holding the original VAE
            state dict (readable by ``torch.load(..., weights_only=True)``).

    Returns:
        An ``AutoencoderKLWan`` built with its default configuration whose
        parameters are the checkpoint tensors under their diffusers names.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the renamed
            keys do not match the model exactly.
    """
    # Always materialise the tensors on CPU so that a checkpoint saved from a
    # GPU process can be converted on any machine; callers move the model later.
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location="cpu")

    exact_key_map = _build_vae_exact_key_map()
    new_state_dict = {_remap_vae_key(key, exact_key_map): value for key, value in old_state_dict.items()}

    # Build the module without allocating weights, then adopt the checkpoint
    # tensors directly (``assign=True``) instead of copying into placeholders.
    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
|
|
|
|
|
|
def get_transformer_config() -> Dict[str, Dict[str, Any]]:
    """Return the hard-coded configuration used to build the transformer.

    Returns:
        A dict with a single ``"diffusers_config"`` entry whose value holds the
        keyword arguments for the diffusers transformer constructor.
    """
    config = {
        "diffusers_config": {
            "hidden_size": 4096,
            "in_channels": 16,
            "num_attention_heads": 32,
            "num_layers": 40,
            "out_channels": 16,
            "patch_size": [1, 2, 2],
            "rope_dim_list": [16, 56, 56],
            "text_dim": 4096,
            "rope_type": "rope",
            "theta": 10000,
        },
    }
    return config
|
|
|
|
|
|
# Parameter-name prefixes that live under ``double_blocks.<i>.attn`` in the
# diffusers model but directly under ``double_blocks.<i>`` in the original one.
_JOYIMAGE_ATTN_SUFFIXES = (
    "img_attn_qkv.",
    "img_attn_q_norm.",
    "img_attn_k_norm.",
    "img_attn_proj.",
    "txt_attn_qkv.",
    "txt_attn_q_norm.",
    "txt_attn_k_norm.",
    "txt_attn_proj.",
)


def _remap_joyimage_attention_key(key):
    """Move an attention weight of a double block into its ``attn`` submodule.

    Example: ``double_blocks.0.img_attn_qkv.weight`` ->
    ``double_blocks.0.attn.img_attn_qkv.weight``.  Keys outside
    ``double_blocks.`` and keys that are already nested are returned unchanged.
    """
    if not key.startswith("double_blocks."):
        return key
    for suffix in _JOYIMAGE_ATTN_SUFFIXES:
        if "." + suffix in key and ".attn." + suffix not in key:
            return key.replace("." + suffix, ".attn." + suffix)
    return key


def convert_transformer(ckpt_path: str):
    """Convert an original transformer checkpoint into a diffusers model.

    Args:
        ckpt_path: Path of the original checkpoint file.  The state dict may be
            stored either at the top level or under a ``"model"`` entry.

    Returns:
        A ``JoyImageEditTransformer3DModel`` built from ``get_transformer_config()``
        whose parameters are the (renamed) checkpoint tensors.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the renamed
            keys do not match the model exactly.
    """
    # Load onto CPU regardless of the device the checkpoint was saved from, so
    # the conversion neither requires nor fills up an accelerator.
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    original_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

    remapped = {_remap_joyimage_attention_key(key): value for key, value in original_state_dict.items()}

    config = get_transformer_config()
    # Instantiate without allocating weights, then adopt the checkpoint tensors.
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(remapped, strict=True, assign=True)
    return transformer
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of the conversion script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.  ``--output_path``
        is mandatory; every checkpoint path defaults to ``None``.

    Raises:
        SystemExit: Raised by ``argparse`` when ``--output_path`` is missing or
            ``--dtype`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path",
        type=str,
        default=None,
        help="Path to original transformer checkpoint",
    )
    parser.add_argument(
        "--vae_ckpt_path",
        type=str,
        default=None,
        help="Path to original VAE checkpoint",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path to the original text encoder (Qwen3-VL) checkpoint",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Path to the original text encoder tokenizer",
    )
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path where converted model should be saved",
    )
    # Restrict the value to the names the script can translate into a torch
    # dtype so that a typo is reported by argparse instead of a later KeyError.
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=("fp32", "fp16", "bf16"),
        help="Torch dtype to save the transformer in.",
    )
    parser.add_argument("--flow_shift", type=float, default=7.0)
    return parser.parse_args()
|
|
|
|
|
|
# Translates the ``--dtype`` command-line value into the torch dtype the
# converted weights are cast to before saving.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

if __name__ == "__main__":
    args = get_args()
    transformer = None
    vae = None

    if args.dtype not in DTYPE_MAPPING:
        raise ValueError(f"Unsupported --dtype {args.dtype!r}; expected one of {sorted(DTYPE_MAPPING)}.")
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let the script fail much later.
        required_flags = {
            "--transformer_ckpt_path": args.transformer_ckpt_path,
            "--vae_ckpt_path": args.vae_ckpt_path,
            "--text_encoder_path": args.text_encoder_path,
        }
        missing_flags = [flag for flag, value in required_flags.items() if value is None]
        if missing_flags:
            raise ValueError(f"--save_pipeline requires {', '.join(missing_flags)} to be set.")

    # Both model classes write ``config.json`` and weight files with the same
    # names.  When the two are converted in one standalone run they therefore
    # get their own subfolder; a single component still goes to ``output_path``.
    convert_both = args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer_dir = f"{args.output_path}/transformer" if convert_both else args.output_path
            transformer.save_pretrained(transformer_dir, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        vae = vae.to(dtype=dtype)
        if not args.save_pipeline:
            vae_dir = f"{args.output_path}/vae" if convert_both else args.output_path
            vae.save_pretrained(vae_dir, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        processor = AutoProcessor.from_pretrained(args.text_encoder_path)
        text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(
            args.text_encoder_path, torch_dtype=torch.bfloat16
        ).to("cuda")
        # Honour --tokenizer_path when given; otherwise the tokenizer is read
        # from the text encoder directory.
        tokenizer_source = args.tokenizer_path if args.tokenizer_path is not None else args.text_encoder_path
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

        # NOTE(review): --flow_shift is parsed by get_args() (default 7.0) but
        # the scheduler shift is hard-coded here.  Confirm which value is
        # intended before wiring the flag through.
        flow_shift = 1.5
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=flow_shift)

        # NOTE(review): the pipeline is assembled on "cuda", so this branch
        # needs a CUDA device even though it only saves files.
        transformer = transformer.to("cuda")
        vae = vae.to("cuda")
        pipe = JoyImageEditPipeline(
            processor=processor,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        ).to("cuda")
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
        processor.save_pretrained(f"{args.output_path}/processor")
|