# mypy: ignore-errors
"""
Function-related variable tracking classes for Dynamo's symbolic execution.
This module contains classes that track different types of functions during graph
compilation, including:
- User-defined functions and methods
- Built-in functions and methods
- Wrapped functions (e.g. from decorators)
- Special function types (e.g. functools.partial)
- Triton kernels and related function types
These classes are responsible for:
- Tracking function calls and their arguments
- Managing function closures and cell variables
- Handling function attributes and special methods
- Maintaining guards for function identity and closure contents
- Supporting function inlining and specialization
- Enabling proper symbolic execution of different function types
The variable trackers here work together with the rest of Dynamo to enable
accurate graph capture while handling Python's various function-related behaviors.
"""
import builtins
import functools
import inspect
import itertools
import sys
import types
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch
import torch
from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
ObservedException,
ObservedGeneratorExit,
ObservedUserStopIteration,
raise_observed_exception,
SkipFrame,
unimplemented_v2,
Unsupported,
)
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import (
check_constant_args,
check_unspec_or_constant_args,
cmp_name_to_op_mapping,
counters,
identity,
is_function,
is_wrapper_or_member_descriptor,
istype,
make_cell,
)
from .base import typestr, ValueMutationNew, VariableTracker
from .constant import ConstantVariable
try:
from torch.distributed.fsdp._fully_shard import _fsdp_param_group
except ModuleNotFoundError:
_fsdp_param_group = None
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._higher_order_ops.triton_kernel_wrap import (
TritonGridType,
TritonKernelType,
)
_F = TypeVar("_F", bound=Callable)
def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
# Source propagation is best effort since not every object we encounter has a source to begin with.
if isinstance(val, VariableTracker):
return val
elif not source:
return VariableTracker.build(tx, val)
else:
# Create a lazy variable to avoid guarding on __defaults__ unless really
# needed.
return variables.LazyVariableTracker.create(val, source)
def wrap_args_kwargs(tx: "InstructionTranslator", result):
for k, v in list(result.items()):
if isinstance(v, (tuple, dict)):
# args/kwargs
result[k] = wrap_bound_arg(tx, v)
def init_cellvars(parent, result: dict[str, VariableTracker], code):
"""
Update `result` to add mapping from local name to new cells created
directly by `code`, or update SideEffects in `parent` if a local cell is
already in `result` (cell argument).
"""
side_effects = parent.output.side_effects
for name in code.co_cellvars:
new_cell = side_effects.track_cell_new()
if name in result:
# This handles when a function argument is a cell (e.g., captured by
# a nested func). See `MAKE_CELL` bytecode for more info.
side_effects.store_cell(new_cell, result.pop(name))
result[name] = new_cell
def _create_nested_fn(
code, f_globals, name, defaults, closure, kwdefaults, annotations
):
from types import FunctionType
func = FunctionType(code, f_globals, name, defaults, closure)
func.__kwdefaults__ = kwdefaults
if isinstance(annotations, tuple):
from itertools import pairwise
annotations = dict(pairwise(annotations))
# TypeError: __annotations__ must be set to a dict object
assert annotations is None or isinstance(annotations, dict)
func.__annotations__ = annotations
return func
fn_known_dunder_attrs = {
"__annotations__",
"__defaults__",
"__kwdefaults__",
"__code__",
"__globals__",
"__closure__",
"__doc__",
}
def fn_var_getattr(tx, fn, source, name):
source = source and AttrSource(source, name)
try:
subobj = inspect.getattr_static(fn, name)
except AttributeError:
# function does not have a __getattr__ or __getattribute__ method,
# so we can safely assume that this attribute is absent
raise_observed_exception(AttributeError, tx)
# Special handling for known dunder attributes
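# (inspect.getattr_static on these typically returns the raw descriptor from
# the function type rather than the live value, so re-read the actual object
# with plain getattr.)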
if name in fn_known_dunder_attrs:
subobj = getattr(fn, name)
if source:
return variables.LazyVariableTracker.create(subobj, source)
return VariableTracker.build(tx, subobj)
class BaseUserFunctionVariable(VariableTracker):
def get_filename(self):
return self.get_code().co_filename
def get_name(self):
return self.get_code().co_name
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
result = False
try:
result = hasattr(self.get_function(), name)
except NotImplementedError:
if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
result = True
return variables.ConstantVariable.create(result)
def inspect_parameter_names(self):
return list(inspect.signature(self.get_function()).parameters)
def closure_vars(self, tx):
return {}
class UserFunctionVariable(BaseUserFunctionVariable):
"""Some unsupported user-defined global function"""
_nonvar_fields = {
"fn",
"is_constant",
*BaseUserFunctionVariable._nonvar_fields,
}
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
return cls(value, source=source)
def __init__(self, fn, is_constant=False, **kwargs) -> None:
super().__init__(**kwargs)
if getattr(fn, "_dynamo_marked_constant", False):
# This method should be treated as a constant for the purposes of compilation
self.is_constant = True
else:
self.is_constant = False
assert isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)), (
f"expected FunctionType found {typestr(fn)} {fn}"
)
# TODO(anijain2305) - Replace directly calling UserFunctionVariable with
# VariableBuilder, which handles the wrapping of _torchdynamo_inline.
# unpack @torch._dynamo.optimize()(fn) wrapped function
fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
self.fn: types.FunctionType = fn
def as_python_constant(self):
if istype(self, UserFunctionVariable):
return self.fn
# subclasses (such as methods) usually aren't a constant
return super().as_python_constant()
def self_args(self):
return []
def get_function(self):
return self.fn
def get_code(self):
return self.fn.__code__
def python_type(self):
return types.FunctionType
def has_self(self):
return getattr(self.fn, "__self__", None) is not None
def get_globals(self):
return self.fn.__globals__
def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]:
"""
Assume `args` and `kwargs` are VariableTracker arguments for a call to
this function, create new bindings for initial locals.
"""
assert not self.is_constant
root_tx = parent.output.root_tx
wrap = functools.partial(wrap_bound_arg, tx=root_tx)
fn: types.FunctionType = self.fn
defaults = fn.__defaults__ or []
defaults_sources = [
None if self.source is None else DefaultsSource(self.source, idx)
for idx, _ in enumerate(defaults)
]
fake_func = types.FunctionType(
fn.__code__,
fn.__globals__,
fn.__name__,
tuple(
[
wrap(val=arg, source=source)
for arg, source in zip(defaults, defaults_sources)
]
),
fn.__closure__,
)
if fn.__kwdefaults__:
kwdefaults_sources = {
k: (
None
if self.source is None
else DefaultsSource(self.source, k, is_kw=True)
)
for k in fn.__kwdefaults__
}
fake_func.__kwdefaults__ = {
k: wrap(val=v, source=kwdefaults_sources[k])
for k, v in fn.__kwdefaults__.items()
}
bound = inspect.signature(fake_func).bind(*args, **kwargs)
bound.apply_defaults()
result = dict(bound.arguments.items())
wrap_args_kwargs(root_tx, result)
init_cellvars(parent, result, fn.__code__)
closure = self.fn.__closure__ or ()
assert len(closure) == len(self.fn.__code__.co_freevars)
for idx, name, cell in zip(
itertools.count(), self.fn.__code__.co_freevars, closure
):
# TODO refactor these 3 branches.
side_effects = parent.output.side_effects
if cell in side_effects:
cell_var = side_effects[cell]
elif self.source:
closure_cell = GetItemSource(
AttrSource(self.source, "__closure__"), idx
)
closure_cell_contents = AttrSource(closure_cell, "cell_contents")
try:
contents_var = VariableTracker.build(
parent, cell.cell_contents, closure_cell_contents
)
except ValueError:
# Cell has not yet been assigned
contents_var = variables.DeletedVariable()
cell_var = side_effects.track_cell_existing(
closure_cell, cell, contents_var
)
else:
# TODO figure out why source isn't available here, and whether
# we can fix that and remove this branch.
try:
contents_var = VariableTracker.build(parent, cell.cell_contents)
except ValueError:
# Cell has not yet been assigned
contents_var = variables.DeletedVariable()
cell_var = side_effects.track_cell_existing(None, cell, contents_var)
result[name] = cell_var
return result
def var_getattr(self, tx: "InstructionTranslator", name: str):
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
return fn_var_getattr(tx, self.fn, self.source, name)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
result = hasattr(self.fn, name)
return variables.ConstantVariable.create(result)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle a `nonstrict_trace(fn)` call
if self.fn is torch._dynamo.nonstrict_trace:
bound = inspect.signature(self.fn).bind(*args, **kwargs)
fn_var = bound.args[0]
if not isinstance(fn_var, BaseUserFunctionVariable):
typ = fn_var.python_type()
msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
unimplemented_v2(
gb_type="TypeError from user code",
context=f"call_function({self.value}, {args}, {kwargs})",
explanation=msg,
hints=[
*graph_break_hints.USER_ERROR,
],
)
if not isinstance(fn_var, UserFunctionVariable):
fn_name = fn_var.get_name()
msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950
unimplemented_v2(
gb_type="Limitation of `nonstrict_trace",
context=f"{self}",
explanation=msg,
hints=[
f"make sure definition of {fn_name} is outside ",
"`torch.compile` region",
],
)
fn = fn_var.fn
return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)
if self.is_constant:
return invoke_and_store_as_constant(
tx, self.fn, self.get_name(), args, kwargs
)
if (
tx.output.current_tracer.under_activation_checkpoint
and not tx.output.current_tracer.allow_side_effects_under_checkpoint
):
try:
from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
except Exception:
FSDPState = None
if FSDPState is not None and self.fn in [
FSDPState._pre_forward,
FSDPState._post_forward,
]:
with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
return super().call_function(tx, args, kwargs)
return super().call_function(tx, args, kwargs)
class BuiltinMethodVariable(BaseUserFunctionVariable):
def __init__(self, fn, is_constant=False, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(fn, types.BuiltinMethodType)
self.fn = fn
@staticmethod
def is_supported_builtin_method(obj):
method_self = obj.__self__
method_name = obj.__name__
# TODO(anijain2305) - Add support for more builtin methods
# Supports tuple.__new__ and frozenset({....}).__contains__
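# e.g. `tuple.__new__` is a builtin method whose __self__ is the `tuple`
# type itself, while `frozenset({...}).__contains__` is bound to a concrete
# frozenset instance, which is what the two checks below look for.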
return (method_self is tuple and method_name == "__new__") or (
type(method_self) is frozenset and method_name == "__contains__"
)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
method_self = self.fn.__self__
name = self.fn.__name__
obj_source = self.source and AttrSource(self.source, "__self__")
obj_vt = VariableTracker.build(tx, method_self, obj_source)
return obj_vt.call_method(tx, name, args, kwargs)
class LocalGeneratorObjectVariable(VariableTracker):
def __init__(
self,
code: types.CodeType,
f_globals,
inline_tracer: Optional["InstructionTranslator"],
**kwargs,
):
super().__init__(**kwargs)
self.code = code
self.f_globals = f_globals
self.inline_tracer = inline_tracer
def get_code(self):
return self.code
def get_filename(self):
return self.get_code().co_filename
def get_name(self):
return self.get_code().co_name
def get_function(self):
raise NotImplementedError
def has_self(self):
return False
def __name__(self):
return self.get_name()
def __str__(self):
return f"{self.__class__.__name__}({self.get_name()})"
__repr__ = __str__
def reconstruct(self, codegen):
from torch._dynamo.side_effects import disallow_side_effects_in_generator
from torch._dynamo.symbolic_convert import (
InstructionTranslator,
save_and_restart_speculation_log,
temporarely_allow_writes_to_output_graph,
)
tx = InstructionTranslator.current_tx()
save = save_and_restart_speculation_log(tx)
disallow = disallow_side_effects_in_generator(tx)
temp = temporarely_allow_writes_to_output_graph(tx)
with save, disallow, temp:
tracer = self._get_inline_tracer(tx)
if not tracer.generator_exhausted:
self.remaining_items = self.force_unpack_var_sequence(tx)
variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen)
def bind_args(self, tx, args, kwargs):
return self.fn.bind_args(tx, args, kwargs)
def get_globals(self):
return self.f_globals
def python_type(self):
return types.GeneratorType
def _get_inline_tracer(self, tx):
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
if self.inline_tracer is None:
self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
tx, self, [], {}
)
return self.inline_tracer
def next_variable(self, tx):
tracer = self._get_inline_tracer(tx)
if self._is_generator_exhausted():
raise_observed_exception(StopIteration, tx)
try:
# Hierarchically, tx can be seen as the parent of the inline tracer
# created on call_function. Any exception needs to be propagated to tx
# for Dynamo to behave correctly
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
return tracer.inline_call_()
except ObservedException as e:
raise e
except InfiniteGeneratorError:
# test/dynamo/test_misc.py::test_iterator_limit
raise
except Unsupported as e:
torch._dynamo.eval_frame.skip_code(self.get_code())
raise SkipFrame from e
finally:
counters["unimplemented"] |= counters["inline_call"]
def has_unpack_var_sequence(self, tx):
return False
def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
return True
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
result = []
while True:
try:
result.append(self.next_variable(tx))
except ObservedUserStopIteration:
handle_observed_exception(tx)
break
return result
def _setup_exception(self, tx, exc):
tracer = self._get_inline_tracer(tx)
try:
tracer._raise_exception_variable(exc)
except ObservedException as e:
# if no handler is available (i.e. user code doesn't catch it), the
# exception is raised again.
tracer.exception_handler(e)
def _is_generator_just_started(self):
return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
def _is_generator_exhausted(self):
return getattr(self.inline_tracer, "generator_exhausted", False)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
# iter(gen) returns itself
return self
elif name == "send":
# Sends a value into the generator function. Returns the next value
# yielded by the generator, or raises StopIteration if the generator
# exits without yielding another value
if self._is_generator_just_started() and len(args):
# can't send non-None value to a just-started generator
# Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
if not all(
isinstance(arg, ConstantVariable) and arg.value is None
for arg in args
):
raise_observed_exception(TypeError, tx)
tracer = self._get_inline_tracer(tx)
tracer.push_many(args)
return self.next_variable(tx)
elif name == "close":
# * Raises a GeneratorExit at the point where the generator function was paused.
# * If the generator function catches the exception and returns a
# value, this value is returned from close() - Python 3.13+
# * If the generator function is already closed, or raises GeneratorExit
# (by not catching the exception), close() returns None.
# * If the generator yields a value, a RuntimeError is raised.
# * If the generator raises any other exception, it is propagated to the caller.
# * If the generator has already exited due to an exception or normal
# exit, close() returns None and has no other effect.
# Return None if close is called on a just-started generator
# See test GeneratorCloseCpythonTests::test_close_not_started
tracer = self._get_inline_tracer(tx)
if self._is_generator_just_started() or self._is_generator_exhausted():
tracer.generator_exhausted = True
return variables.ConstantVariable(None)
# Raise GeneratorExit to see if user code catches it. Any other exception
# is propagated to the parent frame.
try:
self._setup_exception(
tx, variables.ExceptionVariable(GeneratorExit, ())
)
# There's an extra block on Python 3.12+ to handle StopIteration
# see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
#
# 1 0 RETURN_GENERATOR
# 2 POP_TOP
# 4 RESUME 0
# 2 6 LOAD_CONST 1 (1)
# 8 YIELD_VALUE 1
# 10 RESUME 1
# 12 POP_TOP
# 14 RETURN_CONST 0 (None)
# >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR)
# 18 RERAISE 1
# ExceptionTable:
# 4 to 14 -> 16 [0] lasti
if (
sys.version_info >= (3, 12)
and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
):
tracer.generator_exhausted = True
return variables.ConstantVariable(None)
except ObservedGeneratorExit:
# If the generator doesn't catch GeneratorExit, we just return None, as per the text above
tracer.generator_exhausted = True
return variables.ConstantVariable(None)
try:
# Raise RuntimeError if the generator yields any other value
if self.next_variable(tx):
raise_observed_exception(RuntimeError, tx)
except ObservedGeneratorExit:
tracer.generator_exhausted = True
return variables.ConstantVariable(None)
except ObservedUserStopIteration:
# In Python 3.13+, one can capture GeneratorExit and return a value
# See test_generator.py::test_close_capture_GeneratorExit_return
# https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
# https://github.com/python/cpython/pull/104771
assert tracer.symbolic_result is not None
return tracer.symbolic_result
elif name == "throw":
# * Raises an exception at the point where the generator was paused, and
# returns the next value yielded by the generator.
# * If the generator exits without yielding, raise StopIteration
# * If the generator function does not catch the passed-in exception,
# or raises a different exception, then that exception propagates to the caller.
# Setup the exception table and jump target in case of try...finally
tracer = self._get_inline_tracer(tx)
try:
# In Python 3.9, the exception is represented as a triple (typ, val, tb)
# In such cases, we re-raise the exception object given to avoid
# creating a new object, so that IS_OP works.
# See: https://github.com/pytorch/pytorch/pull/146496
self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
except ObservedException: # noqa: TRY203
# propagate the exception back to the parent caller
raise
retval = self.next_variable(tx)
# The exception raised before is still active. We need to check the exception
# table one more time to find the next target. But why? Let's walk
# through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
#
# z = 0
# def whoo():
# global z
# z = 0
# try:
# yield 1
# except ValueError:
# yield 2
# finally:
# z += 1
# z += 10
#
# gen = whoo()
# next(gen)
# gen.throw(ValueError)
# print('z', z) -> z = 1
#
# ...
# >> 58 PUSH_EXC_INFO
#
# 8 60 LOAD_GLOBAL 2 (ValueError)
# 70 CHECK_EXC_MATCH
# 72 POP_JUMP_IF_FALSE 7 (to 88)
# 74 POP_TOP
#
# 9 76 LOAD_CONST 3 (2)
# 78 YIELD_VALUE 3 <------ ValueError is still active here
# 80 RESUME 1
# 82 POP_TOP
# 84 POP_EXCEPT
# 86 jump_backward 34 (to 20)
# ...
#
# ExceptionTable:
# 4 to 8 -> 124 [0] lasti
# 12 to 18 -> 58 [0]
# 20 to 56 -> 124 [0] lasti
# 58 to 82 -> 90 [1] lasti <------ move to 90
# 84 to 86 -> 96 [0]
# 88 to 88 -> 90 [1] lasti
# 90 to 94 -> 96 [0]
# 96 to 116 -> 118 [1] lasti
# 118 to 122 -> 124 [0] lasti
#
# In this scenario, a generator can yield after `throw()` is called. Even
# after the exception is raised a few lines above, it remains active
# within the `78 YIELD_VALUE` instruction. When the generator resumes
# after the second yield on instruction `80 RESUME`, we cannot simply
# return the control flow to the next instruction. Instead, one must
# check the exception table (or equivalent) to find the next target
# In this case, it says the instruction pointer must be moved to 90.
#
# Without this step, if we let the trace proceed to the next
# instruction, it would follow the control flow where the exception
# raised by `throw()` was handled and swallowed, potentially leading
# to incorrect behavior.
exc_type = type("__InternalThrowException", (Exception,), {})
try:
self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
self.next_variable(tx)
except get_dynamo_observed_exception(exc_type):
# We should get back the exception raised before.
pass
else:
raise_observed_exception(RuntimeError, tracer)
return retval
return super().call_method(tx, name, args, kwargs)
class ContextlibContextManagerLocalGeneratorObjectVariable(
LocalGeneratorObjectVariable
):
"""
.. note::
This is only used when the function is annotated with @contextlib.contextmanager
It is a special case of a generator function, as we do not allow returning a context manager
from a torch.compile function.
"""
class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
"""functions that behaves like iterators
.. note::
This is a wrapper around (Nested)UserFunctionVariable
"""
def __init__(
self,
vt: VariableTracker,
*,
generator_cls=LocalGeneratorObjectVariable,
**kwargs,
):
super().__init__(**kwargs)
self.vt = vt
self.generator_cls = generator_cls
def __getattr__(self, name):
if name in self.__class__.__dict__.keys():
return getattr(self, name)
return getattr(self.vt, name)
def _build_inline_tracer(self, tx, args, kwargs):
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
return InliningInstructionTranslator.build_inline_tracer(
tx,
self,
args,
kwargs,
)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
assert is_generator(self.vt.get_code())
inline_tracer = self._build_inline_tracer(tx, args, kwargs)
code = self.vt.get_code()
f_globals = self.vt.get_globals()
# calling a generator returns a generator object
return self.generator_cls(
code,
f_globals,
inline_tracer,
source=self.source,
)
class FunctionDecoratedByContextlibContextManagerVariable(
LocalGeneratorFunctionVariable
):
"""
.. note::
This is only used when the function is annotated with @contextlib.contextmanager
"""
def __init__(self, vt, **kwargs):
super().__init__(
vt,
generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
**kwargs,
)
def _build_inline_tracer(self, tx, args, kwargs):
# NOTE: This only exists to not break support for context manager when
# config.enable_faithful_generator_behavior = False and
# config.enable_trace_contextlib = True. In case the former is false,
# Dynamo should still be able to trace through @contextmanager functions
tracer = super()._build_inline_tracer(tx, args, kwargs)
assert isinstance(
tracer,
torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
)
tracer.is_generator_from_ctx_manager = True
return tracer
class UserMethodVariable(UserFunctionVariable):
"""Some unsupported user-defined method"""
def __init__(self, fn, obj, **kwargs) -> None:
super().__init__(fn=fn, **kwargs)
self.obj = obj
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.fn}, {self.obj})"
def self_args(self):
return [self.obj]
def python_type(self):
return types.MethodType
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# NOTE this is to handle methods annotated by `nonstrict_trace`. Usually
# a `nonstrict_trace`-ed function will be wrapped by
# `VariableTracker.build` and route to `TorchInGraphFunctionVariable`,
# but in the case of method, we manually wrap it with `UserMethodVariable`
# inside `UserDefinedObjectVariable.var_getattr`.
#
# We might be able to simplify this away by canonicalizing the
# function/method wrapping code paths.
from ..trace_rules import is_nonstrict_trace_callable
if is_nonstrict_trace_callable(self.fn):
call_args = [*self.self_args(), *args]
var = variables.TorchInGraphFunctionVariable(
self.fn, nonstrict_traceable=True
)
return var.call_function(tx, call_args, kwargs)
# For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
# rather than simple inlining. E.g., putting `call_method` op in FX graph for `forward` method
# since we ensure `forward` of allowed modules can be traced by AOT safely.
# Note this is not only for allowed modules, as user customized modules can extend from
# allowed modules but using parent's `forward` method, which is also covered by this branch.
# If we are tracing the higher order op, we want Dynamo to step inside
# the module call so that Dynamo can see the underlying parameters and
# buffers and raise them as inputs to the graph. The is_root_tracer
# check bypasses the if condition for non-root tracers and directly
# calls the super().call_function at the end, which is basically
# equivalent to inlining the method.
if tx.output.is_root_tracer() and isinstance(
self.obj, variables.NNModuleVariable
):
module_attr = getattr(self.fn, "__module__", "")
# inline torch.nn.utils.parametrize
if (
module_attr is not None
and module_attr.startswith("torch.nn.")
and module_attr != "torch.nn.utils.parametrize"
or self.is_constant
):
return self.obj.call_method(
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
)
elif (
_fsdp_param_group is not None
and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state
):
return variables.TorchCtxManagerClassVariable(self.fn).call_function(
tx, (self.obj, *args), kwargs
)
if self.is_constant:
fn = getattr(self.obj.value, self.fn.__name__)
return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
return super().call_function(tx, args, kwargs)
def inspect_parameter_names(self):
return super().inspect_parameter_names()[1:]
def var_getattr(self, tx: "InstructionTranslator", name: str):
source = self.source and AttrSource(self.source, name)
if name == "__self__":
return self.obj
if name == "__func__":
return VariableTracker.build(tx, self.fn, source)
return super().var_getattr(tx, name)
class WrappedUserMethodVariable(UserMethodVariable):
def __init__(self, wrapped, context, **kwargs) -> None:
kwargs.pop("fn", None)
kwargs.pop("obj", None)
super().__init__(wrapped.fn, wrapped.obj, **kwargs)
self.wrapped = wrapped
self.context = context
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
self.context.enter(tx)
result = super().call_function(tx, args, kwargs)
self.context.exit(tx)
return result
class WrappedUserFunctionVariable(UserFunctionVariable):
def __init__(self, wrapped, context, **kwargs) -> None:
kwargs.pop("fn", None)
kwargs.pop("obj", None)
super().__init__(wrapped.fn, **kwargs)
self.wrapped = wrapped
self.context = context
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
self.context.enter(tx)
result = super().call_function(tx, args, kwargs)
self.context.exit(tx)
return result
def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs):
def convert(x):
if isinstance(x, variables.TensorVariable):
return x.get_real_value()
return x.as_python_constant()
args = [convert(x) for x in args]
kwargs = {k: convert(v) for k, v in kwargs.items()}
res = fn(*args, **kwargs)
return tx.output.register_attr_or_module(
res,
name,
source=ConstantSource(name),
)
class NestedUserFunctionVariable(BaseUserFunctionVariable):
_nonvar_fields = {
"f_globals",
*BaseUserFunctionVariable._nonvar_fields,
}
def __init__(
self,
fn_name,
code,
f_globals,
defaults,
kwdefaults,
annotations,
closure,
# This is present when this function is created by
# `functools.wraps(wrapped_fn)(this_fn)`.
wrapped_fn=None,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(fn_name.as_python_constant(), str)
assert isinstance(code.as_python_constant(), types.CodeType)
assert isinstance(f_globals, dict)
self.fn_name = fn_name
self.code = code
self.f_globals = f_globals
self.defaults = defaults
self.kwdefaults = kwdefaults
self.annotations = annotations
self.closure = closure
self.wrapped_fn: Optional[VariableTracker] = wrapped_fn
def self_args(self):
return []
def get_code(self):
return self.code.as_python_constant()
def python_type(self):
return types.FunctionType
def get_function(self):
if self.closure:
raise NotImplementedError
func = types.FunctionType(
self.code.as_python_constant(),
self.f_globals,
self.fn_name.as_python_constant(),
)
if self.defaults:
func.__defaults__ = self.defaults.as_python_constant()
if self.kwdefaults:
func.__kwdefaults__ = self.kwdefaults.as_python_constant()
if self.annotations:
annotations = self.annotations.as_python_constant()
if isinstance(annotations, tuple):
from itertools import pairwise
annotations = dict(pairwise(annotations))
# TypeError: __annotations__ must be set to a dict object
assert isinstance(annotations, dict)
func.__annotations__ = annotations
return func
def has_closure(self):
return self.closure is not None
def has_self(self):
return False
def get_globals(self):
return self.f_globals
def bind_args(self, parent, args, kwargs):
code = self.get_code()
func = types.FunctionType(
code,
self.f_globals,
self.fn_name.as_python_constant(),
tuple(self.defaults.items) if self.defaults else None,
tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
)
if self.kwdefaults:
func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
bound = inspect.signature(func).bind(*args, **kwargs)
bound.apply_defaults()
result = dict(bound.arguments.items())
wrap_args_kwargs(parent.output.root_tx, result)
init_cellvars(parent, result, code)
for idx, name in enumerate(code.co_freevars):
assert name not in result
cell = self.closure.items[idx]
result[name] = cell
return result
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from(__name__, "_create_nested_fn")
)
codegen(self.code)
codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)])
codegen(ConstantVariable.create(self.code.value.co_name))
if self.defaults:
codegen(self.defaults)
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.closure:
codegen(self.closure)
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.kwdefaults:
codegen(self.kwdefaults)
else:
codegen.extend_output([codegen.create_load_const(None)])
if self.annotations:
try:
annotations = self.annotations.as_python_constant()
codegen.extend_output(
[codegen.create_load_const_unchecked(annotations)]
)
except NotImplementedError:
codegen(self.annotations)
else:
codegen.extend_output([codegen.create_load_const(None)])
codegen.extend_output(create_call_function(7, False))
if self.wrapped_fn:
codegen.add_push_null(
lambda: codegen.load_import_from("functools", "wraps")
)
codegen(self.wrapped_fn)
codegen.extend_output(create_call_function(1, False))
codegen.extend_output(create_rot_n(2))
codegen.extend_output(create_call_function(1, True))
class SkipFunctionVariable(VariableTracker):
_nonvar_fields = {
"value",
"reason",
*VariableTracker._nonvar_fields,
}
def __init__(self, value, reason=None, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
self.reason = reason
def as_python_constant(self):
return self.value
@classmethod
def create_with_source(cls, value, source):
if not is_wrapper_or_member_descriptor(value):
# These descriptors are not guaranteed to return the same object on
# attribute lookup. They are unlikely to be changed, so we can skip
# guarding them.
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return cls(value, source=source)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented_v2(
gb_type="Skip calling `torch.compiler.disable()`d function",
context=str(self.value),
explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`",
hints=[
"Remove the `torch.compiler.disable` call",
],
)
elif self.value is torch._dynamo.graph_break:
graph_break_msg = kwargs.get("msg", None)
if graph_break_msg:
graph_break_msg = graph_break_msg.as_python_constant()
unimplemented_v2(
gb_type="Call to `torch._dynamo.graph_break()`",
context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
explanation=f"User-inserted graph break. Message: {graph_break_msg}",
hints=[
"Remove the `torch._dynamo.graph_break()` call.",
],
)
else:
qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
module_or = getattr(self.value, "__module__", None)
module_name = "<unknown module>" if module_or is None else str(module_or)
try:
path = inspect.getfile(self.value)
explanation = (
f"Dynamo developers have intentionally marked that the function `{qualname}` "
f"in file `{path}` should not be traced."
)
hints = [
f"Avoid calling the function `{qualname}`.",
]
# TODO improve trace_rules reasoning to provide better hints.
# How do we tell that a function/file should NOT be removed from skip files?
# Do a very basic check for now.
if "_dynamo" not in path:
hints += [
f"Remove the function `{qualname}` or the file `{path}` "
"from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
"attempting to trace into the function.",
"Please file an issue to PyTorch.",
# TODO suggest mark_force_inline when implemented
]
except TypeError:
known_python_builtin_modules = {"_abc", "_warnings"}
if module_or in known_python_builtin_modules:
explanation = (
f"Dynamo does not know how to trace the Python builtin "
f"`{module_name}.{qualname}`."
)
hints = [
"If you are attempting to call a logging function (e.g. `_warnings.warn`), "
"you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
"Please file an issue on GitHub "
"so the PyTorch team can add support for it. ",
]
elif module_or is not None and module_or.startswith("optree"):
explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
hints = [
" Consider using torch.utils._pytree - "
"https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
]
# also warn on it because most users won't see the graph break message
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
else:
explanation = (
f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
f"This function is either a Python builtin (e.g. _warnings.warn) "
f"or a third-party C/C++ Python extension (perhaps created with pybind)."
)
hints = [
"If it is a Python builtin, please file an issue on GitHub "
"so the PyTorch team can add support for it and see the next case for a workaround.",
"If it is a third-party C/C++ Python extension, please "
"either wrap it into a PyTorch-understood custom operator "
"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
"for more details) or, if it is traceable, use "
"`torch.compiler.allow_in_graph`.",
]
# also warn on it because most users won't see the graph break message
torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
if qualname == "allow_in_graph":
explanation = (
"Found an allow_in_graph decorator to a function which "
"is created inside the parent function that is getting "
"compiled. This is not supported for now."
)
hints = []
reason = self.reason if self.reason else "<missing reason>"
unimplemented_v2(
gb_type="Attempted to call function marked as skipped",
context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
explanation=explanation,
hints=hints,
)
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
return variables.ConstantVariable.create(hasattr(self.value, name))
def var_getattr(self, tx: "InstructionTranslator", name: str):
if name in cmp_name_to_op_mapping:
return variables.GetAttrVariable(self, name)
return fn_var_getattr(tx, self.value, self.source, name)
class WrapperUserFunctionVariable(VariableTracker):
"""
Used to represent a wrapper object that contains the actual callable as an
attribute. For example, torch.jit.script/trace have the original function at
their _torchdynamo_inline attribute. Similarly, functions with
__script_if_tracing_wrapper have the original attr at "__original_fn".
"""
def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None:
super().__init__(**kwargs)
self.wrapper_obj = wrapper_obj
self.attr_to_trace = attr_to_trace
def var_getattr(self, tx: "InstructionTranslator", name):
if name == self.attr_to_trace:
val = getattr(self.wrapper_obj, self.attr_to_trace)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, val, source)
return super().var_getattr(tx, name)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
return variables.UserFunctionVariable(
polyfills.getattr_and_trace
).call_function(
tx, [self, variables.ConstantVariable(self.attr_to_trace), *args], kwargs
)
def _traceable_collective_remaps():
# We can't rely on importing from distributed, since it's not always built
if torch.distributed.is_available():
from torch.distributed._functional_collectives import (
traceable_collective_remaps,
)
return traceable_collective_remaps
return {}
def _traceable_collectives_source(tx: "InstructionTranslator", fn):
assert torch.distributed.is_available(), "Illegal invocation."
assert fn in _traceable_collective_remaps().values()
inner_name = fn.__name__
path_source = tx.import_source("torch.distributed._functional_collectives")
return AttrSource(path_source, inner_name)
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
"""
Some of the torch.distributed.* collective APIs can be rewritten to 'traceable' collectives.
This class provides both a way to check if a function is remappable, and to perform the remapping.
In the case that a function is 'remappable' but only for some combinations of call-time arguments,
we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
than the status quo, as we currently graph-break on all distributed.* collectives.
"""
def __init__(self, fn, *, replacement_var, **kwargs) -> None:
super().__init__(fn, **kwargs)
assert isinstance(replacement_var, UserFunctionVariable)
self.replacement_var = replacement_var
@staticmethod
def create(tx: "InstructionTranslator", old_fn, source, **options):
new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
return CollectiveFunctionRewriteVariable(
old_fn,
replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
source=source,
**options,
)
@staticmethod
def can_rewrite(variable):
return (
inspect.isfunction(variable) and variable in _traceable_collective_remaps()
)
@staticmethod
def rewrite(tx: "InstructionTranslator", fn):
new_fn = _traceable_collective_remaps()[fn]
return new_fn, _traceable_collectives_source(tx, new_fn)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# call_function must check any unsupported arguments and graph-break.
# It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
# since that's the contract for putting a mapping in `traceable_collective_remaps`
import torch.distributed as dist
from torch.distributed._functional_collectives import REDUCE_OP_TO_STR
# Merge args into kwargs so positional and keyword args
# can be processed the same way.
signature = inspect.signature(self.fn)
kwargs = dict(signature.bind(*args, **kwargs).arguments)
args = ()
if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
unimplemented_v2(
gb_type="async_op=True for distributed collectives",
context=f"{self.fn}, {args=}, {kwargs=}",
explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
if self.fn in (
dist.all_reduce,
dist.reduce_scatter_tensor,
dist._reduce_scatter_base,
):
reduce_op_var = kwargs.get("op")
reduce_op = (
reduce_op_var.value
if reduce_op_var is not None
else signature.parameters["op"].default
)
if reduce_op not in REDUCE_OP_TO_STR:
raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
kwargs["op"] = variables.ConstantVariable.create(
REDUCE_OP_TO_STR[reduce_op]
)
return self.replacement_var.call_function(tx, args, kwargs)
class FunctoolsWrapsVariable(UserFunctionVariable):
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if not kwargs and len(args) == 1:
def wraps(fn):
if isinstance(fn, variables.NestedUserFunctionVariable):
return fn.clone(wrapped_fn=args[0])
unimplemented_v2(
gb_type="functools.wraps",
context=f"{fn}",
explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
return variables.LambdaVariable(wraps)
return super().call_function(tx, args, kwargs)
class CollectionsNamedTupleFunction(UserFunctionVariable):
def as_python_constant(self):
return self.fn
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
constant_args = check_constant_args(args, kwargs)
if constant_args:
value = self.fn(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
return variables.UserDefinedClassVariable(
value, mutation_type=ValueMutationNew()
)
unimplemented_v2(
gb_type="namedtuple construction",
context=f"{args=}, {kwargs=}",
explanation="`torch.compile` only support certain input types for namedtuple",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
class FunctoolsPartialVariable(VariableTracker):
def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None:
super().__init__(**kwargs)
self.func = func
assert isinstance(args, list)
self.args = args
assert isinstance(keywords, dict)
self.keywords = keywords
# fake_value is used for id calculation. Creating this value and id'ing
# on it is sufficient for tracing purposes.
self.fake_value = functools.partial(identity)
def python_type(self):
return functools.partial
def reconstruct(self, codegen):
codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
codegen(self.func)
if self.args:
codegen.foreach(self.args)
if not self.keywords:
codegen.extend_output(create_call_function(len(self.args) + 1, False))
return
codegen.foreach(self.keywords.values())
keys = tuple(self.keywords.keys())
codegen.extend_output(
codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False)
)
def get_function(self):
return self.as_python_constant()
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
merged_args = self.args + args
merged_kwargs = {**self.keywords, **kwargs}
return self.func.call_function(tx, merged_args, merged_kwargs)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
# functools.partial uses slots, so attributes are constant
return variables.ConstantVariable.create(
hasattr(functools.partial(identity), name)
)
def var_getattr(self, tx: "InstructionTranslator", name: str):
source = self.source and AttrSource(self.source, name)
# Handle __slots__
if name == "func":
return self.func
if name == "args":
return variables.ListVariable(self.args, source=source)
if name == "keywords":
items = {ConstantVariable.create(k): v for k, v in self.keywords.items()}
return variables.ConstDictVariable(items, source=source)
raise_observed_exception(AttributeError, tx)
def as_python_constant(self):
return functools.partial(
self.func.as_python_constant(),
*[arg.as_python_constant() for arg in self.args],
**{k: v.as_python_constant() for k, v in self.keywords.items()},
)
def guard_as_python_constant(self):
"""Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
return functools.partial(
self.func.guard_as_python_constant(),
*[v.guard_as_python_constant() for v in self.args],
**{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
)
class PolyfilledFunctionVariable(VariableTracker):
_nonvar_fields = {
"fn",
"wrapped_fn",
"traceable_fn",
*VariableTracker._nonvar_fields,
}
@classmethod
@functools.lru_cache(None)
def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
return {}
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return cls(value, source=source)
def __init__(self, fn: _F, **kwargs) -> None:
super().__init__(**kwargs)
self.fn: _F = fn
handler = self._get_polyfill_handlers().get(fn, fn)
assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}"
for candidate_attr in (
"__torch_dynamo_polyfill__", # registered polyfill
"__python_implementation__", # self handler from third-party libraries
):
candidate = getattr(handler, candidate_attr, None)
if candidate:
assert callable(candidate)
traceable_fn = candidate
break
else:
raise RuntimeError(
f"Polyfill handler {handler} does not have a traceable function"
)
self.wrapped_fn: _F = handler
self.traceable_fn: _F = traceable_fn
@property
def polyfill_fn(self) -> _F:
return self.traceable_fn
def can_constant_fold_through(self):
return getattr(
self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False
)
def get_function(self):
return self.as_python_constant()
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):
result = (
self.fn( # use the original function which is faster than the polyfill
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
)
return VariableTracker.build(tx, result)
# Special case for sum on tuple/list of ints
if (
self.fn is builtins.sum
and len(args) == 1
and not kwargs
and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
and all(
(isinstance(x, variables.ConstantVariable) and isinstance(x.value, int))
or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
for x in args[0].items
)
):
return variables.SymNodeVariable.create(
tx,
tx.output.create_proxy(
"call_function",
torch.sym_sum,
(tuple(a.as_proxy() for a in args[0].items),),
{},
),
sym_num=torch.sym_sum(
[
(
x.value
if isinstance(x, variables.ConstantVariable)
else x.sym_num
)
for x in args[0].items
]
),
)
traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
return traceable_function_variable.call_function(tx, args, kwargs)
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__call__":
return self.call_function(tx, args, kwargs)
method = getattr(self.fn, name, None)
assert method is not None, f"Member {name} not found in {self.fn}"
assert is_function(method), f"Member {name} is not callable in {self.fn}"
options = {}
if self.source:
options["source"] = AttrSource(self.source, name)
polyfilled_method_variable = PolyfilledFunctionVariable(method, **options)
return polyfilled_method_variable.call_function(tx, args, kwargs)
def as_python_constant(self):
return self.fn
class TracebackVariable(VariableTracker):
# We don't track tracebacks. A call to any function of the `traceback` module is a no-op.
def call_function(self, tx, args, kwargs): ...
class SysFunctionVariable(VariableTracker):
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
self.value = value
def exc_info(self, tx):
if len(tx.exn_vt_stack):
exn = tx.exn_vt_stack[-1]
typ = exn.exc_type
tb = None
items = [
VariableTracker.build(tx, typ),
exn,
VariableTracker.build(tx, tb),
]
else:
items = [
variables.ConstantVariable(None),
variables.ConstantVariable(None),
variables.ConstantVariable(None),
]
return variables.TupleVariable(items)
def exception(self, tx):
return self.exc_info(tx).items[1]
def call_function(self, tx, args, kwargs):
if self.value is sys.exc_info:
return self.exc_info(tx)
assert self.value is sys.exception
return self.exception(tx)
from torch._higher_order_ops.triton_kernel_wrap import (
TMADescriptorMetadata,
TritonHOPifier,
)
class DynamoTritonHOPifier(TritonHOPifier):
def raise_unsupported(self, msg: str) -> Never:
raise Unsupported(msg)
def is_callable(self, maybe_callable: Any) -> bool:
return isinstance(
maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable)
)
def get_value(self, val: Any) -> Any:
return val.value
def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
from .lists import BaseListVariable
if isinstance(grid, BaseListVariable):
return grid.as_proxy()
else:
unimplemented_v2(
gb_type="unsupported grid type for triton hop check_grid",
context=f"grid type = {type(grid)}",
explanation="`torch.compile` only supports list-like grid for check_grid",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)
def call_grid(self, grid, meta, tx):
meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()}
grid = grid.call_function(tx, [meta], {})
return grid
# We use this function to wrap call_prune_configs
def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
from .builder import SourcelessBuilder
wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
result = wrapped_user_function.call_function(tx, args, kwargs)
return result
def wrap_user_defined_obj(self, user_obj, tx, variable, name):
from .builder import VariableBuilder
wrapped_user_obj = VariableBuilder(
tx, AttrSource(variable.kernel_source, f"{name}")
)._wrap(user_obj)
return wrapped_user_obj
def maybe_unpack_configs(self, configs, tx):
# unpack the list of configs
configs = configs.unpack_var_sequence(tx)
# guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
configs = [config.guard_as_python_constant() for config in configs]
return configs
def maybe_unpack_heuristic_result(self, result: Any) -> Any:
if not result.is_python_constant():
self.raise_unsupported(
"@triton.heuristics must return constant values because configs can only contain constant values."
)
return result.guard_as_python_constant()
# We need to override call_getitem here so that we can add the source in the case
# where we call the triton kernel with a grid
def call_getitem(
self,
variable: "TritonKernelVariable",
args: Sequence[Any],
) -> "TritonKernelVariable":
# __getitem__ should only be called if we don't already have a grid
# Only grid needs to be passed
if variable.grid is not None or len(args) != 1:
self.raise_unsupported(
"Triton kernels should be called with only a single grid"
)
return type(variable)(
kernel=variable.kernel,
kernel_idx=variable.kernel_idx,
grid=args[0],
kernel_source=variable.source,
)
def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
from .constant import ConstantVariable
from .dicts import ConstDictVariable
# as we can only pass tensors as non-const args in fx graph,
# here we replace TMA descriptors (TMADescriptorVariable
# instances) with the underlying tensors, while moving the
# TMA descriptor-related metadata to a separate argument,
# so that we can reconstruct the TMA descriptors downstream
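# Roughly speaking, a kernel kwarg `desc` backed by tensor `t` becomes
#     combined_args_raw["desc"] = <tensor tracker for t>
#     tma_descriptor_metadata["desc"] = ([*dims], [*block_dims], element_size)
# (see TMADescriptorVariable.to_metadata below).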
tma_descriptor_metadata: TMADescriptorMetadata = {}
for k in list(combined_args_raw.keys()):
v = combined_args_raw[k]
if isinstance(v, TMADescriptorVariable):
tma_descriptor_metadata[k] = v.to_metadata()
combined_args_raw[k] = v.data_ptr.from_tensor
combined_args = {
variables.ConstantVariable.create(k): v
for k, v in combined_args_raw.items()
}
from torch._higher_order_ops.triton_kernel_wrap import (
kernel_side_table,
triton_kernel_wrapper_mutation,
)
# Combine args and kwargs and pass them as a dict, so that if the user-defined
# triton kernel has parameters named 'grid' or 'kernel', they do not conflict
# with the parameters of the wrapper function
constant_args = {
k: v.as_python_constant()
for k, v in combined_args_raw.items()
if isinstance(v, ConstantVariable)
}
non_constant_args = {
k: v
for k, v in combined_args.items()
if not isinstance(v, ConstantVariable)
}
for v in non_constant_args.values():
v = v.realize()
if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)):
self.raise_unsupported(
f"Unexpected argument type for a Triton kernel: {repr(v)}."
)
constant_args_idx = kernel_side_table.add_constant_args(constant_args)
meta = ConstDictVariable(non_constant_args, dict)
tx.output.create_proxy(
"call_function",
triton_kernel_wrapper_mutation,
(),
{
"kernel_idx": variable.kernel_idx,
"constant_args_idx": constant_args_idx,
"grid": grids,
"tma_descriptor_metadata": tma_descriptor_metadata,
"kwargs": meta.as_proxy(),
},
)
return variables.ConstantVariable(
None,
)
dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()
class TritonKernelVariable(VariableTracker):
grid: "TritonGridType"
kernel: "TritonKernelType"
kernel_idx: Optional[int]
kernel_source: "AttrSource"
def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
self.kernel_source = kwargs.pop("kernel_source", None)
super().__init__(**kwargs)
dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
return dynamo_triton_hopifier_singleton.call_triton_kernel(
self, args, kwargs, tx
)
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__":
return dynamo_triton_hopifier_singleton.call_getitem(self, args)
elif name == "run":
return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx)
# Bail out to parent's implementation
return super().call_method(tx, name, args, kwargs)
def specialize_symbolic(self, arg: Any) -> Any:
from .constant import ConstantVariable
from .tensor import SymNodeVariable
# See [Note: Specialize tl.constexpr args in user-defined triton kernels]
if isinstance(arg, SymNodeVariable):
return ConstantVariable.create(arg.evaluate_expr())
return arg
class TMADescriptorVariable(VariableTracker):
def __init__(
self,
data_ptr: "variables.DataPtrVariable",
dims: "list[ConstantVariable]",
block_dims: "list[ConstantVariable]",
element_size: "ConstantVariable",
**kwargs,
):
assert isinstance(data_ptr, variables.DataPtrVariable)
super().__init__(**kwargs)
self.data_ptr = data_ptr
self.dims = dims
self.block_dims = block_dims
self.element_size = element_size
def to_metadata(self):
return (
[dim.as_proxy() for dim in self.dims],
[dim.as_proxy() for dim in self.block_dims],
self.element_size.as_proxy(),
)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from(
"triton.tools.experimental_descriptor",
f"create_{len(self.dims)}d_tma_descriptor",
)
)
self.data_ptr.reconstruct(codegen)
args = [*self.dims, *self.block_dims, self.element_size]
codegen.foreach(args)
codegen.call_function(len(args) + 1, False)
class CreateTMADescriptorVariable(VariableTracker):
def __init__(
self,
rank: int,
**kwargs,
) -> None:
assert rank in (1, 2)
super().__init__(**kwargs)
self.rank = rank
def call_function(
self,
tx: "InstructionTranslator",
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
ptr = kwargs["ptr"] if "ptr" in kwargs else args[0]
if not isinstance(ptr, variables.DataPtrVariable):
raise Unsupported(
"Please ensure there were no graph breaks between "
f"create_{self.rank}d_tma_descriptor and the upstream "
".data_ptr() call."
)
if self.rank == 1:
assert len(args) + len(kwargs) == 4
dims = [
kwargs["dim"] if "dim" in kwargs else args[1],
]
block_dims = [
kwargs["block_dim"] if "block_dim" in kwargs else args[2],
]
else:
assert len(args) + len(kwargs) == 6
dims = [
kwargs["dim1"] if "dim1" in kwargs else args[1],
kwargs["dim0"] if "dim0" in kwargs else args[2],
]
block_dims = [
kwargs["block_dim1"] if "block_dim1" in kwargs else args[3],
kwargs["block_dim0"] if "block_dim0" in kwargs else args[4],
]
element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]
return TMADescriptorVariable(
data_ptr=ptr,
dims=dims,
block_dims=block_dims,
element_size=element_size,
)