292 lines
12 KiB
Python
292 lines
12 KiB
Python
![]() |
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import importlib
|
||
|
import math
|
||
|
import os
|
||
|
from dataclasses import dataclass
|
||
|
from enum import Enum
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import flax
|
||
|
import jax.numpy as jnp
|
||
|
from huggingface_hub.utils import validate_hf_hub_args
|
||
|
|
||
|
from ..utils import BaseOutput, PushToHubMixin
|
||
|
|
||
|
|
||
|
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||
|
|
||
|
|
||
|
# NOTE: This is an enum (rather than a Union of scheduler classes) because that keeps the docs simple and avoids
# circular imports when it is referenced from `_compatibles` inside the schedulers module. When it appears as a
# type in pipelines it effectively means "one of these schedulers", since a concrete scheduler instance is what
# actually gets passed in.
class FlaxKarrasDiffusionSchedulers(Enum):
    # Member values are fixed integers 1..6, in declaration order.
    FlaxDDIMScheduler = 1
    FlaxDDPMScheduler = 2
    FlaxPNDMScheduler = 3
    FlaxLMSDiscreteScheduler = 4
    FlaxDPMSolverMultistepScheduler = 5
    FlaxEulerDiscreteScheduler = 6
|
||
|
|
||
|
|
||
|
@dataclass
class FlaxSchedulerOutput(BaseOutput):
    """
    Output container returned by a Flax scheduler's `step` function.

    Args:
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            The computed sample (x_{t-1}) for the previous timestep. Feed `prev_sample` back in as the model
            input for the next iteration of the denoising loop.
    """

    # The sample produced by one scheduler step.
    prev_sample: jnp.ndarray
|
||
|
|
||
|
|
||
|
class FlaxSchedulerMixin(PushToHubMixin):
    """
    Mixin containing common functions for the schedulers.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be
          overridden by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    # NOTE(review): presumably consumed by the config machinery (defined elsewhere) to keep `dtype` out of the
    # saved config -- confirm against the config mixin.
    ignore_for_config = ["dtype"]
    # Only read (never mutated) here; subclasses override it with their own list.
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON-file.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `huggingface-cli login`.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we
                use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can
                be any identifier allowed by git.

        Returns:
            `(scheduler, state)`, or `(scheduler, state, unused_kwargs)` when `return_unused_kwargs=True`.
            `state` is the result of `scheduler.create_state()` when the scheduler defines `create_state` and has
            a truthy `has_state` attribute; otherwise it is `None`.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode)
        to use this method in a firewalled environment.

        </Tip>

        """
        config, kwargs = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
        scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

        # Stateless schedulers get `None`; without this default `state` would be unbound below and the
        # return statements would raise UnboundLocalError.
        state = None
        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
            state = scheduler.create_state()

        if return_unused_kwargs:
            return scheduler, state, unused_kwargs

        return scheduler, state

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using
        the [`~FlaxSchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in
                your namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # Resolve this class's name plus `_compatibles` to classes exported by the top-level package.
        # `dict.fromkeys` deduplicates while keeping a deterministic order; `set()` ordering would vary
        # with string hash seeding between runs.
        compatible_classes_str = list(dict.fromkeys([cls.__name__] + cls._compatibles))
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
|
||
|
|
||
|
|
||
|
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
    """
    Broadcast `x` to `shape`, aligning the axes of `x` with the *leading* axes of `shape`.

    Ordinary broadcasting aligns trailing axes. Here `x` is first padded with trailing singleton dimensions so
    that its existing axes line up with the first `x.ndim` entries of `shape`. For example, a `(batch,)` vector
    becomes `(batch, 1, 1, 1)` before being broadcast to `(batch, C, H, W)`.

    Args:
        x: array to broadcast; must not have more dimensions than `shape`.
        shape: target shape.

    Returns:
        An array with exactly the shape `shape`.
    """
    assert len(shape) >= x.ndim
    trailing_ones = (1,) * (len(shape) - x.ndim)
    padded = x.reshape(x.shape + trailing_ones)
    return jnp.broadcast_to(padded, shape)
|
||
|
|
||
|
|
||
|
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray:
    """
    Discretize the cosine alpha_bar function into a beta schedule.

    alpha_bar(t), for t in [0, 1], is the cumulative product of (1 - beta) up to time t. Step i covers the
    interval [i / N, (i + 1) / N] and gets beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), clipped to
    `max_beta`. Here N is `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to prevent singularities.
        dtype: dtype of the returned array (defaults to `jnp.float32`).

    Returns:
        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    n = num_diffusion_timesteps
    betas = [
        min(1 - _cosine_alpha_bar((step + 1) / n) / _cosine_alpha_bar(step / n), max_beta)
        for step in range(n)
    ]
    return jnp.array(betas, dtype=dtype)
|
||
|
|
||
|
|
||
|
@flax.struct.dataclass
class CommonSchedulerState:
    """
    Schedule arrays shared by the Flax schedulers, built from a scheduler's config by `create`.
    """

    # `1.0 - betas`, element-wise.
    alphas: jnp.ndarray
    # The beta schedule selected or provided by the scheduler config.
    betas: jnp.ndarray
    # Cumulative product of `alphas` along axis 0.
    alphas_cumprod: jnp.ndarray

    @classmethod
    def create(cls, scheduler):
        """
        Build the shared schedule arrays from `scheduler.config`.

        `config.trained_betas` is used directly when it is not `None`. Otherwise the betas are generated from
        `config.beta_schedule` ("linear", "scaled_linear" or "squaredcos_cap_v2") over
        `config.num_train_timesteps` steps, in `scheduler.dtype`.

        Args:
            scheduler: object exposing `config` (reads `trained_betas`, `beta_schedule`, `beta_start`,
                `beta_end`, `num_train_timesteps`) and `dtype`.

        Returns:
            A `CommonSchedulerState` holding `alphas`, `betas` and `alphas_cumprod`.

        Raises:
            NotImplementedError: if `config.beta_schedule` is not one of the supported names.
        """
        cfg = scheduler.config

        if cfg.trained_betas is not None:
            # Explicit betas take precedence over any named schedule.
            betas = jnp.asarray(cfg.trained_betas, dtype=scheduler.dtype)
        elif cfg.beta_schedule == "linear":
            betas = jnp.linspace(cfg.beta_start, cfg.beta_end, cfg.num_train_timesteps, dtype=scheduler.dtype)
        elif cfg.beta_schedule == "scaled_linear":
            # Linear in sqrt(beta), then squared; this schedule is very specific to the latent diffusion model.
            sqrt_betas = jnp.linspace(
                cfg.beta_start**0.5, cfg.beta_end**0.5, cfg.num_train_timesteps, dtype=scheduler.dtype
            )
            betas = sqrt_betas**2
        elif cfg.beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            betas = betas_for_alpha_bar(cfg.num_train_timesteps, dtype=scheduler.dtype)
        else:
            raise NotImplementedError(
                f"beta_schedule {cfg.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
            )

        alphas = 1.0 - betas
        return cls(alphas=alphas, betas=betas, alphas_cumprod=jnp.cumprod(alphas, axis=0))
|
||
|
|
||
|
|
||
|
def get_sqrt_alpha_prod(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Look up the noising coefficients for `timesteps` and broadcast them against `original_samples`.

    Takes `a = state.alphas_cumprod[timesteps]` and returns `(sqrt(a), sqrt(1 - a))`. Each coefficient is
    flattened and then broadcast from the left to `original_samples.shape`. `noise` is accepted for signature
    symmetry with the callers but is not used.
    """
    target_shape = original_samples.shape
    alpha_prod_t = state.alphas_cumprod[timesteps]

    sqrt_alpha_prod = broadcast_to_shape_from_left((alpha_prod_t**0.5).flatten(), target_shape)
    sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(((1 - alpha_prod_t) ** 0.5).flatten(), target_shape)
    return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
|
||
|
|
||
|
|
||
|
def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Mix `noise` into `original_samples` at the given `timesteps`.

    Returns `sqrt(a) * original_samples + sqrt(1 - a) * noise`, where `a = state.alphas_cumprod[timesteps]`
    is broadcast from the left against `original_samples`.
    """
    signal_scale, noise_scale = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
    return signal_scale * original_samples + noise_scale * noise
|
||
|
|
||
|
|
||
|
def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
    """
    Compute the velocity `sqrt(a) * noise - sqrt(1 - a) * sample`.

    Here `a = state.alphas_cumprod[timesteps]`, broadcast from the left against `sample`.
    """
    sqrt_alpha, sqrt_one_minus_alpha = get_sqrt_alpha_prod(state, sample, noise, timesteps)
    return sqrt_alpha * noise - sqrt_one_minus_alpha * sample
|