# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AmusedImg2ImgPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = AmusedImg2ImgPipeline.from_pretrained(
        ...     "amused/amused-512", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "winter mountains"
        >>> input_image = (
        ...     load_image(
        ...         "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg"
        ...     )
        ...     .resize((512, 512))
        ...     .convert("RGB")
        ... )
        >>> image = pipe(prompt, input_image).images[0]
        ```
"""


class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
    _last_supported_version = "0.33.1"
    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
    # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
    # off the meta device. There should be a way to fix this instead of just not offloading it
    _exclude_from_cpu_offload = ["vqvae"]

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
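        # The VQ-VAE halves the spatial resolution once per downsampling block, so
        # `2 ** (len(block_out_channels) - 1)` is the total factor between an input image and its grid of
        # latent tokens; 8 is only a fallback used when no `vqvae` is registered.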
        self.vae_scale_factor = (
            2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.5,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
"""
|
|
The call function to the pipeline for generation.
|
|
|
|
Args:
|
|
prompt (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
|
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
|
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
|
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
|
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
|
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
|
latents as `image`, but if passing latents directly it is not encoded again.
|
|
strength (`float`, *optional*, defaults to 0.5):
|
|
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
|
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
|
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
|
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
|
essentially ignores `image`.
|
|
num_inference_steps (`int`, *optional*, defaults to 12):
|
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
|
expense of slower inference.
|
|
guidance_scale (`float`, *optional*, defaults to 10.0):
|
|
A higher guidance scale value encourages the model to generate images closely linked to the text
|
|
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
|
negative_prompt (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
|
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
|
The number of images to generate per prompt.
|
|
generator (`torch.Generator`, *optional*):
|
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
|
generation deterministic.
|
|
prompt_embeds (`torch.Tensor`, *optional*):
|
|
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
|
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
|
pooled and projected final hidden states.
|
|
encoder_hidden_states (`torch.Tensor`, *optional*):
|
|
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
|
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
|
negative_encoder_hidden_states (`torch.Tensor`, *optional*):
|
|
Analogous to `encoder_hidden_states` for the positive prompt.
|
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
|
return_dict (`bool`, *optional*, defaults to `True`):
|
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
|
plain tuple.
|
|
callback (`Callable`, *optional*):
|
|
A function that calls every `callback_steps` steps during inference. The function is called with the
|
|
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
|
callback_steps (`int`, *optional*, defaults to 1):
|
|
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
|
every step.
|
|
cross_attention_kwargs (`dict`, *optional*):
|
|
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
|
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
|
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
|
The targeted aesthetic score according to the laion aesthetic classifier. See
|
|
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
|
|
https://huggingface.co/papers/2307.01952.
|
|
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
|
The targeted height, width crop coordinates. See the micro-conditioning section of
|
|
https://huggingface.co/papers/2307.01952.
|
|
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
|
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
|
|
|
|
Examples:
|
|
|
|
Returns:
|
|
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
|
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
|
`tuple` is returned where the first element is a list with the generated images.
|
|
"""
|
|
|
|
        if (prompt_embeds is not None and encoder_hidden_states is None) or (
            prompt_embeds is None and encoder_hidden_states is not None
        ):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
            negative_prompt_embeds is None and negative_encoder_hidden_states is not None
        ):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if isinstance(prompt, str):
            prompt = [prompt]

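        # The effective batch size is the number of prompts (or provided prompt embeddings) multiplied by
        # `num_images_per_prompt`; the text embeddings below are tiled to match it.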
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        batch_size = batch_size * num_images_per_prompt

        if prompt_embeds is None:
            input_ids = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids.to(self._execution_device)

            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            prompt_embeds = outputs.text_embeds
            encoder_hidden_states = outputs.hidden_states[-2]

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

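        # Classifier-free guidance: encode the negative prompt (or reuse the provided negative embeddings)
        # and concatenate it in front of the positive embeddings along the batch dimension, so a single
        # transformer forward pass yields both unconditional and conditional logits.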
        if guidance_scale > 1.0:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * len(prompt)

                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]

                input_ids = self.tokenizer(
                    negative_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.tokenizer.model_max_length,
                ).input_ids.to(self._execution_device)

                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                negative_prompt_embeds = outputs.text_embeds
                negative_encoder_hidden_states = outputs.hidden_states[-2]

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        image = self.image_processor.preprocess(image)

        height, width = image.shape[-2:]

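        # SDXL-style micro-conditioning (see https://huggingface.co/papers/2307.01952): the original image
        # size, the crop coordinates, and the target aesthetic score are packed into a small vector that is
        # fed to the transformer alongside the text conditioning.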
        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )

        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)

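        # As in other img2img pipelines, `strength` selects how much of the schedule is actually run: with
        # the default `num_inference_steps=12` and `strength=0.5`, only the last 6 of the 12 timesteps are
        # executed, starting from a partially masked version of the input image tokens.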
        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
        num_inference_steps = int(len(self.scheduler.timesteps) * strength)
        start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps

        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

        if needs_upcasting:
            self.vqvae.float()

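        # Encode the input image to continuous latents, then quantize them to discrete codebook indices:
        # the last element of the quantizer's info tuple (`quantize(...)[2][2]`) holds the per-position
        # token ids that the scheduler masks and the transformer predicts.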
        latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
        latents_bsz, channels, latents_height, latents_width = latents.shape
        latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
        latents = self.scheduler.add_noise(
            latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
        )
        latents = latents.repeat(num_images_per_prompt, 1, 1)

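        # Iterative unmasking loop: at every step the transformer predicts logits over the codebook for the
        # masked token positions (with optional classifier-free guidance), and the scheduler keeps the most
        # confident predictions while re-masking the rest for the next step.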
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
                timestep = self.scheduler.timesteps[i]

                if guidance_scale > 1.0:
                    model_input = torch.cat([latents] * 2)
                else:
                    model_input = latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if guidance_scale > 1.0:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

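        # Decode the final token grid back to pixels with the VQ-VAE decoder. `force_not_quantize=True` is
        # passed because `latents` are already discrete codebook indices (looked up in the codebook rather
        # than re-quantized), and the decoded image is clipped to the [0, 1] range expected by the image
        # processor.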
        if output_type == "latent":
            output = latents
        else:
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.image_processor.postprocess(output, output_type)

        if needs_upcasting:
            self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)