522 lines
21 KiB
Python
522 lines
21 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team.
|
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import importlib
|
|
import inspect
|
|
import os
|
|
from array import array
|
|
from collections import OrderedDict
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
from zipfile import is_zipfile
|
|
|
|
import safetensors
|
|
import torch
|
|
from huggingface_hub import DDUFEntry
|
|
from huggingface_hub.utils import EntryNotFoundError
|
|
|
|
from ..quantizers import DiffusersQuantizer
|
|
from ..utils import (
|
|
GGUF_FILE_EXTENSION,
|
|
SAFE_WEIGHTS_INDEX_NAME,
|
|
SAFETENSORS_FILE_EXTENSION,
|
|
WEIGHTS_INDEX_NAME,
|
|
_add_variant,
|
|
_get_model_file,
|
|
deprecate,
|
|
is_accelerate_available,
|
|
is_gguf_available,
|
|
is_torch_available,
|
|
is_torch_version,
|
|
logging,
|
|
)
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
_CLASS_REMAPPING_DICT = {
|
|
"Transformer2DModel": {
|
|
"ada_norm_zero": "DiTTransformer2DModel",
|
|
"ada_norm_single": "PixArtTransformer2DModel",
|
|
}
|
|
}
|
|
|
|
|
|
if is_accelerate_available():
|
|
from accelerate import infer_auto_device_map
|
|
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
|
|
|
|
|
|
# Adapted from `transformers` (see modeling_utils.py)
|
|
def _determine_device_map(
|
|
model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
|
|
):
|
|
if isinstance(device_map, str):
|
|
special_dtypes = {}
|
|
if hf_quantizer is not None:
|
|
special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
|
|
special_dtypes.update(
|
|
{
|
|
name: torch.float32
|
|
for name, _ in model.named_parameters()
|
|
if any(m in name for m in keep_in_fp32_modules)
|
|
}
|
|
)
|
|
|
|
target_dtype = torch_dtype
|
|
if hf_quantizer is not None:
|
|
target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
|
|
|
|
no_split_modules = model._get_no_split_modules(device_map)
|
|
device_map_kwargs = {"no_split_module_classes": no_split_modules}
|
|
|
|
if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
|
|
device_map_kwargs["special_dtypes"] = special_dtypes
|
|
elif len(special_dtypes) > 0:
|
|
logger.warning(
|
|
"This model has some weights that should be kept in higher precision, you need to upgrade "
|
|
"`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
|
|
)
|
|
|
|
if device_map != "sequential":
|
|
max_memory = get_balanced_memory(
|
|
model,
|
|
dtype=torch_dtype,
|
|
low_zero=(device_map == "balanced_low_0"),
|
|
max_memory=max_memory,
|
|
**device_map_kwargs,
|
|
)
|
|
else:
|
|
max_memory = get_max_memory(max_memory)
|
|
|
|
if hf_quantizer is not None:
|
|
max_memory = hf_quantizer.adjust_max_memory(max_memory)
|
|
|
|
device_map_kwargs["max_memory"] = max_memory
|
|
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
|
|
|
|
if hf_quantizer is not None:
|
|
hf_quantizer.validate_environment(device_map=device_map)
|
|
|
|
return device_map
|
|
|
|
|
|
def _fetch_remapped_cls_from_config(config, old_class):
|
|
previous_class_name = old_class.__name__
|
|
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
|
|
|
|
# Details:
|
|
# https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
|
|
if remapped_class_name:
|
|
# load diffusers library to import compatible and original scheduler
|
|
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
|
remapped_class = getattr(diffusers_library, remapped_class_name)
|
|
logger.info(
|
|
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
|
|
f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
|
|
" DOESN'T affect the final results."
|
|
)
|
|
return remapped_class
|
|
else:
|
|
return old_class
|
|
|
|
|
|
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
|
|
"""
|
|
Find the device of param_name from the device_map.
|
|
"""
|
|
if device_map is None:
|
|
return "cpu"
|
|
else:
|
|
module_name = param_name
|
|
# find next higher level module that is defined in device_map:
|
|
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
|
|
while len(module_name) > 0 and module_name not in device_map:
|
|
module_name = ".".join(module_name.split(".")[:-1])
|
|
if module_name == "" and "" not in device_map:
|
|
raise ValueError(f"{param_name} doesn't have any device set.")
|
|
return device_map[module_name]
|
|
|
|
|
|
def load_state_dict(
|
|
checkpoint_file: Union[str, os.PathLike],
|
|
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
|
disable_mmap: bool = False,
|
|
map_location: Union[str, torch.device] = "cpu",
|
|
):
|
|
"""
|
|
Reads a checkpoint file, returning properly formatted errors if they arise.
|
|
"""
|
|
# TODO: maybe refactor a bit this part where we pass a dict here
|
|
if isinstance(checkpoint_file, dict):
|
|
return checkpoint_file
|
|
try:
|
|
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
|
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
|
if dduf_entries:
|
|
# tensors are loaded on cpu
|
|
with dduf_entries[checkpoint_file].as_mmap() as mm:
|
|
return safetensors.torch.load(mm)
|
|
if disable_mmap:
|
|
return safetensors.torch.load(open(checkpoint_file, "rb").read())
|
|
else:
|
|
return safetensors.torch.load_file(checkpoint_file, device=map_location)
|
|
elif file_extension == GGUF_FILE_EXTENSION:
|
|
return load_gguf_checkpoint(checkpoint_file)
|
|
else:
|
|
extra_args = {}
|
|
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
|
# mmap can only be used with files serialized with zipfile-based format.
|
|
if (
|
|
isinstance(checkpoint_file, str)
|
|
and map_location != "meta"
|
|
and is_torch_version(">=", "2.1.0")
|
|
and is_zipfile(checkpoint_file)
|
|
and not disable_mmap
|
|
):
|
|
extra_args = {"mmap": True}
|
|
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
|
|
except Exception as e:
|
|
try:
|
|
with open(checkpoint_file) as f:
|
|
if f.read().startswith("version"):
|
|
raise OSError(
|
|
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
|
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
|
"you cloned."
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
|
"model. Make sure you have saved the model properly."
|
|
) from e
|
|
except (UnicodeDecodeError, ValueError):
|
|
raise OSError(
|
|
f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
|
|
)
|
|
|
|
|
|
def load_model_dict_into_meta(
|
|
model,
|
|
state_dict: OrderedDict,
|
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
|
model_name_or_path: Optional[str] = None,
|
|
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
|
keep_in_fp32_modules: Optional[List] = None,
|
|
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
|
|
unexpected_keys: Optional[List[str]] = None,
|
|
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
|
offload_index: Optional[Dict] = None,
|
|
state_dict_index: Optional[Dict] = None,
|
|
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
|
|
) -> List[str]:
|
|
"""
|
|
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
|
params on a `meta` device. It replaces the model params with the data from the `state_dict`
|
|
"""
|
|
|
|
is_quantized = hf_quantizer is not None
|
|
empty_state_dict = model.state_dict()
|
|
|
|
for param_name, param in state_dict.items():
|
|
if param_name not in empty_state_dict:
|
|
continue
|
|
|
|
set_module_kwargs = {}
|
|
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
|
|
# in int/uint/bool and not cast them.
|
|
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
|
|
if dtype is not None and torch.is_floating_point(param):
|
|
if keep_in_fp32_modules is not None and any(
|
|
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
|
):
|
|
param = param.to(torch.float32)
|
|
set_module_kwargs["dtype"] = torch.float32
|
|
# For quantizers have save weights using torch.float8_e4m3fn
|
|
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
|
|
pass
|
|
else:
|
|
param = param.to(dtype)
|
|
set_module_kwargs["dtype"] = dtype
|
|
|
|
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
|
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
|
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
|
old_param = model
|
|
splits = param_name.split(".")
|
|
for split in splits:
|
|
old_param = getattr(old_param, split)
|
|
|
|
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
|
old_param = None
|
|
|
|
if old_param is not None:
|
|
if dtype is None:
|
|
param = param.to(old_param.dtype)
|
|
|
|
if old_param.is_contiguous():
|
|
param = param.contiguous()
|
|
|
|
param_device = _determine_param_device(param_name, device_map)
|
|
|
|
# bnb params are flattened.
|
|
# gguf quants have a different shape based on the type of quantization applied
|
|
if empty_state_dict[param_name].shape != param.shape:
|
|
if (
|
|
is_quantized
|
|
and hf_quantizer.pre_quantized
|
|
and hf_quantizer.check_if_quantized_param(
|
|
model, param, param_name, state_dict, param_device=param_device
|
|
)
|
|
):
|
|
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
|
|
else:
|
|
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
|
raise ValueError(
|
|
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
|
)
|
|
if param_device == "disk":
|
|
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
|
elif param_device == "cpu" and state_dict_index is not None:
|
|
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
|
elif is_quantized and (
|
|
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
|
|
):
|
|
hf_quantizer.create_quantized_param(
|
|
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
|
|
)
|
|
else:
|
|
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
|
|
|
|
return offload_index, state_dict_index
|
|
|
|
|
|
def _load_state_dict_into_model(
|
|
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
|
) -> List[str]:
|
|
# Convert old format to new format if needed from a PyTorch state_dict
|
|
# copy state_dict so _load_from_state_dict can modify it
|
|
state_dict = state_dict.copy()
|
|
error_msgs = []
|
|
|
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
|
# so we need to apply the function recursively.
|
|
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
|
|
local_metadata = {}
|
|
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
|
|
if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
|
|
logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
|
|
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
|
module._load_from_state_dict(*args)
|
|
|
|
for name, child in module._modules.items():
|
|
if child is not None:
|
|
load(child, prefix + name + ".", assign_to_params_buffers)
|
|
|
|
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
|
|
|
|
return error_msgs
|
|
|
|
|
|
def _fetch_index_file(
|
|
is_local,
|
|
pretrained_model_name_or_path,
|
|
subfolder,
|
|
use_safetensors,
|
|
cache_dir,
|
|
variant,
|
|
force_download,
|
|
proxies,
|
|
local_files_only,
|
|
token,
|
|
revision,
|
|
user_agent,
|
|
commit_hash,
|
|
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
|
):
|
|
if is_local:
|
|
index_file = Path(
|
|
pretrained_model_name_or_path,
|
|
subfolder or "",
|
|
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
|
|
)
|
|
else:
|
|
index_file_in_repo = Path(
|
|
subfolder or "",
|
|
_add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
|
|
).as_posix()
|
|
try:
|
|
index_file = _get_model_file(
|
|
pretrained_model_name_or_path,
|
|
weights_name=index_file_in_repo,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
local_files_only=local_files_only,
|
|
token=token,
|
|
revision=revision,
|
|
subfolder=None,
|
|
user_agent=user_agent,
|
|
commit_hash=commit_hash,
|
|
dduf_entries=dduf_entries,
|
|
)
|
|
if not dduf_entries:
|
|
index_file = Path(index_file)
|
|
except (EntryNotFoundError, EnvironmentError):
|
|
index_file = None
|
|
|
|
return index_file
|
|
|
|
|
|
def _fetch_index_file_legacy(
|
|
is_local,
|
|
pretrained_model_name_or_path,
|
|
subfolder,
|
|
use_safetensors,
|
|
cache_dir,
|
|
variant,
|
|
force_download,
|
|
proxies,
|
|
local_files_only,
|
|
token,
|
|
revision,
|
|
user_agent,
|
|
commit_hash,
|
|
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
|
):
|
|
if is_local:
|
|
index_file = Path(
|
|
pretrained_model_name_or_path,
|
|
subfolder or "",
|
|
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
|
).as_posix()
|
|
splits = index_file.split(".")
|
|
split_index = -3 if ".cache" in index_file else -2
|
|
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
|
index_file = ".".join(splits)
|
|
if os.path.exists(index_file):
|
|
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
|
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
|
index_file = Path(index_file)
|
|
else:
|
|
index_file = None
|
|
else:
|
|
if variant is not None:
|
|
index_file_in_repo = Path(
|
|
subfolder or "",
|
|
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
|
).as_posix()
|
|
splits = index_file_in_repo.split(".")
|
|
split_index = -2
|
|
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
|
index_file_in_repo = ".".join(splits)
|
|
try:
|
|
index_file = _get_model_file(
|
|
pretrained_model_name_or_path,
|
|
weights_name=index_file_in_repo,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
local_files_only=local_files_only,
|
|
token=token,
|
|
revision=revision,
|
|
subfolder=None,
|
|
user_agent=user_agent,
|
|
commit_hash=commit_hash,
|
|
dduf_entries=dduf_entries,
|
|
)
|
|
index_file = Path(index_file)
|
|
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
|
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
|
except (EntryNotFoundError, EnvironmentError):
|
|
index_file = None
|
|
|
|
return index_file
|
|
|
|
|
|
def _gguf_parse_value(_value, data_type):
|
|
if not isinstance(data_type, list):
|
|
data_type = [data_type]
|
|
if len(data_type) == 1:
|
|
data_type = data_type[0]
|
|
array_data_type = None
|
|
else:
|
|
if data_type[0] != 9:
|
|
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
|
|
data_type, array_data_type = data_type
|
|
|
|
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
|
|
_value = int(_value[0])
|
|
elif data_type in [6, 12]:
|
|
_value = float(_value[0])
|
|
elif data_type in [7]:
|
|
_value = bool(_value[0])
|
|
elif data_type in [8]:
|
|
_value = array("B", list(_value)).tobytes().decode()
|
|
elif data_type in [9]:
|
|
_value = _gguf_parse_value(_value, array_data_type)
|
|
return _value
|
|
|
|
|
|
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
|
"""
|
|
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
|
|
attributes.
|
|
|
|
Args:
|
|
gguf_checkpoint_path (`str`):
|
|
The path the to GGUF file to load
|
|
return_tensors (`bool`, defaults to `True`):
|
|
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
|
|
metadata in memory.
|
|
"""
|
|
|
|
if is_gguf_available() and is_torch_available():
|
|
import gguf
|
|
from gguf import GGUFReader
|
|
|
|
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
|
|
else:
|
|
logger.error(
|
|
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
|
|
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
|
|
)
|
|
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
|
|
|
|
reader = GGUFReader(gguf_checkpoint_path)
|
|
|
|
parsed_parameters = {}
|
|
for tensor in reader.tensors:
|
|
name = tensor.name
|
|
quant_type = tensor.tensor_type
|
|
|
|
# if the tensor is a torch supported dtype do not use GGUFParameter
|
|
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
|
|
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
|
|
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
|
|
raise ValueError(
|
|
(
|
|
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
|
|
"\n\nCurrently the following quantization types are supported: \n\n"
|
|
f"{_supported_quants_str}"
|
|
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
|
|
)
|
|
)
|
|
|
|
weights = torch.from_numpy(tensor.data.copy())
|
|
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
|
|
|
|
return parsed_parameters
|