111 lines
4.3 KiB
Python
111 lines
4.3 KiB
Python
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Any, Union
|
|
|
|
|
|
@dataclass
class DistributedConfig:
    """
    Base class for distributed configs.

    Every entry of the instance `__dict__` is treated as a configuration value: that is what `to_dict`,
    `to_json_string`, `to_json_file` and iteration expose, and it is also the only set of names that `from_dict`
    overrides and `update` are allowed to overwrite.
    """

    # Whether expert parallelism is enabled; off by default.
    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc..

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "DistributedConfig":
        """
        Constructs a DistributedConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters. Its items are forwarded
                to the constructor as keyword arguments, so a key the constructor does not accept raises
                `TypeError`.
            **kwargs: Additional keyword arguments to override dictionary values. Keys that do not name an
                existing configuration attribute are ignored.

        Returns:
            DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # Overrides that match nothing are intentionally discarded: this constructor only returns the instance.
        config._apply_overrides(kwargs)
        return config

    # NOTE(review): the docstring below documents `use_diff` and `QuantizationConfig()`, neither of which this
    # method accepts or uses; the text is left untouched because the method is marked as copied from upstream.
    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.
        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default
                `QuantizationConfig()` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. It is a
            deep copy, so mutating it does not affect the instance.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        # Go through `to_dict()` so that the string form and `to_json_file` serialize the same data.
        return json.dumps(self.to_dict(), indent=2) + "\n"

    def update(self, **kwargs) -> dict[str, Any]:
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Only names already present in the instance `__dict__` count as existing attributes; method names and
        other class-level names are never overwritten and are returned as unused.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        return self._apply_overrides(kwargs)

    def _apply_overrides(self, overrides: dict[str, Any]) -> dict[str, Any]:
        """
        Assigns every item of `overrides` whose key is an existing configuration attribute.

        Args:
            overrides (`Dict[str, Any]`):
                Candidate attribute values. The mapping itself is not modified.

        Returns:
            `Dict[str, Any]`: The items of `overrides` that were not applied, in their original order.
        """
        unused_kwargs = {}
        for key, value in overrides.items():
            # Membership in the instance `__dict__` (instead of `hasattr`) prevents methods and dunders such as
            # `to_dict` or `__dict__` from being shadowed or replaced, which would break serialization.
            if key in self.__dict__:
                setattr(self, key, value)
            else:
                unused_kwargs[key] = value
        return unused_kwargs
|