172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import enum
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
from accelerate.commands.utils import CustomArgumentParser
|
|
|
|
|
|
class ConversionStatus(enum.Enum):
|
|
NOT_YET_IMPLEMENTED = 0
|
|
REMOVED = -1
|
|
|
|
|
|
ARGUMENT_KEY_MAPPING = {
|
|
# New keys in FSDP2
|
|
"fsdp_version": "fsdp_version",
|
|
"fsdp_reshard_after_forward": "fsdp_reshard_after_forward",
|
|
# https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
|
|
# https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
|
|
"fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy",
|
|
"fsdp_backward_prefetch": ConversionStatus.REMOVED,
|
|
"fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED,
|
|
"fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading",
|
|
"fsdp_offload_params": "fsdp_offload_params",
|
|
"fsdp_sharding_strategy": "fsdp_reshard_after_forward",
|
|
"fsdp_state_dict_type": "fsdp_state_dict_type",
|
|
"fsdp_sync_module_states": ConversionStatus.REMOVED,
|
|
"fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap",
|
|
"fsdp_min_num_params": "fsdp_min_num_params",
|
|
"fsdp_use_orig_params": ConversionStatus.REMOVED,
|
|
"fsdp_activation_checkpointing": "fsdp_activation_checkpointing",
|
|
}
|
|
|
|
ARGUMENT_VALUE_MAPPING = {
|
|
"fsdp_sharding_strategy": {
|
|
"FULL_SHARD": True,
|
|
"SHARD_GRAD_OP": False,
|
|
"HYBRID_SHARD": True,
|
|
"HYBRID_SHARD_ZERO2": False,
|
|
"NO_SHARD": False,
|
|
},
|
|
"fsdp_reshard_after_forward": { # Needed to convert newly created configs using FSDP1 to FSDP2
|
|
"FULL_SHARD": True,
|
|
"SHARD_GRAD_OP": False,
|
|
"HYBRID_SHARD": True,
|
|
"HYBRID_SHARD_ZERO2": False,
|
|
"NO_SHARD": False,
|
|
},
|
|
}
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _validate_to_fsdp2_args(args):
|
|
if not Path(args.config_file).exists():
|
|
raise FileNotFoundError(f"Config file {args.config_file} not found")
|
|
|
|
if not args.overwrite and args.output_file is None:
|
|
raise ValueError("If --overwrite is not set, --output_file must be provided")
|
|
|
|
if not args.overwrite and Path(args.output_file).exists():
|
|
raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")
|
|
|
|
|
|
def convert_config_to_fsdp2(config: dict) -> dict:
|
|
fsdp_config = config.get("fsdp_config", {})
|
|
|
|
if not fsdp_config:
|
|
logger.info("No FSDP config found in the config file, skipping conversion...")
|
|
return config
|
|
|
|
new_fsdp_config = {}
|
|
|
|
if fsdp_config.get("fsdp_version", 1) == 2:
|
|
logger.warning("Config already specfies FSDP2, skipping conversion...")
|
|
logger.warning(
|
|
"If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
|
|
)
|
|
return config
|
|
|
|
for key, value in fsdp_config.items():
|
|
conversion_status = ARGUMENT_KEY_MAPPING.get(key, None)
|
|
if isinstance(conversion_status, ConversionStatus) or conversion_status is None:
|
|
conversion_status = key
|
|
new_fsdp_config[conversion_status] = value
|
|
continue
|
|
|
|
if conversion_status == ConversionStatus.REMOVED:
|
|
logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...")
|
|
continue
|
|
|
|
if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED:
|
|
logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...")
|
|
continue
|
|
|
|
if conversion_status is None:
|
|
logger.warning(f"Argument {key} is not being converted, skipping this key...")
|
|
new_fsdp_config[key] = value
|
|
else:
|
|
if key in ARGUMENT_VALUE_MAPPING:
|
|
value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
|
|
new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value
|
|
|
|
new_fsdp_config["fsdp_version"] = 2
|
|
config["fsdp_config"] = new_fsdp_config
|
|
return config
|
|
|
|
|
|
def to_fsdp2_command_parser(subparsers=None):
|
|
description = "Convert an Accelerate config from FSDP1 to FSDP2"
|
|
|
|
if subparsers is not None:
|
|
parser = subparsers.add_parser("to-fsdp2", description=description)
|
|
else:
|
|
parser = CustomArgumentParser(description=description)
|
|
|
|
parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True)
|
|
parser.add_argument(
|
|
"--overwrite",
|
|
action="store_true",
|
|
help="Overwrite the config file if it exists",
|
|
default=False,
|
|
)
|
|
parser.add_argument(
|
|
"--output_file",
|
|
type=str,
|
|
help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)",
|
|
default=None,
|
|
)
|
|
if subparsers is not None:
|
|
parser.set_defaults(func=to_fsdp2_command)
|
|
|
|
return parser
|
|
|
|
|
|
def load_config(config_file: str) -> dict:
|
|
with open(config_file) as f:
|
|
config = yaml.safe_load(f)
|
|
if not config:
|
|
raise ValueError("Config file is empty")
|
|
|
|
return config
|
|
|
|
|
|
def to_fsdp2_command(args):
|
|
_validate_to_fsdp2_args(args)
|
|
config = load_config(args.config_file)
|
|
|
|
if args.overwrite and args.output_file is None:
|
|
args.output_file = args.config_file
|
|
|
|
new_config = convert_config_to_fsdp2(config)
|
|
|
|
with open(args.output_file, "w") as f:
|
|
yaml.dump(new_config, f)
|