554 lines
21 KiB
Python
554 lines
21 KiB
Python
# mypy: allow-untyped-defs
|
|
import enum
|
|
import inspect
|
|
import numbers
|
|
import types
|
|
import typing
|
|
import warnings
|
|
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch._jit_internal import boolean_dispatched
|
|
from torch._ops import OpOverload, OpOverloadPacket
|
|
|
|
from ._compatibility import compatibility
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from .node import Argument
|
|
|
|
# Public API of this module; every name listed here is defined further down.
__all__ = [
    "ArgsKwargsPair",
    "check_for_mutable_operation",
    "get_signature_for_torch_op",
    "create_type_hint",
    "type_matches",
    "normalize_function",
    "normalize_module",
]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.

    This is the result type of the argument-normalization helpers in this
    module (``normalize_function`` and ``normalize_module``).
    """

    # Positional arguments, in call order.
    args: tuple[Any, ...]
    # Keyword arguments, keyed by parameter name.
    kwargs: dict[str, Any]
|
|
|
|
|
|
# Hand-written signature lists keyed by the callable they describe. Consulted by
# `get_signature_for_torch_op` before it falls back to the TorchScript builtin lookup.
_manual_overrides: dict[Callable, list[inspect.Signature]] = {}
|
|
|
|
|
|
def _nonzero_schemas():
    """
    Build the hand-written signatures used as the manual override for
    ``torch.nonzero``.

    Returns:
        list[inspect.Signature]: Two signatures, in this order:
        ``(self)`` and ``(self, *, as_tuple: bool)``.
    """

    # The stubs exist only so `inspect.signature` can read their parameter
    # lists; they are never called.
    def _self_only(self):
        pass

    def _self_and_as_tuple(self, *, as_tuple: bool):
        pass

    return [inspect.signature(stub) for stub in (_self_only, _self_and_as_tuple)]
|
|
|
|
|
|
# Register the hand-written `torch.nonzero` signatures at import time.
_manual_overrides[torch.nonzero] = _nonzero_schemas()
|
|
|
|
|
|
class _FakeGlobalNamespace:
    """
    Minimal stand-in object bound to the name ``__torch__`` when TorchScript
    annotation strings are evaluated.

    The only attribute it resolves is ``torch`` (the real ``torch`` module);
    any other attribute lookup raises ``RuntimeError``.
    """

    def __getattr__(self, name):
        if name != "torch":
            raise RuntimeError("Expected a torch namespace lookup")
        return torch
|
|
|
|
|
|
# Global namespace handed to `eval` when a TorchScript annotation string is
# turned into a Python type (see `_torchscript_type_to_python_type`). Each key
# is a name that may appear in such a string.
_type_eval_globals = {
    "Tensor": torch.Tensor,
    "Device": torch.device,
    "Layout": torch.layout,
    "number": numbers.Number,
    "Future": torch.jit.Future,
    "AnyEnumType": enum.Enum,
    "QScheme": torch.qscheme,
    # Stand-in object that only supports the attribute lookup `.torch`.
    "__torch__": _FakeGlobalNamespace(),
    "NoneType": type(None),
    "Storage": torch.UntypedStorage,
    "t": typing.TypeVar("t"),
}
# Also expose every attribute of the `typing` module under its own name.
# These assignments run after the dict literal, so a `typing` attribute whose
# name collides with a key above replaces that entry.
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)
|
|
|
|
|
|
def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
    """
    Translate a TorchScript type into the corresponding Python type.

    The type's ``annotation_str`` is evaluated as a Python expression with
    ``_type_eval_globals`` as the global namespace, which is what maps names
    such as "List" or "Future" to real objects (``typing.List``,
    ``torch.jit.Future``). Nested/parametrized types are handled by the
    expression evaluation itself.

    Args:
        ts_type: TorchScript type object exposing ``annotation_str``.

    Returns:
        Whatever the annotation expression evaluates to.
    """
    # NOTE: this uses `eval`; the expression is the annotation string reported
    # by the TorchScript type object itself.
    annotation_expr = ts_type.annotation_str
    return eval(annotation_expr, _type_eval_globals)
|
|
|
|
|
|
def _torchscript_schema_to_signature_impl(
    ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
    """
    Build an ``inspect.Signature`` that mirrors a TorchScript function schema.

    Args:
        ts_schema: Schema whose ``arguments`` and ``returns`` are converted.

    Returns:
        inspect.Signature: One parameter per schema argument. The return
        annotation is ``None`` for no returns, the single converted type for
        one return, and a tuple of converted types otherwise.
    """
    from inspect import Parameter

    params: list[Parameter] = []
    for schema_arg in ts_schema.arguments:
        annotation = _torchscript_type_to_python_type(schema_arg.type)
        if schema_arg.has_default_value():
            default = schema_arg.default_value
        else:
            default = Parameter.empty

        # TODO: Figure out if this is safe. Signatures generated for
        # PythonArgParser appear to call the first tensor argument `input`
        # rather than `self`. If a caller later turns that positional argument
        # into a keyword argument, a name mismatch would break things, so the
        # name is normalized to "input" here.
        name = "input" if schema_arg.name == "self" else schema_arg.name

        if schema_arg.kwarg_only:
            kind = Parameter.KEYWORD_ONLY
        else:
            kind = Parameter.POSITIONAL_OR_KEYWORD

        if name == "from":
            # "from" is a Python keyword, so it can only be represented as a
            # positional-only parameter.
            assert kind == Parameter.POSITIONAL_OR_KEYWORD
            # The parameter-kind enum is internal to the inspect package,
            # which makes this assignment awkward to annotate.
            kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # Nothing may precede a positional-only parameter except other
            # positional-only parameters, so convert everything collected so far.
            for position, earlier in enumerate(params):
                assert earlier.kind == Parameter.POSITIONAL_OR_KEYWORD
                params[position] = earlier.replace(kind=Parameter.POSITIONAL_ONLY)

        params.append(
            Parameter(name=name, kind=kind, default=default, annotation=annotation)
        )

    converted_returns = [
        _torchscript_type_to_python_type(schema_ret.type)
        for schema_ret in ts_schema.returns
    ]
    if len(converted_returns) == 1:
        return_annotation = converted_returns[0]
    elif converted_returns:
        return_annotation = tuple(converted_returns)
    else:
        return_annotation = None

    return inspect.Signature(params, return_annotation=return_annotation)
|
|
|
|
|
|
# Memo for `_torchscript_schema_to_signature`, keyed by
# (schema name, schema overload name).
_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
|
|
|
|
|
|
def _torchscript_schema_to_signature(
    ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
    """
    Cached front end for ``_torchscript_schema_to_signature_impl``.

    Results are memoized in ``_SCHEMA_TO_SIGNATURE_CACHE`` under the key
    ``(ts_schema.name, ts_schema.overload_name)`` because this is called on
    the hot path of FakeTensor dispatch.

    Args:
        ts_schema: TorchScript function schema to convert.

    Returns:
        inspect.Signature: The (possibly cached) signature for the schema.
    """
    key = (ts_schema.name, ts_schema.overload_name)
    signature = _SCHEMA_TO_SIGNATURE_CACHE.get(key)
    if signature is None:
        signature = _torchscript_schema_to_signature_impl(ts_schema)
        _SCHEMA_TO_SIGNATURE_CACHE[key] = signature
    return signature
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(
    target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
):
    """
    Best-effort check that a call does not resolve to a mutable operator schema.

    The signatures and schemas of ``target`` are looked up and every signature
    that can bind ``args``/``kwargs`` is collected. Only when exactly one
    signature binds is its schema inspected; zero matches or an ambiguous
    match are silently ignored.

    Args:
        target: Callable being traced.
        args: Positional arguments at the call site.
        kwargs: Keyword arguments at the call site.

    Raises:
        RuntimeError: If the single matching schema reports ``is_mutable``.
    """
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
    if not (signatures and schemas):
        # No signature/schema information available; nothing to check.
        return

    # Keep the schema of every overload whose signature accepts this call.
    bindable_schemas = []
    for candidate_signature, candidate_schema in zip(signatures, schemas):
        try:
            candidate_signature.bind(*args, **kwargs)
        except TypeError:
            continue
        bindable_schemas.append(candidate_schema)

    # Zero matches: cannot check. Several matches: ambiguous, and since the
    # mutability check is best effort, do nothing.
    if len(bindable_schemas) != 1:
        return

    (schema,) = bindable_schemas
    if schema.is_mutable:
        raise RuntimeError(
            f"Tried to trace mutable operation {schema}. FX only supports functional "
            f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
            f"are not supported"
        )
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for
        return_schemas (bool): When True, return a ``(signatures, schemas)`` tuple
            instead of just the signatures.

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional Python signatures
            and the optional TorchScript Function signature. For an operator with a
            manual override, the override signatures are returned only when
            return_schemas=True (paired with ``None`` schemas); otherwise ``None``.
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [
            getattr(op, overload_name)._schema for overload_name in op.overloads()
        ]
    else:
        manual_signatures = _manual_overrides.get(op)
        if manual_signatures:
            # NOTE(review): without return_schemas the override signatures are
            # dropped and None is returned. This looks unintentional, but
            # callers currently observe None here, so the behaviour is kept.
            if return_schemas:
                return manual_signatures, None
            return None

        builtin_name = torch.jit._builtins._find_builtin(op)
        if builtin_name is None:
            # Not a known TorchScript builtin: no signature information.
            if return_schemas:
                return None, None
            return None
        schemas = torch._C._jit_get_schemas_for_operator(builtin_name)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
    if return_schemas:
        return signatures, schemas
    return signatures
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Produces a type hint for the given argument.

    When `x` is a `list` or `tuple` of types, the result is `list[T]` or
    `tuple[T, ...]` respectively, where `T` is an element of `x` that is a
    superclass of all the others. If no such element exists, or `x` is empty,
    `T` is `Any`.

    When `x` is neither a `list` nor a `tuple`, `x` itself is returned. If
    building the hint fails (for example because the elements are not
    classes), a warning is emitted and `x` is returned unchanged.
    """
    try:
        if isinstance(x, (list, tuple)):
            wants_list = isinstance(x, list)

            # todo(chilli): Figure out the right way for mypy to handle this
            def wrap(element_type):
                if wants_list:
                    return list[element_type]  # type: ignore[valid-type]
                return tuple[element_type, ...]  # type: ignore[valid-type]

            if len(x) == 0:
                return wrap(Any)

            # Walk the elements, widening `widest` whenever an element turns
            # out to be a superclass of it; give up on unrelated types.
            widest = x[0]
            for candidate in x:
                if issubclass(candidate, widest):
                    continue
                if not issubclass(widest, candidate):
                    return wrap(Any)
                widest = candidate
            return wrap(widest)
    except Exception:
        # We tried to create a type hint for list but failed.
        warnings.warn(
            f"We were not able to successfully create type hint from the type {x}"
        )
    return x
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any) -> bool:
    """
    Decide whether an argument of type ``argument_type`` is acceptable for a
    parameter annotated with ``signature_type``.

    Rules, applied in order:

    * identical objects match;
    * a ``Union`` signature matches if any of its members matches;
    * a ``list[T]`` signature matches ``int`` when ``T`` is ``int``, a
      ``list[U]`` where ``U`` is a subclass of ``T``, or a tuple type whose
      elements are all subclasses of ``T`` (``...`` is ignored);
    * an ``int`` signature matches ``torch.dtype``;
    * a ``numbers.Number`` signature matches ``int`` and ``float``;
    * otherwise two classes match when the argument is a subclass of the
      signature type.

    Element types that are not plain classes (for example a nested
    parametrized generic such as ``list[int]``) never match; they yield
    ``False`` rather than an exception.

    Args:
        signature_type: Annotation taken from the candidate signature.
        argument_type: Type of the value supplied by the caller.

    Returns:
        bool: True if the argument type is accepted.
    """

    def _is_subclass(candidate: Any, parent: Any) -> bool:
        # `issubclass` raises TypeError when either operand is not a plain
        # class (e.g. a parametrized generic); treat that as "does not match".
        try:
            return issubclass(candidate, parent)
        except TypeError:
            return False

    sig_origin_type = getattr(signature_type, "__origin__", signature_type)

    if signature_type is argument_type:
        return True

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union
    if sig_origin_type is typing.Union and signature_type != argument_type:
        sig_contained = signature_type.__args__
        return any(type_matches(c, argument_type) for c in sig_contained)

    if getattr(signature_type, "__origin__", None) is list:
        sig_el_type = signature_type.__args__[0]

        # int can be promoted to list[int]
        if argument_type is int and sig_el_type is int:
            return True

        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug."
            )
            return False
        if getattr(argument_type, "__origin__", None) is list:
            return _is_subclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, "__origin__", None) is not tuple:
                return False
            contained = t.__args__
            if contained == ((),):  # Tuple[()].__args__ == ((),) for some reason
                return True
            return all(
                (c is Ellipsis) or _is_subclass(c, sig_el_type) for c in contained
            )

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True
    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return _is_subclass(argument_type, signature_type)

    return False
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def normalize_function(
    target: Callable,
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,
    arg_types: Optional[tuple[Any]] = None,
    kwarg_types: Optional[dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the functional's
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Two strategies are used depending on `target`:

    * Anything that is neither a builtin function nor an
      `OpOverload`/`OpOverloadPacket` is normalized against its own Python
      signature (after `inspect.unwrap`).
    * Builtins and op overloads are normalized against the signatures returned
      by `get_signature_for_torch_op`; the overload is chosen by which
      signatures can bind `args`/`kwargs`.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If more than one overload signature binds the call and
            neither `arg_types` nor `kwarg_types` was given.
    """
    if kwargs is None:
        kwargs = {}
    new_args_and_kwargs = None
    if not isinstance(target, types.BuiltinFunctionType) and not (
        isinstance(target, (OpOverloadPacket, OpOverload))
    ):
        # Plain Python callable: use its own signature.
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched["if_true"], dispatched["if_false"]
            if (
                inspect.signature(if_true).parameters
                != inspect.signature(if_false).parameters
            ):
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
            sig, args, kwargs, normalize_to_only_use_kwargs
        )
    else:
        # Builtin or op overload: resolve against the operator's signatures.
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
        matched_schemas = []
        if torch_op_schemas:
            # Iterate through all of the schema until we find one that matches
            # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
            # values. If none matches, `new_args_and_kwargs` will be None
            for candidate_signature in torch_op_schemas:
                try:
                    candidate_signature.bind(*args, **kwargs)
                    matched_schemas.append(candidate_signature)
                except TypeError:
                    # This overload cannot accept the call; try the next one.
                    continue

            if len(matched_schemas) == 0:
                # Did not match any schema. Cannot normalize
                pass
            elif len(matched_schemas) == 1:
                # Matched exactly one schema, unambiguous
                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
                    matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs
                )
            else:
                if arg_types is not None or kwarg_types is not None:
                    # Several overloads bind the values; use the caller-supplied
                    # types to pick the first overload whose annotations all match.
                    # NOTE(review): this loop scans every signature in
                    # `torch_op_schemas`, not only those in `matched_schemas` --
                    # confirm that is intended.
                    arg_types = arg_types if arg_types else cast(tuple[Any], ())
                    kwarg_types = kwarg_types if kwarg_types else {}
                    for candidate_signature in torch_op_schemas:
                        sig_matches = True
                        try:
                            bound_types = candidate_signature.bind(
                                *arg_types, **kwarg_types
                            )
                            for arg_name, arg_type in bound_types.arguments.items():
                                param = candidate_signature.parameters[arg_name]
                                sig_matches = sig_matches and type_matches(
                                    param.annotation, arg_type
                                )
                        except TypeError:
                            # The types do not even bind to this overload.
                            sig_matches = False
                        if sig_matches:
                            new_args_and_kwargs = (
                                _args_kwargs_to_normalized_args_kwargs(
                                    candidate_signature,
                                    args,
                                    kwargs,
                                    normalize_to_only_use_kwargs,
                                )
                            )
                            break
                    # If no overload matched the types, the result stays None.
                else:
                    # Matched more than one schema. In this situation, the caller must provide the types of
                    # the arguments of the overload they expect.
                    # NOTE(review): the message below refers to
                    # `normalize_arguments()`, which is not the name of this
                    # function -- confirm which API it is meant to point at.
                    schema_printouts = "\n".join(
                        str(schema) for schema in matched_schemas
                    )
                    raise RuntimeError(
                        f"Tried to normalize arguments to {torch.typename(target)} but "
                        f"the schema match was ambiguous! Please provide argument types to "
                        f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
                    )

    return new_args_and_kwargs
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def normalize_module(
    root: torch.nn.Module,
    target: str,
    args: tuple[Any],
    kwargs: Optional[dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the signature of the submodule's
    `forward` and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Only submodules whose class is the attribute of the same name on
    `torch.nn` are normalized; for any other submodule `None` is returned.

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): Qualified name of the submodule, as accepted by
            `root.get_submodule`
        args (Tuple[Any]): Tuple of args to the module call
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the module call
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If `root` has no submodule named `target`.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(
            f"Tried to normalize node with target {target} but root did not "
            f"have that target!"
        ) from e

    module_cls = submod.__class__
    if not hasattr(module_cls, "__name__"):
        return None
    # Only classes that `torch.nn` exposes under the very same name qualify.
    if not (getattr(torch.nn, module_cls.__name__, None) == module_cls):
        return None

    forward_sig = inspect.signature(inspect.unwrap(submod.forward))
    call_kwargs = {} if kwargs is None else kwargs
    return _args_kwargs_to_normalized_args_kwargs(
        forward_sig, args, call_kwargs, normalize_to_only_use_kwargs
    )
|
|
|
|
|
|
def _args_kwargs_to_normalized_args_kwargs(
    sig: inspect.Signature,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
    """
    Bind a call's arguments to `sig`, fill in defaults, and return them as an
    ArgsKwargsPair, or None if the signature shape is not supported by this
    normalization.

    Args:

        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported. Unless `normalize_to_only_use_kwargs`
            is set, the first `len(args)` parameters stay positional and the
            rest become keyword arguments, in signature order.
    """
    # Positional-only and varargs (*args, **kwargs) signatures are not
    # currently supported.
    supported_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    has_unsupported_kind = any(
        parameter.kind not in supported_kinds for parameter in sig.parameters.values()
    )
    # One signature is exempt, which is common for random/uniform, i.e.:
    #    Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
    # `from` is a Python keyword, so such functions must have positional-only
    # args, yet they may still be dispatched with kwargs.
    if has_unsupported_kind and list(sig.parameters) != [
        "input",
        "from",
        "to",
        "generator",
    ]:
        return None

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    parameter_names = list(sig.parameters)
    positional_count = 0 if normalize_to_only_use_kwargs else len(args)
    new_args = tuple(
        bound.arguments[name] for name in parameter_names[:positional_count]
    )
    new_kwargs: dict[str, Any] = {
        name: bound.arguments[name] for name in parameter_names[positional_count:]
    }
    return ArgsKwargsPair(new_args, new_kwargs)
|