# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
from typing import Optional, Union

from huggingface_hub import snapshot_download
from PIL import Image

from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging


INDEX_FILE = "diffusion_pytorch_model.bin"


logger = logging.get_logger(__name__)


LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_config", "from_config"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
    },
}

ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
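
# Illustration: the flattened mapping lets later lookups ignore which library a
# base class came from (hypothetical REPL session):
#
#     >>> ALL_IMPORTABLE_CLASSES["SchedulerMixin"]
#     ['save_config', 'from_config']
#     >>> ALL_IMPORTABLE_CLASSES["PreTrainedModel"]
#     ['save_pretrained', 'from_pretrained']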


class DiffusionPipeline(ConfigMixin):
    r"""
    Base class for all diffusion pipelines.

    Stores each sub-module (model, scheduler, tokenizer, ...) as an attribute and
    records its originating library and class name in a `model_index.json`
    config, so that a whole pipeline can be round-tripped with `save_pretrained`
    / `from_pretrained`.
    """

    config_name = "model_index.json"

    def register_modules(self, **kwargs):
        # import here to avoid a circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            # retrieve the library the module comes from
            library = module.__module__.split(".")[0]

            # check whether the module is defined in a pipeline folder, i.e. in
            # diffusers.pipelines.<pipeline_dir>.pipeline_<pipeline_dir>
            pipeline_file = module.__module__.split(".")[-1]
            pipeline_dir = module.__module__.split(".")[-2]
            is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)

            # if the library is not in LOADABLE_CLASSES, it is a custom module.
            # Likewise, if it is a pipeline module, the module lives inside the
            # pipeline folder, so in both cases we set the library to the
            # pipeline folder name.
            if library not in LOADABLE_CLASSES or is_pipeline_module:
                library = pipeline_dir

            # retrieve the class name
            class_name = module.__class__.__name__

            register_dict = {name: (library, class_name)}

            # save to the model index config
            self.register_to_config(**register_dict)

            # set the module as an attribute of the pipeline
            setattr(self, name, module)
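
    # Illustration: a concrete pipeline's `__init__` registers its components
    # via `register_modules` (a minimal sketch, assuming a DDPM-style pipeline
    # with `unet` and `scheduler` arguments; the names are illustrative):
    #
    #     def __init__(self, unet, scheduler):
    #         super().__init__()
    #         self.register_modules(unet=unet, scheduler=scheduler)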

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        # write model_index.json to the target directory
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module", None)

        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class)
                    if issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES, grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            # save each component into a subfolder named after it
            save_method = getattr(sub_model, save_method_name)
            save_method(os.path.join(save_directory, pipeline_component_name))
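
    # Illustration: after `pipe.save_pretrained("./my-pipeline")`, the folder
    # contains the index plus one subfolder per registered component (a sketch;
    # the component names depend on the concrete pipeline):
    #
    #     my-pipeline/
    #     ├── model_index.json
    #     ├── unet/
    #     └── scheduler/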

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a `DiffusionPipeline` from a pretrained pipeline checkpoint.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Either a model id of a pipeline hosted on the Hugging Face Hub
                or a path to a local directory containing a `model_index.json`
                and the saved sub-modules.
            kwargs:
                `cache_dir`, `resume_download`, `proxies`, `local_files_only`,
                `use_auth_token` and `revision` are forwarded to the Hub
                download; already-instantiated sub-modules can also be passed by
                name to override loading them from the checkpoint.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)

        # 1. Download the checkpoints and configs.
        # Use `snapshot_download` so the whole pipeline repo is fetched at once.
        if not os.path.isdir(pretrained_model_name_or_path):
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        config_dict = cls.get_config_dict(cached_folder)

        # 2. Load the pipeline class. If an explicit subclass was used, keep it;
        # otherwise resolve the class named in the config (a custom module would
        # be loaded from the Hub here).
        if cls != DiffusionPipeline:
            pipeline_class = cls
        else:
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

        # some modules can be passed directly to the init; in that case they are
        # already instantiated in `kwargs`, so extract them here
        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}
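
        # Illustration: a caller may pass a ready-made component, e.g.
        # `DiffusionPipeline.from_pretrained(path, scheduler=my_scheduler)`
        # (`my_scheduler` is hypothetical); such objects land in
        # `passed_class_obj` above and skip the loading logic below.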

        # import here to avoid a circular import
        from diffusers import pipelines

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            is_pipeline_module = hasattr(pipelines, library_name)
            loaded_sub_model = None

            # if the module was passed in directly, use it instead of loading it
            if name in passed_class_obj:
                # 1. check that the passed object has the correct parent class
                if not is_pipeline_module:
                    library = importlib.import_module(library_name)
                    class_obj = getattr(library, class_name)
                    importable_classes = LOADABLE_CLASSES[library_name]
                    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

                    expected_class_obj = None
                    for candidate_name, class_candidate in class_candidates.items():
                        if issubclass(class_obj, class_candidate):
                            expected_class_obj = class_candidate

                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
                        raise ValueError(
                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                            f" {expected_class_obj}"
                        )
                else:
                    logger.warning(
                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
                        " has the correct type"
                    )

                # set the passed class object
                loaded_sub_model = passed_class_obj[name]
            elif is_pipeline_module:
                # the class lives inside the pipeline's own folder
                pipeline_module = getattr(pipelines, library_name)
                class_obj = getattr(pipeline_module, class_name)
                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # otherwise import the class from the library
                library = importlib.import_module(library_name)
                class_obj = getattr(library, class_name)
                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

            if loaded_sub_model is None:
                load_method_name = None
                for candidate_name, class_candidate in class_candidates.items():
                    if issubclass(class_obj, class_candidate):
                        load_method_name = importable_classes[candidate_name][1]

                load_method = getattr(class_obj, load_method_name)

                # check if the module is stored in a subdirectory
                if os.path.isdir(os.path.join(cached_folder, name)):
                    loaded_sub_model = load_method(os.path.join(cached_folder, name))
                else:
                    # else load from the root directory
                    loaded_sub_model = load_method(cached_folder)

            init_kwargs[name] = loaded_sub_model  # e.g. UNet(...), DiffusionSchedule(...)

        # 4. Instantiate the pipeline
        model = pipeline_class(**init_kwargs)
        return model
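
    # Illustration: a typical round trip (a sketch; the model id is only an
    # example checkpoint, and the available components depend on the pipeline):
    #
    #     pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")
    #     pipe.save_pretrained("./ddpm-local")
    #     pipe = DiffusionPipeline.from_pretrained("./ddpm-local")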

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            # single image: add a batch dimension
            images = images[None, ...]
        # scale floats in [0, 1] to uint8 in [0, 255]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images
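
    # Illustration: `images` is expected as floats in [0, 1] with shape
    # (height, width, channels) or (batch, height, width, channels)
    # (a sketch with random data):
    #
    #     import numpy as np
    #
    #     arr = np.random.rand(2, 64, 64, 3)
    #     pils = DiffusionPipeline.numpy_to_pil(arr)  # list of 2 PIL.Image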