122 lines
3.2 KiB
Python
122 lines
3.2 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
from ...utils.dataclasses import (
|
|
ComputeEnvironment,
|
|
DistributedType,
|
|
DynamoBackend,
|
|
FP8BackendType,
|
|
PrecisionType,
|
|
SageMakerDistributedType,
|
|
)
|
|
from ..menu import BulletMenu
|
|
|
|
|
|
# Names of the selectable torch dynamo backends.
# `_convert_dynamo_backend` indexes into this list by position and passes the entry to
# `DynamoBackend(...)`, so the order must stay stable and every entry must be spelled
# exactly like the corresponding `DynamoBackend` value.
DYNAMO_BACKENDS = [
    "EAGER",
    "AOT_EAGER",
    "INDUCTOR",
    "AOT_TS_NVFUSER",
    "NVPRIMS_NVFUSER",
    "CUDAGRAPHS",
    "OFI",
    "FX2TRT",
    "ONNXRT",
    "TENSORRT",
    "AOT_TORCHXLA_TRACE_ONCE",
    "TORCHXLA_TRACE_ONCE",
    "IPEX",
    "TVM",
]
|
|
|
|
|
|
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
    """
    Prompt with `input_text` until an answer can be returned.

    Args:
        input_text: Prompt handed to `input`.
        convert_value (*optional*): Callable applied to the raw answer. If it raises, the prompt is shown again.
        default (*optional*): Returned unchanged (it is not converted) when the answer is empty and `default` is not
            `None`.
        error_message (*optional*): Printed each time `convert_value` raises.

    Returns:
        `default`, the converted answer, or the raw answer string when no converter is given.
    """
    while True:
        answer = input(input_text)

        # An empty answer selects the default, when one was supplied.
        if not answer and default is not None:
            return default

        if convert_value is None:
            return answer

        try:
            return convert_value(answer)
        except Exception:
            # Conversion failed: optionally report it, then ask again.
            if error_message is not None:
                print(error_message)
|
|
|
|
|
|
def _ask_options(input_text, options=None, convert_value=None, default=0):
    """
    Show a `BulletMenu` built from `input_text` and `options`, run it, and return the outcome.

    Args:
        input_text: Prompt passed to `BulletMenu`.
        options (*optional*): Choices passed to `BulletMenu`. When omitted, a new empty list is used.
        convert_value (*optional*): Callable applied to the value returned by `menu.run`.
        default (*optional*, defaults to 0): Forwarded to `menu.run` as `default_choice`.

    Returns:
        The result of `menu.run`, passed through `convert_value` when one is given.
    """
    # Build a fresh list per call: a mutable `[]` default would be one shared object handed
    # to every `BulletMenu` created without explicit options.
    if options is None:
        options = []
    menu = BulletMenu(input_text, options)
    result = menu.run(default_choice=default)
    return convert_value(result) if convert_value is not None else result
|
|
|
|
|
|
def _convert_compute_environment(value):
    """
    Map a positional choice (anything `int()` accepts) to the matching `ComputeEnvironment`.

    Index 0 is `"LOCAL_MACHINE"` and index 1 is `"AMAZON_SAGEMAKER"`.
    """
    environments = ("LOCAL_MACHINE", "AMAZON_SAGEMAKER")
    return ComputeEnvironment(environments[int(value)])
|
|
|
|
|
|
def _convert_distributed_mode(value):
    """
    Map a positional choice (anything `int()` accepts) to the matching `DistributedType`.

    The position refers to the ordered names listed in the body, starting with `"NO"` at index 0.
    """
    modes = (
        "NO",
        "MULTI_CPU",
        "MULTI_XPU",
        "MULTI_HPU",
        "MULTI_GPU",
        "MULTI_NPU",
        "MULTI_MLU",
        "MULTI_SDAA",
        "MULTI_MUSA",
        "XLA",
    )
    return DistributedType(modes[int(value)])
|
|
|
|
|
|
def _convert_dynamo_backend(value):
    """
    Map a positional choice (anything `int()` accepts) into `DYNAMO_BACKENDS` and return the `.value` of the
    corresponding `DynamoBackend` member.
    """
    backend_name = DYNAMO_BACKENDS[int(value)]
    return DynamoBackend(backend_name).value
|
|
|
|
|
|
def _convert_mixed_precision(value):
    """
    Map a positional choice (anything `int()` accepts) to the matching `PrecisionType`.

    The order is `"no"`, `"fp16"`, `"bf16"`, `"fp8"`.
    """
    precisions = ("no", "fp16", "bf16", "fp8")
    return PrecisionType(precisions[int(value)])
|
|
|
|
|
|
def _convert_sagemaker_distributed_mode(value):
    """
    Map a positional choice (anything `int()` accepts) to the matching `SageMakerDistributedType`.

    The order is `"NO"`, `"DATA_PARALLEL"`, `"MODEL_PARALLEL"`.
    """
    modes = ("NO", "DATA_PARALLEL", "MODEL_PARALLEL")
    return SageMakerDistributedType(modes[int(value)])
|
|
|
|
|
|
def _convert_fp8_backend(value):
    """
    Map a positional choice (anything `int()` accepts) to the matching `FP8BackendType`.

    Index 0 is `"TE"` and index 1 is `"MSAMP"`.
    """
    backends = ("TE", "MSAMP")
    return FP8BackendType(backends[int(value)])
|
|
|
|
|
|
def _convert_yes_no_to_bool(value):
    """
    Translate a case-insensitive `"yes"` / `"no"` answer into `True` / `False`.

    Raises:
        `KeyError`: If the lower-cased answer is neither `"yes"` nor `"no"`.
    """
    answer = value.lower()
    if answer == "yes":
        return True
    if answer == "no":
        return False
    raise KeyError(answer)
|
|
|
|
|
|
class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    A help formatter for subcommands that strips the `<command> [<args>] ` placeholder from the usage text.
    """

    def _format_usage(self, usage, actions, groups, prefix):
        # Let the parent class build the usage text, then drop the generic placeholder from it.
        formatted = super()._format_usage(usage, actions, groups, prefix)
        return formatted.replace("<command> [<args>] ", "")
|