# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""

import importlib.util
import inspect
import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union

from huggingface_hub.utils import is_jinja_available  # noqa: F401
from packaging.version import Version, parse

from . import logging


# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata

try:
    _package_map = importlib_metadata.packages_distributions()  # load-once to avoid expensive calls
except Exception:
    _package_map = None


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)


def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
    global _package_map
    pkg_exists = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"

    if pkg_exists:
        if _package_map is None:
            _package_map = defaultdict(list)
            try:
                # Fallback for Python < 3.10: build the import-name -> distribution-name map ourselves.
                for dist in importlib_metadata.distributions():
                    _top_level_declared = (dist.read_text("top_level.txt") or "").split()
                    _inferred_opt_names = {
                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
                    } - {None}
                    _top_level_inferred = filter(lambda name: "." not in name, _inferred_opt_names)
                    for pkg in _top_level_declared or _top_level_inferred:
                        _package_map[pkg].append(dist.metadata["Name"])
            except Exception:
                pass
        try:
            if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]:
                if len(_package_map[pkg_name]) > 1:
                    logger.warning(
                        f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}"
                    )
                pkg_name = _package_map[pkg_name][0]
            pkg_version = importlib_metadata.version(pkg_name)
            logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
        except (ImportError, importlib_metadata.PackageNotFoundError):
            pkg_exists = False

    return pkg_exists, pkg_version
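

# Illustrative sketch (not executed here): `_is_package_available` reports both presence and
# version, and `get_dist_name=True` maps an import name to its distribution name via
# `importlib.metadata`, e.g.
#
#     >>> _is_package_available("safetensors")
#     (True, "0.4.2")        # version string depends on the local environment
#     >>> _is_package_available("optimum", get_dist_name=True)
#     (True, "1.23.0")       # picks one distribution if several provide the `optimum` namespace
#
# The version numbers above are placeholders, not pinned requirements.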


if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available, _torch_version = _is_package_available("torch")
else:
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False
    _torch_version = "N/A"

_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
else:
    _flax_available = False

if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _safetensors_available, _safetensors_version = _is_package_available("safetensors")
else:
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False
    _safetensors_version = "N/A"

_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
    candidates = (
        "onnxruntime",
        "onnxruntime-cann",
        "onnxruntime-directml",
        "ort_nightly_directml",
        "onnxruntime-gpu",
        "ort_nightly_gpu",
        "onnxruntime-migraphx",
        "onnxruntime-openvino",
        "onnxruntime-qnn",
        "onnxruntime-rocm",
        "onnxruntime-training",
        "onnxruntime-vitisai",
    )
    _onnxruntime_version = None
    # For the metadata, we have to look for both onnxruntime and onnxruntime-x
    for pkg in candidates:
        try:
            _onnxruntime_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _onnx_available = _onnxruntime_version is not None
    if _onnx_available:
        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")

# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
try:
    candidates = (
        "opencv-python",
        "opencv-contrib-python",
        "opencv-python-headless",
        "opencv-contrib-python-headless",
    )
    _opencv_version = None
    for pkg in candidates:
        try:
            _opencv_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _opencv_available = _opencv_version is not None
    if _opencv_available:
        logger.debug(f"Successfully imported cv2 version {_opencv_version}")
except importlib_metadata.PackageNotFoundError:
    _opencv_available = False

_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata is registered under a different name: beautifulsoup4
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported beautifulsoup4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False

_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
    _invisible_watermark_available = False

_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
_wandb_available, _wandb_version = _is_package_available("wandb")
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
_compel_available, _compel_version = _is_package_available("compel")
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
_peft_available, _peft_version = _is_package_available("peft")
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
_timm_available, _timm_version = _is_package_available("timm")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_imageio_available, _imageio_version = _is_package_available("imageio")
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
_scipy_available, _scipy_version = _is_package_available("scipy")
_librosa_available, _librosa_version = _is_package_available("librosa")
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
_xformers_available, _xformers_version = _is_package_available("xformers")
_gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")


def is_torch_available():
    return _torch_available


def is_torch_xla_available():
    return _torch_xla_available


def is_torch_npu_available():
    return _torch_npu_available


def is_flax_available():
    return _flax_available


def is_transformers_available():
    return _transformers_available


def is_inflect_available():
    return _inflect_available


def is_unidecode_available():
    return _unidecode_available


def is_onnx_available():
    return _onnx_available


def is_opencv_available():
    return _opencv_available


def is_scipy_available():
    return _scipy_available


def is_librosa_available():
    return _librosa_available


def is_xformers_available():
    return _xformers_available


def is_accelerate_available():
    return _accelerate_available


def is_k_diffusion_available():
    return _k_diffusion_available


def is_note_seq_available():
    return _note_seq_available


def is_wandb_available():
    return _wandb_available


def is_tensorboard_available():
    return _tensorboard_available


def is_compel_available():
    return _compel_available


def is_ftfy_available():
    return _ftfy_available


def is_bs4_available():
    return _bs4_available


def is_torchsde_available():
    return _torchsde_available


def is_invisible_watermark_available():
    return _invisible_watermark_available


def is_peft_available():
    return _peft_available


def is_torchvision_available():
    return _torchvision_available


def is_matplotlib_available():
    return _matplotlib_available


def is_safetensors_available():
    return _safetensors_available


def is_bitsandbytes_available():
    return _bitsandbytes_available


def is_google_colab():
    return _is_google_colab


def is_sentencepiece_available():
    return _sentencepiece_available


def is_imageio_available():
    return _imageio_available


def is_gguf_available():
    return _gguf_available


def is_torchao_available():
    return _torchao_available


def is_optimum_quanto_available():
    return _optimum_quanto_available


def is_timm_available():
    return _timm_available


def is_pytorch_retinaface_available():
    return _pytorch_retinaface_available


def is_better_profanity_available():
    return _better_profanity_available


def is_nltk_available():
    return _nltk_available


def is_cosmos_guardrail_available():
    return _cosmos_guardrail_available


def is_hpu_available():
    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# docstyle-ignore
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# docstyle-ignore
ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""

# docstyle-ignore
OPENCV_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
install opencv-python`
"""

# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""

# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. Check out the instructions on the
installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
"""

# docstyle-ignore
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""

# docstyle-ignore
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""

# docstyle-ignore
K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""

# docstyle-ignore
NOTE_SEQ_IMPORT_ERROR = """
{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
install note-seq`
"""

# docstyle-ignore
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""

# docstyle-ignore
TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""

# docstyle-ignore
COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""

# docstyle-ignore
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""

# docstyle-ignore
INVISIBLE_WATERMARK_IMPORT_ERROR = """
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
"""

# docstyle-ignore
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""

# docstyle-ignore
SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""

# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece`
"""

# docstyle-ignore
BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""

# docstyle-ignore
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but they were not found in your environment. You can install them with pip: `pip install imageio imageio-ffmpeg`
"""

# docstyle-ignore
GGUF_IMPORT_ERROR = """
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
"""

# docstyle-ignore
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
torchao`
"""

# docstyle-ignore
QUANTO_IMPORT_ERROR = """
{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
install optimum-quanto`
"""

# docstyle-ignore
PYTORCH_RETINAFACE_IMPORT_ERROR = """
{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
"""

# docstyle-ignore
BETTER_PROFANITY_IMPORT_ERROR = """
{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
"""

# docstyle-ignore
NLTK_IMPORT_ERROR = """
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
"""


BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
        ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
        ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
        ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
        ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
        ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
        ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
        ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
        ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
        ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
        ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
        ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
        ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
        ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
        ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
        ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))

    if name in [
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    ] and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )

    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
        "<", "4.26.0"
    ):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )
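

# Illustrative usage (a hedged sketch, not part of this module): a class that needs optional
# backends typically calls `requires_backends` at construction time, e.g.
#
#     class MyAudioPipeline:  # hypothetical class name
#         def __init__(self):
#             requires_backends(self, ["librosa", "scipy"])
#
# If either backend is missing, the formatted messages from BACKENDS_MAPPING are joined and
# raised as a single ImportError naming `MyAudioPipeline`.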


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class using it as a metaclass will raise the ImportError generated by
    `requires_backends` each time a user tries to access any method or attribute of that class.
    """

    def __getattr__(cls, key):
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            return super().__getattr__(cls, key)
        requires_backends(cls, cls._backends)
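

# Illustrative sketch of how a dummy placeholder is usually declared (names below are
# hypothetical; the real placeholders live in generated dummy_*_objects modules):
#
#     class SomeTorchOnlyPipeline(metaclass=DummyObject):  # hypothetical placeholder
#         _backends = ["torch"]
#
#         def __init__(self, *args, **kwargs):
#             requires_backends(self, ["torch"])
#
# Accessing an attribute on such a class triggers DummyObject.__getattr__, which re-raises the
# backend ImportError instead of failing with a confusing AttributeError.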


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    Compares a library version to some requirement using a given operation.

    Args:
        library_or_version (`str` or `packaging.version.Version`):
            A library name or a version to check.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`.
        requirement_version (`str`):
            The version to compare the library version against.
    """
    if operation not in STR_OPERATION_TO_FUNC.keys():
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
    operation = STR_OPERATION_TO_FUNC[operation]
    if isinstance(library_or_version, str):
        library_or_version = parse(importlib_metadata.version(library_or_version))
    return operation(library_or_version, parse(requirement_version))
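

# Illustrative usage (version numbers are placeholders, not pinned requirements):
#
#     >>> compare_versions(parse("2.1.0"), ">=", "2.0")
#     True
#     >>> compare_versions("transformers", "<", "4.25.0")  # resolves the installed version first
#
# Passing a string defers to importlib.metadata, so it raises PackageNotFoundError if the
# library is not installed; the is_*_version helpers below guard against that case.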


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of PyTorch
    """
    return compare_versions(parse(_torch_version), operation, version)


def is_torch_xla_version(operation: str, version: str):
    """
    Compares the current torch_xla version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of torch_xla
    """
    if not _torch_xla_available:
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)


def is_transformers_version(operation: str, version: str):
    """
    Compares the current Transformers version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _transformers_available:
        return False
    return compare_versions(parse(_transformers_version), operation, version)


def is_hf_hub_version(operation: str, version: str):
    """
    Compares the current Hugging Face Hub version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _hf_hub_available:
        return False
    return compare_versions(parse(_hf_hub_version), operation, version)


def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _accelerate_available:
        return False
    return compare_versions(parse(_accelerate_version), operation, version)


def is_peft_version(operation: str, version: str):
    """
    Compares the current PEFT version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _peft_available:
        return False
    return compare_versions(parse(_peft_version), operation, version)


def is_bitsandbytes_version(operation: str, version: str):
    """
    Compares the current bitsandbytes version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _bitsandbytes_available:
        return False
    return compare_versions(parse(_bitsandbytes_version), operation, version)


def is_gguf_version(operation: str, version: str):
    """
    Compares the current gguf version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _gguf_available:
        return False
    return compare_versions(parse(_gguf_version), operation, version)


def is_torchao_version(operation: str, version: str):
    """
    Compares the current torchao version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _torchao_available:
        return False
    return compare_versions(parse(_torchao_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
    """
    Compares the current k-diffusion version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _k_diffusion_available:
        return False
    return compare_versions(parse(_k_diffusion_version), operation, version)


def is_optimum_quanto_version(operation: str, version: str):
    """
    Compares the current optimum-quanto version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _optimum_quanto_available:
        return False
    return compare_versions(parse(_optimum_quanto_version), operation, version)


def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects.

    Args:
        module (ModuleType):
            Module to extract the objects from.

    Returns:
        dict: Dictionary of object names and corresponding values
    """

    objects = {}
    for name in dir(module):
        if name.startswith("_"):
            continue
        objects[name] = getattr(module, name)

    return objects


class OptionalDependencyNotAvailable(BaseException):
    """
    An error indicating that an optional dependency of Diffusers was not found in the environment.
    """


class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    def __reduce__(self):
        return (self.__class__, (self._name, self.__file__, self._import_structure))
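

# Illustrative sketch of how a package __init__ typically swaps itself for a _LazyModule; the
# exact import structure and extra objects vary per package, and this is not executed here:
#
#     import sys
#
#     sys.modules[__name__] = _LazyModule(
#         __name__,
#         globals()["__file__"],
#         _import_structure,
#         module_spec=__spec__,
#         extra_objects={"__version__": __version__},
#     )
#
# Attribute access on the package then goes through __getattr__ above, importing submodules lazily.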