1027 lines
48 KiB
Python
1027 lines
48 KiB
Python
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
||
|
||
import math
|
||
from typing import List, Optional, Tuple, Union
|
||
|
||
import numpy as np
|
||
import torch
|
||
|
||
from ..configuration_utils import ConfigMixin, register_to_config
|
||
from ..utils import deprecate, is_scipy_available
|
||
from ..utils.torch_utils import randn_tensor
|
||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||
|
||
|
||
if is_scipy_available():
|
||
import scipy.stats
|
||
|
||
|
||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing a cumulative-product function `alpha_bar(t)` defined on t in [0, 1].

    For every step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`, where `N`
    is `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): how many betas to generate.
        max_beta (`float`): upper bound applied to each beta; values below 1 keep the schedule away from
            singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): which `alpha_bar` function to discretize,
            either `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): a 1-D float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # Each beta is derived from the ratio of alpha_bar at the two ends of its time interval.
    capped_betas = [
        min(
            1 - alpha_bar_fn((step + 1) / num_diffusion_timesteps) / alpha_bar_fn(step / num_diffusion_timesteps),
            max_beta,
        )
        for step in range(num_diffusion_timesteps)
    ]
    return torch.tensor(capped_betas, dtype=torch.float32)
|
||
|
||
|
||
class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||
"""
|
||
`DPMSolverMultistepInverseScheduler` is the reverse scheduler of [`DPMSolverMultistepScheduler`].
|
||
|
||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||
methods the library implements for all schedulers such as loading and saving.
|
||
|
||
Args:
|
||
num_train_timesteps (`int`, defaults to 1000):
|
||
The number of diffusion steps to train the model.
|
||
beta_start (`float`, defaults to 0.0001):
|
||
The starting `beta` value of inference.
|
||
beta_end (`float`, defaults to 0.02):
|
||
The final `beta` value.
|
||
beta_schedule (`str`, defaults to `"linear"`):
|
||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||
trained_betas (`np.ndarray`, *optional*):
|
||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||
solver_order (`int`, defaults to 2):
|
||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||
sampling, and `solver_order=3` for unconditional sampling.
|
||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||
thresholding (`bool`, defaults to `False`):
|
||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||
as Stable Diffusion.
|
||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||
sample_max_value (`float`, defaults to 1.0):
|
||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||
`algorithm_type="dpmsolver++"`.
|
||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
||
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
||
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
||
paper, and the `dpmsolver++` type implements the algorithms in the
|
||
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
||
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||
solver_type (`str`, defaults to `midpoint`):
|
||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
||
lower_order_final (`bool`, defaults to `True`):
|
||
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
||
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||
euler_at_final (`bool`, defaults to `False`):
|
||
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
||
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
||
steps, but sometimes may result in blurring.
|
||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||
variance_type (`str`, *optional*):
|
||
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
||
contains the predicted Gaussian variance.
|
||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||
steps_offset (`int`, defaults to 0):
|
||
An offset added to the inference steps, as required by some model families.
|
||
"""
|
||
|
||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||
order = 1
|
||
|
||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    use_karras_sigmas: Optional[bool] = False,
    use_exponential_sigmas: Optional[bool] = False,
    use_beta_sigmas: Optional[bool] = False,
    use_flow_sigmas: Optional[bool] = False,
    flow_shift: Optional[float] = 1.0,
    lambda_min_clipped: float = -float("inf"),
    variance_type: Optional[str] = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Validate the configuration and precompute the noise-schedule tensors (betas, alphas, sigmas, lambdas).

    Raises:
        ImportError: if `use_beta_sigmas` is set but scipy is not available.
        ValueError: if more than one of `use_beta_sigmas`, `use_exponential_sigmas`, `use_karras_sigmas` is set.
        NotImplementedError: if `beta_schedule`, `algorithm_type` or `solver_type` is not supported.
    """
    if self.config.use_beta_sigmas and not is_scipy_available():
        raise ImportError("Make sure to install scipy if you want to use beta sigmas.")

    # The alternative sigma schedules are mutually exclusive.
    sigma_schedule_flags = (
        self.config.use_beta_sigmas,
        self.config.use_exponential_sigmas,
        self.config.use_karras_sigmas,
    )
    if sum(sigma_schedule_flags) > 1:
        raise ValueError(
            "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
        )

    if algorithm_type in ("dpmsolver", "sde-dpmsolver"):
        deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
        deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

    # Explicit betas take precedence over any named schedule.
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
    self.betas = betas

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    # Currently we only support VP-type noise schedule
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
    self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # settings for DPM-Solver: "deis" is remapped to "dpmsolver++", anything else unknown is rejected
    if algorithm_type == "deis":
        self.register_to_config(algorithm_type="dpmsolver++")
    elif algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
        raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    # "logrho", "bh1" and "bh2" are remapped to "midpoint", anything else unknown is rejected
    if solver_type in ["logrho", "bh1", "bh2"]:
        self.register_to_config(solver_type="midpoint")
    elif solver_type not in ["midpoint", "heun"]:
        raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    # setable values
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(
        np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
    )
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    self.use_karras_sigmas = use_karras_sigmas
    self.use_exponential_sigmas = use_exponential_sigmas
    self.use_beta_sigmas = use_beta_sigmas
|
||
|
||
@property
def step_index(self):
    """
    Read-only view of the internal step counter `_step_index` (the index of the current timestep); it is `None`
    until the counter has been initialized.
    """
    current_index = self._step_index
    return current_index
|
||
|
||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Compute the discrete timesteps and matching sigmas for the diffusion chain (call before inference).

    Besides `self.timesteps` and `self.sigmas`, this sets `self.noisiest_timestep` and `self.num_inference_steps`
    and resets the multistep state (`model_outputs`, `lower_order_nums`, `_step_index`).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    flipped_lambdas = torch.flip(self.lambda_t, [0])
    clipped_idx = torch.searchsorted(flipped_lambdas, self.config.lambda_min_clipped).item()
    self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
    if self.config.timestep_spacing == "linspace":
        grid = np.linspace(0, self.noisiest_timestep, num_inference_steps + 1)
        timesteps = grid.round()[:-1].copy().astype(np.int64)
    elif self.config.timestep_spacing == "leading":
        step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1)
        # integer timesteps are obtained by multiplying the step index by the ratio
        grid = np.arange(0, num_inference_steps + 1) * step_ratio
        timesteps = grid.round()[:-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / num_inference_steps
        # walk down from the noisiest timestep, then reverse into ascending order
        grid = np.arange(self.noisiest_timestep + 1, 0, -step_ratio)
        timesteps = grid.round()[::-1].copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', "
            "'leading' or 'trailing'."
        )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)

    # Every branch yields `sigmas` plus the one trailing value (`final_sigma`) that is appended afterwards.
    if self.config.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = timesteps.copy().astype(np.int64)
        final_sigma = sigmas[-1:]
    elif self.config.use_exponential_sigmas:
        sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        final_sigma = sigmas[-1:]
    elif self.config.use_beta_sigmas:
        sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        final_sigma = sigmas[-1:]
    elif self.config.use_flow_sigmas:
        alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
        sigmas = 1.0 - alphas
        shifted = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
        sigmas = np.flip(shifted)[:-1].copy()
        timesteps = (sigmas * self.config.num_train_timesteps).copy()
        final_sigma = sigmas[-1:]
    else:
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigma_max = (
            (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
        ) ** 0.5
        final_sigma = [sigma_max]
    sigmas = np.concatenate([sigmas, final_sigma]).astype(np.float32)

    # kept on CPU to avoid too much CPU/GPU communication
    self.sigmas = torch.from_numpy(sigmas).to("cpu")

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps; keep the first occurrence of each, in original order.
    _, first_occurrence = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(first_occurrence)]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
    self.num_inference_steps = len(timesteps)

    # reset the multistep history
    self.model_outputs = [None] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
|
||
|
||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||
|
||
https://huggingface.co/papers/2205.11487
|
||
"""
|
||
dtype = sample.dtype
|
||
batch_size, channels, *remaining_dims = sample.shape
|
||
|
||
if dtype not in (torch.float32, torch.float64):
|
||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||
|
||
# Flatten sample for doing quantile calculation along each image
|
||
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||
|
||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||
|
||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||
s = torch.clamp(
|
||
s, min=1, max=self.config.sample_max_value
|
||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||
|
||
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
||
sample = sample.to(dtype)
|
||
|
||
return sample
|
||
|
||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||
def _sigma_to_t(self, sigma, log_sigmas):
|
||
# get log sigma
|
||
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
||
|
||
# get distribution
|
||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||
|
||
# get sigmas range
|
||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||
high_idx = low_idx + 1
|
||
|
||
low = log_sigmas[low_idx]
|
||
high = log_sigmas[high_idx]
|
||
|
||
# interpolate sigmas
|
||
w = (low - log_sigma) / (low - high)
|
||
w = np.clip(w, 0, 1)
|
||
|
||
# transform interpolation to time range
|
||
t = (1 - w) * low_idx + w * high_idx
|
||
t = t.reshape(sigma.shape)
|
||
return t
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||
if self.config.use_flow_sigmas:
|
||
alpha_t = 1 - sigma
|
||
sigma_t = sigma
|
||
else:
|
||
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
||
sigma_t = sigma * alpha_t
|
||
|
||
return alpha_t, sigma_t
|
||
|
||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||
|
||
# Hack to make sure that other schedulers which copy this function don't break
|
||
# TODO: Add this logic to the other schedulers
|
||
if hasattr(self.config, "sigma_min"):
|
||
sigma_min = self.config.sigma_min
|
||
else:
|
||
sigma_min = None
|
||
|
||
if hasattr(self.config, "sigma_max"):
|
||
sigma_max = self.config.sigma_max
|
||
else:
|
||
sigma_max = None
|
||
|
||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||
|
||
rho = 7.0 # 7.0 is the value used in the paper
|
||
ramp = np.linspace(0, 1, num_inference_steps)
|
||
min_inv_rho = sigma_min ** (1 / rho)
|
||
max_inv_rho = sigma_max ** (1 / rho)
|
||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||
return sigmas
|
||
|
||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
||
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||
"""Constructs an exponential noise schedule."""
|
||
|
||
# Hack to make sure that other schedulers which copy this function don't break
|
||
# TODO: Add this logic to the other schedulers
|
||
if hasattr(self.config, "sigma_min"):
|
||
sigma_min = self.config.sigma_min
|
||
else:
|
||
sigma_min = None
|
||
|
||
if hasattr(self.config, "sigma_max"):
|
||
sigma_max = self.config.sigma_max
|
||
else:
|
||
sigma_max = None
|
||
|
||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||
|
||
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
||
return sigmas
|
||
|
||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||
def _convert_to_beta(
|
||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||
) -> torch.Tensor:
|
||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||
|
||
# Hack to make sure that other schedulers which copy this function don't break
|
||
# TODO: Add this logic to the other schedulers
|
||
if hasattr(self.config, "sigma_min"):
|
||
sigma_min = self.config.sigma_min
|
||
else:
|
||
sigma_min = None
|
||
|
||
if hasattr(self.config, "sigma_max"):
|
||
sigma_max = self.config.sigma_max
|
||
else:
|
||
sigma_max = None
|
||
|
||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||
|
||
sigmas = np.array(
|
||
[
|
||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||
for ppf in [
|
||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||
]
|
||
]
|
||
)
|
||
return sigmas
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
||
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Translate the raw model output into the quantity the configured solver integrates: a data (`x0`) prediction
    for `dpmsolver++` / `sde-dpmsolver++`, or a noise (`epsilon`) prediction for `dpmsolver` / `sde-dpmsolver`.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed as the second
            positional argument after a (deprecated, ignored) `timestep`.

    Returns:
        `torch.Tensor`:
            The converted model output.

    Raises:
        ValueError: if `sample` is missing or `config.prediction_type` is not supported by the algorithm type.
    """
    # Legacy positional layout is (model_output, timestep, sample).
    if args:
        timestep = args[0]
    else:
        timestep = kwargs.pop("timestep", None)

    if sample is None:
        if len(args) <= 1:
            raise ValueError("missing `sample` as a required keyword argument")
        sample = args[1]

    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    def alpha_sigma_at_current_step():
        # (alpha_t, sigma_t) for the sigma stored at the current step index
        return self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            alpha_t, sigma_t = alpha_sigma_at_current_step()
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = alpha_sigma_at_current_step()
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            # here the raw sigma is used directly, without the alpha/sigma conversion
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    if self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = alpha_sigma_at_current_step()
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = alpha_sigma_at_current_step()
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # threshold in x0-space, then map the thresholded x0 back to a noise prediction
            alpha_t, sigma_t = alpha_sigma_at_current_step()
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
||
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    The step moves `sample` from the sigma at `self.step_index` to the sigma at `self.step_index + 1`.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed as the third
            positional argument after the (deprecated, ignored) `timestep` and `prev_timestep`.
        noise (`torch.Tensor`, *optional*):
            Noise tensor; required by the `sde-dpmsolver++` and `sde-dpmsolver` algorithm types.

    Returns:
        `torch.Tensor`:
            The sample tensor after the step.

    Raises:
        ValueError: if `sample` is not provided.
    """
    # Legacy positional layout is (model_output, timestep, prev_timestep, sample).
    if args:
        timestep = args[0]
    else:
        timestep = kwargs.pop("timestep", None)
    if len(args) > 1:
        prev_timestep = args[1]
    else:
        prev_timestep = kwargs.pop("prev_timestep", None)

    if sample is None:
        if len(args) <= 2:
            raise ValueError("missing `sample` as a required keyword argument")
        sample = args[2]

    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # "s" is the current step, "t" is the step being moved to
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index + 1])
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])

    # lambda = log(alpha) - log(sigma); h is the lambda difference covered by this step
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    h = lambda_t - lambda_s

    algorithm_type = self.config.algorithm_type
    if algorithm_type == "dpmsolver++":
        sample_coeff = sigma_t / sigma_s
        model_coeff = alpha_t * (torch.exp(-h) - 1.0)
        x_t = sample_coeff * sample - model_coeff * model_output
    elif algorithm_type == "dpmsolver":
        sample_coeff = alpha_t / alpha_s
        model_coeff = sigma_t * (torch.exp(h) - 1.0)
        x_t = sample_coeff * sample - model_coeff * model_output
    elif algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        sample_coeff = sigma_t / sigma_s * torch.exp(-h)
        model_coeff = alpha_t * (1 - torch.exp(-2.0 * h))
        noise_coeff = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
        x_t = sample_coeff * sample + model_coeff * model_output + noise_coeff * noise
    elif algorithm_type == "sde-dpmsolver":
        assert noise is not None
        sample_coeff = alpha_t / alpha_s
        model_coeff = 2.0 * (sigma_t * (torch.exp(h) - 1.0))
        noise_coeff = sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0)
        x_t = sample_coeff * sample - model_coeff * model_output + noise_coeff * noise
    return x_t
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
||
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Perform one second-order multistep DPMSolver update.

    Args:
        model_output_list (`List[torch.Tensor]`):
            Converted model outputs; only the last two entries (`[-1]` and `[-2]`) are read.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. For backward compatibility it may
            also be passed as the third positional argument after `model_output_list`.
        noise (`torch.Tensor`, *optional*):
            Noise tensor. Asserted to be non-`None` for the `sde-dpmsolver++` and `sde-dpmsolver` algorithm types
            and ignored for the other ones.

    Returns:
        `torch.Tensor`:
            The updated sample, evaluated at `self.sigmas[self.step_index + 1]`.
    """
    # Legacy positional/keyword arguments: accepted, reported as deprecated, otherwise ignored.
    legacy_timestep_list = args[0] if args else kwargs.pop("timestep_list", None)
    legacy_prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)

    if sample is None:
        if len(args) <= 2:
            raise ValueError("missing `sample` as a required keyword argument")
        sample = args[2]

    if legacy_timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if legacy_prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # Sigmas of the target step, the current step and the step before it.
    idx = self.step_index
    raw_target, raw_current, raw_previous = self.sigmas[idx + 1], self.sigmas[idx], self.sigmas[idx - 1]

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(raw_target)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(raw_current)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(raw_previous)

    # lambda = log(alpha) - log(sigma) for each of the three steps.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    newest, older = model_output_list[-1], model_output_list[-2]

    h = lambda_t - lambda_s0
    h_prev = lambda_s0 - lambda_s1
    ratio = h_prev / h

    # Zeroth-order term and first-order difference of the model outputs.
    D0 = newest
    D1 = (1.0 / ratio) * (newest - older)

    config = self.config
    algorithm = config.algorithm_type

    if algorithm == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        decay = torch.exp(-h) - 1.0
        data_coeff = alpha_t * decay
        if config.solver_type == "midpoint":
            x_t = (sigma_t / sigma_s0) * sample - data_coeff * D0 - 0.5 * data_coeff * D1
        elif config.solver_type == "heun":
            x_t = (sigma_t / sigma_s0) * sample - data_coeff * D0 + (alpha_t * (decay / h + 1.0)) * D1
    elif algorithm == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        growth = torch.exp(h) - 1.0
        noise_coeff = sigma_t * growth
        if config.solver_type == "midpoint":
            x_t = (alpha_t / alpha_s0) * sample - noise_coeff * D0 - 0.5 * noise_coeff * D1
        elif config.solver_type == "heun":
            x_t = (alpha_t / alpha_s0) * sample - noise_coeff * D0 - (sigma_t * (growth / h - 1.0)) * D1
    elif algorithm == "sde-dpmsolver++":
        assert noise is not None
        if config.solver_type == "midpoint":
            shrink = 1 - torch.exp(-2.0 * h)
            data_coeff = alpha_t * shrink
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + data_coeff * D0
                + 0.5 * data_coeff * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif config.solver_type == "heun":
            shrink = 1 - torch.exp(-2.0 * h)
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * shrink) * D0
                + (alpha_t * (shrink / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    elif algorithm == "sde-dpmsolver":
        assert noise is not None
        if config.solver_type == "midpoint":
            noise_coeff = sigma_t * (torch.exp(h) - 1.0)
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * noise_coeff * D0
                - noise_coeff * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        elif config.solver_type == "heun":
            growth = torch.exp(h) - 1.0
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * growth) * D0
                - 2.0 * (sigma_t * (growth / h - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
    return x_t
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
||
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            Converted model outputs; only the last three entries (`[-1]`, `[-2]`, `[-3]`) are read.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process. For backward compatibility it may also
            be passed as the third positional argument after `model_output_list`.
        noise (`torch.Tensor`, *optional*):
            Noise tensor. Asserted to be non-`None` for the `sde-dpmsolver++` algorithm type and ignored for
            `dpmsolver++` and `dpmsolver`.

    Returns:
        `torch.Tensor`:
            The updated sample, evaluated at `self.sigmas[self.step_index + 1]`.

    Raises:
        `ValueError`: If `sample` is passed neither by keyword nor positionally.
        `NotImplementedError`: If `self.config.algorithm_type` has no third-order update here (e.g.
            `sde-dpmsolver`).
    """

    # Legacy positional/keyword arguments: accepted, reported as deprecated, otherwise ignored.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # Sigmas of the target step, the current step and the two steps before it.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    # lambda = log(alpha) - log(sigma) for each of the four steps.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # Zeroth-, first- and second-order difference terms built from the three model outputs.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

    algorithm_type = self.config.algorithm_type
    if algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif algorithm_type == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    elif algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
            + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    else:
        # Previously this fell through to `return x_t` with `x_t` unbound (UnboundLocalError).
        raise NotImplementedError(
            f"The third-order multistep update is not implemented for algorithm_type `{algorithm_type}`."
        )
    return x_t
|
||
|
||
def _init_step_index(self, timestep):
    """
    Set `self._step_index` to the position of `timestep` inside `self.timesteps`.

    Args:
        timestep (`int` or `torch.Tensor`):
            The timestep to look up. A tensor is first moved to the device of `self.timesteps`.
    """
    schedule = self.timesteps
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(schedule.device)

    matches = (schedule == timestep).nonzero()
    match_count = len(matches)

    if match_count == 0:
        # Timestep not present in the schedule: fall back to the last index.
        self._step_index = len(schedule) - 1
        return

    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    chosen = 1 if match_count > 1 else 0
    self._step_index = matches[chosen].item()
|
||
|
||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int` or `torch.Tensor`):
            The current discrete timestep in the diffusion chain. Only used to initialize `self.step_index` when
            it has not been set yet.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        variance_noise (`torch.Tensor`, *optional*):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself. Useful for methods such as [`CycleDiffusion`]. Only read when `algorithm_type` is
            `sde-dpmsolver` or `sde-dpmsolver++`.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`: If `self.num_inference_steps` is `None` (i.e. `set_timesteps` has not been run).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
    )
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    # Shift the history of converted model outputs by one and append the newest one.
    model_output = self.convert_model_output(model_output, sample=sample)
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    # Only the SDE variants consume noise; draw it unless the caller supplied `variance_noise`.
    if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
    elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = variance_noise
    else:
        noise = None

    # Use a lower-order update while too few model outputs are stored or near the end of the schedule.
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        # `noise` must be forwarded here as well: the third-order SDE update requires it.
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
|
||
|
||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
|
||
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Return `sample` unchanged.

    Provided so this scheduler can be swapped with schedulers that need to scale the denoising model input
    depending on the current timestep. Any extra positional or keyword arguments are accepted and ignored.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The same `sample` object that was passed in.
    """
    return sample
|
||
|
||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` at the noise levels selected by `timesteps`.

    Args:
        original_samples (`torch.Tensor`):
            The samples to add noise to.
        noise (`torch.Tensor`):
            The noise to add.
        timesteps (`torch.IntTensor`):
            One timestep per entry; each is looked up in `self.timesteps` to pick a sigma from `self.sigmas`.

    Returns:
        `torch.Tensor`:
            `alpha_t * original_samples + sigma_t * noise`, with `alpha_t`/`sigma_t` obtained from
            `self._sigma_to_alpha_sigma_t`.
    """
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    target_device = original_samples.device
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    def _lookup(value):
        # No match -> last index; several matches -> the second one; otherwise the only match.
        hits = (schedule_timesteps == value).nonzero()
        if len(hits) == 0:
            return len(schedule_timesteps) - 1
        return hits[1 if len(hits) > 1 else 0].item()

    step_indices = [_lookup(value) for value in timesteps]

    # Append trailing singleton dims until sigma has as many dims as the samples.
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise
|
||
|
||
def __len__(self):
    """Return `self.config.num_train_timesteps`."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|