mirror of
https://github.com/huggingface/diffusers.git
synced 2026-06-02 00:01:34 +08:00
[Enhance] Update reference (#3723)
* update reference pipeline * update reference pipeline --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
@@ -97,7 +98,14 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
@@ -130,8 +138,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
||||
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
||||
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
|
||||
@@ -223,15 +231,12 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
@@ -266,6 +271,9 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
@@ -274,6 +282,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
@@ -289,6 +298,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
|
||||
@@ -308,6 +318,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
images.append(image_)
|
||||
|
||||
image = images
|
||||
height, width = image[0].shape[-2:]
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -720,14 +731,15 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
controlnet_latent_model_input = latents
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
else:
|
||||
controlnet_latent_model_input = latent_model_input
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
controlnet_latent_model_input,
|
||||
control_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
|
||||
@@ -9,6 +9,7 @@ from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
|
||||
|
||||
@@ -179,6 +180,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
attention_auto_machine_weight: float = 1.0,
|
||||
gn_auto_machine_weight: float = 1.0,
|
||||
style_fidelity: float = 0.5,
|
||||
@@ -248,6 +250,11 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
attention_auto_machine_weight (`float`):
|
||||
Weight of using reference query for self attention's context.
|
||||
If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
|
||||
@@ -295,6 +302,9 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
@@ -303,6 +313,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Preprocess reference image
|
||||
@@ -748,6 +759,10 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user