147 lines
4.5 KiB
Python
147 lines
4.5 KiB
Python
# mypy: allow-untyped-defs
|
|
import logging
|
|
import warnings
|
|
from collections.abc import Iterable
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import torch.export
|
|
import torch.export._trace
|
|
from torch._utils_internal import log_export_usage
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
__all__ = ["report_exportability"]
|
|
|
|
|
|
def _generate_inputs_for_submodules(
    model: torch.nn.Module,
    target_submodules: Iterable[str],
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, tuple[Any, Any]]:
    """
    Generate inputs for the targeted submodules in the given model.

    The root model is run once with ``args``/``kwargs`` while forward pre-hooks
    on the targeted submodules record what each of them is called with. Note
    that if two submodules refer to the same obj, this function doesn't work.

    Args:
        model: root model.
        target_submodules: names (as yielded by ``model.named_modules()``) of
            the submodules that we want to generate inputs for. Any iterable
            is accepted, including one-shot iterators.
        args: positional inputs to the root model.
        kwargs: keyword inputs to the root model. ``None`` is treated as ``{}``.

    Returns:
        A dict that maps from submodule name to the ``(args, kwargs)`` pair it
        was called with. If a submodule is called more than once, the last
        call wins. If running the model raises, a warning is emitted and the
        inputs recorded before the failure are returned.
    """
    kwargs = kwargs or {}

    # Materialize once: `target_submodules` may be a one-shot iterator, which a
    # repeated `in` test would exhaust on the first lookup, and a set keeps
    # every membership test O(1).
    wanted_names = set(target_submodules)

    handles = []
    results: dict[str, tuple[Any, Any]] = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_args, module_kwargs):
        # Returning None leaves the submodule's inputs untouched.
        results[submodule_to_names[module]] = (module_args, module_kwargs)

    try:
        for name, mod in model.named_modules():
            if name in wanted_names:
                handles.append(
                    mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
                )
        model(*args, **kwargs)
    except Exception as e:
        # Best effort: report the failure but still return what was captured.
        warnings.warn(
            f"Failed to generate submodule inputs because of the following error:\n{e}",
            stacklevel=2,
        )
    finally:
        # Always detach the hooks so the model is left unmodified.
        for h in handles:
            h.remove()
    return results
|
|
|
|
|
|
def report_exportability(
    mod: torch.nn.Module,
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,
    *,
    strict: bool = True,
    pre_dispatch: bool = False,
) -> dict[str, Optional[Exception]]:
    """
    Report exportability issues for a module in one-shot.

    The root module is exported first. When a module fails to export, or no
    inputs were recorded for it, its direct children are tried in turn with
    the inputs captured from one run of the root module. Each module *type*
    is attempted at most once: a later module whose type was already tried is
    skipped, gets no entry in the report, and its children are not visited.

    Args:
        mod: root module.
        args: args to the root module.
        kwargs: kwargs to the root module.
        strict: forwarded to ``torch.export._trace._export``.
        pre_dispatch: forwarded to ``torch.export._trace._export``.
    Returns:
        A dict that maps from submodule name to the exception that was raised when trying to export it.
        `None` means the module is exportable without issue.
        Sample output:
        {
            '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_2': None
        }
    """

    log_export_usage(event="export.report_exportability")

    kwargs = kwargs or {}

    all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
    submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

    tried_module_types = set()
    report: dict[str, Optional[Exception]] = {}

    def first_line(exc):
        # The repr of a built-in exception escapes newlines in its message as
        # the two characters backslash + "n", while a custom __repr__ may
        # contain real newlines. Cut at whichever appears so the per-module
        # log line and the summary key are derived the same way.
        text = repr(exc)
        for sep in ("\n", "\\n"):
            text = text.split(sep, 1)[0]
        return text

    def try_export(module, module_name, args, kwargs):
        # Only one representative per module type is attempted.
        if type(module) in tried_module_types:
            return
        tried_module_types.add(type(module))

        # Both are None when no inputs were captured for this submodule; in
        # that case skip the export and go straight to its children.
        if args is not None or kwargs is not None:
            try:
                torch.export._trace._export(
                    module,
                    args,
                    kwargs,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                )
            except Exception as e:
                log.warning(
                    "Failed exporting `%s` with exception: %s",
                    module_name,
                    first_line(e),
                )
                report[module_name] = e
            else:
                report[module_name] = None
                log.info("Successfully exported `%s`", module_name)
                # Exportable as a whole: no need to descend into children.
                return

        for name, submod in module.named_children():
            sub_module_name = name if module_name == "" else f"{module_name}.{name}"

            submod_args, submod_kwargs = submod_inputs.get(
                sub_module_name, (None, None)
            )

            try_export(submod, sub_module_name, submod_args, submod_kwargs)

    try_export(mod, "", args, kwargs)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # summary is logged deterministically.
    unique_issues = dict.fromkeys(
        first_line(exception)
        for exception in report.values()
        if exception is not None
    )

    log.warning("Found %d export issues:", len(unique_issues))
    for issue in unique_issues:
        # Pass the text as an argument, never as the format string.
        log.warning("%s", issue)

    return report
|