139 lines
4.9 KiB
Python
139 lines
4.9 KiB
Python
![]() |
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""
|
||
|
Generic utilities
|
||
|
"""
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
from dataclasses import fields, is_dataclass
|
||
|
from typing import Any, Tuple
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from .import_utils import is_torch_available, is_torch_version
|
||
|
|
||
|
|
||
|
def is_tensor(x) -> bool:
    """
    Check whether `x` is an instance of `torch.Tensor` or `np.ndarray`.

    `torch` is imported (and `torch.Tensor` considered) only when `is_torch_available()` returns a truthy value;
    otherwise only `np.ndarray` is accepted.
    """
    if is_torch_available():
        import torch

        # torch.Tensor is listed first so it is checked before the NumPy type
        accepted_types = (torch.Tensor, np.ndarray)
    else:
        accepted_types = (np.ndarray,)

    return isinstance(x, accepted_types)
|
||
|
|
||
|
|
||
|
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    first.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `BaseOutput` subclasses.
        """
        if is_torch_available():
            import torch.utils._pytree

            def _unflatten(values, context):
                # Rebuild an instance of the subclass from the flattened dict representation
                return cls(**torch.utils._pytree._dict_unflatten(values, context))

            if is_torch_version("<", "2.2"):
                # Older torch only exposes the private registration entry point
                torch.utils._pytree._register_pytree_node(
                    cls,
                    torch.utils._pytree._dict_flatten,
                    _unflatten,
                )
            else:
                torch.utils._pytree.register_pytree_node(
                    cls,
                    torch.utils._pytree._dict_flatten,
                    _unflatten,
                    serialized_type_name=f"{cls.__module__}.{cls.__name__}",
                )

    def __post_init__(self) -> None:
        """
        Fill the dictionary side of the output from its dataclass fields.

        If the first field holds a `dict` and every other field is `None`, the entries of that dict are copied into
        `self`. Otherwise every field whose value is not `None` is stored under its field name.

        Raises:
            `ValueError`: If the dataclass declares no fields.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and isinstance(first_field, dict):
            for key, value in first_field.items():
                self[key] = value
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Always raises: deleting entries is not supported."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """Always raises: `setdefault` is not supported."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """Always raises: `pop` is not supported."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """Always raises: `update` is not supported."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k: Any) -> Any:
        """
        Return the entry stored under `k` when `k` is a string; otherwise index or slice the result of `to_tuple()`
        with `k`.
        """
        if isinstance(k, str):
            # Look the key up directly in the underlying mapping instead of copying
            # every item into a temporary dict on each access.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name: Any, value: Any) -> None:
        """
        Set attribute `name`, and also update the dictionary entry of the same name when that key already exists and
        `value` is not `None`.
        """
        if name in self and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Store `value` under `key` in the dictionary and set it as an attribute of the same name."""
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        """
        Pickling support. For dataclass instances the constructor arguments are replaced by the current field values
        (in field order); non-dataclass instances use the inherited behavior unchanged.
        """
        if not is_dataclass(self):
            return super().__reduce__()
        constructor, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return (constructor, args, *remaining)

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
|