#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import yaml

from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION


hf_cache_home = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
default_json_config_file = os.path.join(cache_dir, "default_config.json")
default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml")

# For backward compatibility: the default config is the json one if it's the only existing file.
if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file):
    default_config_file = default_yaml_config_file
else:
    default_config_file = default_json_config_file
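# With no environment overrides, `default_config_file` resolves to
# `~/.cache/huggingface/accelerate/default_config.yaml`.

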
def load_config_from_file(config_file):
    if config_file is not None:
        if not os.path.isfile(config_file):
            raise FileNotFoundError(
                f"The passed configuration file `{config_file}` does not exist. "
                "Please pass an existing file to `accelerate launch`, or use the default one "
                "created through `accelerate config` and run `accelerate launch` "
                "without the `--config_file` argument."
            )
    else:
        config_file = default_config_file
    with open(config_file, encoding="utf-8") as f:
        if config_file.endswith(".json"):
            if (
                json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
                == ComputeEnvironment.LOCAL_MACHINE
            ):
                config_class = ClusterConfig
            else:
                config_class = SageMakerConfig
            return config_class.from_json_file(json_file=config_file)
        else:
            if (
                yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE)
                == ComputeEnvironment.LOCAL_MACHINE
            ):
                config_class = ClusterConfig
            else:
                config_class = SageMakerConfig
            return config_class.from_yaml_file(yaml_file=config_file)
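
# Example (illustrative, not shipped with the CLI): a minimal YAML config that
# `load_config_from_file` would resolve to a `ClusterConfig`. Enum fields are
# stored as their string values (e.g. `DistributedType.NO == "NO"`):
#
#     compute_environment: LOCAL_MACHINE
#     distributed_type: 'NO'
#     mixed_precision: 'no'
#     num_processes: 1
#     use_cpu: false
#     debug: false

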
@dataclass
class BaseConfig:
    compute_environment: ComputeEnvironment
    distributed_type: Union[DistributedType, SageMakerDistributedType]
    mixed_precision: str
    use_cpu: bool
    debug: bool

    def to_dict(self):
        # Work on a shallow copy so that converting top-level enum fields below
        # does not overwrite the attributes of this config instance.
        result = self.__dict__.copy()
        # For serialization, it's best to convert Enums to strings (or their underlying value type).

        def _convert_enums(value):
            if isinstance(value, Enum):
                return value.value
            if isinstance(value, dict):
                if not bool(value):
                    # Empty sub-configs are filtered out of the final dict below.
                    return None
                for key1, value1 in value.items():
                    value[key1] = _convert_enums(value1)
            return value

        for key, value in result.items():
            result[key] = _convert_enums(value)
        result = {k: v for k, v in result.items() if v is not None}
        return result
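
    # For instance (illustrative), `ComputeEnvironment.LOCAL_MACHINE` serializes
    # to the string "LOCAL_MACHINE", and `None` fields and empty sub-config dicts
    # are omitted from the result.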

    @staticmethod
    def process_config(config_dict):
        """
        Processes `config_dict` and sets default values for any missing keys.
        """
        if "compute_environment" not in config_dict:
            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
        if "distributed_type" not in config_dict:
            raise ValueError("A `distributed_type` must be specified in the config file.")
        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
            config_dict["num_processes"] = 1
        if "mixed_precision" not in config_dict:
            config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
        if "fp16" in config_dict:  # Convert the config to the new format.
            del config_dict["fp16"]
        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
            dynamo_backend = config_dict.pop("dynamo_backend")
            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
        if "use_cpu" not in config_dict:
            config_dict["use_cpu"] = False
        if "debug" not in config_dict:
            config_dict["debug"] = False
        if "enable_cpu_affinity" not in config_dict:
            config_dict["enable_cpu_affinity"] = False
        return config_dict
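
    # Example (illustrative, assuming `DistributedType` compares equal to its
    # string value): a legacy dict such as
    #     {"distributed_type": "NO", "fp16": True}
    # is upgraded in place to
    #     {"compute_environment": ComputeEnvironment.LOCAL_MACHINE, "distributed_type": "NO",
    #      "num_processes": 1, "mixed_precision": "fp16", "use_cpu": False,
    #      "debug": False, "enable_cpu_affinity": False}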

    @classmethod
    def from_json_file(cls, json_file=None):
        json_file = default_json_config_file if json_file is None else json_file
        with open(json_file, encoding="utf-8") as f:
            config_dict = json.load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {json_file} contains unknown keys ({extra_keys}). Please try upgrading"
                " your `accelerate` version, or fix (and potentially remove) these keys in your config file."
            )

        return cls(**config_dict)

    def to_json_file(self, json_file):
        with open(json_file, "w", encoding="utf-8") as f:
            content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
            f.write(content)

    @classmethod
    def from_yaml_file(cls, yaml_file=None):
        yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        config_dict = cls.process_config(config_dict)
        extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
        if len(extra_keys) > 0:
            raise ValueError(
                f"The config file at {yaml_file} contains unknown keys ({extra_keys}). Please try upgrading"
                " your `accelerate` version, or fix (and potentially remove) these keys in your config file."
            )
        return cls(**config_dict)

    def to_yaml_file(self, yaml_file):
        with open(yaml_file, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f)

    def __post_init__(self):
        if isinstance(self.compute_environment, str):
            self.compute_environment = ComputeEnvironment(self.compute_environment)
        if isinstance(self.distributed_type, str):
            if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
                self.distributed_type = SageMakerDistributedType(self.distributed_type)
            else:
                self.distributed_type = DistributedType(self.distributed_type)
        if getattr(self, "dynamo_config", None) is None:
            self.dynamo_config = {}


@dataclass
class ClusterConfig(BaseConfig):
    num_processes: int = -1  # For instance if we use SLURM and the user manually passes it in
    machine_rank: int = 0
    num_machines: int = 1
    gpu_ids: Optional[str] = None
    main_process_ip: Optional[str] = None
    main_process_port: Optional[int] = None
    rdzv_backend: Optional[str] = "static"
    same_network: Optional[bool] = False
    main_training_function: str = "main"
    enable_cpu_affinity: bool = False

    # args for FP8 training
    fp8_config: dict = None
    # args for deepspeed_plugin
    deepspeed_config: dict = None
    # args for fsdp
    fsdp_config: dict = None
    # args for megatron_lm
    megatron_lm_config: dict = None
    # args for ipex
    ipex_config: dict = None
    # args for mpirun
    mpirun_config: dict = None
    # args for TPU
    downcast_bf16: bool = False

    # args for TPU pods
    tpu_name: str = None
    tpu_zone: str = None
    tpu_use_cluster: bool = False
    tpu_use_sudo: bool = False
    command_file: str = None
    commands: list[str] = None
    tpu_vm: list[str] = None
    tpu_env: list[str] = None

    # args for dynamo
    dynamo_config: dict = None

    def __post_init__(self):
        if self.deepspeed_config is None:
            self.deepspeed_config = {}
        if self.fsdp_config is None:
            self.fsdp_config = {}
        if self.megatron_lm_config is None:
            self.megatron_lm_config = {}
        if self.ipex_config is None:
            self.ipex_config = {}
        if self.mpirun_config is None:
            self.mpirun_config = {}
        if self.fp8_config is None:
            self.fp8_config = {}
        return super().__post_init__()
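
# Example (illustrative): the sub-configs above are plain dicts. A DeepSpeed run
# might carry `deepspeed_config = {"zero_stage": 2, "gradient_accumulation_steps": 1}`
# (key names assumed from typical `accelerate config` output), while unused
# sub-configs stay empty and are dropped from serialization by `to_dict`.

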
@dataclass
class SageMakerConfig(BaseConfig):
    ec2_instance_type: str
    iam_role_name: str
    image_uri: Optional[str] = None
    profile: Optional[str] = None
    region: str = "us-east-1"
    num_machines: int = 1
    gpu_ids: str = "all"
    # Note: this f-string is evaluated once, at class-definition time, so it always
    # uses the default `num_machines` (1) rather than the value passed at init.
    base_job_name: str = f"accelerate-sagemaker-{num_machines}"
    pytorch_version: str = SAGEMAKER_PYTORCH_VERSION
    transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION
    py_version: str = SAGEMAKER_PYTHON_VERSION
    sagemaker_inputs_file: str = None
    sagemaker_metrics_file: str = None
    additional_args: dict = None
    dynamo_config: dict = None
    enable_cpu_affinity: bool = False
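
# Example round trip (illustrative; assumes this module's relative imports
# resolve, i.e. it is used from within the `accelerate` package):
#
#     config = ClusterConfig(
#         compute_environment="LOCAL_MACHINE",
#         distributed_type="NO",
#         mixed_precision="no",
#         use_cpu=False,
#         debug=False,
#     )
#     config.to_yaml_file("default_config.yaml")
#     reloaded = ClusterConfig.from_yaml_file("default_config.yaml")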