949 lines
44 KiB
Python
949 lines
44 KiB
Python
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
|
|
|
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
|
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
|
from ...models import AutoencoderKL, ChromaTransformer2DModel
|
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
|
from ...utils import (
|
|
USE_PEFT_BACKEND,
|
|
is_torch_xla_available,
|
|
logging,
|
|
replace_example_docstring,
|
|
scale_lora_layers,
|
|
unscale_lora_layers,
|
|
)
|
|
from ...utils.torch_utils import randn_tensor
|
|
from ..pipeline_utils import DiffusionPipeline
|
|
from .pipeline_output import ChromaPipelineOutput
|
|
|
|
|
|
# PyTorch/XLA is optional: `xm` is only imported when `is_torch_xla_available()` reports it,
# and `XLA_AVAILABLE` records whether that import happened.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
# Usage example spliced into `ChromaPipeline.__call__`'s docstring by `replace_example_docstring`.
# The example uses `ChromaTransformer2DModel`, so it must be imported alongside `ChromaPipeline`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ChromaPipeline, ChromaTransformer2DModel

        >>> model_id = "lodestones/Chroma"
        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
        >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
        >>> pipe = ChromaPipeline.from_pretrained(
        ...     model_id,
        ...     transformer=transformer,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> prompt = [
        ...     "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
        ... ]
        >>> negative_prompt = [
        ...     "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        ... ]
        >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma.png")
        ```
"""
|
|
|
|
|
|
# Adapted from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image token count to a timestep shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)`` and is evaluated at
    ``image_seq_len``. Values outside ``[base_seq_len, max_seq_len]`` are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return intercept + slope * image_seq_len
|
|
|
|
|
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    The schedule comes from exactly one source: custom `timesteps`, custom `sigmas`, or, when neither is given,
    `num_inference_steps`. Any extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps. Only forwarded when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. The scheduler's `set_timesteps` must accept a `timesteps` argument.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. The scheduler's `set_timesteps` must accept a `sigmas` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the step count. The step count is the
        length of that schedule when custom values were given, otherwise `num_inference_steps` unchanged.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler cannot accept the custom values.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler build its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        arg_name, arg_value, label = "timesteps", timesteps, "timestep"
    else:
        arg_name, arg_value, label = "sigmas", sigmas, "sigmas"

    # Custom schedules are only supported by schedulers whose `set_timesteps` has the matching keyword.
    if arg_name not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{arg_name: arg_value}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
|
|
class ChromaPipeline(
|
|
DiffusionPipeline,
|
|
FluxLoraLoaderMixin,
|
|
FromSingleFileMixin,
|
|
TextualInversionLoaderMixin,
|
|
FluxIPAdapterMixin,
|
|
):
|
|
r"""
|
|
The Chroma pipeline for text-to-image generation.
|
|
|
|
Reference: https://huggingface.co/lodestones/Chroma/
|
|
|
|
Args:
|
|
transformer ([`ChromaTransformer2DModel`]):
|
|
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
|
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
|
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
|
vae ([`AutoencoderKL`]):
|
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
|
|
text_encoder ([`T5EncoderModel`]):
|
|
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
|
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
|
tokenizer (`T5TokenizerFast`):
|
|
Second Tokenizer of class
|
|
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
|
"""
|
|
|
|
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
|
_optional_components = ["image_encoder", "feature_extractor"]
|
|
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
|
|
|
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    transformer: ChromaTransformer2DModel,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
):
    """Register the pipeline components and derive the image-processing settings.

    `image_encoder` and `feature_extractor` may be `None` (they are listed in `_optional_components`); they are
    used by `encode_image`.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "transformer": transformer,
        "scheduler": scheduler,
        "image_encoder": image_encoder,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)

    # VAE scale factor derived from the VAE's block count; 8 when no VAE is registered.
    if getattr(self, "vae", None):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    else:
        self.vae_scale_factor = 8

    # Latents are packed into 2x2 patches, so the latent height/width must also be divisible by 2.
    # The image processor therefore works with twice the VAE scale factor.
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
    self.default_sample_size = 128
|
|
|
|
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode `prompt` with the T5 text encoder and build the matching attention mask.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or prompts to encode.
        num_images_per_prompt (`int`):
            How many times each prompt's embeddings and mask are repeated.
        max_sequence_length (`int`):
            Tokenizer padding/truncation length.
        device (`torch.device`, *optional*):
            Target device; defaults to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            dtype of the returned embeddings and mask; defaults to `self.text_encoder.dtype`.

    Returns:
        `(prompt_embeds, attention_mask)`, each with `len(prompt) * num_images_per_prompt` rows. The copies of a given
        prompt are adjacent. The mask has one entry per padded token position.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask.clone()

    # Chroma requires the attention mask to include one padding token. Every position whose index is <= the
    # prompt's token count is kept. Assuming right padding, that count is the index of the first padding slot.
    seq_lengths = attention_mask.sum(dim=1)
    mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
    attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()

    prompt_embeds = self.text_encoder(
        text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
    )[0]

    # Cast to the dtype resolved above (the caller's `dtype`, falling back to the text encoder's); it used to be
    # overwritten with `self.text_encoder.dtype` here, which silently ignored the `dtype` argument.
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    attention_mask = attention_mask.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings and attention mask for each generation per prompt, using an mps-friendly method.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    attention_mask = attention_mask.repeat(1, num_images_per_prompt)
    attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, attention_mask
|
|
|
|
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Union[str, List[str]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    do_classifier_free_guidance: bool = True,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
):
    r"""
    Encode the prompt and, when classifier-free guidance is enabled, the negative prompt.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. May be `None` when `prompt_embeds` is given; the batch size is then read from
            `prompt_embeds.shape[0]`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide the image generation. Defaults to the empty string. A single string is repeated
            to the batch size. Only used when `do_classifier_free_guidance` is `True` and `negative_prompt_embeds`
            is `None`.
        device (`torch.device`, *optional*):
            Target device; defaults to `self._execution_device`.
        num_images_per_prompt (`int`):
            Number of images generated per prompt. Applied to embeddings computed here.
        prompt_embeds, prompt_attention_mask (`torch.Tensor`, *optional*):
            Precomputed embeddings and mask. They are returned as given and are not repeated per image.
        negative_prompt_embeds, negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Precomputed negative embeddings and mask. They are used as given when guidance is enabled.
        do_classifier_free_guidance (`bool`):
            Whether to produce negative embeddings.
        max_sequence_length (`int`):
            Tokenizer padding/truncation length for prompts encoded here.
        lora_scale (`float`, *optional*):
            Scale applied to the text encoder's LoRA layers for this call (PEFT backend only).

    Returns:
        The tuple `(prompt_embeds, text_ids, prompt_attention_mask, negative_prompt_embeds, negative_text_ids,
        negative_prompt_attention_mask)`. `text_ids` and `negative_text_ids` are all-zero `(seq_len, 3)` tensors.
        `negative_text_ids` is `None` when guidance is disabled.

    Raises:
        TypeError: If `negative_prompt` and `prompt` have different types.
        ValueError: If `negative_prompt` has a different batch size than `prompt`.
    """
    device = device or self._execution_device
    uses_lora_mixin = isinstance(self, FluxLoraLoaderMixin)

    # Store the LoRA scale so the monkey-patched text-encoder LoRA function can read it,
    # then scale the text encoder's LoRA layers for this call.
    if lora_scale is not None and uses_lora_mixin:
        self._lora_scale = lora_scale
        if self.text_encoder is not None and USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = prompt_embeds.shape[0] if prompt is None else len(prompt)

    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

    if self.text_encoder is not None:
        ids_dtype = self.text_encoder.dtype
    else:
        ids_dtype = self.transformer.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=ids_dtype)

    negative_text_ids = None
    if do_classifier_free_guidance:
        if negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            if isinstance(negative_prompt, str):
                negative_prompt = batch_size * [negative_prompt]

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            if batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=ids_dtype)

    # Retrieve the original scale by scaling back the LoRA layers.
    if self.text_encoder is not None and uses_lora_mixin and USE_PEFT_BACKEND:
        unscale_lora_layers(self.text_encoder, lora_scale)

    return (
        prompt_embeds,
        text_ids,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_text_ids,
        negative_prompt_attention_mask,
    )
|
|
|
|
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
    """Embed `image` with the CLIP vision encoder.

    Non-tensor inputs are first converted to pixel values by `self.feature_extractor`. The pixels are moved to
    `device` and cast to the image encoder's parameter dtype. Each resulting embedding is repeated
    `num_images_per_prompt` times along dim 0, with copies of the same image kept adjacent.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if isinstance(image, torch.Tensor):
        pixel_values = image
    else:
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values

    pixel_values = pixel_values.to(device=device, dtype=encoder_dtype)
    embeddings = self.image_encoder(pixel_values).image_embeds
    return embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
|
|
|
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
    """Build one embedding tensor per loaded IP-Adapter.

    If `ip_adapter_image_embeds` is `None`, each image is encoded with `encode_image` and given a new leading
    dimension. Otherwise the supplied embeddings are used as-is. A single item is accepted in place of a list. Each
    per-adapter tensor is then concatenated `num_images_per_prompt` times along dim 0 and moved to `device`.

    Raises:
        ValueError: If the number of images/embeddings differs from
            `self.transformer.encoder_hid_proj.num_ip_adapters`.
    """
    if ip_adapter_image_embeds is None:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]

        num_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
        if len(images) != num_adapters:
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {num_adapters} IP Adapters."
            )

        per_adapter = [self.encode_image(img, device, 1)[None, :] for img in images]
    else:
        embeds = ip_adapter_image_embeds if isinstance(ip_adapter_image_embeds, list) else [ip_adapter_image_embeds]

        num_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
        if len(embeds) != num_adapters:
            raise ValueError(
                f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(embeds)} image embeds and {num_adapters} IP Adapters."
            )

        per_adapter = list(embeds)

    return [torch.cat([emb] * num_images_per_prompt, dim=0).to(device=device) for emb in per_adapter]
|
|
|
|
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_embeds=None,
    negative_prompt_attention_mask=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate the arguments of a pipeline call.

    Only a warning is logged when `height`/`width` are not multiples of `vae_scale_factor * 2`.

    Raises:
        ValueError: On unknown callback tensor names; when both or neither of `prompt`/`prompt_embeds` are given;
            when `prompt` is not a `str`/`list`; when both `negative_prompt` and `negative_prompt_embeds` are given;
            when embeddings are passed without their attention mask; or when `max_sequence_length` exceeds 512.
    """
    if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
        logger.warning(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Precomputed embeddings are only usable together with the mask that marks their valid tokens.
    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask`")

    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
        raise ValueError(
            "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask`"
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
|
|
|
@staticmethod
|
|
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
|
latent_image_ids = torch.zeros(height, width, 3)
|
|
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
|
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
|
|
|
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
|
|
|
latent_image_ids = latent_image_ids.reshape(
|
|
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
|
)
|
|
|
|
return latent_image_ids.to(device=device, dtype=dtype)
|
|
|
|
@staticmethod
|
|
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
|
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
|
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
|
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
|
|
|
return latents
|
|
|
|
@staticmethod
|
|
def _unpack_latents(latents, height, width, vae_scale_factor):
|
|
batch_size, num_patches, channels = latents.shape
|
|
|
|
# VAE applies 8x compression on images but we must also account for packing which requires
|
|
# latent height and width to be divisible by 2.
|
|
height = 2 * (int(height) // (vae_scale_factor * 2))
|
|
width = 2 * (int(width) // (vae_scale_factor * 2))
|
|
|
|
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
|
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
|
|
|
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
|
|
|
return latents
|
|
|
|
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding. The VAE splits its input tensor into slices and decodes them in several steps,
    which saves some memory and allows larger batch sizes.
    """
    vae = self.vae
    vae.enable_slicing()
|
|
|
|
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding. If `enable_vae_slicing` was called earlier, decoding goes back to a single step.
    """
    vae = self.vae
    vae.disable_slicing()
|
|
|
|
def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE decoding and encoding. The VAE splits its input tensor into tiles and processes them in several
    steps, which saves a large amount of memory and allows larger images.
    """
    vae = self.vae
    vae.enable_tiling()
|
|
|
|
def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE decoding. If `enable_vae_tiling` was called earlier, decoding goes back to a single step.
    """
    vae = self.vae
    vae.disable_tiling()
|
|
|
|
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return `(latents, latent_image_ids)` for a batch of pixel size `height` x `width`.

    If `latents` is given, it is only moved to `device` and cast to `dtype`; it is neither packed nor shape-checked.
    Otherwise, Gaussian noise of shape `(batch_size, num_channels_latents, latent_h, latent_w)` is sampled with
    `generator` and packed into 2x2 patch tokens.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size` (only checked when sampling).
    """
    # Pixel size -> latent size, rounded down to an even number because latents are packed into 2x2 patches.
    packing_unit = self.vae_scale_factor * 2
    latent_h = (int(height) // packing_unit) * 2
    latent_w = (int(width) // packing_unit) * 2

    if latents is not None:
        image_ids = self._prepare_latent_image_ids(batch_size, latent_h // 2, latent_w // 2, device, dtype)
        return latents.to(device=device, dtype=dtype), image_ids

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    noise_shape = (batch_size, num_channels_latents, latent_h, latent_w)
    noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
    packed = self._pack_latents(noise, batch_size, num_channels_latents, latent_h, latent_w)

    image_ids = self._prepare_latent_image_ids(batch_size, latent_h // 2, latent_w // 2, device, dtype)
    return packed, image_ids
|
|
|
|
def _prepare_attention_mask(
|
|
self,
|
|
batch_size,
|
|
sequence_length,
|
|
dtype,
|
|
attention_mask=None,
|
|
):
|
|
if attention_mask is None:
|
|
return attention_mask
|
|
|
|
# Extend the prompt attention mask to account for image tokens in the final sequence
|
|
attention_mask = torch.cat(
|
|
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
|
|
dim=1,
|
|
)
|
|
attention_mask = attention_mask.to(dtype)
|
|
|
|
return attention_mask
|
|
|
|
@property
def guidance_scale(self):
    """Classifier-free guidance scale stored on the pipeline (`self._guidance_scale`)."""
    scale = self._guidance_scale
    return scale
|
|
|
|
@property
def joint_attention_kwargs(self):
    """Extra attention keyword arguments stored on the pipeline (`self._joint_attention_kwargs`)."""
    attention_kwargs = self._joint_attention_kwargs
    return attention_kwargs
|
|
|
|
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. the stored guidance scale is strictly greater than 1."""
    return 1 < self._guidance_scale
|
|
|
|
@property
def num_timesteps(self):
    """Number of timesteps stored on the pipeline (`self._num_timesteps`)."""
    count = self._num_timesteps
    return count
|
|
|
|
@property
def current_timestep(self):
    """Timestep stored on the pipeline as the current one (`self._current_timestep`)."""
    timestep = self._current_timestep
    return timestep
|
|
|
|
@property
def interrupt(self):
    """Interrupt flag stored on the pipeline (`self._interrupt`)."""
    flag = self._interrupt
    return flag
|
|
|
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 35,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 5.0,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image. A falsy value (`None` or `0`) falls back to the default.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image. A falsy value (`None` or `0`) falls back to the default.
        num_inference_steps (`int`, *optional*, defaults to 35):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, a linear schedule from `1.0` down to
            `1 / num_inference_steps` with `num_inference_steps` entries is used.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`, in which case an additional negative-prompt transformer pass is run at every step. Higher guidance
            scale encourages to generate images that are closely linked to the text `prompt`, usually at the
            expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic. A list must have one generator per generated image.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`. Provided latents are only
            moved to the pipeline's device and dtype (they are not packed), so they must already be in the packed
            layout that the pipeline generates internally.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters. If only a negative IP-Adapter input is given, an
            all-zeros image is used here.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional negative image input to work with IP Adapters. If only a positive IP-Adapter input is given,
            an all-zeros image is used here.
        negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated negative image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        prompt_attention_mask (torch.Tensor, *optional*):
            Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
            Chroma requires a single padding token remain unmasked. Please refer to
            https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
        negative_prompt_attention_mask (torch.Tensor, *optional*):
            Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
            prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to
            https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
            skips VAE decoding and returns the packed latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is used as the LoRA scale when encoding the prompt.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`. The function must return a dict; its `"latents"` and
            `"prompt_embeds"` entries, if present, replace the pipeline's current values.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, *optional*, defaults to 512):
            Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
        generated images.
    """

    # `or` (not `is None`) means 0 also falls back to the default resolution.
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # Per-call state exposed through the read-only properties.
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompt(s). The LoRA scale, if any, comes from `joint_attention_kwargs["scale"]`.
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        text_ids,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_text_ids,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    # NOTE(review): `// 4` presumably reflects 2x2 latent packing (see `_pack_latents`) — confirm.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    # Number of image tokens in the packed latent sequence.
    image_seq_len = latents.shape[1]
    # Sequence-length-dependent schedule shift `mu`, computed from the scheduler config's
    # base/max settings and handed to `retrieve_timesteps`.
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )

    # Extend the text masks with ones for the image tokens; each stays None when no mask was given.
    attention_mask = self._prepare_attention_mask(
        batch_size=latents.shape[0],
        sequence_length=image_seq_len,
        dtype=latents.dtype,
        attention_mask=prompt_attention_mask,
    )
    negative_attention_mask = self._prepare_attention_mask(
        batch_size=latents.shape[0],
        sequence_length=image_seq_len,
        dtype=latents.dtype,
        attention_mask=negative_prompt_attention_mask,
    )

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    # Only used to decide when to advance the progress bar below.
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # If only one side of the IP-Adapter input (positive or negative) was given, synthesize an
    # all-zeros image for the other side, one per loaded IP-Adapter.
    # NOTE(review): numpy images are normally (height, width, 3); the (width, height, 3) order here
    # looks swapped. It is likely harmless for an all-zeros image, but confirm.
    if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
        negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
    ):
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
        negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
    ):
        ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    # The denoising loop writes IP-Adapter embeds into this dict, so it must exist.
    if self.joint_attention_kwargs is None:
        self._joint_attention_kwargs = {}

    image_embeds = None
    negative_image_embeds = None
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )
    if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
        negative_image_embeds = self.prepare_ip_adapter_image_embeds(
            negative_ip_adapter_image,
            negative_ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once interrupted, the remaining steps are skipped; decoding below still runs.
            if self.interrupt:
                continue

            self._current_timestep = t
            # The positive and negative passes share one kwargs dict, so the matching IP-Adapter
            # embeds are written into it right before each pass.
            if image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # NOTE(review): the transformer gets `timestep / 1000`, presumably because scheduler
            # timesteps are on a 0-1000 scale while the model expects 0-1 — confirm.
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                attention_mask=attention_mask,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            if self.do_classifier_free_guidance:
                if negative_image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=negative_text_ids,
                    img_ids=latent_image_ids,
                    attention_mask=negative_attention_mask,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                # Classifier-free guidance: neg + w * (pos - neg).
                noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Requested tensors are looked up by *local variable name*, so renaming locals
                # such as `latents` or `prompt_embeds` would break callbacks that ask for them.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # Only these two values are read back from the callback's returned dict.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # Advance the progress bar on the last step, and past the warm-up steps once every
            # `scheduler.order` steps.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            # On XLA devices, call `xm.mark_step()` once per denoising step.
            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if output_type == "latent":
        # Packed latents are returned as-is, without unpacking or VAE decoding.
        image = latents
    else:
        # Unpack to the VAE's latent layout, undo the latent normalization (divide by scaling_factor,
        # add shift_factor), then decode and post-process.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ChromaPipelineOutput(images=image)
|