# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class DistributedConfig:
    """
    Base class for distributed configs
    """

    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc..

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a DistributedConfig instance from a dictionary of parameters.

        Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters. It is expanded
                directly into the constructor, so a key that is not a constructor argument raises
                `TypeError`.
            **kwargs: Additional keyword arguments to override dictionary values. Keys that do not match
                an existing, non-callable attribute of the new instance are ignored.

        Returns:
            DistributedConfig: Instance of DistributedConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # Overrides are applied after construction so they win over `config_dict`.
        config._apply_overrides(kwargs)
        return config

    # NOTE(review): the `use_diff` entry in the docstring below does not correspond to any parameter of
    # this method. The body is left untouched because it is marked as copied from another module
    # (presumably it has to stay identical to that source - confirm before editing).
    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default
                `QuantizationConfig()` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Deep copy of the instance's attribute dictionary, so mutating the result
            does not affect this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string (2-space indent, trailing newline) representing the configuration
            instance.
        """
        # Go through `to_dict()` like `to_json_file` does, so both JSON outputs share one source.
        return json.dumps(self.to_dict(), indent=2) + "\n"

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing
        attributes, returning all the unused kwargs.

        A key that names a method (or any other callable member) of the class is never applied: it is
        returned as unused instead of shadowing that member on the instance.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update
            the instance.
        """
        return self._apply_overrides(kwargs)

    def _apply_overrides(self, overrides: dict[str, Any]) -> dict[str, Any]:
        """Set every accepted key of `overrides` on the instance and return the rest as a new dict."""
        unused = {}
        for key, value in overrides.items():
            if self._accepts_override(key):
                setattr(self, key, value)
            else:
                # Collected into a fresh dict so the caller's mapping is never modified.
                unused[key] = value
        return unused

    def _accepts_override(self, key: str) -> bool:
        """Return True if `key` is an existing attribute that is not a callable member of the class."""
        if not hasattr(self, key):
            return False
        # `hasattr` is also true for method names such as `to_dict`; assigning to those would replace
        # the method on this instance and add the key to `__dict__`, breaking serialization.
        return not callable(getattr(type(self), key, None))