# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import importlib
|
|
import os
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from huggingface_hub.utils import validate_hf_hub_args
|
|
from typing_extensions import Self
|
|
|
|
from ..utils import BaseOutput, PushToHubMixin
|
|
|
|
|
|
# File name of a scheduler's JSON configuration; exposed on every scheduler as `SchedulerMixin.config_name`.
SCHEDULER_CONFIG_NAME: str = "scheduler_config.json"
|
|
|
|
|
|
# NOTE: We make this type an enum because it simplifies usage in docs and prevents
|
|
# circular imports when used for `_compatibles` within the schedulers module.
|
|
# When it's used as a type in pipelines, it really is a Union because the actual
|
|
# scheduler instance is passed in.
|
|
class KarrasDiffusionSchedulers(Enum):
|
|
DDIMScheduler = 1
|
|
DDPMScheduler = 2
|
|
PNDMScheduler = 3
|
|
LMSDiscreteScheduler = 4
|
|
EulerDiscreteScheduler = 5
|
|
HeunDiscreteScheduler = 6
|
|
EulerAncestralDiscreteScheduler = 7
|
|
DPMSolverMultistepScheduler = 8
|
|
DPMSolverSinglestepScheduler = 9
|
|
KDPM2DiscreteScheduler = 10
|
|
KDPM2AncestralDiscreteScheduler = 11
|
|
DEISMultistepScheduler = 12
|
|
UniPCMultistepScheduler = 13
|
|
DPMSolverSDEScheduler = 14
|
|
EDMEulerScheduler = 15
|
|
|
|
|
|
AysSchedules = {
|
|
"StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
|
|
"StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
|
|
"StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
|
|
"StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
|
|
"StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
|
|
}
|
|
|
|
|
|
@dataclass
class SchedulerOutput(BaseOutput):
    """
    Output container returned by a scheduler's `step` function.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The computed sample `(x_{t-1})` for the previous timestep. Feed `prev_sample` back in as the model input
            for the next iteration of the denoising loop.
    """

    prev_sample: torch.Tensor
|
|
|
|
|
|
class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of the *names* of scheduler classes that are compatible with the
          parent scheduler class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class
          (should be overridden by parent class).
    """

    # NOTE(review): `load_config`, `from_config` and `save_config` are not defined in this class; they are presumably
    # provided by `ConfigMixin`, which concrete schedulers are expected to inherit alongside this mixin — confirm.
    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ) -> Self:
        r"""
        Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
                      configuration saved with [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated by
                `huggingface-cli login` is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        Returns:
            The scheduler built by `cls.from_config` from the loaded configuration. When `return_unused_kwargs=True`,
            the result is whatever `from_config` returns in that mode (presumably a `(scheduler, unused_kwargs)`
            tuple — confirm against `ConfigMixin.from_config`).

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # Always ask `load_config` for the kwargs it did not consume, so they can be forwarded to `from_config`.
        # The commit hash is requested (which fixes the 3-tuple return shape) but is not needed here.
        config, kwargs, _commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        # Delegates entirely to `save_config`.
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler.

        Returns:
            `List[Type[SchedulerMixin]]`: The compatible scheduler *classes* (this scheduler's own class included)
            that are exported by the top-level package, in a deterministic order. Names listed in `_compatibles` that
            the package does not export are skipped.
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """Resolve `cls.__name__` and `cls._compatibles` to class objects exported by the top-level package."""
        # `dict.fromkeys` de-duplicates while keeping insertion order (this class first, then `_compatibles` in
        # declaration order). A `set` would make the order depend on per-process string-hash randomization.
        # Unpacking accepts any iterable `_compatibles` (list or tuple).
        compatible_classes_str = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        # Top-level package of this module (the first dotted component of `__name__`).
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
|