# mypy: ignore-errors

"""
This module contains miscellaneous variable tracker implementations for various
Python types and features used in Dynamo's symbolic execution. These classes
help track and propagate information about different kinds of variables during
graph capture.

Key classes include:
- SuperVariable: Handles super() calls and method resolution
- ExceptionVariable: Tracks exception objects
- RandomVariable: Manages random number generators
- GetAttrVariable: Tracks attribute access
- MethodWrapperVariable: Handles method wrappers
- PythonModuleVariable: Tracks Python modules
- NumpyVariable: Handles numpy functions and types
- StringFormatVariable: Manages string formatting
- DebuggingVariable: Handles print and logging
"""

import dataclasses
import functools
import inspect
import itertools
import random
import re
import sys
import types
import warnings
from typing import Optional, TYPE_CHECKING

import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree

from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import raise_observed_exception, unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import (
    AttrSource,
    GenericAttrSource,
    GetItemSource,
    TypeSource,
    WeakRefCallSource,
)
from ..utils import (
    check_unspec_or_constant_args,
    cmp_name_to_op_mapping,
    identity,
    is_tensor_base_attr_getter,
    istype,
    list_methods,
    proxy_args_kwargs,
    set_example_value,
    tuple_methods,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class NO_SUCH_SUBOBJ:
    pass


class SuperVariable(VariableTracker):
    _nonvar_fields = {
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, typevar, objvar=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # typevar is the first argument to super(). In the case where no argument
        # is provided to super(), it is the __class__ object where
        # the super() function is being called
        self.typevar = typevar
        # objvar here must be an instance or subtype of typevar.
        # In the case where super() is called without arguments, it is the first argument
        # to the current function where super() is called from (self for regular method,
        # cls for a classmethod)
        self.objvar = objvar

    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            codegen.extend_output(create_call_function(2, False))
        else:
            codegen.extend_output(create_call_function(1, False))

    def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
        assert self.objvar, "1-arg super not implemented"
        search_type = self.typevar.as_python_constant()

        # The rest of this function does two things:
        #   - Walk the mro to find where the attribute comes from to be
        #     able to provide accurate source
        #   - Call the getattr to get the object
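        # Illustrative example (hypothetical classes, not from this module):
        #   class A: ...
        #   class B(A): ...
        #   b = B()
        #   super(B, b).f starts the attribute search one past B in
        #   type(b).__mro__, i.e. at A.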
# When objvar is "self", use type(self), when objvar is "cls", use it as-is type_to_use = self.objvar.python_type() type_to_use_source = ( TypeSource(self.objvar.source) if self.objvar.source else None ) if issubclass(type_to_use, type): type_to_use = self.objvar.value type_to_use_source = self.objvar.source source = None search_mro = type_to_use.__mro__ try: start_index = search_mro.index(search_type) + 1 except ValueError: # Corner case where the typevar is not in the mro of the objvar # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 return getattr(super(search_type, type_to_use), name), None # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 # super has its getattro implementation. The key point is that instead of calling getattr, it checks the # attribute in the class __dict__ for index in range(start_index, len(search_mro)): # Dont call getattr, just check the __dict__ of the class if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): if resolved_getattr is not NO_SUCH_SUBOBJ: # Equivalent of something like type(L['self']).__mro__[1].attr_name if type_to_use_source: source = AttrSource( GetItemSource( AttrSource(type_to_use_source, "__mro__"), index ), name, ) return resolved_getattr, source unimplemented("Unable to resolve super getattr") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": # Check if getattr is a constant. If not, delay the actual work by # wrapping the result in GetAttrVariable. Mostly super is called with a # method, so most of the work is delayed to call_function. # # We could have just implemented a const_getattr. However, super is # special when it comes to finding sources. Compared to other VTs, super # requires the attr name to walk the mro and find the actual source (and # not just AttrSource). 
        value, source = self._resolved_getattr_and_source(self, name)
        if not variables.ConstantVariable.is_literal(value):
            return GetAttrVariable(self, name)
        if source:
            install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
        return variables.ConstantVariable.create(value, source=source)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        inner_fn, source = self._resolved_getattr_and_source(self, name)

        if inner_fn is object.__init__:
            return LambdaVariable(identity)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutation_type, AttributeMutationNew)
                and not (args or kwargs)
            ):
                with do_not_convert_to_tracable_parameter():
                    return variables.UserFunctionVariable(
                        unpatched_nn_module_init, source=source
                    ).call_function(tx, [self.objvar] + args, kwargs)
            else:
                unimplemented("super() nn.Module.__init__")
        elif (
            self.objvar.source
            and hasattr(inner_fn, "__name__")
            and inner_fn.__name__ == "__new__"
            and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
        ):
            user_cls = inner_fn.__self__
            if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
                user_cls_vt = variables.BuiltinVariable(user_cls)
            else:
                user_cls_source = source.member
                user_cls_vt = variables.UserDefinedClassVariable(
                    user_cls, source=user_cls_source
                )
            return user_cls_vt.call_method(tx, "__new__", args, kwargs)
        elif isinstance(inner_fn, staticmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            return variables.UserFunctionVariable(
                inner_fn.__func__, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, classmethod) and isinstance(
            inner_fn.__func__, types.FunctionType
        ):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source
            ).call_function(tx, args, kwargs)
        elif is_standard_setattr(inner_fn) and isinstance(
            self.objvar, UserDefinedObjectVariable
        ):
            return self.objvar.method_setattr_standard(tx, *args, **kwargs)
        elif inner_fn is object.__delattr__:
            attr = args[0]
            try:
                attr = attr.as_python_constant()
            except NotImplementedError:
                unimplemented(f"non-const delattr attr: {attr}")
            if not tx.output.side_effects.is_attribute_mutation(self.objvar):
                unimplemented(f"delattr({self.objvar}, {attr}, ...)")

            tx.output.side_effects.store_attr(
                self.objvar, attr, variables.DeletedVariable()
            )
            return variables.ConstantVariable(None)
        elif (
            isinstance(self.objvar, variables.UserDefinedDictVariable)
            and inner_fn in self.objvar._dict_methods
        ):
            return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedTupleVariable)
            and inner_fn in tuple_methods
        ):
            return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
        elif (
            isinstance(self.objvar, variables.UserDefinedListVariable)
            and inner_fn in list_methods
        ):
            return self.objvar._list_vt.call_method(tx, name, args, kwargs)
        elif inner_fn is object.__getattribute__:
            # object.__getattribute__ has no side effects. We can directly call
            # __getattribute__ to access the attribute.
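            # Minimal sketch of the semantics being modeled (illustrative):
            #   object.__getattribute__(obj, "x") reads the attribute while
            #   bypassing any user-defined __getattribute__/__getattr__ on
            #   type(obj), so no user code can run here.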
            attr_name = args[0].value
            if tx.output.side_effects.has_pending_mutation_of_attr(
                self.objvar, attr_name
            ):
                result = tx.output.side_effects.load_attr(
                    self.objvar, attr_name, deleted_ok=True
                )
                if isinstance(result, variables.DeletedVariable):
                    raise_observed_exception(AttributeError, tx)
                return result

            try:
                # NB - use object.__getattribute__ to prevent running any user code
                attr_value = object.__getattribute__(self.objvar.value, attr_name)
            except AttributeError:
                raise_observed_exception(AttributeError, tx)

            attr_source = None
            if self.objvar.source is not None:
                # set up an object.__getattribute__(self.objvar, name) source
                attr_source = GenericAttrSource(self.objvar.source, attr_name)
            return VariableTracker.build(tx, attr_value, attr_source)

        unimplemented(f"non-function or method super: {inner_fn}")


class ExceptionVariable(VariableTracker):
    # The ExceptionVariable corresponds to the BaseException class in Python
    def __init__(self, exc_type, args, **kwargs) -> None:
        super().__init__(**kwargs)
        self.exc_type = exc_type
        self.args = args

        # When raising a new exception while another exception is already being
        # handled, the new exception's __context__ attribute is automatically
        # set to the handled exception.
        self.__context__ = ConstantVariable(None)

        # Set when the user raised an exception from another:
        #   raise ... from ...
        self.__cause__ = ConstantVariable(None)

        # Boolean flag that controls whether the __context__ attribute is set
        self.__suppress_context__ = ConstantVariable(False)

        # Contains the call stack where the exception was raised. Dynamo does
        # not track the traceback, so this variable is always set to None.
        self.__traceback__ = ConstantVariable(None)

    def set_context(self, context: "ExceptionVariable"):
        self.__context__ = context

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
        )
        codegen.foreach(self.args)
        codegen.call_function(len(self.args), False)

        def codegen_attr(name: str) -> None:
            attr = getattr(self, name)
            if istype(attr, ConstantVariable):
                assert attr.value in (True, False, None), attr
            else:
                codegen.dup_top()
                codegen(attr)
                codegen.extend_output(codegen.rot_n(2))
                codegen.store_attr(name)

        codegen_attr("__context__")
        codegen_attr("__cause__")
        codegen_attr("__suppress_context__")

    def python_type(self):
        return self.exc_type

    def call_setattr(
        self,
        tx: "InstructionTranslator",
        name_var: VariableTracker,
        val: VariableTracker,
    ):
        def raise_error(msg):
            raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])

        name = name_var.as_python_constant()
        if name == "__context__":
            self.set_context(val)
        elif name == "__cause__":
            if (isinstance(val, ConstantVariable) and val.value is None) or isinstance(
                val,
                (
                    variables.BuiltinVariable,
                    variables.ExceptionVariable,
                    variables.UserDefinedExceptionClassVariable,
                    variables.UserDefinedExceptionObjectVariable,
                ),
            ):
                self.__cause__ = val
                self.__suppress_context__ = variables.ConstantVariable(True)
            else:
                raise_error("exception cause must be None or derive from BaseException")
        elif name == "__suppress_context__":
            if isinstance(val, ConstantVariable) and val.value in (True, False):
                self.__suppress_context__ = val
            else:
                raise_error("exception cause must be None or derive from BaseException")
        elif name == "__traceback__":
            if isinstance(val, ConstantVariable) and val.value is None:
                self.__traceback__ = val
            else:
                unimplemented(f"setattr(ExceptionVariable, {name_var}, {val})")
        else:
            unimplemented(f"setattr(ExceptionVariable, {name_var}, {val})")
        return variables.ConstantVariable(None)
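
    # Example of the eager semantics being mirrored (illustrative):
    #   try:
    #       raise KeyError("k")
    #   except KeyError as e:
    #       raise ValueError("v") from e  # sets __cause__ = e and
    #                                     # __suppress_context__ = True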

    def call_method(self, tx, name, args, kwargs):
        if name == "__setattr__":
            return self.call_setattr(tx, *args)
        elif name == "with_traceback":
            [tb] = args
            self.call_setattr(tx, ConstantVariable("__traceback__"), tb)
            return self
        else:
            return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "__context__":
            return self.__context__
        elif name == "__cause__":
            return self.__cause__
        elif name == "__suppress_context__":
            return self.__suppress_context__
        elif name == "__traceback__":
            return variables.ConstantVariable(None)
        elif name == "args":
            return variables.ListVariable(self.args, source=self.source)
        return super().var_getattr(tx, name)

    def __str__(self):
        return f"{self.__class__.__name__}({self.exc_type})"

    __repr__ = __str__


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class DelayGraphBreakVariable(UnknownVariable):
    """
    Used to insert a dummy variable in the stack to do the graph break at
    CALL_FUNCTION.
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special; it lets you execute arbitrary code at
    Dynamo compile time.
    """

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        # Second argument is runtime lambda, ignored
        assert len(args) <= 2
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                (),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")

        return variables.ConstantVariable.create(None)


class CellVariable(VariableTracker):
    # If the cell existed before Dynamo tracing started, this will be the
    # VariableTracker that represents the cell content.
    #
    # Note that all mutation to the cell (i.e., its content) will be buffered in
    # SideEffects, rather than being reflected here. One can think of
    # `CellVariable` as a special case of `UserDefinedObjectVariable`.
    pre_existing_contents: Optional[VariableTracker]

    # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the
    # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`).
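    # For example (illustrative):
    #   def outer():
    #       x = 1
    #       def inner():
    #           return x
    #   Here `x` lives in a cell of `outer`; when tracing `outer` as the root
    #   frame, local_name would be "x".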
    local_name: Optional[str] = None

    def __init__(
        self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.pre_existing_contents = pre_existing_contents


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


def produce_trampoline_autograd_apply(fn_cls):
    def trampoline_autograd_apply(*args, **kwargs):
        return fn_cls.apply(*args, **kwargs)

    trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
    return trampoline_autograd_apply


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    _nonvar_fields = {
        "fn_cls",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, fn_cls, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx: "InstructionTranslator", args, kwargs):
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True

        VariableTracker.visit(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            if config.capture_autograd_function is False:
                warnings.warn(
                    "The config.capture_autograd_function flag is deprecated, it's now always true."
                )

            from torch._functorch.autograd_function import (
                autograd_function_forward_rewritten,
            )
            from torch.autograd.function import _is_setup_context_defined

            forward_fn = self.fn_cls.forward

            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
            if is_setup_ctx_defined:
                # If setup_context is defined, we generate a new forward function which includes
                # the original forward and setup_context function, and trace the new forward function.
                forward_fn = autograd_function_forward_rewritten(
                    self.fn_cls.forward, self.fn_cls.setup_context
                )

            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defined vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defined jvp")

            from .higher_order_ops import AutogradFunctionApplyVariable

            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)
            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
            # the forward function, as we don't want to generate guards for new_forward.__closure__
            # if forward is rewritten by autograd_function_forward_rewritten.
            # But we still need to generate correct guards for the original forward and setup_context
            # functions, so we have to add guards manually.
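            # (Illustrative) If the user later rebinds MyFn.forward or
            # MyFn.setup_context, the FUNCTION_MATCH guards installed below
            # fail and the graph is recompiled against the new functions.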
            if self.source:
                fwd_src = AttrSource(self.source, "forward")
                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
                if is_setup_ctx_defined:
                    setup_ctx_src = AttrSource(self.source, "setup_context")
                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))

            return val

        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            sig = inspect.signature(fn)
            if len(args) - 1 == len(sig._parameters):
                args = args[1:]  # Don't use context
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_backward(self, tx: "InstructionTranslator", args, kwargs):
        fn = self.fn_cls.backward
        assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
        assert isinstance(fn, types.FunctionType)
        fn_source = AttrSource(self.source, "backward")

        return variables.UserFunctionVariable(fn, source=fn_source).call_function(
            tx, args, kwargs
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ):
        from ..trace_rules import is_callable_allowed
        from .builder import wrap_fx_proxy

        if name == "apply":
            if is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)
        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
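            # Fallback (descriptive): other attributes, e.g. user-defined
            # staticmethods or classmethods on the autograd.Function subclass,
            # are resolved below via trace_rules, mirroring ordinary attribute
            # dispatch.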
            from .. import trace_rules

            source = AttrSource(self.source, name) if self.source is not None else None
            try:
                obj = inspect.getattr_static(self.fn_cls, name)
            except AttributeError:
                obj = None

            if isinstance(obj, staticmethod):
                func = obj.__get__(self.fn_cls)
                if source is not None:
                    return (
                        trace_rules.lookup(func)
                        .create_with_source(func, source=source)
                        .call_function(tx, args, kwargs)
                    )
                else:
                    return trace_rules.lookup(func)(func).call_function(
                        tx, args, kwargs
                    )
            elif isinstance(obj, classmethod):
                return variables.UserMethodVariable(
                    obj.__func__, self, source=source
                ).call_function(tx, args, kwargs)
            else:
                unimplemented(f"Unsupported method: {name}")


@dataclasses.dataclass
class SavedTensorBox:
    tensors: list[VariableTracker] = dataclasses.field(default_factory=list)


class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        proxy=None,
        saved_tensors=None,
        needs_input_grad=None,
        non_differentiable=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.proxy = proxy
        self.saved_tensors = saved_tensors
        self.needs_input_grad = needs_input_grad
        self.non_differentiable = non_differentiable

    @staticmethod
    def create(tx: "InstructionTranslator", args=None, kwargs=None):
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, (), {}
        )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                proxy=proxy,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        set_example_value(proxy.node, out.value)

        return out

    def as_proxy(self):
        if self.proxy is None:
            unimplemented("proxy not set")
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        elif name == "mark_non_differentiable":
            assert len(kwargs) == 0
            self.non_differentiable = proxy_args_kwargs(args, {})[0]
            return variables.ConstantVariable.create(None)

        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")
        if self.saved_tensors is None:
            unimplemented(
                "save_for_backward only supported on a newly constructed FunctionCtx"
            )

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite
        # previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name in ["save_for_backward", "mark_non_differentiable"]:
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            if self.source:
                source = AttrSource(self.source, "needs_input_grad")
                return VariableTracker.build(tx, self.value.needs_input_grad, source)
        return super().var_getattr(tx, name)


class AutogradEngineVariable(UserDefinedObjectVariable):
    """
    Represents a torch._C._ImperativeEngine instance.
    """

    def __init__(
        self,
        value,
        value_type=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "queue_callback":
            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
                assert tx.one_graph, (
                    "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                )
                return variables.UserFunctionVariable(
                    torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
                    source=self.source,
                ).call_function(
                    tx,
                    (tx.output.side_effects.get_ca_final_callbacks_var(), *args),
                    kwargs,
                )
            else:
                unimplemented(
                    "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
                )
        else:
            unimplemented(f"torch._C._ImperativeEngine method: {name}")


class LambdaVariable(VariableTracker):
    def __init__(self, fn, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs)


class GetAttrVariable(VariableTracker):
    _nonvar_fields = {
        "name",
        "py_type",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, obj, name, py_type=None, **kwargs) -> None:
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name
        self.py_type = py_type  # In some cases we know the type (e.g. tensor methods)

    def python_type(self):
        if self.py_type is not None:
            return self.py_type
        else:
            return super().python_type()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def as_python_constant(self):
        constant = self.obj.as_python_constant()
        try:
            return getattr(constant, self.name)
        except AttributeError:
            raise NotImplementedError(f"{self} is not a constant") from None

    def const_getattr(self, tx: "InstructionTranslator", name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        codegen.extend_output(codegen.create_load_attrs(self.name))

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.obj.call_method(tx, self.name, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        if (
            name in ("__getitem__", "get")
            and self.name == "__dict__"
            and not kwargs
            and args[0].is_python_constant()
            and isinstance(
                self.obj,
                (
                    variables.UserDefinedObjectVariable,
                    variables.NNModuleVariable,
                    variables.UserDefinedClassVariable,
                ),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                # redirect to var_getattr on the original obj
                return obj.var_getattr(tx, key)

            # Return the default value for get
            if name == "get":
                if len(args) == 2:
                    return args[1]
                else:
                    return variables.ConstantVariable(None)
        elif (
            name == "__contains__"
            and self.name == "__dict__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
            and isinstance(
                self.obj,
                (
                    variables.UserDefinedObjectVariable,
                    variables.NNModuleVariable,
                    variables.UserDefinedClassVariable,
                ),
            )
        ):
            obj = self.obj
            key = args[0].as_python_constant()
            if obj.has_key_in_generic_dict(tx, key):
                return variables.ConstantVariable(True)
            else:
                return variables.ConstantVariable(False)
        elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
            if isinstance(self.obj, variables.UserDefinedObjectVariable):
                # Bypass any custom setattr as we are updating the `__dict__` itself
                return self.obj.method_setattr_standard(tx, args[0], args[1])
            if isinstance(self.obj, variables.NNModuleVariable):
                # This matches how `setattr` is handled for NNModuleVariable
                self.obj.convert_to_unspecialized(tx)

        return super().call_method(tx, name, args, kwargs)


class MethodWrapperVariable(VariableTracker):
    def __init__(self, method_wrapper, **kwargs) -> None:
        super().__init__(**kwargs)
        self.method_wrapper = method_wrapper
        self._builtin_fns = {}

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
            args[0], variables.TensorVariable
        ):
            assert len(args) == 1 and len(kwargs) == 0

            return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)

        # method-wrapper variables are common in __init__ calls. For example,
        # str("foo").__init__ is a method-wrapper. These method wrappers point
        # to C functions. Here we intercept if these method-wrappers are from
        # builtins and then call the function counterpart directly by obtaining
        # the self object.
        self_obj = self.method_wrapper.__self__
        wrapper_name = self.method_wrapper.__name__

        # TODO(dynamo-team) - We can perhaps expand the scope to more names and
        # more builtins.
        if wrapper_name == "__init__":
            fn_obj = type(self_obj).__init__
            if fn_obj is object.__init__:
                return variables.BuiltinVariable(object).call_method(
                    tx, wrapper_name, [self_obj, *args], kwargs
                )

        return super().call_function(tx, args, kwargs)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.method_wrapper


class GetSetDescriptorVariable(VariableTracker):
    def __init__(self, desc, **kwargs) -> None:
        super().__init__(**kwargs)
        self.desc = desc

    def var_getattr(self, tx: "InstructionTranslator", name):
        if name == "__get__" and self.source:
            source = AttrSource(self.source, "__get__")
            return VariableTracker.build(tx, self.desc.__get__, source)
        else:
            return super().var_getattr(tx, name)

    def is_python_constant(self):
        return True

    def as_python_constant(self):
        return self.desc


class PythonModuleVariable(VariableTracker):
    _nonvar_fields = {
        "value",
        "is_torch",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value: types.ModuleType, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value
        self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")

    def python_type(self):
        return types.ModuleType

    def as_python_constant(self):
        return self.value

    def __repr__(self) -> str:
        return f"PythonModuleVariable({self.value})"

    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def var_getattr(self, tx: "InstructionTranslator", name):
        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
            return tx.output.side_effects.load_attr(self, name)

        if self.is_torch or name not in self.value.__dict__:
            attr_value = getattr(self.value, name)
        else:
            attr_value = self.value.__dict__[name]

        source = self.source and AttrSource(self.source, name)
        return VariableTracker.build(tx, attr_value, source)


class TypingVariable(VariableTracker):
    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Create a new typing variable, e.g., `List[int]`
        if name == "__getitem__" and len(args) == 1:
            new_typing = self.value[args[0].as_python_constant()]
            return TypingVariable(new_typing)
        unimplemented("unsupported method call on typing variable")

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        from .builder import SourcelessBuilder, VariableBuilder

        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
            return tx.output.side_effects.load_attr(self, name)
        value = getattr(self.value, name)
        if self.source:
            attr_source = AttrSource(self.source, name)
            return VariableBuilder(tx, attr_source)(value)
        else:
            return SourcelessBuilder.create(tx, value)

    def as_python_constant(self):
        return self.value

    def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None:
        # We're just trying to load the type here. Reconstructing the type from
        # scratch is tricky - for a type like `typing.List[int]` we'd need to
        # deconstruct the origin and args. The origin for `List[int]` is `list`
        # and the args is `(int,)`. When we recombine those we get the parts
        # back and need to emit code for:
        #
        #   `typing.List[int]`
        #
        # But it's worse than that - what if `typing` isn't in the globals (or
        # was loaded like `import typing as _typing ; _typing.List[int]`)? So we
        # really need to do something like:
        #
        #   `sys.modules["typing"].List[int]`
        #
        # Argh - but what if they rewrote the global `int`? So we have to do:
        #
        #   `sys.modules["typing"].List[sys.modules["builtins"].int]`
        #
        # But where do we get `sys`? What if they never imported it or have
        # something ELSE called `sys`?
        #
        # Let's skip all that noise and just emit it as a simple const.
        codegen.append_output(codegen.create_load_const(self.value))


@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
    """
    This generates a mapping from numpy modules to their torch._numpy
    module equivalents.
    """
    from ..utils import NP_TO_TNP_MODULE

    np_fn_to_tnp_fn = {}

    for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
        for fn_name, tnp_fn in tnp_mod.__dict__.items():
            if callable(tnp_fn):
                # some internal details do leak from tnp
                # which are not part of numpy API.
                if np_fn := getattr(np_mod, fn_name, None):
                    np_fn_to_tnp_fn[np_fn] = tnp_fn

    return np_fn_to_tnp_fn


@functools.lru_cache(maxsize=1)
def get_tnp_to_np_map():
    """
    This is just the reverse mapping of get_np_to_tnp_map() - mapping from
    torch._numpy modules to their numpy equivalents.
    """
    m = get_np_to_tnp_map()
    return {v: k for k, v in m.items()}


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, it is able to trace a small subset
    of numpy functions as well as numpy dtypes.
    """

    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    @classmethod
    def get_constant_collection_for_func(cls, fn):
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return np_constant_collections_map.get(fn, None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        func = get_np_to_tnp_map().get(self.value)
        if func is None:
            unimplemented(
                f"Can't find numpy function {self.value} in torch._numpy. "
                "Please file an issue to request support for this function."
            )
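
        # (Descriptive) get_np_to_tnp_map() maps e.g. numpy.sum to
        # torch._numpy.sum, so the call below can be re-dispatched onto
        # traceable torch._numpy ops.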

        # We are dealing with a function that produces a const collection type
        # (np.dtype, np.iinfo/np.finfo)
        if (
            collection_variable_typ := self.get_constant_collection_for_func(func)
        ) is not None:
            try:
                return collection_variable_typ(
                    self.value(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    )
                )
            except NotImplementedError:
                unimplemented(
                    f"{self.value.__name__} with non-const args: {args} {kwargs}"
                )
        else:
            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            if self.can_constant_fold_through(func) and (
                check_unspec_or_constant_args(args, kwargs)
            ):
                # constant fold
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO: Add all the functions that go from constants to constants
            # to can_constant_fold_through.
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32.
            # We return a string as we don't want to serialize non-PyTorch
            # objects in the output FX graph.
            # In torch/_numpy we normalize strings to their dtypes when the
            # input is a dtype, as NumPy does.
            return self.value.__name__

        return super().as_proxy()


# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def __repr__(self) -> str:
        return "NullVariable"

    def reconstruct(self, codegen):
        if sys.version_info < (3, 11):
            unimplemented("cannot reconstruct NullVariable in < Python 3.11")
        codegen.append_output(create_instruction("PUSH_NULL"))


class DeletedVariable(VariableTracker):
    """Marker used to implement delattr()"""


class StringFormatVariable(VariableTracker):
    """
    Represents a call to str.format(); we delay calling format until after
    the graph.
    """
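
    # Example (illustrative):
    #   "{}-{k}".format(t, k="x") with a symbolic `t` is kept as a
    #   StringFormatVariable and reconstructed after the graph; with
    #   all-constant arguments, create() folds it straight into a
    #   ConstantVariable.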
""" _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} @classmethod def create(cls, format_string, sym_args, sym_kwargs): if all( x.is_python_constant() for x in itertools.chain(sym_args, sym_kwargs.values()) ): return variables.ConstantVariable.create( format_string.format( *[v.as_python_constant() for v in sym_args], **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, ) ) return cls(format_string, list(sym_args), dict(sym_kwargs)) def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: super().__init__(**kwargs) assert isinstance(format_string, str) self.format_string = format_string self.sym_args = sym_args self.sym_kwargs = sym_kwargs def __repr__(self) -> str: return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" def reconstruct(self, codegen): codegen.add_push_null( lambda: codegen.extend_output( [ codegen.create_load_const(self.format_string), codegen.create_load_attr("format"), ] ), call_function_ex=True, ) codegen(variables.TupleVariable(self.sym_args)) kwargs = { variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() } codegen(variables.ConstDictVariable(kwargs)) codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1)) class DebuggingVariable(VariableTracker): """ Represents a call to a debugging function like print(), or something registered to config.reorderable_logging_functions. """ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value @staticmethod def is_reorderable_logging_function(obj): return ( callable(obj) and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) and obj in torch._dynamo.config.reorderable_logging_functions ) def call_function(self, tx: "InstructionTranslator", args, kwargs): if tx.export: # For export cases, we can just make debugging functions no-ops return if not self.can_reorder_logs(self.value, args, kwargs): unimplemented( f"Reordering debugging function {self.value} " f"with inputs {args} {kwargs} is not yet implemented." ) tx.debug_locals.append((self, list(args))) def reconstruct(self, codegen): return self.source.reconstruct(codegen) @staticmethod def can_reorder_logs(fn, args, kwargs) -> True: """ Run some additional checks for what sort of function calls can we actually reorder. """ allowed_input_types = ( variables.TensorVariable, variables.ConstantVariable, StringFormatVariable, ) flat_args = pytree.tree_leaves([args, kwargs]) for arg in flat_args: if not isinstance(arg, allowed_input_types): return False return True class LoggingLoggerVariable(VariableTracker): """ Represents a call to any of logging.Logger methods """ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value def call_method( self, tx, name, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if tx.export: # For export cases, we can just make debugging functions no-ops return method = getattr(self.value, name, None) function = getattr(method, "__func__", None) if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): return variables.ConstantVariable.create(None) unimplemented( "Logger not supported for non-export cases. 
" "To avoid graph breaks caused by logger in compile-mode, it is recommended to" " disable logging by adding logging methods to config.ignore_logger_methods" ) class ConstantLikeVariable(VariableTracker): """self.value is a compile-time constant, but not a literal""" _error_prefix = "ConstantLikeVariable" try: from numpy import ( dtype as np_dtype, floating as np_floating, generic as np_generic, ) except ImportError: np_floating = type("invalid_type", (), {}) np_dtype = type("invalid_type", (), {}) def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value def as_python_constant(self): return self.value def call_method( self, tx, name, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: try: # we only support constant propagation for methods cargs = [x.as_python_constant() for x in args] ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} except NotImplementedError: unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})") result = getattr(self.value, name)(*cargs, **ckwargs) if variables.ConstantVariable.is_literal(result): return variables.ConstantVariable.create(result) if isinstance(result, re.Match): return ConstantRegexMatchVariable(result) unimplemented(f"{self._error_prefix}.{name}() -> {result}") def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: result = getattr(self.value, name) if isinstance(result, self.np_floating): result = float(result) if isinstance(result, self.np_dtype): return NumpyDTypeVariable(result) if isinstance(result, type) and issubclass(result, self.np_generic): # things like x.dtype.type return NumpyVariable(result) if variables.ConstantVariable.is_literal(result): return variables.ConstantVariable.create(result) return GetAttrVariable(self, name) class RegexPatternVariable(ConstantLikeVariable): _error_prefix = "re.Pattern" class ConstantRegexMatchVariable(ConstantLikeVariable): _error_prefix = "re.Match" class TorchVersionVariable(ConstantLikeVariable): _error_prefix = "torch.__version__" def __init__(self, **kwargs) -> None: kwargs.setdefault("value", torch.__version__) assert kwargs["value"] is torch.__version__ super().__init__(**kwargs) class NumpyTypeInfoVariable(ConstantLikeVariable): _error_prefix = "np.iinfo/np.finfo" class NumpyDTypeVariable(ConstantLikeVariable): _error_prefix = "np.dtype[...]" def as_proxy(self): """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. This also handles unsupported things nicely (i.e. structured arrays and object arrays). """ return self.value.type.__name__ np_constant_collections_map = { tnp.finfo: NumpyTypeInfoVariable, tnp.iinfo: NumpyTypeInfoVariable, tnp.dtype: NumpyDTypeVariable, } class RandomClassVariable(VariableTracker): """random.Random""" def __init__(self, **kwargs) -> None: super().__init__(**kwargs) def call_function(self, tx: "InstructionTranslator", args, kwargs): if len(args) > 1: unimplemented("random.Random() with > 1 arg") elif kwargs: unimplemented("random.Random() with kwargs") seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] return RandomVariable( seed=seed, mutation_type=variables.base.ValueMutationNew() ) class RandomVariable(VariableTracker): """random.Random() Implemented by wrapping a VariableTracker around a random.Random object. The supported methods for the random.Random object cannot be overriden. 

    _nonvar_fields = {
        "random",
        *VariableTracker._nonvar_fields,
    }

    _supported_fn_names = {
        "random",
        "randint",
        "randrange",
        "uniform",
    }

    def __init__(
        self,
        rand: Optional[random.Random] = None,
        seed: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if rand is not None:
            assert self.is_supported_random_obj(rand)
            self.random = random.Random()
            self.random.setstate(rand.getstate())
        else:
            seed = seed.as_python_constant() if seed is not None else None
            self.random = random.Random(seed)

    def python_type(self):
        return random.Random

    def as_python_constant(self):
        return self.random

    @staticmethod
    def is_supported_random_obj(val):
        if type(val) is not random.Random:
            return False
        for name in itertools.chain(
            RandomVariable._supported_fn_names, ("seed", "getstate", "setstate")
        ):
            if not hasattr(val, name):
                return False
            meth = getattr(val, name)
            if inspect.isbuiltin(meth):  # e.g. random.Random.random
                if meth != getattr(random.Random, name).__get__(val):
                    return False
            else:
                if getattr(meth, "__func__", None) is not getattr(random.Random, name):
                    return False
        return True

    @staticmethod
    def check_state(state):
        assert type(state) is tuple
        assert type(state[0]) is int
        assert type(state[1]) is tuple
        assert all(type(x) is int for x in state[1])
        assert state[2] is None or type(state[2]) is float

    @staticmethod
    def wrap_state(state):
        RandomVariable.check_state(state)
        return variables.TupleVariable(
            [
                variables.ConstantVariable.create(state[0]),
                variables.TupleVariable(
                    [variables.ConstantVariable.create(x) for x in state[1]]
                ),
                variables.ConstantVariable.create(state[2]),
            ]
        )

    @staticmethod
    def unwrap_state(state):
        state_obj = state.as_python_constant()
        RandomVariable.check_state(state_obj)
        return state_obj

    def call_method(
        self,
        tx,
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        if name == "seed":
            tx.output.side_effects.mutation(self)
            self.random.seed(
                *[x.as_python_constant() for x in args],
                **{key: val.as_python_constant() for key, val in kwargs.items()},
            )
            return variables.ConstantVariable.create(None)
        elif name == "getstate":
            return self.wrap_state(self.random.getstate())
        elif name == "setstate":
            tx.output.side_effects.mutation(self)
            self.random.setstate(self.unwrap_state(args[0]))
            return variables.ConstantVariable.create(None)
        elif name in self._supported_fn_names:
            tx.output.side_effects.mutation(self)
            state = self.random.getstate()

            def call_random_meth(*args, **kwargs):
                r = random.Random()
                r.setstate(state)
                return getattr(r, name)(*args, **kwargs)

            # self.random's state is not actually updated by call_random_meth,
            # so update it here by calling the method directly.
            getattr(self.random, name)(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )

            return call_random_fn(tx, call_random_meth, args, kwargs)
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen):
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(random),
                    codegen.create_load_attr("Random"),
                ]
            )
        )
        codegen.call_function(0, False)
        # NOTE: using add_push_null may result in NULL being duplicated,
        # so defer the push_null to call_function
        codegen.dup_top()
        codegen.load_attr("setstate")
        codegen(self.wrap_state(self.random.getstate()))
        codegen.call_function(1, True)
        codegen.pop_top()


class WeakRefVariable(VariableTracker):
    @staticmethod
    def build(tx, weakref_value, **options):
        source = options.get("source", None)
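        # (Descriptive) weakref_value is expected to be a weakref.ref; calling
        # it below dereferences it, and the tracker specializes on the current
        # referent.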
        referent = weakref_value()
        source = source and WeakRefCallSource(source)
        referent_vt = VariableTracker.build(tx, referent, source)
        options["source"] = source
        return WeakRefVariable(referent_vt, **options)

    def __init__(self, referent_vt, **options):
        super().__init__(**options)
        self.referent_vt = referent_vt

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        return self.referent_vt

    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
        codegen(self.referent_vt)
        codegen.extend_output(create_call_function(1, False))