315 lines
13 KiB
Python
315 lines
13 KiB
Python
![]() |
# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
|
|||
|
#
|
|||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
# you may not use this file except in compliance with the License.
|
|||
|
# You may obtain a copy of the License at
|
|||
|
#
|
|||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
#
|
|||
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
# See the License for the specific language governing permissions and
|
|||
|
# limitations under the License.
|
|||
|
|
|||
|
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
|||
|
# and https://github.com/hojonathanho/diffusion
|
|||
|
|
|||
|
from dataclasses import dataclass
|
|||
|
from typing import Optional, Tuple, Union
|
|||
|
|
|||
|
import flax
|
|||
|
import jax.numpy as jnp
|
|||
|
|
|||
|
from ..configuration_utils import ConfigMixin, register_to_config
|
|||
|
from .scheduling_utils_flax import (
|
|||
|
CommonSchedulerState,
|
|||
|
FlaxKarrasDiffusionSchedulers,
|
|||
|
FlaxSchedulerMixin,
|
|||
|
FlaxSchedulerOutput,
|
|||
|
add_noise_common,
|
|||
|
get_velocity_common,
|
|||
|
)
|
|||
|
|
|||
|
|
|||
|
@flax.struct.dataclass
class DDIMSchedulerState:
    """State container passed to, and returned from, the `FlaxDDIMScheduler` methods."""

    # shared schedule data; the scheduler reads `alphas_cumprod` from it
    common: CommonSchedulerState
    # cumulative alpha product used in place of `alphas_cumprod[prev_timestep]`
    # when the previous timestep would be negative
    final_alpha_cumprod: jnp.ndarray

    # settable values
    # standard deviation of the initial noise distribution
    init_noise_sigma: jnp.ndarray
    # timesteps of the diffusion chain, replaced by `set_timesteps`
    timesteps: jnp.ndarray
    # stays `None` until `set_timesteps` has been called
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build a state from the required fields, leaving `num_inference_steps` at its default."""
        state_fields = dict(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )
        return cls(**state_fields)
|
|||
|
|
|||
|
|
|||
|
@dataclass
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxDDIMScheduler.step`: the base scheduler output plus the scheduler state."""

    # scheduler state handed back to the caller alongside the new sample
    state: DDIMSchedulerState
|
|||
|
|
|||
|
|
|||
|
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://huggingface.co/papers/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip the predicted sample for numerical stability. The clip range is determined by
            `clip_sample_range`.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, default `epsilon`):
            indicates what the model predicts: the noise (`epsilon`), the denoised sample (`sample`), or the
            velocity (`v_prediction`).
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        """Always `True`: the methods of this scheduler take an explicit `DDIMSchedulerState`."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        clip_sample: bool = True,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        # Only `dtype` is stored on the instance here; the other arguments are read back through `self.config`.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
        """
        Create the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, optional):
                shared schedule state. When `None`, it is created from this scheduler via
                `CommonSchedulerState.create(self)`.

        Returns:
            `DDIMSchedulerState`: state whose `timesteps` cover every training timestep in descending order and whose
            `num_inference_steps` is still `None`.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        final_alpha_cumprod = (
            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DDIMSchedulerState.create(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Return `sample` unchanged; this scheduler applies no input scaling.

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDIMSchedulerState`):
                the `FlaxDDIMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model. Must be between
                `1` and `num_train_timesteps` (inclusive).
            shape (`Tuple`, optional):
                not used by this scheduler.

        Returns:
            `DDIMSchedulerState`: a copy of `state` with `num_inference_steps` and `timesteps` replaced.

        Raises:
            `ValueError`: if `num_inference_steps` is smaller than 1 or larger than `num_train_timesteps`.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be a positive integer, but got {num_inference_steps}.")
        if num_inference_steps > self.config.num_train_timesteps:
            # A larger value would make the integer step ratio below 0, collapsing every timestep onto
            # `steps_offset` and turning `step` into a no-op.
            raise ValueError(
                f"`num_inference_steps` ({num_inference_steps}) cannot be larger than `num_train_timesteps`"
                f" ({self.config.num_train_timesteps})."
            )

        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
        )

    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
        """
        Compute the variance term of formula (16) of the DDIM paper for moving from `timestep` to `prev_timestep`.

        When `prev_timestep` is negative, `state.final_alpha_cumprod` is used as the previous cumulative alpha.
        """
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(
            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`, default `0.0`):
                weight η applied to the standard deviation sigma_t of formula (16). This method adds no random noise
                term, so `eta` only reduces the coefficient of the "direction pointing to x_t".
            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

        Returns:
            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor and the second is the unchanged
            `state`.

        Raises:
            `ValueError`: if `set_timesteps` has not been called on `state`, or if `prediction_type` is not one of
            `epsilon`, `sample` or `v_prediction`.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
        # Ideally, read the DDIM paper for an in-detail understanding

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

        alphas_cumprod = state.common.alphas_cumprod
        final_alpha_cumprod = state.final_alpha_cumprod

        # 2. compute alphas, betas
        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clip(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, state)

        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `add_noise_common` using the shared schedule stored in `state.common`."""
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def get_velocity(
        self,
        state: DDIMSchedulerState,
        sample: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `get_velocity_common` using the shared schedule stored in `state.common`."""
        return get_velocity_common(state.common, sample, noise, timesteps)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|