# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import contextlib
|
|
import os
|
|
import tempfile
|
|
from typing import TYPE_CHECKING, Dict
|
|
|
|
from huggingface_hub import DDUFEntry
|
|
from tqdm import tqdm
|
|
|
|
from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
|
|
if is_transformers_available():
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
|
|
if is_safetensors_available():
|
|
import safetensors.torch
|
|
|
|
|
|
def _load_tokenizer_from_dduf(
|
|
cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
|
|
) -> "PreTrainedTokenizer":
|
|
"""
|
|
Load a tokenizer from a DDUF archive.
|
|
|
|
In practice, `transformers` do not provide a way to load a tokenizer from a DDUF archive. This function is a
|
|
workaround by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted
|
|
files. There is an extra cost of extracting the files, but of limited impact as the tokenizer files are usually
|
|
small-ish.
|
|
"""
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
for entry_name, entry in dduf_entries.items():
|
|
if entry_name.startswith(name + "/"):
|
|
tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
|
|
# need to create intermediary directory if they don't exist
|
|
os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
|
|
with open(tmp_entry_path, "wb") as f:
|
|
with entry.as_mmap() as mm:
|
|
f.write(mm)
|
|
return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs)
|
|
|
|
|
|
def _load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    `transformers` has no native support for reading a model out of a DDUF archive, so this helper works around that:
    it instantiates the model from its config file and feeds it a state dict assembled directly from the archived
    safetensors entries.
    """

    def _is_component_weight(entry_name: str) -> bool:
        # Safetensors entries under this component's folder hold its weights.
        return entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")

    # The component's config is mandatory; fail fast when it is absent.
    config_entry = dduf_entries.get(f"{name}/config.json")
    if config_entry is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    # The generation config, by contrast, is optional.
    generation_config = dduf_entries.get(f"{name}/generation_config.json", None)

    weight_files = [entry for entry_name, entry in dduf_entries.items() if _is_component_weight(entry_name)]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )
    if is_transformers_version("<", "4.47.0"):
        raise ImportError(
            "You need to install `transformers>4.47.0` in order to load a transformers model from a DDUF file. "
            "You can install it with: `pip install --upgrade transformers`"
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        from transformers import AutoConfig, GenerationConfig

        # Materialize the config on disk so AutoConfig can parse it.
        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w") as f:
            f.write(config_entry.read_text())
        config = AutoConfig.from_pretrained(tmp_config_file)

        if generation_config is not None:
            # Same round-trip through disk for the optional generation config.
            tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
            with open(tmp_generation_config_file, "w") as f:
                f.write(generation_config.read_text())
            generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)

        # Keep every safetensors file memory-mapped simultaneously while the state dict is assembled,
        # then release all the mappings together when the ExitStack closes.
        state_dict = {}
        with contextlib.ExitStack() as stack:
            for weight_entry in tqdm(weight_files, desc="Loading state_dict"):
                mmap = stack.enter_context(weight_entry.as_mmap())
                state_dict.update(safetensors.torch.load(mmap))

        return cls.from_pretrained(
            pretrained_model_name_or_path=None,
            config=config,
            generation_config=generation_config,
            state_dict=state_dict,
            **kwargs,
        )
|