# mypy: ignore-errors

"""
Function-related variable tracking classes for Dynamo's symbolic execution.

This module contains classes that track different types of functions during
graph compilation, including:

- User-defined functions and methods
- Built-in functions and methods
- Wrapped functions (e.g. from decorators)
- Special function types (e.g. functools.partial)
- Triton kernels and related function types

These classes are responsible for:

- Tracking function calls and their arguments
- Managing function closures and cell variables
- Handling function attributes and special methods
- Maintaining guards for function identity and closure contents
- Supporting function inlining and specialization
- Enabling proper symbolic execution of different function types

The variable trackers here work together with the rest of Dynamo to enable
accurate graph capture while handling Python's various function-related
behaviors.
"""

import builtins
import functools
import inspect
import itertools
import sys
import types
from collections.abc import Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch

import torch

from .. import graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
    get_dynamo_observed_exception,
    handle_observed_exception,
    InfiniteGeneratorError,
    ObservedException,
    ObservedGeneratorExit,
    ObservedUserStopIteration,
    raise_observed_exception,
    SkipFrame,
    unimplemented_v2,
    Unsupported,
)
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import (
    check_constant_args,
    check_unspec_or_constant_args,
    cmp_name_to_op_mapping,
    counters,
    identity,
    is_function,
    is_wrapper_or_member_descriptor,
    istype,
    make_cell,
)
from .base import typestr, ValueMutationNew, VariableTracker
from .constant import ConstantVariable


try:
    from torch.distributed.fsdp._fully_shard import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None

if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator
    from torch._higher_order_ops.triton_kernel_wrap import (
        TritonGridType,
        TritonKernelType,
    )


_F = TypeVar("_F", bound=Callable)


def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
    # Source propagation is best effort, since not every object we encounter
    # has a source to begin with.
    if isinstance(val, VariableTracker):
        return val
    elif not source:
        return VariableTracker.build(tx, val)
    else:
        # Create a lazy variable to avoid guarding on __defaults__ unless
        # really needed.
        return variables.LazyVariableTracker.create(val, source)


def wrap_args_kwargs(tx: "InstructionTranslator", result):
    for k, v in list(result.items()):
        if isinstance(v, (tuple, dict)):
            # args/kwargs
            result[k] = wrap_bound_arg(tx, v)


def init_cellvars(parent, result: dict[str, VariableTracker], code):
    """
    Update `result` to add a mapping from local name to the new cells created
    directly by `code`, or update SideEffects in `parent` if a local cell is
    already in `result` (cell argument).
    """
    side_effects = parent.output.side_effects
    for name in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        if name in result:
            # This handles the case when a function argument is a cell (e.g.,
            # captured by a nested func). See `MAKE_CELL` bytecode for more info.
            side_effects.store_cell(new_cell, result.pop(name))
        result[name] = new_cell
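
# Illustrative sketch (hypothetical user code, comment only) of the cell
# argument case handled in `init_cellvars` above: an argument becomes a cell
# when a nested function captures it, so CPython emits MAKE_CELL for it:
#
#     def counter(start):        # `start` appears in counter.__code__.co_cellvars
#         def bump():
#             nonlocal start     # `start` appears in bump.__code__.co_freevars
#             start += 1
#             return start
#         return bump
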

def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        from itertools import pairwise

        annotations = dict(pairwise(annotations))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func
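
# Minimal illustrative sketch of the tuple-to-dict annotations conversion
# above, shown for the unambiguous single-pair case (the exact tuple layout
# that reaches this helper is an assumption here):
#
#     >>> from itertools import pairwise
#     >>> dict(pairwise(("x", int)))
#     {'x': <class 'int'>}
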

fn_known_dunder_attrs = {
    "__annotations__",
    "__defaults__",
    "__kwdefaults__",
    "__code__",
    "__globals__",
    "__closure__",
    "__doc__",
}


def fn_var_getattr(tx, fn, source, name):
    source = source and AttrSource(source, name)
    try:
        subobj = inspect.getattr_static(fn, name)
    except AttributeError:
        # Functions do not have a __getattr__ or __getattribute__ method, so
        # we can safely assume that this attribute is absent.
        raise_observed_exception(AttributeError, tx)

    # Special handling for known dunder attributes
    if name in fn_known_dunder_attrs:
        subobj = getattr(fn, name)
    if source:
        return variables.LazyVariableTracker.create(subobj, source)
    return VariableTracker.build(tx, subobj)


class BaseUserFunctionVariable(VariableTracker):
    def get_filename(self):
        return self.get_code().co_filename

    def get_name(self):
        return self.get_code().co_name

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> VariableTracker:
        result = False
        try:
            result = hasattr(self.get_function(), name)
        except NotImplementedError:
            if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
                result = True
        return variables.ConstantVariable.create(result)

    def inspect_parameter_names(self):
        return list(inspect.signature(self.get_function()).parameters)

    def closure_vars(self, tx):
        return {}


class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function"""

    _nonvar_fields = {
        "fn",
        "is_constant",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(value, source=source)

    def __init__(self, fn, is_constant=False, **kwargs) -> None:
        super().__init__(**kwargs)
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of
            # compilation.
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)), (
            f"expected FunctionType found {typestr(fn)} {fn}"
        )
        # TODO(anijain2305) - Replace directly calling UserFunctionVariable with
        # VariableBuilder, which handles the wrapping of _torchdynamo_inline.
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        self.fn: types.FunctionType = fn

    def as_python_constant(self):
        if istype(self, UserFunctionVariable):
            return self.fn
        # subclasses (such as methods) usually aren't a constant
        return super().as_python_constant()

    def self_args(self):
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]:
        """
        Assume `args` and `kwargs` are VariableTracker arguments for a call to
        this function, and create new bindings for its initial locals.
        """
        assert not self.is_constant
        root_tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=root_tx)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: (
                    None
                    if self.source is None
                    else DefaultsSource(self.source, k, is_kw=True)
                )
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(root_tx, result)
        init_cellvars(parent, result, fn.__code__)

        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            # TODO refactor these 3 branches.
            side_effects = parent.output.side_effects
            if cell in side_effects:
                cell_var = side_effects[cell]
            elif self.source:
                closure_cell = GetItemSource(
                    AttrSource(self.source, "__closure__"), idx
                )
                closure_cell_contents = AttrSource(closure_cell, "cell_contents")
                try:
                    contents_var = VariableTracker.build(
                        parent, cell.cell_contents, closure_cell_contents
                    )
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                cell_var = side_effects.track_cell_existing(
                    closure_cell, cell, contents_var
                )
            else:
                # TODO figure out why source isn't available here, and whether
                # we can fix that and remove this branch.
                try:
                    contents_var = VariableTracker.build(parent, cell.cell_contents)
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                cell_var = side_effects.track_cell_existing(None, cell, contents_var)

            result[name] = cell_var

        return result

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        return fn_var_getattr(tx, self.fn, self.source, name)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> VariableTracker:
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Handle a `nonstrict_trace(fn)` call
        if self.fn is torch._dynamo.nonstrict_trace:
            bound = inspect.signature(self.fn).bind(*args, **kwargs)
            fn_var = bound.args[0]
            if not isinstance(fn_var, BaseUserFunctionVariable):
                typ = fn_var.python_type()
                msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
                unimplemented_v2(
                    gb_type="TypeError from user code",
                    context=f"call_function({self.value}, {args}, {kwargs})",
                    explanation=msg,
                    hints=[
                        *graph_break_hints.USER_ERROR,
                    ],
                )

            if not isinstance(fn_var, UserFunctionVariable):
                fn_name = fn_var.get_name()
                msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside the `torch.compile` region."  # noqa: B950
                unimplemented_v2(
                    gb_type="Limitation of `nonstrict_trace`",
                    context=f"{self}",
                    explanation=msg,
                    hints=[
                        f"make sure the definition of {fn_name} is outside the ",
                        "`torch.compile` region",
                    ],
                )

            fn = fn_var.fn
            return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)

        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        if (
            tx.output.current_tracer.under_activation_checkpoint
            and not tx.output.current_tracer.allow_side_effects_under_checkpoint
        ):
            try:
                from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
            except Exception:
                FSDPState = None

            if FSDPState is not None and self.fn in [
                FSDPState._pre_forward,
                FSDPState._post_forward,
            ]:
                with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
                    return super().call_function(tx, args, kwargs)

        return super().call_function(tx, args, kwargs)


class BuiltinMethodVariable(BaseUserFunctionVariable):
    def __init__(self, fn, is_constant=False, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(fn, types.BuiltinMethodType)
        self.fn = fn

    @staticmethod
    def is_supported_builtin_method(obj):
        method_self = obj.__self__
        method_name = obj.__name__

        # TODO(anijain2305) - Add support for more builtin methods
        # Supports tuple.__new__ and frozenset({....}).__contains__
        return (method_self is tuple and method_name == "__new__") or (
            type(method_self) is frozenset and method_name == "__contains__"
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        method_self = self.fn.__self__
        name = self.fn.__name__
        obj_source = self.source and AttrSource(self.source, "__self__")
        obj_vt = VariableTracker.build(tx, method_self, obj_source)
        return obj_vt.call_method(tx, name, args, kwargs)
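
# Illustrative sketch (hypothetical user code, comment only) of the two bound
# builtin methods currently accepted by `is_supported_builtin_method` above:
#
#     tuple.__new__(tuple, [1, 2])         # method_self is tuple, "__new__"
#     frozenset({1, 2}).__contains__(1)    # method_self is a frozenset, "__contains__"
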

class LocalGeneratorObjectVariable(VariableTracker):
    def __init__(
        self,
        code: types.CodeType,
        f_globals,
        inline_tracer: Optional["InstructionTranslator"],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.code = code
        self.f_globals = f_globals
        self.inline_tracer = inline_tracer

    def get_code(self):
        return self.code

    def get_filename(self):
        return self.get_code().co_filename

    def get_name(self):
        return self.get_code().co_name

    def get_function(self):
        raise NotImplementedError

    def has_self(self):
        return False

    def __name__(self):
        return self.get_name()

    def __str__(self):
        return f"{self.__class__.__name__}({self.get_name()})"

    __repr__ = __str__

    def reconstruct(self, codegen):
        from torch._dynamo.side_effects import disallow_side_effects_in_generator
        from torch._dynamo.symbolic_convert import (
            InstructionTranslator,
            save_and_restart_speculation_log,
            temporarely_allow_writes_to_output_graph,
        )

        tx = InstructionTranslator.current_tx()
        save = save_and_restart_speculation_log(tx)
        disallow = disallow_side_effects_in_generator(tx)
        temp = temporarely_allow_writes_to_output_graph(tx)
        with save, disallow, temp:
            tracer = self._get_inline_tracer(tx)
            if not tracer.generator_exhausted:
                self.remaining_items = self.force_unpack_var_sequence(tx)
        variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen)

    def bind_args(self, tx, args, kwargs):
        return self.fn.bind_args(tx, args, kwargs)

    def get_globals(self):
        return self.f_globals

    def python_type(self):
        return types.GeneratorType

    def _get_inline_tracer(self, tx):
        from torch._dynamo.symbolic_convert import InliningInstructionTranslator

        if self.inline_tracer is None:
            self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
                tx, self, [], {}
            )
        return self.inline_tracer

    def next_variable(self, tx):
        tracer = self._get_inline_tracer(tx)

        if self._is_generator_exhausted():
            raise_observed_exception(StopIteration, tx)

        try:
            # Hierarchically, tx can be seen as the parent of the inline tracer
            # created on call_function. Any exception needs to be propagated to
            # tx for Dynamo to behave correctly.
            with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
                return tracer.inline_call_()
        except ObservedException as e:
            raise e
        except InfiniteGeneratorError:
            # test/dynamo/test_misc.py::test_iterator_limit
            raise
        except Unsupported as e:
            torch._dynamo.eval_frame.skip_code(self.get_code())
            raise SkipFrame from e
        finally:
            counters["unimplemented"] |= counters["inline_call"]

    def has_unpack_var_sequence(self, tx):
        return False

    def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
        return True

    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
        result = []
        while True:
            try:
                result.append(self.next_variable(tx))
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                break
        return result

    def _setup_exception(self, tx, exc):
        tracer = self._get_inline_tracer(tx)
        try:
            tracer._raise_exception_variable(exc)
        except ObservedException as e:
            # If no handler is available (i.e. user code doesn't catch it), the
            # exception is raised again.
            tracer.exception_handler(e)

    def _is_generator_just_started(self):
        return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0

    def _is_generator_exhausted(self):
        return getattr(self.inline_tracer, "generator_exhausted", False)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__next__":
            return self.next_variable(tx)
        elif name == "__iter__":
            # iter(gen) returns itself
            return self
        elif name == "send":
            # Sends a value into the generator function. Returns the next value
            # yielded by the generator, or raises StopIteration if the
            # generator exits without yielding another value.
            if self._is_generator_just_started() and len(args):
                # Can't send a non-None value to a just-started generator.
                # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
                if not all(
                    isinstance(arg, ConstantVariable) and arg.value is None
                    for arg in args
                ):
                    raise_observed_exception(TypeError, tx)

            tracer = self._get_inline_tracer(tx)
            tracer.push_many(args)
            return self.next_variable(tx)
        elif name == "close":
            # * Raises a GeneratorExit at the point where the generator
            #   function was paused.
            # * If the generator function catches the exception and returns a
            #   value, this value is returned from close() - Python 3.13+
            # * If the generator function is already closed, or raises
            #   GeneratorExit (by not catching the exception), close() returns
            #   None.
            # * If the generator yields a value, a RuntimeError is raised.
            # * If the generator raises any other exception, it is propagated
            #   to the caller.
            # * If the generator has already exited due to an exception or
            #   normal exit, close() returns None and has no other effect.

            # Return None if close is called on a just-started generator
            # See test GeneratorCloseCpythonTests::test_close_not_started
            tracer = self._get_inline_tracer(tx)
            if self._is_generator_just_started() or self._is_generator_exhausted():
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)

            # Raise GeneratorExit to see if user code catches it. Any other
            # exception is propagated to the parent frame.
            try:
                self._setup_exception(
                    tx, variables.ExceptionVariable(GeneratorExit, ())
                )

                # There's an extra block on Python 3.12+ to handle StopIteration
                # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
                #
                #  1           0 RETURN_GENERATOR
                #              2 POP_TOP
                #              4 RESUME                   0
                #
                #  2           6 LOAD_CONST               1 (1)
                #              8 YIELD_VALUE              1
                #             10 RESUME                   1
                #             12 POP_TOP
                #             14 RETURN_CONST             0 (None)
                #        >>   16 CALL_INTRINSIC_1         3 (INTRINSIC_STOPITERATION_ERROR)
                #             18 RERAISE                  1
                #  ExceptionTable:
                #    4 to 14 -> 16 [0] lasti
                if (
                    sys.version_info >= (3, 12)
                    and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
                ):
                    tracer.generator_exhausted = True
                    return variables.ConstantVariable(None)
            except ObservedGeneratorExit:
                # If it doesn't catch, we just return None, as per the text above
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)

            try:
                # Raise RuntimeError if the generator yields any other value
                if self.next_variable(tx):
                    raise_observed_exception(RuntimeError, tx)
            except ObservedGeneratorExit:
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)
            except ObservedUserStopIteration:
                # In Python 3.13+, one can capture GeneratorExit and return a value
                # See test_generator.py::test_close_capture_GeneratorExit_return
                # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
                # https://github.com/python/cpython/pull/104771
                assert tracer.symbolic_result is not None
                return tracer.symbolic_result
        elif name == "throw":
            # * Raises an exception at the point where the generator was
            #   paused, and returns the next value yielded by the generator.
            # * If the generator exits without yielding, raise StopIteration.
            # * If the generator function does not catch the passed-in
            #   exception, or raises a different exception, then that exception
            #   propagates to the caller.
            # Set up the exception table and jump target in case of try...finally
            tracer = self._get_inline_tracer(tx)

            try:
                # In Python 3.9, the exception is represented as a triple
                # (typ, val, tb). In such cases, we re-raise the exception
                # object given to avoid creating a new object, so that IS_OP
                # works.
                # See: https://github.com/pytorch/pytorch/pull/146496
                self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
            except ObservedException:  # noqa: TRY203
                # propagate the exception back to the parent caller
                raise

            retval = self.next_variable(tx)

            # The exception raised before is still active. We need to check the
            # exception table one more time to find the next target. But why?
            # Let's walk through an example and its generated bytecode:
            # https://godbolt.org/z/ebdTbMv8M
            #
            #   z = 0
            #   def whoo():
            #       global z
            #       z = 0
            #       try:
            #           yield 1
            #       except ValueError:
            #           yield 2
            #       finally:
            #           z += 1
            #       z += 10
            #
            #   gen = whoo()
            #   next(gen)
            #   gen.throw(ValueError)
            #   print('z', z) -> z = 1
            #
            #   ...
            #   >>   58 PUSH_EXC_INFO
            #
            #  8      60 LOAD_GLOBAL              2 (ValueError)
            #         70 CHECK_EXC_MATCH
            #         72 POP_JUMP_IF_FALSE        7 (to 88)
            #         74 POP_TOP
            #
            #  9      76 LOAD_CONST               3 (2)
            #         78 YIELD_VALUE              3   <------ ValueError is still active here
            #         80 RESUME                   1
            #         82 POP_TOP
            #         84 POP_EXCEPT
            #         86 JUMP_BACKWARD           34 (to 20)
            #   ...
            #
            #   ExceptionTable:
            #     4 to 8 -> 124 [0] lasti
            #     12 to 18 -> 58 [0]
            #     20 to 56 -> 124 [0] lasti
            #     58 to 82 -> 90 [1] lasti   <------ move to 90
            #     84 to 86 -> 96 [0]
            #     88 to 88 -> 90 [1] lasti
            #     90 to 94 -> 96 [0]
            #     96 to 116 -> 118 [1] lasti
            #     118 to 122 -> 124 [0] lasti
            #
            # In this scenario, a generator can yield after `throw()` is
            # called. Even after the exception is raised a few lines above, it
            # remains active within the `78 YIELD_VALUE` instruction. When the
            # generator resumes after the second yield on instruction
            # `80 RESUME`, we cannot simply return the control flow to the next
            # instruction. Instead, one must check the exception table (or
            # equivalent) to find the next target. In this case, it says the
            # instruction pointer must be moved to 90.
            #
            # Without this step, if we let the trace proceed to the next
            # instruction, it would follow the control flow where the exception
            # raised by `throw()` was handled and swallowed, potentially
            # leading to incorrect behavior.
            exc_type = type("__InternalThrowException", (Exception,), {})
            try:
                self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
                self.next_variable(tx)
            except get_dynamo_observed_exception(exc_type):
                # We should get back the exception raised before.
                pass
            else:
                raise_observed_exception(RuntimeError, tracer)
            return retval
        return super().call_method(tx, name, args, kwargs)


class ContextlibContextManagerLocalGeneratorObjectVariable(
    LocalGeneratorObjectVariable
):
    """
    .. note::

        This is only used when the function is annotated with
        @contextlib.contextmanager

    It is a special case of a generator function, as we do not allow returning
    a context manager from a torch.compile'd function.
    """
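
# Illustrative sketch (hypothetical user code, comment only) of the CPython
# generator protocol that the `call_method` branches above emulate; the
# behavior stated here follows the Python docs:
#
#     def gen():
#         yield 1
#         yield 2
#
#     g = gen()
#     next(g)              # -> 1
#     g.send(None)         # -> 2; resumes g, the paused `yield` evaluates to None
#     g.close()            # -> None; raises GeneratorExit at the paused yield
#     g.throw(ValueError)  # raises ValueError (g is closed, nothing catches it)
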

class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
    """Functions that behave like iterators.

    .. note::

        This is a wrapper around (Nested)UserFunctionVariable
    """

    def __init__(
        self,
        vt: VariableTracker,
        *,
        generator_cls=LocalGeneratorObjectVariable,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vt = vt
        self.generator_cls = generator_cls

    def __getattr__(self, name):
        if name in self.__class__.__dict__.keys():
            return getattr(self, name)
        return getattr(self.vt, name)

    def _build_inline_tracer(self, tx, args, kwargs):
        from torch._dynamo.symbolic_convert import InliningInstructionTranslator

        return InliningInstructionTranslator.build_inline_tracer(
            tx,
            self,
            args,
            kwargs,
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert is_generator(self.vt.get_code())

        inline_tracer = self._build_inline_tracer(tx, args, kwargs)
        code = self.vt.get_code()
        f_globals = self.vt.get_globals()

        # calling a generator returns a generator object
        return self.generator_cls(
            code,
            f_globals,
            inline_tracer,
            source=self.source,
        )


class FunctionDecoratedByContextlibContextManagerVariable(
    LocalGeneratorFunctionVariable
):
    """
    .. note::

        This is only used when the function is annotated with
        @contextlib.contextmanager
    """

    def __init__(self, vt, **kwargs):
        super().__init__(
            vt,
            generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
            **kwargs,
        )

    def _build_inline_tracer(self, tx, args, kwargs):
        # NOTE: This only exists to not break support for context managers when
        # config.enable_faithful_generator_behavior = False and
        # config.enable_trace_contextlib = True. In case the former is false,
        # Dynamo should still be able to trace through @contextmanager
        # functions.
        tracer = super()._build_inline_tracer(tx, args, kwargs)
        assert isinstance(
            tracer,
            torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
        )
        tracer.is_generator_from_ctx_manager = True
        return tracer
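
# Illustrative sketch (hypothetical user code, comment only) of the
# @contextlib.contextmanager pattern handled by
# FunctionDecoratedByContextlibContextManagerVariable above:
#
#     @contextlib.contextmanager
#     def scope():
#         yield 1                # traced as a generator under the hood
#
#     def fn(x):
#         with scope() as k:     # fine; returning `scope()` itself is not allowed
#             return x + k
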

class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, **kwargs) -> None:
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # NOTE this is to handle methods annotated by `nonstrict_trace`.
        # Usually a `nonstrict_trace`-ed function will be wrapped by
        # `VariableTracker.build` and routed to `TorchInGraphFunctionVariable`,
        # but in the case of a method, we manually wrap it with
        # `UserMethodVariable` inside `UserDefinedObjectVariable.var_getattr`.
        #
        # We might be able to simplify this away by canonicalizing the
        # function/method wrapping code paths.
        from ..trace_rules import is_nonstrict_trace_callable

        if is_nonstrict_trace_callable(self.fn):
            call_args = [*self.self_args(), *args]
            var = variables.TorchInGraphFunctionVariable(
                self.fn, nonstrict_traceable=True
            )
            return var.call_function(tx, call_args, kwargs)

        # For nn.Module methods, redirect to NNModuleVariable.call_method for an
        # optimized solution rather than simple inlining, e.g. putting a
        # `call_method` op in the FX graph for the `forward` method, since we
        # ensure `forward` of allowed modules can be traced by AOT safely.
        # Note this is not only for allowed modules: user-customized modules can
        # extend from allowed modules but use the parent's `forward` method,
        # which is also covered by this branch.
        #
        # If we are tracing the higher order op, we want Dynamo to step inside
        # the module call so that Dynamo can see the underlying parameters and
        # buffers and raise them as inputs to the graph. The is_root_tracer
        # check bypasses the if condition for non-root tracers and directly
        # calls super().call_function at the end, which is basically
        # equivalent to inlining the method.
        if tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        ):
            module_attr = getattr(self.fn, "__module__", "")
            # inline torch.nn.utils.parametrize
            if (
                module_attr is not None
                and module_attr.startswith("torch.nn.")
                and module_attr != "torch.nn.utils.parametrize"
                or self.is_constant
            ):
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        elif (
            _fsdp_param_group is not None
            and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state
        ):
            return variables.TorchCtxManagerClassVariable(self.fn).call_function(
                tx, (self.obj, *args), kwargs
            )

        if self.is_constant:
            fn = getattr(self.obj.value, self.fn.__name__)
            return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)

        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        return super().inspect_parameter_names()[1:]

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        source = self.source and AttrSource(self.source, name)
        if name == "__self__":
            return self.obj
        if name == "__func__":
            return VariableTracker.build(tx, self.fn, source)
        return super().var_getattr(tx, name)


class WrappedUserMethodVariable(UserMethodVariable):
    def __init__(self, wrapped, context, **kwargs) -> None:
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result


class WrappedUserFunctionVariable(UserFunctionVariable):
    def __init__(self, wrapped, context, **kwargs) -> None:
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result


def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs):
    def convert(x):
        if isinstance(x, variables.TensorVariable):
            return x.get_real_value()
        return x.as_python_constant()

    args = [convert(x) for x in args]
    kwargs = {k: convert(v) for k, v in kwargs.items()}
    res = fn(*args, **kwargs)
    return tx.output.register_attr_or_module(
        res,
        name,
        source=ConstantSource(name),
    )
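
# Illustrative sketch (hypothetical user code, comment only) of how a function
# typically reaches `invoke_and_store_as_constant`: decorators such as
# torch._dynamo.assume_constant_result set `_dynamo_marked_constant`, so the
# call is executed eagerly and its result is installed as a graph constant:
#
#     @torch._dynamo.assume_constant_result
#     def get_scale():
#         return 2.0
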

class NestedUserFunctionVariable(BaseUserFunctionVariable):
    _nonvar_fields = {
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        # This is present when this function is created by
        # `functools.wraps(wrapped_fn)(this_fn)`.
        wrapped_fn=None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        self.wrapped_fn: Optional[VariableTracker] = wrapped_fn

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def python_type(self):
        return types.FunctionType

    def get_function(self):
        if self.closure:
            raise NotImplementedError
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                from itertools import pairwise

                annotations = dict(pairwise(annotations))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        code = self.get_code()
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()

        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(parent.output.root_tx, result)
        init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            assert name not in result
            cell = self.closure.items[idx]
            result[name] = cell

        return result

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from(__name__, "_create_nested_fn")
        )
        codegen(self.code)
        codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)])
        codegen(ConstantVariable.create(self.code.value.co_name))

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                annotations = self.annotations.as_python_constant()
                codegen.extend_output(
                    [codegen.create_load_const_unchecked(annotations)]
                )
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, False))

        if self.wrapped_fn:
            codegen.add_push_null(
                lambda: codegen.load_import_from("functools", "wraps")
            )
            codegen(self.wrapped_fn)
            codegen.extend_output(create_call_function(1, False))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))
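
# Illustrative sketch (hypothetical user code, comment only) of the nested
# functions NestedUserFunctionVariable models; reconstruct() rebuilds them via
# `_create_nested_fn`, reapplying functools.wraps when `wrapped_fn` is set:
#
#     def outer(x):
#         @functools.wraps(some_fn)   # tracked as `wrapped_fn`
#         def inner(y=1):             # defaults/closure tracked as VariableTrackers
#             return x + y
#         return inner
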

class SkipFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "value",
        "reason",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value, reason=None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value
        self.reason = reason

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        if not is_wrapper_or_member_descriptor(value):
            # These descriptors are not guaranteed to return the same object on
            # attribute lookup. They are unlikely to be changed, so we can skip
            # guarding them.
            install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(value, source=source)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented_v2(
                gb_type="Skip calling `torch.compiler.disable()`d function",
                context=str(self.value),
                explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`",
                hints=[
                    "Remove the `torch.compiler.disable` call",
                ],
            )
        elif self.value is torch._dynamo.graph_break:
            graph_break_msg = kwargs.get("msg", None)
            if graph_break_msg:
                graph_break_msg = graph_break_msg.as_python_constant()
            unimplemented_v2(
                gb_type="Call to `torch._dynamo.graph_break()`",
                context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
                explanation=f"User-inserted graph break. Message: {graph_break_msg}",
                hints=[
                    "Remove the `torch._dynamo.graph_break()` call.",
                ],
            )
        else:
            qualname = getattr(self.value, "__qualname__", "")
            module_or = getattr(self.value, "__module__", None)
            module_name = "" if module_or is None else str(module_or)
            try:
                path = inspect.getfile(self.value)
                explanation = (
                    f"Dynamo developers have intentionally marked that the function `{qualname}` "
                    f"in file `{path}` should not be traced."
                )
                hints = [
                    f"Avoid calling the function `{qualname}`.",
                ]
                # TODO improve trace_rules reasoning to provide better hints.
                # How do we tell that a function/file should NOT be removed
                # from skip files? Do a very basic check for now.
                if "_dynamo" not in path:
                    hints += [
                        f"Remove the function `{qualname}` or the file `{path}` "
                        "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
                        "attempting to trace into the function.",
                        "Please file an issue to PyTorch.",
                        # TODO suggest mark_force_inline when implemented
                    ]
            except TypeError:
                known_python_builtin_modules = {"_abc", "_warnings"}
                if module_or in known_python_builtin_modules:
                    explanation = (
                        f"Dynamo does not know how to trace the Python builtin "
                        f"`{module_name}.{qualname}`."
                    )
                    hints = [
                        "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
                        "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
                        "Please file an issue on GitHub "
                        "so the PyTorch team can add support for it. ",
                    ]
                elif module_or is not None and module_or.startswith("optree"):
                    explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
                    hints = [
                        " Consider using torch.utils._pytree - "
                        "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
                    ]
                    # also warn on it because most users won't see the graph break message
                    torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
                else:
                    explanation = (
                        f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
                        f"This function is either a Python builtin (e.g. _warnings.warn) "
                        f"or a third-party C/C++ Python extension (perhaps created with pybind)."
                    )
                    hints = [
                        "If it is a Python builtin, please file an issue on GitHub "
                        "so the PyTorch team can add support for it and see the next case for a workaround.",
                        "If it is a third-party C/C++ Python extension, please "
                        "either wrap it into a PyTorch-understood custom operator "
                        "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                        "for more details) or, if it is traceable, use "
                        "`torch.compiler.allow_in_graph`.",
                    ]
                    # also warn on it because most users won't see the graph break message
                    torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
            if qualname == "allow_in_graph":
                explanation = (
                    "Found an allow_in_graph decorator on a function that is "
                    "created inside the parent function being compiled. This "
                    "is not supported for now."
                )
                hints = []
            reason = self.reason if self.reason else ""
            unimplemented_v2(
                gb_type="Attempted to call function marked as skipped",
                context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
                explanation=explanation,
                hints=hints,
            )

    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
        return variables.ConstantVariable.create(hasattr(self.value, name))

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        return fn_var_getattr(tx, self.value, self.source, name)


class WrapperUserFunctionVariable(VariableTracker):
    """
    Used to represent a wrapper object that contains the actual callable as an
    attribute. For example, torch.jit.script/trace have the original function
    at their _torchdynamo_inline attribute. Similarly, functions with
    __script_if_tracing_wrapper have the original attribute at "__original_fn".
    """

    def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None:
        super().__init__(**kwargs)
        self.wrapper_obj = wrapper_obj
        self.attr_to_trace = attr_to_trace

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == self.attr_to_trace:
            val = getattr(self.wrapper_obj, self.attr_to_trace)
            source = self.source and AttrSource(self.source, name)
            return VariableTracker.build(tx, val, source)

        return super().var_getattr(tx, name)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return variables.UserFunctionVariable(
            polyfills.getattr_and_trace
        ).call_function(
            tx, [self, variables.ConstantVariable(self.attr_to_trace), *args], kwargs
        )


def _traceable_collective_remaps():
    # We can't rely on importing from distributed, since it's not always built
    if torch.distributed.is_available():
        from torch.distributed._functional_collectives import (
            traceable_collective_remaps,
        )

        return traceable_collective_remaps
    return {}


def _traceable_collectives_source(tx: "InstructionTranslator", fn):
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()

    inner_name = fn.__name__
    path_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(path_source, inner_name)
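
# Illustrative sketch of the rewrite CollectiveFunctionRewriteVariable (below)
# performs: a call like
#
#     dist.all_reduce(t, op=dist.ReduceOp.SUM)
#
# is redirected to the matching entry in `traceable_collective_remaps` from
# torch.distributed._functional_collectives, with `op` converted to its string
# form through REDUCE_OP_TO_STR before the replacement is invoked.
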

class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs can be rewritten to
    'traceable' collectives.

    This class provides both a way to check whether a function is remappable
    and a way to perform the remapping.

    In the case that a function is 'remappable' but only for some combinations
    of call-time arguments, we check the args at `call_function` time and fall
    back to graph-breaking if needed.

    This is no worse than the status quo, as we currently graph-break on all
    distributed.* collectives.
    """

    def __init__(self, fn, *, replacement_var, **kwargs) -> None:
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx: "InstructionTranslator", old_fn, source, **options):
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(tx: "InstructionTranslator", fn):
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(tx, new_fn)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs
        # of remapped_fn, since that's the contract for putting a mapping in
        # `traceable_collective_remaps`.
        import torch.distributed as dist
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        # Merge args into kwargs so positional and keyword args can be
        # processed the same way.
        signature = inspect.signature(self.fn)
        kwargs = dict(signature.bind(*args, **kwargs).arguments)
        args = ()

        if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
            unimplemented_v2(
                gb_type="async_op=True for distributed collectives",
                context=f"{self.fn}, {args=}, {kwargs=}",
                explanation=f"`torch.compile` doesn't support `async_op=True` for {self.fn}",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

        if self.fn in (
            dist.all_reduce,
            dist.reduce_scatter_tensor,
            dist._reduce_scatter_base,
        ):
            reduce_op_var = kwargs.get("op")
            reduce_op = (
                reduce_op_var.value
                if reduce_op_var is not None
                else signature.parameters["op"].default
            )
            if reduce_op not in REDUCE_OP_TO_STR:
                raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
            kwargs["op"] = variables.ConstantVariable.create(
                REDUCE_OP_TO_STR[reduce_op]
            )
        return self.replacement_var.call_function(tx, args, kwargs)


class FunctoolsWrapsVariable(UserFunctionVariable):
    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if not kwargs and len(args) == 1:

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    return fn.clone(wrapped_fn=args[0])
                unimplemented_v2(
                    gb_type="functools.wraps",
                    context=f"{fn}",
                    explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
                    hints=[
                        *graph_break_hints.SUPPORTABLE,
                    ],
                )

            return variables.LambdaVariable(wraps)

        return super().call_function(tx, args, kwargs)
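
# Illustrative sketch (hypothetical user code, comment only) of the only
# functools.wraps pattern FunctoolsWrapsVariable can trace: the decorated
# function must itself be defined inside the compiled region, i.e. be a
# NestedUserFunctionVariable:
#
#     def deco(fn):
#         @functools.wraps(fn)
#         def wrapper(*args, **kwargs):   # `wrapper` is the nested function
#             return fn(*args, **kwargs)
#         return wrapper
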
construction", context=f"{args=}, {kwargs=}", explanation="`torch.compile` only support certain input types for namedtuple", hints=[ *graph_break_hints.SUPPORTABLE, ], ) class FunctoolsPartialVariable(VariableTracker): def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: super().__init__(**kwargs) self.func = func assert isinstance(args, list) self.args = args assert isinstance(keywords, dict) self.keywords = keywords # fake_value is used for id calculation. Creating this value and id'ng # on it is sufficient for the tracing purposes. self.fake_value = functools.partial(identity) def python_type(self): return functools.partial def reconstruct(self, codegen): codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: codegen.foreach(self.args) if not self.keywords: codegen.extend_output(create_call_function(len(self.args) + 1, False)) return codegen.foreach(self.keywords.values()) keys = tuple(self.keywords.keys()) codegen.extend_output( codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) ) def get_function(self): return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": merged_args = self.args + args merged_kwargs = {**self.keywords, **kwargs} return self.func.call_function(tx, merged_args, merged_kwargs) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> VariableTracker: # functools.partial uses slots, so attributes are constant return variables.ConstantVariable.create( hasattr(functools.partial(identity), name) ) def var_getattr(self, tx: "InstructionTranslator", name: str): source = self.source and AttrSource(self.source, name) # Handle __slots__ if name == "func": return self.func if name == "args": return variables.ListVariable(self.args, source=source) if name == "keywords": items = {ConstantVariable.create(k): v for k, v in self.keywords.items()} return variables.ConstDictVariable(items, source=source) raise_observed_exception(AttributeError, tx) def as_python_constant(self): return functools.partial( self.func.as_python_constant(), *[arg.as_python_constant() for arg in self.args], **{k: v.as_python_constant() for k, v in self.keywords.items()}, ) def guard_as_python_constant(self): """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" return functools.partial( self.func.guard_as_python_constant(), *[v.guard_as_python_constant() for v in self.args], **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, ) class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { "fn", "wrapped_fn", "traceable_fn", *VariableTracker._nonvar_fields, } @classmethod @functools.lru_cache(None) def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) return cls(value, source=source) def __init__(self, fn: _F, **kwargs) -> None: super().__init__(**kwargs) self.fn: _F = fn handler = self._get_polyfill_handlers().get(fn, fn) assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" for candidate_attr in ( "__torch_dynamo_polyfill__", # registered polyfill "__python_implementation__", # self handler from third-party libraries ): candidate = getattr(handler, candidate_attr, None) if candidate: assert 

class PolyfilledFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "fn",
        "wrapped_fn",
        "traceable_fn",
        *VariableTracker._nonvar_fields,
    }

    @classmethod
    @functools.lru_cache(None)
    def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
        return {}

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(value, source=source)

    def __init__(self, fn: _F, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn: _F = fn

        handler = self._get_polyfill_handlers().get(fn, fn)
        assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}"
        for candidate_attr in (
            "__torch_dynamo_polyfill__",  # registered polyfill
            "__python_implementation__",  # self handler from third-party libraries
        ):
            candidate = getattr(handler, candidate_attr, None)
            if candidate:
                assert callable(candidate)
                traceable_fn = candidate
                break
        else:
            raise RuntimeError(
                f"Polyfill handler {handler} does not have a traceable function"
            )

        self.wrapped_fn: _F = handler
        self.traceable_fn: _F = traceable_fn

    @property
    def polyfill_fn(self) -> _F:
        return self.traceable_fn

    def can_constant_fold_through(self):
        return getattr(
            self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False
        )

    def get_function(self):
        return self.as_python_constant()

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
            result = self.fn(  # use the original function, which is faster than the polyfill
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return VariableTracker.build(tx, result)

        # Special case for sum on a tuple/list of ints
        if (
            self.fn is builtins.sum
            and len(args) == 1
            and not kwargs
            and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
            and all(
                (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int))
                or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
                for x in args[0].items
            )
        ):
            return variables.SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function",
                    torch.sym_sum,
                    (tuple(a.as_proxy() for a in args[0].items),),
                    {},
                ),
                sym_num=torch.sym_sum(
                    [
                        (
                            x.value
                            if isinstance(x, variables.ConstantVariable)
                            else x.sym_num
                        )
                        for x in args[0].items
                    ]
                ),
            )

        traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
        return traceable_function_variable.call_function(tx, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__call__":
            return self.call_function(tx, args, kwargs)

        method = getattr(self.fn, name, None)
        assert method is not None, f"Member {name} not found in {self.fn}"
        assert is_function(method), f"Member {name} is not callable in {self.fn}"
        options = {}
        if self.source:
            options["source"] = AttrSource(self.source, name)
        polyfilled_method_variable = PolyfilledFunctionVariable(method, **options)
        return polyfilled_method_variable.call_function(tx, args, kwargs)

    def as_python_constant(self):
        return self.fn
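
# Illustrative sketch of the builtins.sum special case above: summing a
# list/tuple of ints/SymInts emits a single torch.sym_sum node rather than a
# chain of binary additions, e.g. (conceptually) sum([s0, s1, 2]) becomes
# torch.sym_sum((s0, s1, 2)) in the traced graph.
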

class TracebackVariable(VariableTracker):
    # We don't track tracebacks. A call to any function in this module is a
    # no-op.
    def call_function(self, tx, args, kwargs): ...


class SysFunctionVariable(VariableTracker):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def exc_info(self, tx):
        if len(tx.exn_vt_stack):
            exn = tx.exn_vt_stack[-1]
            typ = exn.exc_type
            tb = None
            items = [
                VariableTracker.build(tx, typ),
                exn,
                VariableTracker.build(tx, tb),
            ]
        else:
            items = [
                variables.ConstantVariable(None),
                variables.ConstantVariable(None),
                variables.ConstantVariable(None),
            ]
        return variables.TupleVariable(items)

    def exception(self, tx):
        return self.exc_info(tx).items[1]

    def call_function(self, tx, args, kwargs):
        if self.value is sys.exc_info:
            return self.exc_info(tx)
        assert self.value is sys.exception
        return self.exception(tx)
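
# Illustrative sketch of what SysFunctionVariable (above) models inside a
# compiled frame:
#
#     typ, val, tb = sys.exc_info()   # tb is always traced as None here
#     exc = sys.exception()           # the active exception value, if any
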

from torch._higher_order_ops.triton_kernel_wrap import (
    TMADescriptorMetadata,
    TritonHOPifier,
)


class DynamoTritonHOPifier(TritonHOPifier):
    def raise_unsupported(self, msg: str) -> Never:
        raise Unsupported(msg)

    def is_callable(self, maybe_callable: Any) -> bool:
        return isinstance(
            maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable)
        )

    def get_value(self, val: Any) -> Any:
        return val.value

    def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
        from .lists import BaseListVariable

        if isinstance(grid, BaseListVariable):
            return grid.as_proxy()
        else:
            unimplemented_v2(
                gb_type="unsupported grid type for triton hop check_grid",
                context=f"grid type = {type(grid)}",
                explanation="`torch.compile` only supports a list-like grid for check_grid",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

    def call_grid(self, grid, meta, tx):
        meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()}
        grid = grid.call_function(tx, [meta], {})
        return grid

    # We use this function to wrap call_prune_configs
    def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
        from .builder import SourcelessBuilder

        wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
        result = wrapped_user_function.call_function(tx, args, kwargs)
        return result

    def wrap_user_defined_obj(self, user_obj, tx, variable, name):
        from .builder import VariableBuilder

        wrapped_user_obj = VariableBuilder(
            tx, AttrSource(variable.kernel_source, f"{name}")
        )._wrap(user_obj)
        return wrapped_user_obj

    def maybe_unpack_configs(self, configs, tx):
        # unpack the list of configs
        configs = configs.unpack_var_sequence(tx)

        # guard_as_python_constant inserts guards for Dynamo to check if the
        # configs object changed.
        configs = [config.guard_as_python_constant() for config in configs]

        return configs

    def maybe_unpack_heuristic_result(self, result: Any) -> Any:
        if not result.is_python_constant():
            self.raise_unsupported(
                "@triton.heuristics must return constant values because "
                "configs can only contain constant values."
            )
        return result.guard_as_python_constant()

    # We need to override call_getitem here so that we can add the source in
    # the case where we call the triton kernel with a grid.
    def call_getitem(
        self,
        variable: "TritonKernelVariable",
        args: Sequence[Any],
    ) -> "TritonKernelVariable":
        # __getitem__ should only be called if we don't already have a grid;
        # only a grid needs to be passed.
        if variable.grid is not None or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return type(variable)(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
            kernel_source=variable.source,
        )

    def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
        from .constant import ConstantVariable
        from .dicts import ConstDictVariable

        # As we can only pass tensors as non-const args in the FX graph, we
        # replace TMA descriptors (TMADescriptorVariable instances) with the
        # underlying tensors here, while moving the TMA descriptor-related
        # metadata to a separate argument, so that we can reconstruct the TMA
        # descriptors downstream.
        tma_descriptor_metadata: TMADescriptorMetadata = {}
        for k in list(combined_args_raw.keys()):
            v = combined_args_raw[k]
            if isinstance(v, TMADescriptorVariable):
                tma_descriptor_metadata[k] = v.to_metadata()
                combined_args_raw[k] = v.data_ptr.from_tensor

        combined_args = {
            variables.ConstantVariable.create(k): v
            for k, v in combined_args_raw.items()
        }

        from torch._higher_order_ops.triton_kernel_wrap import (
            kernel_side_table,
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass them as a dict so that if the
        # user-defined triton kernel uses variables named 'grid' or 'kernel',
        # they do not conflict with the parameters of the wrapper function.
        constant_args = {
            k: v.as_python_constant()
            for k, v in combined_args_raw.items()
            if isinstance(v, ConstantVariable)
        }
        non_constant_args = {
            k: v
            for k, v in combined_args.items()
            if not isinstance(v, ConstantVariable)
        }

        for v in non_constant_args.values():
            v = v.realize()
            if not isinstance(
                v, (variables.TensorVariable, variables.SymNodeVariable)
            ):
                self.raise_unsupported(
                    f"Unexpected argument type for a Triton kernel: {repr(v)}."
                )
        constant_args_idx = kernel_side_table.add_constant_args(constant_args)
        meta = ConstDictVariable(non_constant_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": variable.kernel_idx,
                "constant_args_idx": constant_args_idx,
                "grid": grids,
                "tma_descriptor_metadata": tma_descriptor_metadata,
                "kwargs": meta.as_proxy(),
            },
        )

        return variables.ConstantVariable(None)


dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()


class TritonKernelVariable(VariableTracker):
    grid: "TritonGridType"
    kernel: "TritonKernelType"
    kernel_idx: Optional[int]
    kernel_source: "AttrSource"

    def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
        self.kernel_source = kwargs.pop("kernel_source", None)
        super().__init__(**kwargs)
        dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return dynamo_triton_hopifier_singleton.call_triton_kernel(
            self, args, kwargs, tx
        )

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            return dynamo_triton_hopifier_singleton.call_getitem(self, args)
        elif name == "run":
            return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)

    def specialize_symbolic(self, arg: Any) -> Any:
        from .constant import ConstantVariable
        from .tensor import SymNodeVariable

        # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
        if isinstance(arg, SymNodeVariable):
            return ConstantVariable.create(arg.evaluate_expr())
        return arg


class TMADescriptorVariable(VariableTracker):
    def __init__(
        self,
        data_ptr: "variables.DataPtrVariable",
        dims: "list[ConstantVariable]",
        block_dims: "list[ConstantVariable]",
        element_size: "ConstantVariable",
        **kwargs,
    ):
        assert isinstance(data_ptr, variables.DataPtrVariable)

        super().__init__(**kwargs)
        self.data_ptr = data_ptr
        self.dims = dims
        self.block_dims = block_dims
        self.element_size = element_size

    def to_metadata(self):
        return (
            [dim.as_proxy() for dim in self.dims],
            [dim.as_proxy() for dim in self.block_dims],
            self.element_size.as_proxy(),
        )

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from(
                "triton.tools.experimental_descriptor",
                f"create_{len(self.dims)}d_tma_descriptor",
            )
        )
        self.data_ptr.reconstruct(codegen)
        args = [*self.dims, *self.block_dims, self.element_size]
        codegen.foreach(args)
        codegen.call_function(len(args) + 1, False)
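
# Illustrative sketch (hypothetical user code, comment only) of the call shape
# CreateTMADescriptorVariable (below) expects, following
# triton.tools.experimental_descriptor:
#
#     desc = create_2d_tma_descriptor(
#         t.data_ptr(), dim1, dim0, block_dim1, block_dim0, element_size
#     )
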

class CreateTMADescriptorVariable(VariableTracker):
    def __init__(
        self,
        rank: int,
        **kwargs,
    ) -> None:
        assert rank in (1, 2)
        super().__init__(**kwargs)
        self.rank = rank

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        ptr = kwargs["ptr"] if "ptr" in kwargs else args[0]

        if not isinstance(ptr, variables.DataPtrVariable):
            raise Unsupported(
                "Please ensure there were no graph breaks between "
                f"create_{self.rank}d_tma_descriptor and the upstream "
                ".data_ptr() call."
            )

        if self.rank == 1:
            assert len(args) + len(kwargs) == 4
            dims = [
                kwargs["dim"] if "dim" in kwargs else args[1],
            ]
            block_dims = [
                kwargs["block_dim"] if "block_dim" in kwargs else args[2],
            ]
        else:
            assert len(args) + len(kwargs) == 6
            dims = [
                kwargs["dim1"] if "dim1" in kwargs else args[1],
                kwargs["dim0"] if "dim0" in kwargs else args[2],
            ]
            block_dims = [
                kwargs["block_dim1"] if "block_dim1" in kwargs else args[3],
                kwargs["block_dim0"] if "block_dim0" in kwargs else args[4],
            ]
        element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1]

        return TMADescriptorVariable(
            data_ptr=ptr,
            dims=dims,
            block_dims=block_dims,
            element_size=element_size,
        )
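
# Illustrative sketch (hypothetical user code, comment only) of the user-facing
# flow the Triton kernel variables above model:
#
#     grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
#     my_kernel[grid](x_ptr, y_ptr, n, BLOCK_SIZE=1024)
#     # my_kernel[grid]  -> call_method("__getitem__") -> call_getitem
#     # (...)            -> call_function -> triton_kernel_wrapper_mutation HOP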