# mypy: ignore-errors

"""
This module implements variable tracking for PyTorch optimizers during Dynamo tracing.

The OptimizerVariable class provides specialized handling for optimizer instances by:
- Optimizing the tracing of expensive optimizer initialization
- Managing optimizer state and parameter group tracking
- Handling tensor sources and guards for optimizer state tensors
- Supporting CUDA graph execution through static tensor address management
- Providing special handling for parameter gradients and optimizer state tensors

Key features include:
- Efficient initialization tracing via _init_group optimization
- Automatic marking of optimizer state tensors as static for CUDA graphs
- Proper source tracking for parameter groups, gradients, and state tensors
- Guard installation for optimizer state structure
- Support for both CPU and GPU tensor handling
- Cleanup of static tensor references via finalizers

The module integrates with Dynamo's broader tracing system while providing
optimizer-specific optimizations and safety guarantees.
"""

import logging
import weakref
from typing import TYPE_CHECKING

import torch
from torch._logging import getArtifactLogger
from torch.utils._pytree import tree_map_only

from ..guards import GuardBuilder, install_guard
from ..source import (
    AttrSource,
    ConstDictKeySource,
    DictGetItemSource,
    GetItemSource,
    GlobalWeakRefSource,
    GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


perf_hint_log = getArtifactLogger(__name__, "perf_hints")


def _is_static_for_cudagraphs(x):
    """Return True if tensor ``x`` will not need to be copied for cudagraphs.

    Non-CUDA tensors always return True (we don't warn about them). A CUDA
    tensor is considered static if it has a static address type, or, when a
    cudagraph-trees manager exists for its device, if the manager's current
    node recorded it.
    """
    from torch._inductor.cudagraph_trees import get_manager

    if x.is_cuda:
        manager = get_manager(x.device.index, False)
        is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
        if manager:
            return (
                is_static_address
                or manager.current_node._is_cuda_graph_recorded_tensor(x)
            )
        else:
            return is_static_address
    else:
        # Don't print a warning for non-cuda tensors
        return True


class OptimizerVariable(UserDefinedObjectVariable):
    """VariableTracker for PyTorch optimizer instances.

    Its main job is to avoid tracing the optimizer's ``_init_group``: it runs
    that method eagerly on the real optimizer and then reconstructs sources and
    guards for the tensors it touched.
    """

    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ) -> None:
        super().__init__(value, **kwargs)
        # grad tensor -> GradSource of the owning parameter
        self.grad_to_source = grad_to_source or {}
        # param / optimizer-state tensor -> Source used to re-fetch it
        self.tensor_to_source = tensor_to_source or {}
        # module key names of tensors wrapped as static; the finalizer created
        # in create_finalizer pops these from the graph module
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer

        For ``_init_group`` the real method is executed eagerly on
        ``self.value``; guards, sources and list-argument contents are then
        reconstructed manually. If the arguments cannot be mapped to Python
        values (``ArgMappingException``) or guards cannot be installed
        (``GuardInstallException``), fall back to normal tracing. All other
        methods are delegated to the parent class.
        """
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Intercept ``_init_group`` / ``step`` lookups and ``param_groups`` access.

        Accessing ``param_groups`` marks every parameter as a static address and
        sets ``capturable`` where safe (see ``_set_capturable``).
        """
        # Note: this allows us to intercept the call in call_method
        # in the typical case, we return a UserMethodVariable
        # which will directly inline
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        """Raise ``Unsupported`` if any parameter has a pending side-effect mutation."""
        # If there are pending mutations on a parameter (due to using closure)
        # then we need to graph break to allow the python version of the parameter
        # to update, so that running _init_group will initialize the states with
        # the correct values
        side_effects = tx.output.side_effects
        for g in self.value.param_groups:
            for p in g["params"]:
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        """Set ``capturable=True`` on param groups where it is safe to do so.

        A group qualifies when it already has a ``capturable`` key, every param
        is on CUDA or XPU, and no param has optimizer state yet. The realized
        ``param_groups`` VT is then updated with ``capturable=True``.
        """
        from . import LazyVariableTracker

        # We only set capturable if params are on cuda
        # and the state is not initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_gpu = True

            for p in group.get("params", []):
                all_gpu &= p.is_cuda or p.is_xpu
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_gpu

        # track indices to not set so we don't need to
        # in the variable tracker realize the whole state
        # we handle guarding the state specially
        for group in self.value.param_groups:
            if safe_to_set_capturable(group):
                group["capturable"] = True

        source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, source)
        )
        # NOTE(review): this sets capturable=True on the VT of *every* group,
        # including groups rejected by safe_to_set_capturable above — confirm
        # this is intended.
        for param_group_vt in param_groups_vt.items:
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args

        Supported argument VTs are constants, *empty* lists (mapped to a fresh
        ``[]`` that ``_init_group`` fills in), and ``param_groups[i]`` dicts
        (mapped to the real param group). Anything else raises
        ``ArgMappingException``.
        """

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    # If users load an old state dictionary,
    # it's possible that step could be on the cpu
    # if this is the case, move it to the device
    # of the corresponding parameter
    # in most cases this is a no-op because the state is empty
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

    def map_sources_and_install_guards(self, tx):
        """Record sources for optimizer tensors and install the guards ``_init_group`` needs.

        Because ``_init_group`` is run eagerly rather than traced, this:

        1. marks every tensor in ``self.value.state`` as a static address;
        2. realizes VTs for ``param_groups`` and the top-level ``state`` dict
           (realization installs their guards) and guards on key order;
        3. fills ``self.tensor_to_source`` / ``self.grad_to_source`` so that
           ``update_list_args`` and ``wrap_tensor`` can reuse those sources.
        """
        from ..decorators import mark_static_address
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing the _init_group is expensive. But we still have to insert the
        # necessary guards for _init_group. So, we manually handle insertion of
        # guards. We also want to mark all the tensors inside the state dict to
        # be static address.

        # Mark all the tensors in the state dict to be static address. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        params_groups_source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, params_groups_source)
        )

        state_source = self.source and AttrSource(self.source, "state")
        state_vt = VariableTracker.build(tx, self.value.state, state_source)

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()
        tx.output.guard_on_key_order.add(state_source.name())

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually update_list_args
        for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
            # we assume here that all params within a param group
            # are initialized similarly, so realizing (and thereby guarding)
            # the state of the first param that has a grad and a state entry
            # is enough
            for param in group["params"]:
                if param.grad is None:
                    continue
                key_index = next(
                    (i for i, k in enumerate(self.value.state) if k is param),
                    None,
                )
                # Index 0 is a valid position in the state dict, so compare
                # against None rather than relying on truthiness.
                if key_index is not None:
                    LazyVariableTracker.realize_all(
                        VariableTracker.build(
                            tx,
                            self.value.state[param],
                            DictGetItemSource(
                                state_source,
                                ConstDictKeySource(state_source, key_index),
                            ),
                        )
                    )
                    break

            params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
            all_static = True
            non_static_grads = []
            for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                    if not _is_static_for_cudagraphs(p.grad):
                        all_static = False
                        non_static_grads.append(grad_source)
                else:
                    # The grad is None now; guard that it stays None.
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

            # Note: to avoid spam logs only warn if perf hint artifact is enabled
            # (NB: artifacts are only enabled at the debug or warning level)
            if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
                non_static_grads = [src.name() for src in non_static_grads]
                # The format string must be a plain str (not a tuple) so that
                # logging interpolates it cleanly.
                perf_hint_log.warning(
                    "Grad tensors %s will be copied during cudagraphs execution. "
                    "If using cudagraphs and the grad tensor addresses will be the same across runs,"
                    " use torch._dynamo.decorators.mark_static_address to elide this copy.",
                    non_static_grads,
                )

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        for idx, p_state in enumerate(self.value.state.values()):
            p_state_source = DictGetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            tx.output.guard_on_key_order.add(p_state_source.name())
            for inner_idx, v in enumerate(p_state.values()):
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = DictGetItemSource(
                        p_state_source, ConstDictKeySource(p_state_source, inner_idx)
                    )

    def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
        """Wrap state tensor in a TensorVariable

        Param/state tensors reuse the source recorded in ``tensor_to_source``;
        grads reuse ``grad_to_source``; any other tensor is stashed under a
        global weak ref. Non-grad tensors are marked static and their module
        key name is recorded in ``static_tensor_names``.
        """
        from ..decorators import mark_static_address

        # If we have a source for a tensor already use it,
        # if we have not seen a tensor before, stash and use a
        # global weak ref source, since it must be an optimizer tensor
        # that we have missed

        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            source = self.tensor_to_source[tensor_value]
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))
        elif tensor_value in self.grad_to_source:
            source = self.grad_to_source[tensor_value]
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            source = GlobalWeakRefSource(global_name)
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))

        return VariableTracker.build(tx, tensor_value, source)

    def update_list_args(
        self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
    ):
        """Update the args and kwargs to the traced optimizer call

        ``get_python_args`` maps each (empty) ``ListVariable`` argument to a
        fresh Python list, which the eager ``_init_group`` call fills in. The
        filled values are appended back into the corresponding ``ListVariable``
        so traced code observes them. Both positional and keyword list
        arguments are handled.
        """
        arg_pairs = list(zip(args, py_args))
        arg_pairs.extend((kwargs[key], py_kwargs[key]) for key in kwargs)

        for arg, py_arg in arg_pairs:
            if not isinstance(arg, ListVariable):
                continue
            assert isinstance(py_arg, list), (
                "py_arg should be a list in optimizer variable"
            )
            for i, val in enumerate(py_arg):
                tx.output.side_effects.mutation(arg)
                if isinstance(val, torch.Tensor):
                    arg.items.append(self.wrap_tensor(tx, val))
                else:
                    source = arg.source and GetItemSource(arg.source, i)
                    arg.items.append(VariableTracker.build(tx, val, source))

    def create_finalizer(self, tx):
        """Register a graph finalizer that drops static tensor refs when the optimizer dies.

        When the compiled graph module is created, a ``weakref.finalize`` is
        attached to the optimizer. When it fires, it pops every name in
        ``static_tensor_names`` from the module's buffers and parameters and
        clears the tracing context's flat param lists.
        """
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                if tc.params_flat:
                    tc.params_flat.clear()
                if tc.params_flat_unwrap_subclasses:
                    tc.params_flat_unwrap_subclasses.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)