691 lines
34 KiB
Python
691 lines
34 KiB
Python
![]() |
# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
||
|
# and https://github.com/hojonathanho/diffusion
|
||
|
|
||
|
import math
|
||
|
from dataclasses import dataclass
|
||
|
from typing import List, Optional, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||
|
from ..schedulers.scheduling_utils import SchedulerMixin
|
||
|
from ..utils import BaseOutput, logging
|
||
|
from ..utils.torch_utils import randn_tensor
|
||
|
|
||
|
|
||
|
# Module-level logger; `TCDScheduler.set_timesteps` uses it to warn about unusual custom timestep schedules.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
||
|
|
||
|
|
||
|
@dataclass
class TCDSchedulerOutput(BaseOutput):
    """
    Structured result returned by [`TCDScheduler.step`] when `return_dict=True`.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` for the previous timestep; feed it back to the model as the next input of the
            denoising loop.
        pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The noised sample `(x_{s})` predicted from the model output at the current timestep. May be `None`.
    """

    prev_sample: torch.Tensor
    pred_noised_sample: Union[torch.Tensor, None] = None
|
||
|
|
||
|
|
||
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||
|
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor` of dtype `torch.float32` and shape `(num_diffusion_timesteps,)`): the betas used by the
        scheduler to step the model outputs. Beta `i` is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at
        `max_beta`.

    Raises:
        ValueError: if `alpha_transform_type` is not `cosine` or `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # Each beta is one minus the ratio of consecutive alpha_bar values on the [0, 1] grid, clipped to max_beta.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
|
||
|
|
||
|
|
||
|
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||
|
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is exactly zero.

    Implements Algorithm 1 of https://huggingface.co/papers/2305.08891: the square roots of the cumulative alpha
    products are shifted so the last one becomes zero, then scaled so the first one keeps its original value, and the
    result is converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: The rescaled betas; the last entry is `1.0` (zero terminal SNR).
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    head = sqrt_abar[0].clone()
    tail = sqrt_abar[-1].clone()

    # Shift (terminal value -> 0), then scale (first value -> its old value).
    sqrt_abar = (sqrt_abar - tail) * (head / (head - tail))

    # Undo the sqrt, then undo the cumulative product to get per-step alphas.
    abar = sqrt_abar**2
    step_alphas = torch.cat([abar[0:1], abar[1:] / abar[:-1]])

    return 1 - step_alphas
|
||
|
|
||
|
|
||
|
class TCDScheduler(SchedulerMixin, ConfigMixin):
    """
    `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency
    Distillation`, extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.

    This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
    attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
    accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
    functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"scaled_linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        original_inference_steps (`int`, *optional*, defaults to 50):
            The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
            will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        timestep_scaling (`float`, defaults to 10.0):
            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
            error at the default of `10.0` is already pretty small).
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).

    Note:
        In this implementation [`~TCDScheduler.step`] does not apply `clip_sample`/`clip_sample_range` or dynamic
        thresholding (`thresholding`, `dynamic_thresholding_ratio`, `sample_max_value`), and `steps_offset`,
        `timestep_spacing` and `timestep_scaling` are not read by [`~TCDScheduler.set_timesteps`] or `step`. These
        options are accepted and stored in the config but currently have no effect on sampling.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        original_inference_steps: int = 50,
        clip_sample: bool = False,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        timestep_scaling: float = 10.0,
        rescale_betas_zero_snr: bool = False,
    ):
        # Build the training beta schedule; `trained_betas` takes precedence over `beta_schedule`.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
        self.custom_timesteps = False

        self._step_index = None
        self._begin_index = None

    # Return the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`); when it occurs more
    # than once the second occurrence is used.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Initialise `_step_index` from `timestep`, or from `begin_index` when the pipeline has set one.
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def step_index(self):
        """
        The index of the current timestep in `self.timesteps`; `None` until the first call to `step`. Each call to
        `step` increases it by one, and `set_timesteps` resets it.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample (this scheduler returns `sample` unchanged).
        """
        return sample

    # DDIM posterior variance between `timestep` and `prev_timestep` (not called by `step` in this scheduler).
    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Dynamic thresholding helper (not called by `step` in this scheduler).
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule (which is different from the standard `diffusers` implementation). We will then take
                `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
                our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
            strength (`float`, *optional*, defaults to 1.0):
                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.

        Raises:
            ValueError: If both or neither of `num_inference_steps` and `timesteps` are given; if `timesteps` is
                empty, not strictly descending, contains negative values, or starts at or after
                `num_train_timesteps`; or if `num_inference_steps` is incompatible with `num_train_timesteps`,
                `original_inference_steps` or `strength`.
        """
        # 0. Check inputs
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

        # An empty schedule would otherwise fail later with an opaque IndexError on `timesteps[0]`.
        if timesteps is not None and len(timesteps) == 0:
            raise ValueError("`timesteps` must contain at least one timestep.")

        # 1. Calculate the TCD original training/distillation timestep schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
        )

        if original_inference_steps is None:
            # default option, timesteps align with discrete inference steps
            if original_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )
            # TCD Timesteps Setting
            # The skipping step parameter k from the paper.
            k = self.config.num_train_timesteps // original_steps
            # TCD Training/Distillation Steps Schedule
            tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
        else:
            # customised option, sampled timesteps can be any arbitrary value
            tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps * strength))))

        # 2. Calculate the TCD inference timestep schedule.
        if timesteps is not None:
            # 2.1 Handle custom timestep schedules.
            train_timesteps = set(tcd_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

                # The first timestep has its own dedicated warning below, so it is not collected here.
                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
                )

            # The schedule is strictly descending, so the last entry is the smallest. A negative value would
            # silently index `alphas_cumprod` from its end in `step`.
            if timesteps[-1] < 0:
                raise ValueError(f"`timesteps` must be non-negative, got {timesteps[-1]}.")

            # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
            if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule is longer than original_steps
            if original_steps is not None:
                if len(timesteps) > original_steps:
                    logger.warning(
                        f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                        f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                        f" unexpected results when using this timestep schedule."
                    )
            else:
                if len(timesteps) > self.config.num_train_timesteps:
                    logger.warning(
                        f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                        f" the length of the timestep schedule used for training: {self.config.num_train_timesteps}. You may get some"
                        f" unexpected results when using this timestep schedule."
                    )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
            init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order :]
            # NOTE: `self.num_inference_steps` keeps the un-truncated length here; `step` detects the final step
            # via `len(self.timesteps)`, so the truncation does not make it inject noise on the last step.
        else:
            # 2.2 Create the "standard" TCD inference timestep schedule.
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            if original_steps is not None:
                skipping_step = len(tcd_origin_timesteps) // num_inference_steps

                if skipping_step < 1:
                    raise ValueError(
                        f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
                    )

            self.num_inference_steps = num_inference_steps

            if original_steps is not None:
                if num_inference_steps > original_steps:
                    raise ValueError(
                        f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                        f" {original_steps} because the final timestep schedule will be a subset of the"
                        f" `original_inference_steps`-sized initial timestep schedule."
                    )
            else:
                if num_inference_steps > self.config.num_train_timesteps:
                    raise ValueError(
                        f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:"
                        f" {self.config.num_train_timesteps} because the final timestep schedule will be a subset of the"
                        f" `num_train_timesteps`-sized initial timestep schedule."
                    )

            # TCD Inference Steps Schedule
            tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
            # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
            inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = tcd_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

        self._step_index = None
        self._begin_index = None

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.3,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[TCDSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
                step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic
                sampling. Must lie in `[0, 1]`.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a
                tuple `(prev_sample, pred_noised_sample)` is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called, if `eta` is outside `[0, 1]`, or if
                `config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Validate before touching scheduler state; an `assert` would be stripped under `python -O`.
        if not 0 <= eta <= 1.0:
            raise ValueError(f"`eta` (referred to as `gamma` in the paper) must be in [0, 1], got {eta}.")

        if self.step_index is None:
            self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        # The last entry of `self.timesteps` is the final step. `len(self.timesteps)` is used instead of
        # `self.num_inference_steps` because `set_timesteps` with custom `timesteps` and `strength < 1` truncates
        # `self.timesteps` without updating `self.num_inference_steps`.
        is_final_step = prev_step_index >= len(self.timesteps)
        if not is_final_step:
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = torch.tensor(0)

        # Intermediate timestep s = floor((1 - eta) * t_prev); eta = 0 gives s = t_prev, eta = 1 gives s = 0.
        timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long)

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t

        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        alpha_prod_s = self.alphas_cumprod[timestep_s]
        beta_prod_s = 1 - alpha_prod_s

        # 3. Recover (x_0, epsilon) from the model output according to its parameterization
        if self.config.prediction_type == "epsilon":  # noise-prediction
            pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":  # x-prediction
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":  # v-prediction
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for `TCDScheduler`."
            )

        # ...and form the predicted noised sample x_s from them (shared by all parameterizations).
        pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon

        # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step.
        # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
        if eta > 0 and not is_final_step:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype
            )
            prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
                1 - alpha_prod_t_prev / alpha_prod_s
            ).sqrt() * noise
        else:
            prev_sample = pred_noised_sample

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, pred_noised_sample)

        return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)

    # Forward-diffuse `original_samples` to `timesteps`: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # v-prediction target: sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample.
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        # Length of the scheduler is the number of training timesteps, not the inference schedule length.
        return self.config.num_train_timesteps

    # Timestep that follows `timestep` in the current schedule; -1 after the last one (or `timestep - 1` when no
    # inference schedule has been set).
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
        if self.custom_timesteps or self.num_inference_steps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
            if index == self.timesteps.shape[0] - 1:
                prev_t = torch.tensor(-1)
            else:
                prev_t = self.timesteps[index + 1]
        else:
            prev_t = timestep - 1
        return prev_t
|