1338 lines
66 KiB
Python
1338 lines
66 KiB
Python
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import PIL.Image
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
|
|
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
|
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
|
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
|
from ...models.lora import adjust_lora_scale_text_encoder
|
|
from ...schedulers import KarrasDiffusionSchedulers
|
|
from ...utils import (
|
|
USE_PEFT_BACKEND,
|
|
deprecate,
|
|
is_torch_xla_available,
|
|
logging,
|
|
replace_example_docstring,
|
|
scale_lora_layers,
|
|
unscale_lora_layers,
|
|
)
|
|
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
|
|
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
|
from ..stable_diffusion import StableDiffusionPipelineOutput
|
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
|
|
|
|
# torch_xla is an optional dependency: it is imported only when
# `is_torch_xla_available()` reports it, and the outcome is recorded in
# `XLA_AVAILABLE` so later code can branch on a plain boolean.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, keyed by this module's import name.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
|
|
EXAMPLE_DOC_STRING = """
|
|
Examples:
|
|
```py
|
|
>>> # !pip install opencv-python transformers accelerate
|
|
>>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
|
|
>>> from diffusers.utils import load_image
|
|
>>> import numpy as np
|
|
>>> import torch
|
|
|
|
>>> import cv2
|
|
>>> from PIL import Image
|
|
|
|
>>> # download an image
|
|
>>> image = load_image(
|
|
... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
|
|
... )
|
|
>>> np_image = np.array(image)
|
|
|
|
>>> # get canny image
|
|
>>> np_image = cv2.Canny(np_image, 100, 200)
|
|
>>> np_image = np_image[:, :, None]
|
|
>>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
|
|
>>> canny_image = Image.fromarray(np_image)
|
|
|
|
>>> # load control net and stable diffusion v1-5
|
|
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
|
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
|
... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
|
... )
|
|
|
|
>>> # speed up diffusion process with faster scheduler and memory optimization
|
|
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
|
>>> pipe.enable_model_cpu_offload()
|
|
|
|
>>> # generate image
|
|
>>> generator = torch.manual_seed(0)
|
|
>>> image = pipe(
|
|
... "futuristic-looking woman",
|
|
... num_inference_steps=20,
|
|
... generator=generator,
|
|
... image=image,
|
|
... control_image=canny_image,
|
|
... ).images[0]
|
|
```
|
|
"""
|
|
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
|
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute (with
            `sample(generator)` and `mode()` methods) or a `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes its mode.
            Any other value falls through to the `latents` attribute.

    Returns:
        Whatever the selected attribute/method yields.

    Raises:
        AttributeError: If none of the lookups above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    # Either there is no distribution or the mode was not recognised:
    # pre-computed latents are the last resort.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
|
|
|
|
|
|
def prepare_image(image):
    """Convert an input image (or batch) into a float32 torch tensor.

    Tensors are only given a leading batch axis (when 3-D) and cast to float32;
    their values are not rescaled. PIL images and numpy arrays (single or in a
    list) are stacked along a new leading axis, their last axis is moved to
    position 1, and values are mapped with `x / 127.5 - 1.0`.

    Args:
        image: `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a list of PIL
            images / numpy arrays.

    Returns:
        A float32 `torch.Tensor`.
    """
    if isinstance(image, torch.Tensor):
        # A lone 3-D tensor is treated as a single image and gets a batch axis.
        batched = image.unsqueeze(0) if image.ndim == 3 else image
        return batched.to(dtype=torch.float32)

    # Normalise "one image" to "list with one image".
    if isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list):
        first = image[0]
        if isinstance(first, PIL.Image.Image):
            image = np.concatenate([np.array(frame.convert("RGB"))[None, :] for frame in image], axis=0)
        elif isinstance(first, np.ndarray):
            image = np.concatenate([frame[None, :] for frame in image], axis=0)

    channels_first = image.transpose(0, 3, 1, 2)
    return torch.from_numpy(channels_first).to(dtype=torch.float32) / 127.5 - 1.0
|
|
|
|
|
|
class StableDiffusionControlNetImg2ImgPipeline(
|
|
DiffusionPipeline,
|
|
StableDiffusionMixin,
|
|
TextualInversionLoaderMixin,
|
|
StableDiffusionLoraLoaderMixin,
|
|
IPAdapterMixin,
|
|
FromSingleFileMixin,
|
|
):
|
|
r"""
|
|
Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance.
|
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
|
|
|
The pipeline also inherits the following loading methods:
|
|
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
|
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
|
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
|
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
|
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
|
|
|
Args:
|
|
vae ([`AutoencoderKL`]):
|
|
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
|
text_encoder ([`~transformers.CLIPTextModel`]):
|
|
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
|
tokenizer ([`~transformers.CLIPTokenizer`]):
|
|
A `CLIPTokenizer` to tokenize text.
|
|
unet ([`UNet2DConditionModel`]):
|
|
A `UNet2DConditionModel` to denoise the encoded image latents.
|
|
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
|
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
|
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
|
additional conditioning.
|
|
scheduler ([`SchedulerMixin`]):
|
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
|
safety_checker ([`StableDiffusionSafetyChecker`]):
|
|
Classification module that estimates whether generated images could be considered offensive or harmful.
|
|
Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
|
|
more details about a model's potential harms.
|
|
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
|
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
|
"""
|
|
|
|
model_cpu_offload_seq = "text_encoder->unet->vae"
|
|
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
|
_exclude_from_cpu_offload = ["safety_checker"]
|
|
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
|
|
|
|
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    Args:
        vae: Autoencoder; its `config.block_out_channels` determines `vae_scale_factor`.
        text_encoder: Text encoder registered on the pipeline.
        tokenizer: Tokenizer registered on the pipeline.
        unet: Denoising UNet registered on the pipeline.
        controlnet: A single ControlNet, a `MultiControlNetModel`, or a list/tuple of
            ControlNets (which is wrapped in a `MultiControlNetModel`).
        scheduler: Scheduler registered on the pipeline.
        safety_checker: Optional safety checker; may be `None`.
        feature_extractor: Required whenever `safety_checker` is not `None`.
        image_encoder: Optional image encoder; may be `None`.
        requires_safety_checker: Stored in the pipeline config; when `True` and
            `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message interpolates the class name, so it must be an f-string
        # (previously the literal text "{self.__class__}" was shown to the user).
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Several ControlNets are driven through a single MultiControlNetModel wrapper.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Derived from the number of VAE blocks; falls back to 8 when no VAE is registered.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # Control images use the same processor settings except that normalisation is disabled.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
|
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """Deprecated wrapper around `encode_prompt`.

    Emits a deprecation notice, forwards every argument to `encode_prompt`, and
    returns the two resulting embeddings concatenated into one tensor with the
    negative embeddings first and the prompt embeddings second.
    """
    deprecate(
        "_encode_prompt()",
        "1.0.0",
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead."
        " Also, be aware that the output format changed from a concatenated tensor to a tuple.",
        standard_warn=False,
    )

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # Legacy layout: [negative, positive] stacked along the first axis.
    positive_embeds = encoded[0]
    negative_embeds = encoded[1]
    return torch.cat([negative_embeds, positive_embeds])
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
|
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times.
        `negative_prompt_embeds` is returned exactly as passed in (possibly `None`) when
        `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # The batch size comes from the prompt when one is given, otherwise from the
    # pre-computed embeddings (which must then be provided).
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report
        # text that was cut off by the `max_length` limit above.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # An attention mask is passed only when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype for the embeddings: the text encoder's, else the UNet's, else
    # whatever dtype the embeddings already have.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        # NOTE: `clip_skip` is not applied on this path; the first output of the
        # text encoder is used directly.
        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
|
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image with `self.image_encoder` and build its unconditional counterpart.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor`
            (used to obtain `pixel_values` when `image` is not a tensor).
        device: Device the pixel tensor is moved to before encoding.
        num_images_per_prompt: Repeat count applied along dim 0 via `repeat_interleave`.
        output_hidden_states: When truthy, return the second-to-last hidden state of
            the encoder instead of `image_embeds`.

    Returns:
        `(conditional, unconditional)`. With hidden states, the unconditional entry is
        the encoding of an all-zero image; otherwise it is a zero tensor shaped like
        the conditional embeddings.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if isinstance(image, torch.Tensor):
        pixels = image
    else:
        pixels = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = pixels.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def penultimate_states(batch):
        # Second-to-last hidden state, repeated once per requested image.
        states = encoder(batch, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(pixels)
    uncond_states = penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
|
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build one embedding tensor per IP-Adapter, tiled for `num_images_per_prompt`.

    Embeddings come either from `ip_adapter_image_embeds` (pre-computed; with
    classifier-free guidance each tensor is split in two halves, negative first)
    or, when that is `None`, from encoding `ip_adapter_image` with `encode_image`.

    Args:
        ip_adapter_image: Image or list of images, one per projection layer in
            `self.unet.encoder_hid_proj.image_projection_layers`.
        ip_adapter_image_embeds: Optional list of pre-computed embedding tensors.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt: Number of times each embedding is concatenated along dim 0.
        do_classifier_free_guidance: When true, negative embeddings are prepended to
            the positive ones along dim 0.

    Returns:
        List of tensors, one per IP-Adapter.

    Raises:
        ValueError: If the number of images differs from the number of projection layers.
    """
    # One (positive, negative-or-None) pair per IP-Adapter.
    embed_pairs = []

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative, positive = packed.chunk(2)
            else:
                negative, positive = None, packed
            embed_pairs.append((positive, negative))
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Anything other than a plain ImageProjection consumes hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            positive, negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)
            embed_pairs.append(
                (positive[None, :], negative[None, :] if do_classifier_free_guidance else None)
            )

    prepared = []
    for positive, negative in embed_pairs:
        tiled = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negative] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
|
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Args:
        image: Tensor or numpy image batch; converted to PIL only to feed the
            feature extractor, while the original `image` is what the checker receives.
        device: Device the extracted features are moved to.
        dtype: Dtype the `pixel_values` are cast to before the checker call.

    Returns:
        `(image, has_nsfw_concept)`. When no safety checker is registered the image
        is returned untouched and the flag is `None`.
    """
    if self.safety_checker is None:
        return image, None

    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=checker_features.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
|
def decode_latents(self, latents):
    """Deprecated: decode latents with the VAE into a numpy image batch.

    The latents are divided by `vae.config.scaling_factor`, decoded, mapped with
    `x / 2 + 0.5`, clamped to `[0, 1]`, permuted with `(0, 2, 3, 1)` and returned
    as a float32 numpy array on the CPU.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: cheap, and keeps bfloat16 inputs convertible to numpy.
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
|
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are
    forwarded only when the scheduler's `step` declares a parameter of that name.
    eta (η) corresponds to η in the DDIM paper (https://huggingface.co/papers/2010.02502)
    and should be between [0, 1].

    Returns:
        Dict containing `"eta"` and/or `"generator"`, in that order, or empty.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
|
|
|
|
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate call arguments, raising on the first inconsistency found.

    Args:
        prompt: `str`, `list`, or `None` (then `prompt_embeds` is required).
        image: Control image(s); for multiple ControlNets a flat list with one
            entry per ControlNet.
        callback_steps: `None` or a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Mutually exclusive with `prompt`.
        negative_prompt_embeds: Must match `prompt_embeds` in shape when both are given.
        ip_adapter_image: Mutually exclusive with `ip_adapter_image_embeds`.
        ip_adapter_image_embeds: List whose first element is a 3D or 4D tensor.
        controlnet_conditioning_scale: `float` for a single ControlNet; for multiple
            ControlNets a `float` or a flat list with one entry per ControlNet.
        control_guidance_start: Number or list of numbers; each must be `>= 0` and
            strictly smaller than the matching end value.
        control_guidance_end: Number or list of numbers; each must be `<= 1`.
        callback_on_step_end_tensor_inputs: Names restricted to `self._callback_tensor_inputs`.

    Raises:
        ValueError: For invalid values or mismatching lengths/shapes.
        TypeError: For arguments of the wrong type.
        AssertionError: If `self.controlnet` is neither a single nor a multi ControlNet.
    """
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A compiled ControlNet is a wrapper that keeps the real module in `_orig_mod`,
    # so the single/multi classification looks through it. Computed once and
    # reused for both the image and the conditioning-scale checks.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    elif is_multi_controlnet:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        assert False

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif is_multi_controlnet:
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # The length check must live inside the `list` branch: as an `elif`
            # repeating the same `isinstance(..., list)` test it could never run.
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        assert False

    # The declared defaults are plain numbers, which have no `len()`. Broadcast
    # numeric scalars to lists so the checks below work for them as well; any
    # other value is passed through unchanged.
    start_is_scalar = isinstance(control_guidance_start, (int, float))
    end_is_scalar = isinstance(control_guidance_end, (int, float))
    if start_is_scalar and end_is_scalar:
        repeats = len(self.controlnet.nets) if isinstance(self.controlnet, MultiControlNetModel) else 1
        control_guidance_start = [control_guidance_start] * repeats
        control_guidance_end = [control_guidance_end] * repeats
    elif start_is_scalar:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif end_is_scalar:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
|
|
|
|
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
|
def check_image(self, image, prompt, prompt_embeds):
    """Validate the type of `image` and its batch size against the prompt batch size.

    Args:
        image: A PIL image, numpy array or torch tensor, or a list whose first element is one of those.
        prompt: A string, a list of strings, or `None`.
        prompt_embeds: Used for the prompt batch size (`shape[0]`) when `prompt` is neither a string nor a list.

    Raises:
        TypeError: If `image` is none of the accepted types.
        ValueError: If the image batch size is neither 1 nor equal to the prompt batch size.
    """
    accepted_types = (PIL.Image.Image, torch.Tensor, np.ndarray)

    is_single_input = isinstance(image, accepted_types)
    # Only the first element of a list is inspected.
    is_list_input = isinstance(image, list) and isinstance(image[0], accepted_types)

    if not (is_single_input or is_list_input):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # A lone PIL image counts as a batch of one; everything else reports its own length.
    image_batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    # A single image is always acceptable, so the prompt batch size is only consulted otherwise.
    if image_batch_size == 1:
        return

    if image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
|
|
|
|
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
|
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess a ControlNet conditioning image and expand it to the requested batch.

    Args:
        image: Conditioning input handed to `self.control_image_processor.preprocess`.
        width: Target width forwarded to the preprocessor.
        height: Target height forwarded to the preprocessor.
        batch_size: Number of repeats applied when the preprocessed batch holds exactly one image.
        num_images_per_prompt: Number of repeats applied to each image otherwise.
        device: Device the result is moved to.
        dtype: Dtype the result is cast to.
        do_classifier_free_guidance: When true (and `guess_mode` is false) the batch is concatenated with itself.
        guess_mode: Disables the classifier-free-guidance duplication.

    Returns:
        The prepared conditioning tensor.
    """
    control = self.control_image_processor.preprocess(image, height=height, width=width)
    control = control.to(dtype=torch.float32)

    # One image is shared by the whole batch; several images are each repeated per prompt.
    repeats = batch_size if control.shape[0] == 1 else num_images_per_prompt

    control = control.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        control = torch.cat([control, control])

    return control
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
|
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.
        device: Not used by this method.

    Returns:
        A tuple `(timesteps, steps)` with the sliced timesteps and the number of inference steps left to run.
    """
    scheduler = self.scheduler

    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # Higher-order schedulers hold `order` entries per inference step.
    begin_index = first_step * scheduler.order
    remaining_timesteps = scheduler.timesteps[begin_index:]

    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(first_step * scheduler.order)

    return remaining_timesteps, num_inference_steps - first_step
|
|
|
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
|
|
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """Turn the initial image into noised latents for the first denoising step.

    An input whose second dimension is 4 is used directly as latents; anything else is encoded with
    `self.vae` and multiplied by `self.vae.config.scaling_factor`. The latents are expanded to
    `batch_size * num_images_per_prompt` when needed and noised via `self.scheduler.add_noise`.

    Args:
        image: Initial image (or latents); `.to(device=..., dtype=...)` is called on it.
        timestep: Timestep(s) passed to `self.scheduler.add_noise`.
        batch_size: Prompt batch size.
        num_images_per_prompt: Multiplier applied to `batch_size`.
        dtype: Dtype for the image and the sampled noise.
        device: Device for the image and the sampled noise.
        generator: A generator or a list with one generator per effective batch element.

    Returns:
        The noised latents.

    Raises:
        ValueError: For an unsupported `image` type, a generator list of the wrong length, or an image batch
            that cannot be evenly duplicated to the effective batch size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    # From here on `batch_size` is the effective batch size.
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image
    else:
        has_generator_list = isinstance(generator, list)

        if has_generator_list and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if has_generator_list:
            if image.shape[0] < batch_size:
                if batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)

            # Encode sample by sample so each one draws from its own generator.
            encoded = []
            for idx in range(batch_size):
                posterior = self.vae.encode(image[idx : idx + 1])
                encoded.append(retrieve_latents(posterior, generator=generator[idx]))
            init_latents = torch.cat(encoded, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0]:
        if batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )

        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    else:
        init_latents = torch.cat([init_latents], dim=0)

    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

    # get latents
    return self.scheduler.add_noise(init_latents, noise, timestep)
|
|
|
|
@property
def guidance_scale(self):
    """Return `self._guidance_scale`, which `__call__` sets from its `guidance_scale` argument."""
    return self._guidance_scale
|
|
|
|
@property
def clip_skip(self):
    """Return `self._clip_skip`, which `__call__` sets from its `clip_skip` argument."""
    return self._clip_skip
|
|
|
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
|
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
|
# corresponds to doing no classifier free guidance.
|
|
@property
def do_classifier_free_guidance(self):
    """Return whether the stored guidance scale is strictly greater than 1."""
    return self._guidance_scale > 1
|
|
|
|
@property
def cross_attention_kwargs(self):
    """Return `self._cross_attention_kwargs`, which `__call__` sets from its `cross_attention_kwargs` argument."""
    return self._cross_attention_kwargs
|
|
|
|
@property
def num_timesteps(self):
    """Return `self._num_timesteps`, which `__call__` sets to the number of timesteps it iterates over."""
    return self._num_timesteps
|
|
|
|
@property
def interrupt(self):
    """Return `self._interrupt`; `__call__` resets it to `False` and skips denoising iterations while it is true."""
    return self._interrupt
|
|
|
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    control_image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 0.8,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The initial image to be used as the starting point for the image generation process. Can also accept
            image latents as `image`, and if passing latents directly they are not encoded again.
        control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps depends
            on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
            process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
            essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        kwargs:
            The deprecated `callback` and `callback_steps` arguments are read from here. When `callback` is given
            without `callback_steps`, the callback is invoked on every progress update.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # `callback_steps` is `None` unless passed explicitly, and `i % None` would raise a `TypeError` inside the
    # denoising loop. Fall back to calling the legacy callback on every progress update in that case. The
    # untouched `callback_steps` is still what gets validated by `check_inputs` below.
    legacy_callback_interval = 1 if callback_steps is None else callback_steps

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        control_image,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare image
    image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

    # 5. Prepare controlnet_conditioning_image
    if isinstance(controlnet, ControlNetModel):
        control_image = self.prepare_control_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        control_images = []

        for control_image_ in control_image:
            control_image_ = self.prepare_control_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            control_images.append(control_image_)

        control_image = control_images
    else:
        # Raised explicitly (instead of `assert False`) so the guard is not stripped when running with `-O`.
        raise AssertionError(f"Unsupported controlnet type: {type(controlnet)}")

    # 6. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    self._num_timesteps = len(timesteps)

    # 7. Prepare latent variables
    if latents is None:
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 8.2 Create tensor stating which controlnets to keep
    controlnet_keep = []
    for i in range(len(timesteps)):
        # 1.0 while the step lies inside [start, end] of the controlnet's guidance window, 0.0 otherwise
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            if guess_mode and self.do_classifier_free_guidance:
                # Infer ControlNet only for the conditional batch.
                control_model_input = latents
                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
            else:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds

            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=control_image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and self.do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Kept as an explicit loop: `locals()` must be evaluated in this function's scope.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                control_image = callback_outputs.pop("control_image", control_image)

            # update the progress bar and call the legacy callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % legacy_callback_interval == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        empty_device_cache()

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|