# mypy: allow-untyped-defs
|
|
|
|
"""
|
|
Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format.
|
|
|
|
This module implements the bytecode-level tracing system that allows TorchDynamo to analyze
|
|
and transform Python code. It converts Python bytecode instructions into a symbolic format
|
|
that tracks the flow of tensors and other values through the program.
|
|
|
|
Key components:
|
|
- InstructionTranslatorBase: Base class for converting bytecode to symbolic execution
|
|
- InstructionTranslator: Main translator for function bytecode
|
|
- InliningInstructionTranslator: Handles inlining of called functions
|
|
- SpeculationLog: Manages state for speculative execution and rollback
|
|
|
|
The symbolic conversion process handles:
|
|
- Control flow (loops, conditionals, etc.)
|
|
- Function inlining and call stack management
|
|
- Tracking of program values and side effects
|
|
- Graph breaks and resumption points
|
|
- Exception handling and stack frame management
|
|
|
|
This is a core part of TorchDynamo's tracing system that enables ahead-of-time
|
|
optimization of PyTorch programs.
|
|
"""
|
|
|
|
import collections
|
|
import collections.abc
|
|
import contextlib
|
|
import copy
|
|
import dataclasses
|
|
import dis
|
|
import functools
|
|
import importlib
|
|
import inspect
|
|
import itertools
|
|
import linecache
|
|
import logging
|
|
import operator
|
|
import re
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
import types
|
|
import typing
|
|
import weakref
|
|
from typing import Any, Callable, cast, NoReturn, Optional, Union
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch._logging
|
|
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
|
|
from torch._guards import tracing, TracingContext
|
|
from torch.fx.experimental.symbolic_shapes import guard_bool
|
|
from torch.utils._functools import cache_method
|
|
|
|
from . import (
|
|
config,
|
|
exc,
|
|
graph_break_hints,
|
|
logging as torchdynamo_logging,
|
|
trace_rules,
|
|
variables,
|
|
)
|
|
from .bytecode_analysis import (
|
|
get_indexof,
|
|
JUMP_OPNAMES,
|
|
livevars_analysis,
|
|
propagate_line_nums,
|
|
)
|
|
from .bytecode_transformation import (
|
|
cleaned_instructions,
|
|
create_call_function,
|
|
create_instruction,
|
|
create_jump_absolute,
|
|
create_swap,
|
|
get_code_keys,
|
|
Instruction,
|
|
is_generator,
|
|
unique_id,
|
|
)
|
|
from .code_context import code_context
|
|
from .codegen import PyCodegen
|
|
from .exc import (
|
|
ArgsMismatchError,
|
|
BackendCompilerFailed,
|
|
collapse_resume_frames,
|
|
format_graph_break_message,
|
|
get_stack_above_dynamo,
|
|
unimplemented_v2,
|
|
Unsupported,
|
|
)
|
|
from .funcname_cache import get_funcname
|
|
from .guards import GuardBuilder, install_guard
|
|
from .output_graph import GraphCompileReason, OutputGraph
|
|
from .replay_record import DummyModule, ExecutionRecorder
|
|
from .resume_execution import ContinueExecutionCache, ReenterWith
|
|
from .source import (
|
|
AttrSource,
|
|
DictGetItemSource,
|
|
GlobalSource,
|
|
GlobalWeakRefSource,
|
|
LocalCellSource,
|
|
LocalSource,
|
|
Source,
|
|
)
|
|
from .trace_rules import is_builtin_constant, is_forbidden
|
|
from .utils import (
|
|
counters,
|
|
get_fake_value,
|
|
get_instruction_source_311,
|
|
get_metrics_context,
|
|
graph_break_dup_warning_checker,
|
|
istype,
|
|
LazyString,
|
|
proxy_args_kwargs,
|
|
)
|
|
from .variables.base import typestr, ValueMutationNew, VariableTracker
|
|
from .variables.builder import FrameStateSizeEntry, wrap_fx_proxy
|
|
from .variables.builtin import BuiltinVariable
|
|
from .variables.constant import ConstantVariable
|
|
from .variables.ctx_manager import (
|
|
ContextWrappingVariable,
|
|
GenericContextWrappingVariable,
|
|
WithExitFunctionVariable,
|
|
)
|
|
from .variables.dicts import ConstDictVariable, SetVariable
|
|
from .variables.functions import (
|
|
BaseUserFunctionVariable,
|
|
LocalGeneratorFunctionVariable,
|
|
LocalGeneratorObjectVariable,
|
|
NestedUserFunctionVariable,
|
|
SkipFunctionVariable,
|
|
UserFunctionVariable,
|
|
UserMethodVariable,
|
|
)
|
|
from .variables.iter import MAX_ITERATOR_LIMIT
|
|
from .variables.lazy import LazyVariableTracker
|
|
from .variables.lists import (
|
|
BaseListVariable,
|
|
ListIteratorVariable,
|
|
ListVariable,
|
|
SliceVariable,
|
|
TupleVariable,
|
|
)
|
|
from .variables.misc import (
|
|
CellVariable,
|
|
ExceptionVariable,
|
|
GetAttrVariable,
|
|
NullVariable,
|
|
PythonModuleVariable,
|
|
UnknownVariable,
|
|
)
|
|
from .variables.nn_module import NNModuleVariable
|
|
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
|
|
from .variables.torch_function import (
|
|
SymbolicTorchFunctionState,
|
|
TorchFunctionModeVariable,
|
|
)
|
|
from .variables.user_defined import (
|
|
RemovableHandleVariable,
|
|
UserDefinedClassVariable,
|
|
UserDefinedExceptionClassVariable,
|
|
UserDefinedExceptionObjectVariable,
|
|
UserDefinedObjectVariable,
|
|
)
|
|
|
|
|
|
# Module-level loggers: `log` for general messages, plus artifact loggers that
# only emit when the corresponding TORCH_LOGS artifact is enabled.
log = logging.getLogger(__name__)

graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")

trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")

trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode")

# Thread-local storage shared across translator machinery in this module.
tls = threading.local()

# Map comparison-op strings (e.g. "==", "<", "is") to handlers that evaluate
# them symbolically via BuiltinVariable.
compare_op_handlers: dict[str, Any] = {
    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
}
handle_contains = BuiltinVariable(operator.contains).call_function
handle_not = BuiltinVariable(operator.not_).call_function
# `x in y` maps to operator.contains(y, x), hence the reversed argument order;
# `not in` additionally negates the containment result.
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
    tx, [*reversed(args)], {}
)
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
    tx, [handle_contains(tx, [*reversed(args)], {})], {}
)


# Pre-filled PT2 bug-report template URL, used in user-facing error messages.
PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml"
|
|
|
|
|
|
@functools.cache
|
|
def _import_module(name: str) -> types.ModuleType:
|
|
"""
|
|
Import the named module and cache the result. importlib.import_module()
|
|
seems to do some filesystem checking to validate the name so not caching
|
|
this can be slow.
|
|
"""
|
|
return importlib.import_module(name)
|
|
|
|
|
|
@dataclasses.dataclass
class SpeculationEntry:
    """One recorded speculation point: the instruction at which a partial
    compile may later be forced if the speculation fails."""

    filename: str
    lineno: int
    instruction_pointer: int
    inst: Instruction  # for debugging only
    failed: bool = False
    reason: Optional[GraphCompileReason] = None

    def fail_and_restart_analysis(self):
        """
        Mark this speculation as failed, then restart tracing of the current
        frame from the beginning so this branch is not taken again.
        """
        self.failed = True
        restart_reason = (
            self.reason.reason
            if self.reason is not None
            else "Unknown fail_and_restart_analysis"
        )
        raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
|
|
|
|
|
|
@dataclasses.dataclass
class SpeculationLog:
    """
    SpeculationLog replaces the prior copy_graphstate/restore_graphstate
    checkpointing. Rather than saving/restoring state, we restart the
    dynamo conversion process over from the beginning -- but when we
    hit the start of the speculation that failed, we instead generate
    a graph break.
    """

    entries: list[SpeculationEntry] = dataclasses.field(default_factory=list)
    index: int = 0

    def restart(self):
        # Rewind to the beginning; entries are kept so the restarted analysis
        # replays against the same recorded speculations.
        self.index = 0

    def clear(self):
        # Forget all recorded speculations and rewind.
        self.entries.clear()
        self.index = 0

    def next(
        self, filename: str, lineno: int, instruction_pointer, inst
    ) -> SpeculationEntry:
        """
        Lookup or create a SpeculationEntry() that is shared across
        RestartAnalysis calls. Args are used only for debug checks.

        Raises SpeculationLogDivergence when the instruction at this log
        position does not match the recorded entry (i.e. the restarted
        analysis took a different path).
        """
        if len(self.entries) == self.index:
            self.entries.append(
                SpeculationEntry(filename, lineno, instruction_pointer, inst)
            )
        entry = self.entries[self.index]
        prev_entry_msg = ""
        if self.index != 0:
            prev_entry = self.entries[self.index - 1]
            prev_entry_msg = (
                f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
                f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
            )
        if not (
            entry.instruction_pointer == instruction_pointer
            and entry.filename == filename
            and entry.lineno == lineno
        ):
            # Bug fix: the "Actual" line previously printed the literal
            # "(unknown)" even though `filename` is available here.
            raise SpeculationLogDivergence(
                f"""
SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
{prev_entry_msg}
There are two usual reasons why this may have occurred:
- When Dynamo analysis restarted, the second run took a different path than
the first. If this occurred, the previous instruction is the critical instruction that
behaved differently.
- Speculation entries are only added under certain conditions (as seen in
step()), e.g., there must exist operators in the graph; those conditions may
have changed on restart.

If this divergence was intentional, clear the speculation log before restarting (do NOT
do this for graph breaks, you will infinite loop).

Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
"""
            )
        self.index += 1
        return entry
|
|
|
|
|
|
@dataclasses.dataclass
class LocalState:
    """Per-rank automatic-dynamic frame state, keyed by input name."""

    automatic_dynamic: dict[str, FrameStateSizeEntry] = dataclasses.field(
        default_factory=dict
    )

    def render(self) -> str:
        # One "name: rendered entry" line per tracked input.
        lines = [f"{k}: {v.render()}" for k, v in self.automatic_dynamic.items()]
        return "\n".join(lines)
|
|
|
|
|
|
# Mutable box that is shared across restarts
@dataclasses.dataclass
class DistributedState:
    # Process group used at compile time (opaque here; only stored in the
    # visible code -- confirm semantics at usage sites).
    compile_pg: Any
    # This rank's LocalState (automatic-dynamic frame state).
    local_state: LocalState
    # States from every rank; None until populated (presumably via a
    # cross-rank gather -- confirm at the call sites).
    all_states: Optional[list[LocalState]] = None
|
|
|
|
|
|
class TensorifyState:
    """
    Process-wide record of symfloat names (e.g. "zf0") collected from the
    tensorify_python_scalars.py joint fx pass, telling us which float inputs
    should be specialized when analysis restarts.
    """

    force_specializations: set[str] = set()

    @classmethod
    def specialize(cls, index: str) -> None:
        """Record that the symfloat named *index* must be specialized."""
        cls.force_specializations.add(index)

    @classmethod
    def should_specialize(cls, index: str) -> bool:
        """Return True if *index* was marked for specialization."""
        return index in cls.force_specializations

    @classmethod
    def clear(cls) -> None:
        """Forget all recorded specializations."""
        cls.force_specializations.clear()

    @classmethod
    def empty(cls) -> bool:
        """Return True when nothing has been marked for specialization."""
        return not cls.force_specializations
|
|
|
|
|
|
@functools.cache
def _step_logger():
    """Return (and memoize) the torchdynamo step logger bound to `log`."""
    return torchdynamo_logging.get_step_logger(log)
|
|
|
|
|
|
@contextlib.contextmanager
def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"):
    """
    Swap in a fresh, empty speculation log for the duration of the `with`
    block, restoring the saved entries and index afterwards.

    When reconstructing a generator after a graph break, we advance it until
    it is fully exhausted; that adds new entries to the speculation log that
    were not previously observed. Without temporarily clearing the log, this
    could lead to a divergence error.
    """
    saved_entries = tx.speculation_log.entries
    saved_index = tx.speculation_log.index
    tx.speculation_log.entries = []
    tx.speculation_log.index = 0
    try:
        yield
    finally:
        tx.speculation_log.entries = saved_entries
        tx.speculation_log.index = saved_index
|
|
|
|
|
|
@contextlib.contextmanager
def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"):
    """
    Temporarily clear `tx.output.should_exit` so code inside the `with`
    block may write to the output graph; the previous value is restored on
    exit (even on exception).

    Fix: the saved value is now read *before* the try block. Previously it
    was read inside `try`, so if that read raised, the `finally` clause
    would fail with an UnboundLocalError on `tmp`, masking the real error.
    """
    tmp = tx.output.should_exit
    tx.output.should_exit = False
    try:
        yield
    finally:
        tx.output.should_exit = tmp
|
|
|
|
|
|
@dataclasses.dataclass
class BlockStackEntry:
    # Current instruction that pushes something to block_stack
    inst: Instruction
    target: Instruction
    stack_index: int
    with_context: Optional[
        Union[ContextWrappingVariable, GenericContextWrappingVariable]
    ] = None

    def can_restore(self):
        """True if this block carries a context manager we can re-enter."""
        return self.with_context is not None

    def resume_fn(self):
        """Build the ReenterWith descriptor used by resume-function codegen."""
        assert self.stack_index is not None
        ctx = self.with_context
        if ctx and getattr(ctx, "target_values", None):
            return ReenterWith(self.stack_index - 1, tuple(ctx.target_values))
        return ReenterWith(self.stack_index - 1)

    def exit(self, tx, is_graph_break):
        """Exit the stored context manager, unless a graph break should keep it live."""
        assert self.with_context is not None
        if not is_graph_break or self.with_context.exit_on_graph_break():
            return self.with_context.exit(tx)
|
|
|
|
|
|
class SpeculationLogDivergence(AssertionError):
    """Raised by SpeculationLog.next() when a restarted analysis reaches an
    instruction that does not match the recorded speculation entry."""

    pass
|
|
|
|
|
|
class ReturnValueOp(Exception):
    """Internal control-flow signal used by the symbolic tracer (presumably
    raised when a return instruction is executed -- confirm at raise sites;
    see YieldValueOp for the analogous yield signal)."""

    pass
|
|
|
|
|
|
class YieldValueOp(Exception):
    """
    Signal to the symbolic tracer to stop and return control flow to the
    caller.
    """
|
|
|
|
|
|
def stack_op(fn: typing.Callable[..., object]):
    """
    Build an instruction handler for a stack-manipulating opcode: pop `fn`'s
    arity worth of values off the simulated stack, apply `fn` symbolically
    through BuiltinVariable, and push the result back.
    """
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslator", inst: Instruction):
        args = self.popn(nargs)
        result = fn_var.call_function(self, args, {})
        self.push(result)

    return impl
|
|
|
|
|
|
def _detect_and_normalize_assert_statement(
|
|
self: "InstructionTranslatorBase",
|
|
truth_fn: typing.Callable[[object], bool],
|
|
push: bool,
|
|
):
|
|
# Detect if this jump instruction is assert and normalize the assert
|
|
# by pushing dummy error message when nothing is given.
|
|
#
|
|
# Python 3.9 assertion is in following format:
|
|
# 18 POP_JUMP_IF_TRUE 28
|
|
# 20 LOAD_ASSERTION_ERROR
|
|
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
|
|
# 24 CALL_FUNCTION 1 -> optional instruction
|
|
# 26 RAISE_VARARGS
|
|
#
|
|
# Python 3.8 assertion is in following format:
|
|
# 18 POP_JUMP_IF_TRUE 28
|
|
# 20 LOAD_GLOBAL 0 (Assertion type)
|
|
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
|
|
# 24 CALL_FUNCTION 1 -> optional instruction
|
|
# 26 RAISE_VARARGS 1
|
|
|
|
if (truth_fn is not operator.truth) or push:
|
|
return False
|
|
|
|
assert isinstance(self.instruction_pointer, int)
|
|
current_instruction_pointer = self.instruction_pointer
|
|
inst = self.instructions[current_instruction_pointer]
|
|
# Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
|
|
if inst.opname != "LOAD_ASSERTION_ERROR":
|
|
return False
|
|
|
|
current_instruction_pointer += 1
|
|
|
|
# Use dummy error message if its hard to extract
|
|
error_msg = "assertion error"
|
|
|
|
inst = self.instructions[current_instruction_pointer]
|
|
# DETECT RAISE_VARARGS or LOAD CONST
|
|
if inst.opname == "LOAD_CONST":
|
|
if not isinstance(inst.argval, str):
|
|
return False
|
|
error_msg = inst.argval
|
|
|
|
# if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
|
|
# (PRECALL for Python 3.11, CALL for Python 3.12+)
|
|
current_instruction_pointer += 1
|
|
inst = self.instructions[current_instruction_pointer]
|
|
if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"):
|
|
return False
|
|
|
|
# for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS
|
|
# for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS
|
|
current_instruction_pointer += 1
|
|
if inst.opname == "PRECALL":
|
|
current_instruction_pointer += 1
|
|
inst = self.instructions[current_instruction_pointer]
|
|
|
|
if inst.opname != "RAISE_VARARGS":
|
|
return False
|
|
|
|
self.push(ConstantVariable.create(error_msg))
|
|
|
|
return True
|
|
|
|
|
|
explain = False
|
|
|
|
|
|
def log_graph_break(code_options, reason="", exc_info=False, user_stack=None):
    """
    Emit all logging associated with a graph break: a structured-trace
    artifact ("dynamo_graph_break_reason") plus a message on the
    "graph_breaks" artifact logger.

    Args:
        code_options: code-object attributes; co_filename/co_firstlineno are
            used as the location fallback when the user stack is empty.
        reason: human-readable explanation of the break.
        exc_info: when True, append the active exception's traceback to the
            structured-trace payload.
        user_stack: pre-extracted user stack; defaults to the current
            TracingContext stack.
    """
    if user_stack is None:
        user_stack = torch._guards.TracingContext.extract_stack()

    # Pick the innermost user frame as the reported location.
    try:
        frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
    except IndexError:
        # first instruction
        frame_loc = (
            code_options["co_filename"],
            code_options["co_firstlineno"],
        )

    stack_above_dynamo_formatted = ""
    if config.verbose:
        stack_above_dynamo = get_stack_above_dynamo()
        stack_above_dynamo_formatted = "".join(
            traceback.format_list(stack_above_dynamo)
        )
    else:
        # Non-verbose: fold the above-dynamo frames into the user stack and
        # collapse Dynamo-generated resume frames.
        user_stack = get_stack_above_dynamo() + user_stack
        user_stack = collapse_resume_frames(user_stack)
    user_stack_formatted = "".join(traceback.format_list(user_stack))
    user_stack_trace = (
        f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n"
        f"Graph Break Reason: {reason}\n"
        "User code traceback:\n"
    )

    if config.verbose:
        user_stack_trace += (
            f"{stack_above_dynamo_formatted}\n"
            "========== most recent `torch.compile` tracing attempt started here ==========\n\n"
            f"{user_stack_formatted}\n"
            "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! "
            "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another "
            "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python "
            "function, which Dynamo intercepts as a top-level frame.\n"
        )
    else:
        user_stack_trace += str(user_stack_formatted)

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        },
        payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc() if exc_info else ''}",
    )

    # torch._dynamo.explain() formats this a little nicer, and presents a slightly
    # more actionable user code pointer
    if (
        graph_break_log.isEnabledFor(logging.DEBUG)
        and not explain
        and graph_break_dup_warning_checker.add(frame_loc)
    ):
        # This log line MUST contain the string "Graph break in user code",
        # This log line is exercised from
        #   python test/dynamo/test_exc.py -k test_graph_break_log
        graph_break_log.debug(
            user_stack_trace,
        )
    else:
        # This log line MUST not contain the string "Graph break in user code",
        # exercised by
        #   python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
        graph_break_log.debug(
            "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s",
            frame_loc[0],
            frame_loc[1],
            reason,
        )
|
|
|
|
|
|
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    """
    Build a bytecode handler for a conditional-jump instruction family.

    `truth_fn` decides, given the concrete tested value, whether the jump is
    taken; `push` is True for the JUMP_IF_* variants that leave the tested
    value on the stack when branching. The returned `inner` handler dispatches
    on the VariableTracker type of the tested value and either resolves the
    branch at trace time, rewrites `assert` statements, or graph-breaks.
    """
    # graph break message fields for data dependent branching
    _gb_type = "Data-dependent branching"
    _explanation = (
        "Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). "
        "Dynamo does not support tracing dynamic control flow."
    )
    _hints = [
        *graph_break_hints.FUNDAMENTAL,
        "Use `torch.cond` to express dynamic control flow.",
    ]

    def jump_graph_break(self, inst, value, extra_msg=""):
        # Compile the traced prefix as a partial graph, then emit bytecode
        # that performs the jump at runtime and resumes tracing on each arm.
        log_graph_break(
            self.code_options,
            reason=format_graph_break_message(
                gb_type=_gb_type,
                context=f"attempted to jump with {value}",
                explanation=_explanation,
                hints=_hints,
            ),
        )
        if not self.should_compile_partial_graph():
            unimplemented_v2(
                gb_type="Should not compile partial graph (data-dependent branching)",
                context="",
                explanation="Dynamo has determined when encountering data-dependent "
                "branching (e.g. `if my_tensor.item() > 0:`) that it should not "
                "compile the partial graph.",
                hints=[],
            )
        # compile a partial subgraph prefix then jump into user code
        if self.maybe_has_backedge():
            msg = (
                "Skipping frame because there is a graph break in a for/while loop\n"
                f"{self.frame_summary()}"
            )
            log.info(msg)
            raise exc.SkipFrame(msg)

        self.push(value)
        log.debug("generic_jump triggered compile")
        self.output.compile_subgraph(
            self,
            reason=GraphCompileReason(
                f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()]
            ),
        )
        self.pop()

        # Resume functions for the fall-through and the jump-taken arms.
        if_next = self.create_call_resume_at(self.next_instruction)
        if push:
            self.push(value)
        if_jump = self.create_call_resume_at(inst.target)

        if sys.version_info >= (3, 13):
            # 3.13 requires stack[-1] to be bool type
            self.output.add_output_instructions([create_instruction("TO_BOOL")])

        jump_inst = create_instruction(inst.opname, target=if_jump[0])
        jump_inst.copy_positions(inst)
        self.output.add_output_instructions([jump_inst] + if_next + if_jump)

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        # Rewrite `assert` into torch._assert_async when the config asks for it.
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            # Skip over things like `assert True`
            if value.is_python_constant():
                if bool(value.as_python_constant()):
                    return self.jump(inst)
                else:
                    jump_graph_break(self, inst, value)

            # TODO maybe should respect DtoH sync intention of users later??
            # Manually insert torch._assert_async instead of python assert and jump over
            # assert related instructions as we don't need them anymore.

            # if we see Tensor as assert statement, no need to call scalar_tensor
            if isinstance(value, TensorVariable):
                self.output.create_proxy(
                    "call_function",
                    torch._assert_async,
                    *proxy_args_kwargs((value, error_msg), {}),
                )
                self.jump(inst)
                return

            if isinstance(value, SymNodeVariable):
                # if the assertion is normal shape expression.
                # just install guard and bail out.
                sym_expr = value.sym_num
                if not isinstance(sym_expr, torch.SymBool):
                    sym_expr = sym_expr != 0

                result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
                if not result:
                    unimplemented_v2(
                        gb_type="Assertion failed on symbolic shapes",
                        context=str(sym_expr),
                        explanation="",
                        hints=[*graph_break_hints.USER_ERROR],
                    )
                self.jump(inst)
                return

            # Non-tensor scalar: lift it to a tensor so _assert_async applies.
            scalar_to_tensor_proxy = self.output.create_proxy(
                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
            )

            scalar_to_tensor = wrap_fx_proxy(
                self,
                scalar_to_tensor_proxy,
                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
            )

            self.output.create_proxy(
                "call_function",
                torch._assert_async,
                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            # ConstDictVariable is optimized to be very lazy about insertion of
            # guards, so we have to manually insert a SEQUENCE_LENGTH guard
            # here.
            if isinstance(value, ConstDictVariable) and value.source:
                install_guard(value.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
            if truth_fn(value.as_python_constant()):
                if push:
                    self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
        ):
            # True data-dependent branch on a tensor: partial-compile + resume.
            jump_graph_break(self, inst, value)
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            mod = self.output.get_submodule(value.module_key)
            if truth_fn(mod):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            try:
                x = value.var_getattr(self, "__bool__")  # type: ignore[arg-type]
            except exc.ObservedAttributeError:
                exc.handle_observed_exception(self)
                # if __bool__ is missing, trying __len__ to infer a truth value.
                try:
                    x = value.var_getattr(self, "__len__")  # type: ignore[arg-type]
                except exc.ObservedAttributeError:
                    exc.handle_observed_exception(self)
                    x = None

            # __bool__ or __len__ is function
            if isinstance(x, UserMethodVariable):
                result = x.call_function(self, [], {})  # type: ignore[arg-type, assignment]
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, (bool, int)
                ):
                    if truth_fn(result.value):
                        if push:
                            self.push(value)
                        self.jump(inst)
                elif isinstance(result, SymNodeVariable):
                    if result.evaluate_expr():
                        if push:
                            self.push(value)
                        self.jump(inst)
                else:
                    unimplemented_v2(
                        gb_type="Data-dependent branching with non-constant __bool__",
                        context=f"method: {x}, result: {result}",
                        explanation="Attempted to perform data-dependent branching on a user-defined "
                        "object with a __bool__ method that did not return a constant.",
                        hints=[],
                    )
            # __bool__ or __len__ is non-function or not existed in the user defined object
            else:
                if truth_fn(True):
                    if push:
                        self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            # Sequences: truthiness is decided by length.
            if truth_fn(len(value.unpack_var_sequence(self))):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            try:
                # if the user is branching on a SymBool, guard on it
                # if the user has code like:
                #   if size:
                #       ...
                # then they are just testing truthiness: guard that the expr != 0
                if isinstance(value.sym_num, torch.SymBool):
                    eval_result = value.evaluate_expr(self.output)
                else:
                    eval_result = guard_bool(value.sym_num != 0)
            except exc.UserError as e:
                if self.should_compile_partial_graph():
                    return jump_graph_break(self, inst, value, extra_msg=f"\n{e}")
                raise
            if truth_fn(eval_result):
                if push:
                    self.push(value)
                self.jump(inst)
        elif isinstance(value, variables.BackwardHookVariable):
            if truth_fn(True):
                if push:
                    self.push(value)
                self.jump(inst)
        else:
            from .source import is_constant_source

            if value.source is not None and is_constant_source(value.source):
                if truth_fn(value.get_real_value()):  # type: ignore[attr-defined]
                    if push:
                        self.push(value)
                    self.jump(inst)
            else:
                # Nothing above could resolve the branch at trace time.
                unimplemented_v2(
                    gb_type=_gb_type,
                    context=f"attempted to jump with {value}",
                    explanation=_explanation,
                    hints=_hints,
                )

    return inner
|
|
|
|
|
|
def break_graph_if_unsupported(*, push):
    """
    Decorator factory for instruction handlers: when the wrapped handler hits
    an `Unsupported` error, compile the traced prefix as a partial graph,
    re-emit the offending instruction into the output bytecode, and resume
    tracing after it. `push` is the number of values the original instruction
    pushes onto the stack.
    """

    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            speculation = self.speculate()
            if speculation.failed:
                # A previous analysis pass already failed here: graph break now.
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.active_generic_context_managers:
                    # We don't support graph break under GenericContextWrappingVariable,
                    # If there is, we roll back to the checkpoint and fall back.
                    excp.remove_from_stats()
                    unimplemented_v2(
                        gb_type="Graph break under GenericContextWrappingVariable",
                        context=f"Active generic context managers: {self.active_generic_context_managers}",
                        explanation="Attempted to graph break in an active context manager(s) that doesn't support graph breaking.",
                        hints=[
                            "Move the offending context manager(s) to outside the compiled region.",
                            *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                        ],
                        from_exc=excp,
                    )

                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                log_graph_break(
                    self.code_options,
                    exc_info=True,
                    reason=str(excp),
                    user_stack=excp.real_stack,
                )

                if self.maybe_has_backedge():
                    msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(msg)
                    raise exc.SkipFrame(msg) from excp

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, excp.real_stack)
            # Only reached on the exception path (the try returns otherwise):
            # restart analysis so the recorded speculation takes effect.
            speculation.fail_and_restart_analysis()

        def handle_graph_break(
            self: "InstructionTranslatorBase",
            inst: Instruction,
            reason: GraphCompileReason,
        ):
            # Compile the prefix, then emit bytecode that re-runs `inst`
            # eagerly and calls a resume function for the remainder.
            self.output.compile_subgraph(self, reason=reason)
            cg = PyCodegen(self)
            cleanup: list[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    cg.extend_output(
                        b.resume_fn().try_except_torch_function_mode(
                            cg.code_options, cleanup
                        )
                    )
                    continue
                assert b.with_context is not None
                assert isinstance(b.with_context, (ContextWrappingVariable))
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
            del cg

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                kw_names = (
                    self.kw_names.as_python_constant()
                    if self.kw_names is not None
                    else ()
                )
                if len(kw_names) > 0:
                    # KW_NAMES no longer used in 3.13
                    assert sys.version_info < (3, 13)
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                call_insts = create_call_function(inst.arg, False)
                call_insts[-1].copy_positions(inst)
                self.output.add_output_instructions(call_insts)
            else:
                # copy instruction, but without exception table data
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            if (
                sys.version_info >= (3, 11)
                and sys.version_info < (3, 12)
                and inst.opname == "CALL"
            ):
                # stack effect for PRECALL + CALL is split between the two instructions
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)
            # Reconcile the simulated stack with the instruction's net effect.
            self.popn(push - stack_effect)

            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(self.next_instruction)
            )

        return wrapper

    return decorator
|
|
|
|
|
|
class BytecodeDistpatchTableMeta(type):
    """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()"""

    def __init__(cls, name, bases, dct) -> None:
        super().__init__(name, bases, dct)

        def _missing(opname, *args):
            # Fallback handler for opcodes with no corresponding method.
            unimplemented_v2(
                gb_type="Missing bytecode handler",
                context=f"{opname} with args {args}",
                explanation=f"Dynamo does not know how to handle the bytecode instruction `{opname}`.",
                hints=[
                    f"Do not trace code that produces the `{opname}` bytecode instruction "
                    "(see https://docs.python.org/3/library/dis.html for bytecode semantics).",
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

        # Collect a handler per known opcode: the method named after the
        # opcode if the class defines one, else the `_missing` fallback.
        handlers = {}
        for opname, op in dis.opmap.items():
            handlers[op] = getattr(cls, opname, functools.partial(_missing, opname))
        # Dense 256-entry table indexed by opcode; unused slots are None.
        cls.dispatch_table = [handlers.get(i) for i in range(2**8)]
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ExceptionStack:
|
|
"""
|
|
Exception stack that it is shared among all InstructionTranslator instances
|
|
"""
|
|
|
|
# Exception handling in CPython is a bit confusing and some of the bytecode
|
|
    # have a slightly different behavior than what is documented. While reading
|
|
    # the documentation, it is important to notice that the terms "current exception"
|
|
# and "stack" sometimes refers to a C variable with the same name and the
|
|
# exception stack, respectively.
|
|
#
|
|
# The lifetime of an exception is (Python 3.11+):
|
|
# + tx._raise_exception_variable(...) := sets the current_exception variable
|
|
# + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
|
|
# + POP_EXCEPT := pops TOS from the *exception stack*
|
|
|
|
_exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list)
|
|
_current_exception: Optional[VariableTracker] = dataclasses.field(default=None)
|
|
|
|
def clear_current_exception(self):
|
|
self._current_exception = None
|
|
|
|
def set_current_exception(self, val):
|
|
self._set_context_and_break_context_reference_cycle(val)
|
|
self._current_exception = val
|
|
|
|
def move_current_exception_to_stack(self):
|
|
assert self._current_exception is not None
|
|
self.append(self._current_exception)
|
|
self.clear_current_exception()
|
|
|
|
def get_current_exception(self):
|
|
assert self._current_exception is not None
|
|
return self._current_exception
|
|
|
|
def _set_context_recursive(self, val, prev_idx):
|
|
if (ctx := val.__context__) and type(ctx) is not ConstantVariable:
|
|
return val
|
|
if len(self._exc_stack) + prev_idx > 0:
|
|
prev = self._exc_stack[prev_idx]
|
|
self._set_context_recursive(prev, prev_idx - 1)
|
|
val.set_context(prev)
|
|
return val
|
|
|
|
def _break_context_reference_cycle(self, val):
|
|
# See test_exceptions::test_raise_does_not_create_context_chain_cycle
|
|
# Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
|
|
# As noted on CPython, this is O(chain length) but the context chains
|
|
# are usually very small
|
|
o = slow_o = val
|
|
slow_update_toggle = False # floyd's algorithm for detecting cycle
|
|
while True:
|
|
context = o.__context__
|
|
if type(context) is ConstantVariable: # context not set
|
|
break
|
|
|
|
if context is val:
|
|
o.set_context(ConstantVariable(None))
|
|
break
|
|
|
|
o = context
|
|
if o is slow_o:
|
|
# pre-existing cycle - all exceptions on the path were
|
|
# visited and checked
|
|
break
|
|
|
|
if slow_update_toggle:
|
|
slow_o = slow_o.__context__ # visited all exceptions
|
|
slow_update_toggle = not slow_update_toggle
|
|
|
|
def _set_context_and_break_context_reference_cycle(self, val):
|
|
# set Exception.__context__
|
|
self._set_context_recursive(val, len(self._exc_stack) - 1)
|
|
self._break_context_reference_cycle(val)
|
|
|
|
def pop(self):
|
|
return self._exc_stack.pop()
|
|
|
|
def append(self, val):
|
|
self._exc_stack.append(val)
|
|
|
|
def __len__(self):
|
|
return len(self._exc_stack)
|
|
|
|
def __getitem__(self, index):
|
|
return self._exc_stack[index]
|
|
|
|
def __str__(self):
|
|
return f"{self._exc_stack=} - {self._current_exception=}"
|
|
|
|
__repr__ = __str__
|
|
|
|
|
|
class InstructionTranslatorBase(
    metaclass=BytecodeDistpatchTableMeta,
):
    """
    Symbolic bytecode interpreter: steps through a code object's
    instructions, maintaining shadow stack/locals of VariableTrackers.
    Opcode handlers are methods named after the opcode and dispatched via
    the metaclass-built `dispatch_table`.
    """

    output: OutputGraph  # graph + side-effect state being accumulated
    symbolic_locals: dict[str, VariableTracker]  # shadow of frame locals
    symbolic_globals: dict[str, VariableTracker]  # globals touched during tracing
    symbolic_torch_function_state: SymbolicTorchFunctionState
    stack: list[VariableTracker]  # shadow of the CPython value stack
    instruction_pointer: Optional[int]  # index into self.instructions; None = stopped
    current_instruction: Instruction  # instruction currently being handled
    block_stack: list[BlockStackEntry]  # active block (with/try) entries
    lineno: int  # current source line being traced
    kw_names: Optional[ConstantVariable]
    accept_prefix_inst: bool
    prefix_insts: list[Instruction]
    inline_depth: int  # 0 for the root frame; >0 while inlining
    inconsistent_side_effects: bool  # see mark_inconsistent_side_effects()
    current_speculation: Optional[SpeculationEntry]
    dispatch_table: list[Any]  # opcode byte -> handler, built by the metaclass
    exn_vt_stack: ExceptionStack  # shared exception stack (see ExceptionStack)
    exec_recorder: Optional[ExecutionRecorder]  # records values for replay, if enabled
    strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
    start_point: Optional[int]  # instruction pointer at the last jump/run start
|
|
|
|
    def mark_inconsistent_side_effects(self):
        """
        InstructionTranslator has encountered instructions which may cause
        dynamo to see a different version of history from eager
        See: https://github.com/pytorch/pytorch/issues/110765
        """
        # NOTE(review): set-only here; presumably consumed elsewhere (nothing
        # in this class resets it).
        self.inconsistent_side_effects = True
|
|
|
|
def maybe_has_backedge(self):
|
|
# This function employs a heuristic. It does not reliably detect a backedge.
|
|
# The heuristic is straightforward: starting from the current instruction and
|
|
# continuing to the end, if any jump instruction targets an instruction before
|
|
# the current one, there might be a backedge.
|
|
|
|
# Python 3.12 introduced changes to bytecode that group common paths in
|
|
# blockstacks (with or try...else) and allow for early returns. Consequently,
|
|
# there can be multiple RETURN_VALUE instructions. Another heuristic is to
|
|
# halt detection upon encountering the first RETURN_VALUE or RETURN_CONST.
|
|
|
|
# These heuristics can result in both false positives and negatives, but
|
|
# in either case, the Dynamo code remains valid. For false positives
|
|
# (where an edge is incorrectly marked as a backedge), Dynamo will
|
|
# perform a SkipFrame instead of potentially applying optimizations. For
|
|
# false negatives (where an edge that should be marked as a backedge
|
|
# isn't), multiple graphs may be generated if there's a break in the
|
|
# graph during a for loop. In general, its better to have fewer false
|
|
# negatives so that Dynamo does not skip the whole frame.
|
|
|
|
cur_offset = self.current_instruction.offset
|
|
assert self.instruction_pointer is not None
|
|
for inst in self.instructions[self.instruction_pointer :]:
|
|
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
|
|
return False
|
|
if inst.opname in JUMP_OPNAMES:
|
|
jump_offset = inst.argval
|
|
if jump_offset < cur_offset:
|
|
return True
|
|
return False
|
|
|
|
    def cellvars(self):
        """Names of cells created by this frame (cached); includes parents' when inlining."""
        if not hasattr(self, "_cellvars"):
            self._cellvars = tuple(self.code_options["co_cellvars"] or [])
            # An inlined function might depend on the cellvar of the parent
            # function. So, recursively obtain parent cellvars.
            if isinstance(self, InliningInstructionTranslator):
                self._cellvars += self.parent.cellvars()
        return self._cellvars

    def freevars(self):
        """Names of free variables captured by this frame (cached); includes parents' when inlining."""
        if not hasattr(self, "_freevars"):
            self._freevars = tuple(self.code_options["co_freevars"] or [])
            # An inlined function might depend on the freevar of the parent
            # function. So, recursively obtain parent freevars.
            if isinstance(self, InliningInstructionTranslator):
                self._freevars += self.parent.freevars()
        return self._freevars

    def cell_and_freevars(self):
        """Union of cellvars() and freevars(), cached."""
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = self.cellvars() + self.freevars()
        return self._cell_and_freevars
|
|
|
|
    def prune_dead_locals(self):
        """Drop symbolic locals that liveness analysis proves dead at this point."""
        # Only keep the locals that must remain on the stack.
        reads = livevars_analysis(self.instructions, self.current_instruction)
        self.symbolic_locals = {
            k: v for k, v in self.symbolic_locals.items() if k in reads
        }
        # "Garbage collect the heap".
        self.output.side_effects.prune_dead_object_new(self)
|
|
|
|
    def call_function(
        self,
        fn: VariableTracker,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ):
        """Symbolically call fn(*args, **kwargs) and push the result onto the stack."""
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        # Unwrap the underlying Python callable (if the VT exposes one) so
        # forbidden callables are rejected before tracing into them.
        inner_fn = None
        if hasattr(fn, "value"):
            inner_fn = fn.value
        if hasattr(fn, "fn"):
            inner_fn = fn.fn
        if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
            raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
        self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
|
|
|
|
    def inline_generator_function(self, fn, args, kwargs):
        """
        Redirect the call to the generator "call_function"
        """
        # Wrap plain function VTs so generator semantics are modeled.
        if not isinstance(fn, LocalGeneratorFunctionVariable):
            fn = LocalGeneratorFunctionVariable(fn)
        return fn.call_function(self, args, kwargs)

    def inline_user_function_return(self, fn, args, kwargs):
        """
        A call to some user defined function by inlining it.
        """
        # Generators get dedicated handling when faithful generator behavior
        # is enabled; everything else is inlined directly.
        if config.enable_faithful_generator_behavior and is_generator(fn.get_code()):
            return self.inline_generator_function(fn, args, kwargs)
        else:
            return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
|
|
|
|
    def get_line_of_code_header(self, lineno=None):
        """Return a "<file>:<line> in <name> (...)" header string for trace logs."""
        if lineno is None:
            lineno = self.lineno
        inline_depth_str = (
            f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
        )
        funcname = get_funcname(self.f_code.co_filename, lineno)
        funcname_str = "" if funcname is None else f" ({funcname})"
        return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}"

    def get_log_starts_line_log_str(self):
        """Build the "TRACE starts_line" log message including the source line text."""
        log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
        line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
        log_str += f" {line}"
        return log_str
|
|
|
|
    def starts_line(self, lineno):
        """Record that tracing has moved to a new source line (location + logging)."""
        if self.lineno == lineno:
            return
        self.lineno = lineno
        TracingContext.set_current_loc(
            self.f_code.co_filename, lineno, self.f_code.co_name
        )
        # Local import to avoid import cycles at module load time.
        from torch._logging.structured import dump_file

        dump_file(self.f_code.co_filename)
        if trace_source_log.isEnabledFor(logging.DEBUG):
            trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
|
|
|
|
    def step(self):
        """Process exactly one instruction; return False if we should exit."""
        ip = self.instruction_pointer
        if ip is None:
            return False
        self.current_instruction = inst = self.instructions[ip]
        self.instruction_pointer = ip + 1

        if inst.starts_line:
            self.starts_line(inst.starts_line)

        # An empty value stack with a non-empty graph is a safe point to
        # checkpoint for speculative execution / graph breaks.
        if (
            not self.stack
            and self.should_compile_partial_graph()
            and self.is_non_empty_graph()
        ):
            self.current_speculation = self.speculate()
            if self.current_speculation.failed:
                # A previous attempt failed here: break the graph instead.
                return self.step_graph_break(inst)

        if trace_bytecode_log.isEnabledFor(logging.DEBUG):
            trace_bytecode_log.debug(
                "TRACE %s %s %s", inst.opname, inst.argval, self.stack
            )

        self.update_block_stack(inst)

        try:
            # Dispatch to the handler method named after the opcode.
            self.dispatch_table[inst.opcode](self, inst)
            return not self.output.should_exit
        except TensorifyScalarRestartAnalysis:
            raise
        except exc.ObservedException as e:
            # Simulated user-level exception: route through the handler logic.
            self.exception_handler(e)
            return True
        except (ReturnValueOp, YieldValueOp):
            # Normal frame exit signaled by RETURN/YIELD handlers.
            return False
        except Unsupported:
            if self.current_speculation is None:
                log.debug("empty checkpoint")
                raise
            log.debug("step triggered compile", exc_info=True)

        # Unsupported with a checkpoint available: mark the speculation
        # failed and restart analysis (raises).
        self.current_speculation.fail_and_restart_analysis()
|
|
|
|
    if sys.version_info >= (3, 11):

        def update_block_stack(self, inst):
            """Keep the emulated block stack in sync on 3.11+ (which has no real one)."""
            # 3.11+ no longer uses a block stack, but we still keep track of one
            # so that we know which contexts are currently active.
            # For our purposes, all exception table entries with the same target
            # are considered to be part of the same "block".
            # NOTE: we only keep track of with blocks that are not contained in try blocks.
            # This is because we will not create continuation functions on graph breaks in try blocks,
            # but we may for with blocks. We do not push blocks here since
            # with blocks are pushed when handling BEFORE_WITH.
            entry = inst.exn_tab_entry
            if entry:
                # Detect when we have exited the top with block.
                # The with blocks on the block stack are not enclosed in try
                # blocks, so a with block's cleanup code should be in the
                # previous with block (if any).
                if (
                    len(self.block_stack) >= 2
                    and entry.target is not self.block_stack[-1].target
                    and entry.target is self.block_stack[-2].target
                ):
                    # exit the current block
                    self.block_stack.pop()
            else:
                # no longer in any block
                # It is possible for NOPs to be between two instructions
                # in the same block, but the NOPs are not covered by an
                # exception table entry. In this case, assume that we
                # are still in the same block.
                # In 3.12+, JUMP_BACKWARD might also not be covered by
                # an exception table entry, so we also assume that we
                # are still in the same block. It is probably safe to do
                # this in 3.11, even though we haven't encountered this case before.
                if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
                    # If we really escape from a block and the current
                    # instruction is not in another block, then there
                    # should be no other nested blocks that we are in.
                    assert len(self.block_stack) == 1
                    self.block_stack.pop()

    else:

        def update_block_stack(self, inst):
            # Pre-3.11 the block stack is maintained directly by the
            # SETUP_*/POP_BLOCK opcode handlers; nothing to do per-step.
            pass
|
|
|
|
    @property
    def next_instruction(self):
        """The next instruction to execute (pointer was already advanced in step())."""
        return self.instructions[self.instruction_pointer]  # type: ignore[index]
|
|
|
|
    def step_graph_break(self, continue_inst):
        """Compile the partial graph traced so far, then resume at continue_inst."""
        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.current_speculation is not None
        self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        # Emit a jump to the resume point followed by the original bytecode.
        self.output.add_output_instructions(
            [create_jump_absolute(continue_inst)] + self.instructions
        )
|
|
|
|
    def run_ctx_mgr(self):
        """Context manager kept active for the duration of run()."""
        # NB: Don't push the top level frame summary; set_current_loc will
        # take care of it. However, DO make sure we attach real_stack to
        # exceptions
        return TracingContext.current_frame(None)
|
|
|
|
    def run(self):
        """Main interpreter loop: step() until exit, annotating escaping exceptions."""
        with self.run_ctx_mgr():
            try:
                self.output.push_tx(self)
                self.start_point = self.instruction_pointer
                while self.step():
                    pass
            except TensorifyScalarRestartAnalysis:
                raise
            except BackendCompilerFailed:
                raise
            except RuntimeError as e:
                # Attach the partial FX graph to data-dependent errors to aid
                # debugging; all RuntimeErrors are re-raised either way.
                if hasattr(e, "msg") and "Data-dependent" in e.msg:
                    readable_graph = torch.fx.GraphModule(
                        self.output.nn_modules, self.output.graph
                    ).print_readable(
                        print_output=False, include_stride=True, include_device=True
                    )
                    e.partial_fx_graph = readable_graph  # type: ignore[attr-defined]
                    raise

                raise
            except Exception as e:
                # Attach the execution record (if recording) for replay tooling.
                if self.exec_recorder:
                    e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]

                raise
            finally:
                self.output.pop_tx()
                # Cleanup the outputGraph to delete the held tensors. We perform the
                # cleanup only for InstructionTranslator and not
                # InliningInstructionTranslator. The InliningInstructionTranslator
                # mutates the output object and is restored to original state if
                # there was an exception.
                if isinstance(self, InstructionTranslator):
                    self.output.cleanup()
|
|
|
|
    def push(self, val: Optional[VariableTracker]):
        """Push a value onto the shadow value stack."""
        assert val is None or isinstance(val, VariableTracker), (
            f"push expects VariableTracker, got {typestr(val)}"
        )
        self.stack.append(val)  # type: ignore[arg-type]

    def push_many(self, vals: list[VariableTracker]):
        """Push several values; the first element ends up deepest."""
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        """Pop and return the top of the shadow value stack."""
        return self.stack.pop()

    def popn(self, n: int) -> list[VariableTracker]:
        """Pop n values; returned list preserves stack order (deepest first)."""
        return [*reversed([self.pop() for _ in range(n)])]
|
|
|
|
    def LOAD_FAST(self, inst):
        """LOAD_FAST: push local `inst.argval` (also handles comprehension implicits)."""
        name = inst.argval
        if self.exec_recorder and name in self.f_locals:
            self.exec_recorder.add_local_var(name, self.f_locals[name])

        try:
            self.push(self.symbolic_locals[name].unwrap())
        except KeyError:
            if name.startswith("."):
                try:
                    # This happens in dict/list comprehensions
                    new_name = name.replace(".", "implicit")
                    self.push(self.symbolic_locals[new_name])
                except KeyError:
                    unimplemented_v2(
                        gb_type="Attempted to read undefined local variable (implicit)",
                        context=f"LOAD_FAST {name}",
                        explanation=f"Could not find an implicit local variable with name `{name}`",
                        hints=[
                            "This happens in dict/list comprehensions",
                            *graph_break_hints.USER_ERROR,
                        ],
                    )
            else:
                unimplemented_v2(
                    gb_type="Attempted to read undefined local variable",
                    context=f"LOAD_FAST {name}",
                    explanation=f"Could not find a local variable with name `{name}`",
                    hints=[*graph_break_hints.USER_ERROR],
                )

        # for continuation functions
        # ___stack* locals are removed after the first load, so each is
        # only ever read once.
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)
|
|
|
|
    def LOAD_DEREF(self, inst):
        """LOAD_DEREF: push the contents of the cell named `inst.argval`."""
        assert inst.argval in self.cell_and_freevars()
        cell = self.symbolic_locals[inst.argval]
        # Cell contents are read through side-effect tracking so mutations
        # during tracing are observed.
        contents_var = self.output.side_effects.load_cell(cell)
        self.push(contents_var)

        if self.exec_recorder and inst.argval in self.f_locals:
            self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])
|
|
|
|
    def STORE_FAST(self, inst):
        """STORE_FAST: pop TOS into local `inst.argval`."""
        name = inst.argval
        loaded_vt = self.pop()
        # Give the VariableTracker a debug-friendly name matching the local.
        loaded_vt.set_name_hint(name)
        self.symbolic_locals[name] = loaded_vt

    def DELETE_FAST(self, inst):
        """DELETE_FAST: remove local `inst.argval` from the symbolic locals."""
        del self.symbolic_locals[inst.argval]

    def STORE_DEREF(self, inst):  # type: ignore[override]
        """STORE_DEREF: pop TOS into the cell named `inst.argval` (tracked as a side effect)."""
        assert inst.argval in self.cell_and_freevars()
        cell = self.symbolic_locals[inst.argval]
        val = self.pop()
        self.output.side_effects.store_cell(cell, val)

        assert isinstance(cell, CellVariable)  # tame mypy
        if cell.local_name is not None:
            val.set_name_hint(cell.local_name)  # type: ignore[attr-defined]

    # Cells are modeled as ordinary locals, so loading a closure is LOAD_FAST.
    LOAD_CLOSURE = LOAD_FAST
|
|
|
|
    def _load_const(self, inst):
        """Return a ConstantVariable for the instruction's constant, cached by co_consts index."""
        i = inst.arg
        if i is None:
            # No co_consts index available: build directly from argval
            # (cannot use the per-index cache).
            return ConstantVariable.create(value=inst.argval)
        val = self._constants_cache[i]
        if not val:
            self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval)
        return val

    def LOAD_CONST(self, inst):
        """LOAD_CONST: push the (cached) constant."""
        self.push(self._load_const(inst))
|
|
|
|
    def _load_global(self, inst):
        """Push global `inst.argval`, preferring Dynamo-tracked mutations over f_globals."""
        name = inst.argval

        if self.exec_recorder:
            if name in self.f_globals:
                self.exec_recorder.add_global_var(name, self.f_globals[name])
            else:
                assert name in self.f_builtins
                self.exec_recorder.builtins[name] = self.f_builtins[name]

        if name in self.symbolic_globals:
            # The global was stored to during tracing; read it through the
            # side-effect tracker rather than the real f_globals.
            variable = self.output.side_effects[self.symbolic_globals[name]]
            self.push(self.output.side_effects.load_global(variable, name))
            return

        try:
            value = self.f_globals[name]
        except KeyError:
            # LOAD_GLOBAL semantics: missing globals fall back to builtins.
            return self.load_builtin(inst)

        self.push(VariableTracker.build(self, value, GlobalSource(name)))
|
|
|
|
    @functools.cached_property
    def nn_modules_globals_vt(self):
        """Lazily-built VariableTracker for the globals of `torch.nn.modules.module`."""
        module_name = "torch.nn.modules.module"
        module_source = self.import_source(module_name)
        fglobals_value = _import_module(module_name)
        return VariableTracker.build(self, fglobals_value, module_source)
|
|
|
|
    def LOAD_GLOBAL(self, inst):
        """LOAD_GLOBAL, honoring the low arg bit that requests a NULL push.

        On 3.11/3.12 the NULL goes *below* the value; on 3.13 it goes above.
        """
        if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
            self.PUSH_NULL(inst)
        self._load_global(inst)
        if sys.version_info >= (3, 13) and inst.arg % 2:
            self.PUSH_NULL(inst)
|
|
|
|
    def STORE_GLOBAL(self, inst):
        """STORE_GLOBAL: pop TOS into global `inst.argval`, tracked as a side effect."""
        value = self.pop()
        name = inst.argval
        source = GlobalSource(name)
        if name not in self.symbolic_globals:
            # The sentinel's identity keys the side-effect tracker.
            self.symbolic_globals[name] = object()  # type: ignore[assignment] # sentinel object
        variable = self.output.side_effects.track_global_existing(
            source, self.symbolic_globals[name]
        )
        if isinstance(value, RemovableHandleVariable):
            unimplemented_v2(
                gb_type="Storing Tensor hook handle in globals",
                context=name,
                explanation="This is not supported.",
                hints=[],
            )
        self.output.side_effects.store_global(variable, name, value)
|
|
|
|
# Cache note: This cache only exists for the duration of this
|
|
# InstructionTranslator - so it should be safe to do.
|
|
    @cache_method
    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        if "torch_package" in module_name:
            value = torch.package.package_importer._package_imported_modules[
                module_name
            ]
            # torch.package module names can contain <> and dots; sanitize
            # them into a valid identifier.
            alias = (
                module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            )
        else:
            value = _import_module(module_name)
            alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.global_scope
        # Alias must be unused, or already bound to this exact module object.
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)
|
|
|
|
def resolve_name(self, name, package, level):
|
|
"""
|
|
Copied from the Cpython implementation of __import__
|
|
Resolve a relative module name to an absolute one.
|
|
https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
|
|
"""
|
|
bits = package.rsplit(".", level - 1)
|
|
if len(bits) < level:
|
|
raise ImportError("attempted relative import beyond top-level package")
|
|
base = bits[0]
|
|
return f"{base}.{name}" if name else base
|
|
|
|
    def calc_package(self):
        """
        Copied from the Cpython implementation of __import__
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
        """
        package = self.f_globals.get("__package__")
        spec = self.f_globals.get("__spec__")
        if package is not None:
            # __package__ wins, but warn if it disagrees with __spec__.
            if spec is not None and package != spec.parent:
                log.warning(
                    "__package__ != __spec__.parent (%r != %r)",
                    package,
                    spec.parent,
                    stacklevel=3,
                )
            return package
        elif spec is not None:
            return spec.parent
        else:
            # Last resort: derive the package from __name__ / __path__.
            log.warning(
                "can't resolve package from __spec__ or __package__, "
                "falling back on __name__ and __path__",
                stacklevel=3,
            )
            package = self.f_globals["__name__"]
            if "__path__" not in self.f_globals:
                # Plain module (not a package): parent package is the prefix.
                package = package.rpartition(".")[0]
        return package
|
|
|
|
    def IMPORT_NAME(self, inst):
        """IMPORT_NAME: import module `inst.argval`; (level, fromlist) come from the stack."""
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval

        # Are we replaying? if so, load recorded module
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            try:
                value = __import__(
                    module_name,
                    fromlist=fromlist,
                    level=level,
                    globals=self.f_globals,
                )
            except ImportError:
                unimplemented_v2(
                    gb_type="Import failure",
                    context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}",
                    explanation="Failure when attempting to import.",
                    hints=[*graph_break_hints.USER_ERROR],
                )

            if level != 0:
                # Relative import: resolve to an absolute name for the alias.
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)

            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)

        if self.exec_recorder:
            self.exec_recorder.add_local_mod(recorded_name, value)

        if istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented_v2(
                gb_type="Bad import result",
                context=typestr(value),
                explanation="Import result is not a Python module.",
                hints=[],
            )
|
|
|
|
    def IMPORT_FROM(self, inst):
        """IMPORT_FROM: load attribute `inst.argval` from the module at TOS (module stays on stack)."""
        self.DUP_TOP(inst)
        self._load_attr(inst)
|
|
|
|
    def load_builtin_from_argval(self, argval):
        """Push the builtin named `argval`: callables get a guardable source, constants by value."""
        if argval not in self.f_builtins:
            raise Unsupported(f"name '{argval}' is not defined")
        val = self.f_builtins[argval]

        if callable(val):
            # Source the builtin through the builtins dict stashed in the
            # output's global scope, so it can be guarded on.
            builtins_source = GlobalSource(
                self.output.name_of_builtins_dict_key_in_fglobals
            )
            var_source = DictGetItemSource(builtins_source, argval)
            self.push(VariableTracker.build(self, val, var_source))
        else:
            # Non-callable builtins must be constants we can embed directly.
            assert is_builtin_constant(val)
            self.push(ConstantVariable.create(value=val))

    def load_builtin(self, inst):
        """Push the builtin named by the instruction's argval."""
        self.load_builtin_from_argval(inst.argval)
|
|
|
|
    def jump(self, inst):
        """Unconditional jump to inst.target; also accumulates the ir_count metric."""
        assert self.instruction_pointer is not None
        assert self.start_point is not None
        # Count instructions executed since the last jump / start of run.
        get_metrics_context().increment(
            "ir_count", self.instruction_pointer - self.start_point
        )
        self.instruction_pointer = self.indexof[inst.target]
        self.start_point = self.instruction_pointer

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    # Conditional jumps share a generic implementation parameterized by the
    # truthiness predicate and whether the value stays on the stack.
    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
|
|
|
|
    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))

    def SETUP_EXCEPT(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))

    def POP_BLOCK(self, inst):
        """POP_BLOCK: leave the innermost block."""
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        """SETUP_WITH: shared handling with BEFORE_WITH."""
        self.setup_or_before_with(inst)

    def SETUP_FINALLY(self, inst):
        """SETUP_FINALLY: push a block targeting the finally handler."""
        self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))

    def BEGIN_FINALLY(self, inst):
        # Python 3.8: push NULL (modeled as None) meaning "no pending exception".
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        # No-exception with-exit path: the slot below __exit__ must be the
        # None pushed by BEGIN_FINALLY; call __exit__(None, None, None).
        exit, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        # Discard the __exit__ result and the exception slot, restore None.
        self.popn(2)
        self.push(None)

    def CALL_FINALLY(self, inst):
        """
        pushes the address of the next instruction onto the stack and increments
        bytecode counter by delta
        """
        # Python 3.8 only
        addr = self.indexof[self.next_instruction]
        self.push(ConstantVariable.create(addr))
        self.jump(inst)

    def END_FINALLY(self, inst):
        # Python 3.8 only
        # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY
        tos = self.pop()
        if isinstance(tos, ConstantVariable):
            # TOS is a return address pushed by CALL_FINALLY: jump back.
            self.instruction_pointer = tos.as_python_constant()
        else:
            # Non-address TOS (e.g. None): nothing further to do here.
            pass

    def POP_FINALLY(self, inst):
        # Python 3.8 only
        preserve_tos = inst.argval
        if preserve_tos:
            tos = self.pop()
        _ = self.pop()
        if preserve_tos:
            self.push(tos)  # type: ignore[possibly-undefined]
|
|
|
|
    def FOR_ITER(self, inst):
        """FOR_ITER: advance the iterator at TOS, or jump past the loop on exhaustion."""
        it = self.pop().realize()
        try:
            val = it.next_variable(self)
            self.push(it)
            self.push(val)
        except (StopIteration, exc.ObservedUserStopIteration) as e:
            if isinstance(e, exc.ObservedUserStopIteration):
                exc.handle_observed_exception(self)

            # leave iterator upon exhaustion in 3.12
            if sys.version_info >= (3, 12):
                # CPython 3.12 actually jumps to the instruction after the END_FOR
                # and performs the action of END_FOR as part of FOR_ITER. We jump
                # to the END_FOR and run it, so we need to make sure 2 values are
                # on the stack for it to pop.
                self.push(it)
                self.push(ConstantVariable.create(None))
            self.jump(inst)
|
|
|
|
def _raise_exception_variable(self, val) -> NoReturn:
|
|
# User can raise exception in 2 ways
|
|
# 1) raise exception type - raise NotImplementedError
|
|
# 2) raise execption instance - raise NotImplemetedError("foo")
|
|
|
|
# 1) when user raises exception type
|
|
if isinstance(
|
|
val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable)
|
|
):
|
|
# Create the instance of the exception type
|
|
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
|
|
val = val.call_function(self, [], {}) # type: ignore[arg-type]
|
|
|
|
# Handle https://peps.python.org/pep-0479/
|
|
# CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
|
|
if (
|
|
is_generator(self.f_code)
|
|
and isinstance(val, variables.ExceptionVariable)
|
|
and val.exc_type is StopIteration
|
|
):
|
|
val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type]
|
|
|
|
# Save the exception in a global data structure
|
|
self.exn_vt_stack.set_current_exception(val)
|
|
|
|
# 2) when user raises exception instance
|
|
if self._isinstance_exception(val):
|
|
observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined]
|
|
raise observed_exception_type(f"raised exception {val}")
|
|
unimplemented_v2(
|
|
gb_type="Failed to raise exception",
|
|
context=str(exc),
|
|
explanation="Attempted to raise a non-Exception type/value.",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
|
|
    def RAISE_VARARGS(self, inst):
        """RAISE_VARARGS: re-raise (arg=0), raise TOS (arg=1), or raise-from (arg=2)."""
        if inst.arg == 0:
            # re-raise the previous exception. Here CPython refers to the exception
            # on top of the exception stack
            assert len(self.exn_vt_stack)
            val = self.exn_vt_stack[-1]
            assert self._isinstance_exception(val), val
            self._raise_exception_variable(val)
        elif inst.arg == 1:
            # raise TOS
            val = self.stack[-1]
            self._raise_exception_variable(val)
        else:
            # raise TOS1 from TOS; only the `from None` form is supported
            from_vt = self.pop()
            if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
                val = self.pop()
                try:
                    self._raise_exception_variable(val)
                finally:
                    # Update __cause__/__suppress_context__ in the raised exception
                    curr_exc = self.exn_vt_stack.get_current_exception()
                    curr_exc.call_setattr(
                        self, ConstantVariable("__cause__"), ConstantVariable(None)
                    )
            # Non-None cause: graph break.
            unimplemented_v2(
                gb_type="Re-raise with 2 arguments",
                context=str(from_vt),
                explanation="Dynamo does not support `raise ... from [not-None]`",
                hints=[],
            )
|
|
|
|
    def CLEANUP_THROW(self, inst):
        """CLEANUP_THROW (3.12+): re-raise, unless TOS is StopIteration (unsupported)."""
        # https://github.com/python/cpython/pull/96010
        tos = self.stack[-1]
        assert isinstance(tos, ExceptionVariable)
        if tos.exc_type is StopIteration:
            unimplemented_v2(
                gb_type="CLEANUP_THROW with StopIteration",
                context="",
                explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
                hints=[],
            )
        else:
            self.RERAISE(inst)
|
|
|
|
    def RERAISE(self, inst):
        # https://docs.python.org/3/library/dis.html#opcode-RERAISE
        # Re-raises the exception currently on top of the stack. If oparg is
        # non-zero, pops an additional value from the stack which is used to
        # set f_lasti of the current frame.

        if sys.version_info >= (3, 11):
            # RERAISE is currently supported in a narrow case of `raise ... from None`
            val = self.pop()
            if inst.argval:
                # RERAISE 1: discard the lasti value below the exception.
                _ = self.pop()
                self._raise_exception_variable(val)
            else:
                # RERAISE 0
                self.push(val)
                self._raise_exception_variable(val)
        else:
            # Pre-3.11 stack layout: (exc_type, value, traceback).
            _exc = self.pop()
            val = self.pop()
            _tb = self.pop()
            self._raise_exception_variable(val)
|
|
|
|
def _isinstance_exception(self, val):
|
|
return isinstance(
|
|
val,
|
|
(
|
|
variables.ExceptionVariable,
|
|
UserDefinedExceptionClassVariable,
|
|
UserDefinedExceptionObjectVariable,
|
|
),
|
|
)
|
|
|
|
    def WITH_EXCEPT_START(self, inst):
        """WITH_EXCEPT_START: call the context's __exit__ with the active exception info."""
        if sys.version_info >= (3, 11):
            # At the top of the stack are 4 values:
            # - TOP = exc_info()
            # - SECOND = previous exception
            # - THIRD: lasti of exception in exc_info()
            # - FOURTH: the context.__exit__ bound method
            # We call FOURTH(type(TOP), TOP, GetTraceback(TOP)).
            # Then we push the __exit__ return value.
            assert len(self.stack) >= 4
            fn = self.stack[-4]
            val = self.stack[-1]
            assert self._isinstance_exception(val)
            typ = BuiltinVariable(val.exc_type)  # type: ignore[attr-defined]
            # Tracebacks are not modeled; pass None.
            tb = ConstantVariable(None)
        else:
            # Pre-3.11 layout: __exit__ is 7 deep, the exception value 2 deep.
            assert len(self.stack) >= 7
            fn = self.stack[-7]
            val = self.stack[-2]
            assert self._isinstance_exception(val)
            typ = BuiltinVariable(val.exc_type)  # type: ignore[attr-defined]
            tb = ConstantVariable(None)

        self.call_function(fn, [typ, val, tb], {})
|
|
|
|
    def exception_handler(self, raised_exception):
        """Route `raised_exception` to the innermost matching exception handler.

        On 3.11+ this consults the exception table entry of the current
        instruction; pre-3.11 it walks the block stack. If no handler exists
        in this frame, the stack is cleared and the exception is re-raised so
        an outer (inlining) translator can handle it; at the top level this is
        a graph break instead.
        """
        observed_exn_gb_explanation = (
            "Dynamo found no exception handler at the top-level compiled function "
            "when encountering an exception. Exception will propagate outside the compiled region."
        )

        if sys.version_info >= (3, 11):
            exn_tab_entry = self.current_instruction.exn_tab_entry
            if exn_tab_entry:
                # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt

                # 1) pop values from the stack until it matches the stack depth
                # for the handler
                while len(self.stack) > exn_tab_entry.depth:
                    self.pop()

                # 2) if 'lasti' is true, then push the offset that the exception was raised at
                if exn_tab_entry.lasti:
                    self.push(
                        variables.ConstantVariable(self.current_instruction.offset)
                    )

                # 3) push the exception to the stack
                self.push(self.exn_vt_stack.get_current_exception())

                # 4) jump to the handler
                self.jump(exn_tab_entry)
            else:
                # No handler found. Bubble the exception to the parent
                # instruction translater. We use special exception for this.
                self.stack.clear()
                if type(self) is InstructionTranslator:
                    unimplemented_v2(
                        gb_type="Observed exception",
                        context=str(raised_exception),
                        explanation=observed_exn_gb_explanation,
                        hints=[
                            *graph_break_hints.USER_ERROR,
                            *graph_break_hints.SUPPORTABLE,
                        ],
                    )
                raise raised_exception
        else:
            if len(self.block_stack):
                # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455

                block_stack_entry = self.block_stack.pop()

                # Skip over stale EXCEPT_HANDLER entries, unwinding the
                # exception state they own (3 stack slots + one exception).
                while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
                    # TODO(anijain2305) - This is not tested .. unable to create a testcase
                    # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
                    self.popn(3)
                    self.exn_vt_stack.pop()
                    if len(self.block_stack) == 0:
                        # No handler found in this frame. Bubble the exception to the parent
                        # instruction translater.
                        self.stack.clear()
                        if type(self) is InstructionTranslator:
                            unimplemented_v2(
                                gb_type="Observed exception (EXCEPT_HANDLER)",
                                context=str(raised_exception),
                                explanation=observed_exn_gb_explanation
                                + " This graph break is unexpected.",
                                hints=[*graph_break_hints.DYNAMO_BUG],
                            )

                        raise raised_exception
                    block_stack_entry = self.block_stack.pop()

                exception_var = self.exn_vt_stack.get_current_exception()
                self.exn_vt_stack.move_current_exception_to_stack()

                # 1) pop values from the stack until it matches the stack depth
                # for the handler
                while len(self.stack) > block_stack_entry.stack_index:
                    self.pop()

                # Push a dummy block stack entry of EXCEPT_HANDLER
                # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
                except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0)
                self.block_stack.append(
                    BlockStackEntry(except_handler_inst, None, len(self.stack))
                )

                # Push old exception
                if len(self.exn_vt_stack) >= 2:
                    old_exception = self.exn_vt_stack[-2]

                    # Push the old exception on to stack - tb, value, type
                    # Traceback is currently mapped to UnknownVariable
                    self.push(variables.UnknownVariable())
                    self.push(old_exception)
                    self.push(variables.BuiltinVariable(old_exception.exc_type))
                else:
                    # Push empty exception tb, value, type
                    self.push(variables.ConstantVariable(None))
                    self.push(variables.ConstantVariable(None))
                    self.push(variables.ConstantVariable(None))

                # Push new exception - tb, val, type
                # Traceback is currently mapped to UnknownVariable
                self.push(variables.UnknownVariable())
                self.push(exception_var)
                self.push(variables.BuiltinVariable(exception_var.exc_type))

                # Jump to target
                self.jump(block_stack_entry)
            else:
                # No handler found. Bubble the exception to the parent
                # instruction translater. We use special exception for this.
                self.stack.clear()
                if type(self) is InstructionTranslator:
                    unimplemented_v2(
                        gb_type="Observed exception",
                        context=str(raised_exception),
                        explanation=observed_exn_gb_explanation,
                        hints=[
                            *graph_break_hints.USER_ERROR,
                            *graph_break_hints.SUPPORTABLE,
                        ],
                    )
                raise raised_exception
|
|
|
|
def PUSH_EXC_INFO(self, inst):
|
|
# https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO
|
|
# Pops a value from the stack. Pushes the current exception to the top
|
|
# of the stack. Pushes the value originally popped back to the stack.
|
|
#
|
|
# The behavior of this opcode in CPython is a bit different than what it
|
|
# is described. It pops a value from the stack, pushes the top of the
|
|
# exception stack to the interpreter stack and moves the
|
|
# "current exception" to the exception stack.
|
|
#
|
|
# As an example, suppose the stack is in the following state:
|
|
# + stack = [..., ConstantVariable(1), ConstantVariable(2)]
|
|
# + current_exception = TypeError
|
|
# + exception_stack = [ValueError]
|
|
#
|
|
# After PUSH_EXC_INFO is executed
|
|
# + stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
|
|
# + current_exception = None
|
|
# + exception_stack = [ValueError, TypeError]
|
|
|
|
val = self.pop()
|
|
if len(self.exn_vt_stack) == 0:
|
|
prev_exc = ConstantVariable(None)
|
|
else:
|
|
prev_exc = self.exn_vt_stack[-1]
|
|
self.push(prev_exc)
|
|
self.push(val)
|
|
self.exn_vt_stack.move_current_exception_to_stack()
|
|
|
|
    def POP_EXCEPT(self, inst):
        """Handle POP_EXCEPT: an `except` block finished handling an exception.

        3.11+: pop the single exception value. Pre-3.11: pop the
        EXCEPT_HANDLER block entry plus the (tb, value, type) triple it owns.
        In both cases the handled exception is removed from the exception
        stack.
        """
        if sys.version_info >= (3, 11):
            _ = self.pop()
            # This exception is handled and therefore we can clear the error indicator
            assert len(self.exn_vt_stack)
            self.exn_vt_stack.pop()
        else:
            assert len(self.block_stack) > 0
            if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER":
                raise AssertionError(
                    "Bug in Dynamo tracing of exception handling."
                    "Top of the block stack is not EXCEPT_HANDLER."
                )
            self.block_stack.pop()

            # Pop the (tb, value, type) triple pushed by the handler setup.
            self.popn(3)

            # This exception is handled and therefore we can clear the error indicator
            assert len(self.exn_vt_stack)
            self.exn_vt_stack.pop()
|
|
|
|
def check_if_exc_matches(self):
|
|
assert len(self.stack) >= 2
|
|
expected_exc_types = self.pop()
|
|
if sys.version_info >= (3, 11):
|
|
# CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop.
|
|
# This is the description from the disassembly doc
|
|
#
|
|
# Performs exception matching for ``except``. Tests whether the ``STACK[-2]``
|
|
# is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean
|
|
# result of the test.
|
|
exc_instance = self.stack[-1]
|
|
else:
|
|
# This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH
|
|
# There is no documentation but here is the code pointer that does 2 pops
|
|
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
|
|
exc_instance = self.stack.pop()
|
|
|
|
# Users can check exception in 3 ways
|
|
# 1) except NotImplementedError --> BuiltinVariable
|
|
# 2) except CustomException --> UserDefinedExceptionClasVariable
|
|
# 3) except (NotImplemetedError, AttributeError) -> TupleVariable
|
|
|
|
if not isinstance(
|
|
expected_exc_types,
|
|
(
|
|
BuiltinVariable,
|
|
TupleVariable,
|
|
UserDefinedExceptionClassVariable,
|
|
UserDefinedExceptionObjectVariable,
|
|
),
|
|
):
|
|
unimplemented_v2(
|
|
gb_type="Exception with bad expected type",
|
|
context=str(expected_exc_types),
|
|
explanation=f"`except ...` has unsupported type {expected_exc_types}.",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
|
|
if sys.version_info >= (3, 11):
|
|
if not self._isinstance_exception(exc_instance):
|
|
unimplemented_v2(
|
|
gb_type="Caught non-Exception value",
|
|
context=str(exc_instance),
|
|
explanation=f"Except expects to recieve an object of Exception type but received {exc_instance}.",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
|
|
if isinstance(expected_exc_types, TupleVariable):
|
|
expected_types = expected_exc_types.items
|
|
else:
|
|
expected_types = [
|
|
expected_exc_types,
|
|
]
|
|
|
|
for expected_type in expected_types:
|
|
if not isinstance(
|
|
expected_type,
|
|
(
|
|
BuiltinVariable,
|
|
UserDefinedExceptionObjectVariable,
|
|
UserDefinedExceptionClassVariable,
|
|
),
|
|
):
|
|
unimplemented_v2(
|
|
gb_type="Exception with non-type expectation",
|
|
context=str(expected_type),
|
|
explanation=f"`except ...` expects a non-type: {expected_type}.",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
if self._isinstance_exception(exc_instance) and issubclass(
|
|
exc_instance.exc_type, # type: ignore[attr-defined]
|
|
expected_type.fn, # type: ignore[attr-defined]
|
|
):
|
|
return True
|
|
elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
|
|
exc_instance.fn, expected_type.fn
|
|
):
|
|
return True
|
|
|
|
return False
|
|
|
|
def CHECK_EXC_MATCH(self, inst):
|
|
self.push(variables.ConstantVariable(self.check_if_exc_matches()))
|
|
|
|
def JUMP_IF_NOT_EXC_MATCH(self, inst):
|
|
if not self.check_if_exc_matches():
|
|
self.jump(inst)
|
|
|
|
def COMPARE_OP(self, inst):
|
|
if inst.argval == "exception match":
|
|
self.CHECK_EXC_MATCH(inst)
|
|
else:
|
|
self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
|
|
|
|
def GET_ITER(self, inst):
|
|
self.call_function(BuiltinVariable(iter), [self.pop()], {})
|
|
|
|
@break_graph_if_unsupported(push=1)
|
|
def CALL_FUNCTION(self, inst):
|
|
args = self.popn(inst.argval)
|
|
fn = self.pop()
|
|
self.call_function(fn, args, {})
|
|
|
|
    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_EX(self, inst):
        """Handle CALL_FUNCTION_EX: `fn(*args)` (flags 0) or `fn(*args, **kwargs)`
        (flags 1).

        Pops kwargs (if present) and args, then the callable (with the NULL
        sentinel ordering depending on the Python version), normalizes args to
        a list variable and kwargs to a dict variable, and traces the call.
        """
        kwargsvars: VariableTracker
        if inst.argval == 0:
            kwargsvars = ConstDictVariable({})
            argsvars = self.pop()
        elif inst.argval == 1:
            kwargsvars = self.pop()
            argsvars = self.pop()
        else:
            unimplemented_v2(
                gb_type="Variadic function call with bad flags",
                context=f"flags: {inst.argval}",
                explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}",
                hints=[*graph_break_hints.DYNAMO_BUG],
            )

        if sys.version_info >= (3, 13):
            # 3.13 swapped null and callable
            null = self.pop()
            assert isinstance(null, NullVariable)

        fn = self.pop()

        if sys.version_info >= (3, 11) and sys.version_info < (3, 13):
            null = self.pop()
            assert isinstance(null, NullVariable)

        if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable):
            # realize is required for Python 3.8
            kwargsvars = kwargsvars.realize()
            if fn.name == "view" and isinstance(
                argsvars, (ConstantVariable, TensorVariable)
            ):
                # Hack to handle special case in some bert models. Converts
                # x.view(*shape) into x.view(shape), which is correct for view()
                # but not generally. See test_transpose_for_scores().
                argsvars = TupleVariable([argsvars])
            elif (
                fn.name == "random_"
                and isinstance(argsvars, TupleVariable)
                and len(argsvars.items) == 0
                and isinstance(kwargsvars, ConstDictVariable)
                and ConstantVariable.create("from") in kwargsvars
            ):
                # `from`` is python keyword. Adding random_ with `from` in the
                # Fx graph causes syntax error. Even if we convert the kwargs to
                # args, aot_autograd/inductor while lowering generates
                # aten.random.from, again causing syntax errors. Since this
                # usecase is uncommon, graph break.
                unimplemented_v2(
                    gb_type="Tensor.random_ op called with `from` keyword",
                    context="",
                    explanation="This is not supported.",
                    hints=[],
                )
            elif (
                fn.name == "uniform_"
                and isinstance(argsvars, TupleVariable)
                and len(argsvars.items) == 0
                and isinstance(kwargsvars, ConstDictVariable)
                and ConstantVariable.create("from") in kwargsvars
            ):
                # `from`` is python keyword. Adding uniform_ with `from` in the
                # Fx graph causes syntax error. Even if we convert the kwargs to
                # args, aot_autograd/inductor while lowering generates
                # aten.uniform.from, again causing syntax errors. Since this
                # usecase is uncommon, graph break.
                unimplemented_v2(
                    gb_type="Tensor.uniform_ op called with `from` keyword",
                    context="",
                    explanation="This is not supported.",
                    hints=[],
                )

        # Normalize any unpackable args object into a concrete tuple variable.
        if not isinstance(
            argsvars, BaseListVariable
        ) and argsvars.has_force_unpack_var_sequence(self):
            argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))

        # Unpack for cases like fn(**obj) where obj is a map
        if isinstance(kwargsvars, UserDefinedObjectVariable):
            kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]

        if not isinstance(argsvars, BaseListVariable) or not isinstance(
            kwargsvars, ConstDictVariable
        ):
            unimplemented_v2(
                gb_type="Variadic function call with bad args/kwargs type",
                context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
                explanation="Expected args to be a list and kwargs to be a dict",
                hints=[*graph_break_hints.USER_ERROR],
            )

        # Map to a dictionary of str -> VariableTracker
        kwargsvars = kwargsvars.keys_as_python_constant()
        self.call_function(fn, argsvars.items, kwargsvars)
|
|
|
|
@break_graph_if_unsupported(push=1)
|
|
def CALL_FUNCTION_KW(self, inst):
|
|
argnames = self.pop()
|
|
args = self.popn(inst.argval)
|
|
fn = self.pop()
|
|
assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
|
|
argnames = argnames.as_python_constant()
|
|
args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
|
|
kwargs = dict(zip(argnames, kwargs_list))
|
|
assert len(kwargs) == len(argnames)
|
|
self.call_function(fn, args, kwargs)
|
|
|
|
    def LOAD_METHOD_SUPER(self, inst):
        """Handle LOAD_METHOD_SUPER: `super().method` lookup.

        First emulates the `super(...)` call (two arguments are already on the
        stack, so CALL_FUNCTION is replayed with argval=2), then loads the
        named method from the result. `inst.argval` here is a tuple whose
        first element indexes co_names.
        """
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        arg = inst.argval[0]
        argval = self.code_options["co_names"][arg]
        if sys.version_info < (3, 11):
            self._load_attr(dataclasses.replace(inst, argval=argval))
        else:
            self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))
|
|
|
|
    def LOAD_ATTR_SUPER(self, inst):
        """Handle LOAD_ATTR_SUPER: `super().attr` lookup.

        Replays the `super(...)` call (two stack operands) and then performs a
        plain attribute load of the co_names entry referenced by
        `inst.argval[0]`.
        """
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        arg = inst.argval[0]
        argval = self.code_options["co_names"][arg]
        self._load_attr(dataclasses.replace(inst, argval=argval))
|
|
|
|
    def LOAD_METHOD(self, inst):
        """Handle LOAD_METHOD: load `obj.name` and arrange the (NULL, method)
        stack convention required by the version-specific CALL sequence."""
        self._load_attr(inst)
        obj = self.pop()
        if sys.version_info >= (3, 13):
            # 3.13: callable first, then NULL.
            self.push(obj)
            self.PUSH_NULL(inst)
        elif sys.version_info >= (3, 11):
            # always follow the NULL + fn convention, since if obj
            # is actually a method, self is already bound to it, so it
            # doesn't need to be passed in as an arg.
            self.PUSH_NULL(inst)
            self.push(obj)
        else:
            # Pre-3.11: CALL_METHOD expects (method, sentinel) with a plain
            # None sentinel on top.
            self.push(obj)
            self.push(None)
|
|
|
|
def CALL_METHOD(self, inst):
|
|
args = self.popn(inst.argval)
|
|
dummy = self.pop()
|
|
assert dummy is None
|
|
fn = self.pop()
|
|
self.call_function(fn, args, {})
|
|
|
|
def _load_attr(self, inst):
|
|
obj = self.pop()
|
|
result = BuiltinVariable(getattr).call_function(
|
|
self, # type: ignore[arg-type]
|
|
[obj, ConstantVariable.create(inst.argval)],
|
|
{},
|
|
)
|
|
self.push(result)
|
|
|
|
def LOAD_ATTR(self, inst):
|
|
if sys.version_info >= (3, 12):
|
|
if inst.arg % 2:
|
|
self.LOAD_METHOD(inst)
|
|
return
|
|
self._load_attr(inst)
|
|
|
|
    def STORE_ATTR(self, inst):
        """Handle STORE_ATTR (`obj.attr = val`).

        Traces the store through the setattr builtin. If that is unsupported
        and partial compilation is allowed, records the failure in the
        speculation log and restarts analysis so the retry takes the
        graph-break path (store_attr_graph_break).
        """
        speculation = self.speculate()
        if speculation.failed:
            # A previous attempt already failed here; emit the graph break.
            return self.store_attr_graph_break(inst)
        val, obj = self.popn(2)

        if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
            # We don't allow side effects during export on non-constant values
            # https://github.com/pytorch/torchdynamo/issues/1475
            assert not self.export, (
                f"Mutating module attribute {inst.argval} during export."
            )

        try:
            BuiltinVariable(setattr).call_function(
                self,  # type: ignore[arg-type]
                [obj, ConstantVariable.create(inst.argval), val],
                {},
            )
            return
        except Unsupported as e:
            if not self.should_compile_partial_graph():
                raise
            log.debug("STORE_ATTR triggered compile", exc_info=True)
            e.remove_from_stats()
            e.add_to_stats("graph_break")
        # Mark this point as failed and restart symbolic analysis from the top.
        speculation.fail_and_restart_analysis()
|
|
|
|
    def store_attr_graph_break(self, inst):
        """Graph-break path for STORE_ATTR: compile the graph so far, emit the
        original STORE_ATTR instruction to run eagerly, then resume after it."""
        log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
        if not self.should_compile_partial_graph():
            unimplemented_v2(
                gb_type="Should not compile partial graph (STORE_ATTR)",
                context="",
                explanation="Dynamo has determined when encountering an unsupported "
                "STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.",
                hints=[],
            )
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
        )
        # Replay the original bytecode eagerly between the two compiled parts.
        self.output.add_output_instructions([copy.copy(inst)])
        self.popn(2)
        self.output.add_output_instructions(
            self.create_call_resume_at(self.next_instruction)
        )
|
|
|
|
def DELETE_ATTR(self, inst):
|
|
obj = self.pop()
|
|
BuiltinVariable(delattr).call_function(
|
|
self, # type: ignore[arg-type]
|
|
[obj, ConstantVariable.create(inst.argval)],
|
|
{},
|
|
)
|
|
|
|
def create_call_resume_at(self, offset):
|
|
raise AssertionError(
|
|
f"create_call_resume_at not overridden by subclass {type(self)}"
|
|
)
|
|
|
|
def should_compile_partial_graph(self) -> bool:
|
|
raise AssertionError(
|
|
f"should_compile_partial_graph not overridden by subclass {type(self)}"
|
|
)
|
|
|
|
@break_graph_if_unsupported(push=0)
|
|
def STORE_SUBSCR(self, inst):
|
|
val, obj, key = self.popn(3)
|
|
obj.call_method(self, "__setitem__", [key, val], {})
|
|
|
|
def DELETE_SUBSCR(self, inst):
|
|
obj, key = self.popn(2)
|
|
obj.call_method(self, "__delitem__", [key], {})
|
|
|
|
def BUILD_TUPLE(self, inst):
|
|
items = self.popn(inst.argval)
|
|
self.push(TupleVariable(items))
|
|
|
|
def BUILD_SLICE(self, inst):
|
|
items = self.popn(inst.argval)
|
|
self.push(SliceVariable(items))
|
|
|
|
def BUILD_LIST(self, inst):
|
|
items = self.popn(inst.argval)
|
|
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
|
|
|
|
def BUILD_SET(self, inst):
|
|
if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
|
|
unimplemented_v2(
|
|
gb_type="missing BUILD_SET handler",
|
|
context="",
|
|
explanation="Missing BUILD_SET bytecode handler (for testing purposes).",
|
|
hints=[],
|
|
)
|
|
items = self.popn(inst.argval)
|
|
new_set = SetVariable(items, mutation_type=ValueMutationNew())
|
|
self.push(new_set)
|
|
|
|
def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
|
|
seqs = self.popn(inst.argval)
|
|
items = []
|
|
for seq in seqs:
|
|
try:
|
|
items.extend(seq.force_unpack_var_sequence(self))
|
|
except NotImplementedError:
|
|
unimplemented_v2(
|
|
gb_type="Failed to unpack object for BUILD_LIST_UNPACK",
|
|
context=str(seq),
|
|
explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK "
|
|
"bytecode (`[*x, *y, ...]`).",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
self.push(cls(items, mutation_type=ValueMutationNew()))
|
|
|
|
    def BUILD_TUPLE_UNPACK(self, inst):
        # `(*a, *b, ...)`: same unpacking as BUILD_LIST_UNPACK, but the result
        # is a TupleVariable.
        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

    # CPython emits a distinct opcode when the tuple feeds a `f(*a, *b)` call;
    # the tracing behavior is identical.
    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
|
|
|
|
def BUILD_MAP(self, inst):
|
|
items = self.popn(inst.argval * 2)
|
|
d = dict(zip(items[::2], items[1::2]))
|
|
self.push(ConstDictVariable(d, mutation_type=ValueMutationNew()))
|
|
|
|
    def BUILD_MAP_UNPACK(self, inst):
        # `{**a, **b, ...}`: pop the operands, coerce each to a dict through
        # the traced dict() builtin, and merge left-to-right (later keys win).
        items = self.popn(inst.argval)
        # ensure everything is a dict
        items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items]  # type: ignore[arg-type]
        result = {}
        for x in items:
            assert isinstance(x, ConstDictVariable)
            result.update(x.items)
        self.push(
            ConstDictVariable(
                result,
                mutation_type=ValueMutationNew(),
            )
        )

    # The *_WITH_CALL variant (used for `f(**a, **b)`) traces identically.
    BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK
|
|
|
|
def BUILD_CONST_KEY_MAP(self, inst):
|
|
keys = self.pop()
|
|
values = self.popn(inst.argval)
|
|
assert isinstance(keys, TupleVariable)
|
|
assert keys.is_python_constant()
|
|
|
|
keys = keys.force_unpack_var_sequence(self)
|
|
assert len(keys) == len(values)
|
|
|
|
self.push(
|
|
ConstDictVariable(
|
|
dict(zip(keys, values)),
|
|
mutation_type=ValueMutationNew(),
|
|
)
|
|
)
|
|
|
|
def MAP_ADD(self, inst):
|
|
k, v = self.popn(2)
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg].realize()
|
|
assert isinstance(obj, ConstDictVariable)
|
|
obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type]
|
|
|
|
def SET_ADD(self, inst):
|
|
v = self.pop()
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg]
|
|
assert isinstance(obj, SetVariable)
|
|
assert obj.is_mutable()
|
|
return obj.call_method(self, "add", [v], {})
|
|
|
|
def SET_UPDATE(self, inst):
|
|
v = self.pop()
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg]
|
|
assert isinstance(obj, SetVariable)
|
|
assert obj.is_mutable()
|
|
obj.call_method(self, "update", [v], {})
|
|
|
|
def LIST_APPEND(self, inst):
|
|
v = self.pop()
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg].realize()
|
|
assert isinstance(obj, ListVariable)
|
|
assert obj.is_mutable()
|
|
self.output.side_effects.mutation(obj)
|
|
obj.items.append(v)
|
|
|
|
    def MAKE_FUNCTION(self, inst):
        """Handle MAKE_FUNCTION: build a NestedUserFunctionVariable for a
        function object created at runtime (nested def / lambda).

        The flag bits select which optional attributes (defaults, kwdefaults,
        annotations, closure) were pushed; pop order must stay exactly as
        CPython pushes them.
        """
        flags = inst.arg
        if sys.version_info < (3, 11):
            fn_name = self.pop()
        code = self.pop()
        if sys.version_info >= (3, 11):
            # MAKE_FUNCTION behavior actually changed in 3.11, see
            # https://github.com/python/cpython/pull/93189/
            assert hasattr(code.value, "co_qualname")  # type: ignore[attr-defined]
            fn_name = ConstantVariable.create(value=code.value.co_qualname)  # type: ignore[attr-defined]
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        if sys.version_info < (3, 13):
            # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE
            if flags & 0x08:
                closure = self.pop()
            if flags & 0x04:
                annotations = self.pop()
            if flags & 0x02:
                kwdefaults = self.pop()
            if flags & 0x01:
                defaults = self.pop()

        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
            )
        )
|
|
|
|
    def UNPACK_SEQUENCE(self, inst):
        """Handle UNPACK_SEQUENCE (`a, b, c = d`): pop a sequence, push its
        `argval` elements in reverse order (so the first lands on top)."""
        seq = self.pop()
        if isinstance(seq, TensorVariable):
            val = seq.unpack_var_sequence(self, idxes=range(inst.argval))  # type: ignore[arg-type]
        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
            # x, y = a.shape
            proxy = getattr(seq.obj.as_proxy(), seq.name)
            val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
        elif seq.has_force_unpack_var_sequence(self):
            val = seq.force_unpack_var_sequence(self)
        else:
            unimplemented_v2(
                gb_type="Failed to unpack object for UNPACK_SEQUENCE",
                context=str(seq),
                explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode "
                "(i.e. `a, b, c = d`).",
                hints=[*graph_break_hints.USER_ERROR],
            )
        if len(val) != inst.argval:
            unimplemented_v2(
                gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
                context=f"expected length: {inst.argval}, actual: {len(val)}",
                explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
                "(i.e. `a, b, c = d`) with unexpected length.",
                hints=[*graph_break_hints.DYNAMO_BUG],
            )
        for i in reversed(val):
            self.push(i)
|
|
|
|
def UNPACK_EX(self, inst):
|
|
assert 0 <= inst.argval <= 0xFFFF
|
|
prefix = inst.argval & 0xFF # low byte
|
|
suffix = inst.argval >> 8 # high byte
|
|
seq = self.pop()
|
|
if seq.has_force_unpack_var_sequence(self):
|
|
vals = list(seq.force_unpack_var_sequence(self))
|
|
assert len(vals) >= prefix + suffix
|
|
vals_prefix = vals[:prefix]
|
|
vals_list = vals[prefix : len(vals) - suffix]
|
|
vals_suffix = vals[len(vals) - suffix :]
|
|
for item in reversed(vals_suffix):
|
|
self.push(item)
|
|
self.push(TupleVariable(vals_list))
|
|
for item in reversed(vals_prefix):
|
|
self.push(item)
|
|
else:
|
|
unimplemented_v2(
|
|
gb_type="Failed to unpack object for UNPACK_EX",
|
|
context=str(seq),
|
|
explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.",
|
|
hints=[*graph_break_hints.USER_ERROR],
|
|
)
|
|
|
|
    def NOP(self, inst):
        # NOP has no effect on the symbolic state.
        pass
|
|
|
|
    def POP_TOP(self, inst):
        # Discard the top-of-stack value.
        self.pop()
|
|
|
|
def ROT_TWO(self, inst):
|
|
a = self.pop()
|
|
b = self.pop()
|
|
self.push(a)
|
|
self.push(b)
|
|
|
|
def ROT_THREE(self, inst):
|
|
a = self.pop()
|
|
b = self.pop()
|
|
c = self.pop()
|
|
self.push(a)
|
|
self.push(c)
|
|
self.push(b)
|
|
|
|
def ROT_FOUR(self, inst):
|
|
a = self.pop()
|
|
b = self.pop()
|
|
c = self.pop()
|
|
d = self.pop()
|
|
self.push(a)
|
|
self.push(d)
|
|
self.push(c)
|
|
self.push(b)
|
|
|
|
def DUP_TOP(self, inst):
|
|
a = self.pop()
|
|
self.push(a)
|
|
self.push(a)
|
|
|
|
def DUP_TOP_TWO(self, inst):
|
|
a = self.pop()
|
|
b = self.pop()
|
|
self.push(b)
|
|
self.push(a)
|
|
self.push(b)
|
|
self.push(a)
|
|
|
|
def _convert_value(self, value, flag):
|
|
if flag == 1:
|
|
return BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type]
|
|
elif flag == 2:
|
|
return BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type]
|
|
elif flag == 3:
|
|
return BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type]
|
|
return value
|
|
|
|
    def _format_value(self, fmt_spec, flags):
        """Format TOS per an f-string conversion/format-spec (shared by
        FORMAT_VALUE and the 3.13 FORMAT_* opcodes)."""
        value = self.pop()
        if isinstance(value, SymNodeVariable):
            # Symbolic ints/floats are formatted lazily to avoid guarding on
            # their concrete value here.
            from torch._dynamo.variables.lazy import (
                LazySymNodeFormatString,
                LazyVariableTracker,
            )

            value = LazyVariableTracker.create(
                LazySymNodeFormatString(value, fmt_spec), source=value.source
            )
            self.push(value)
            return

        # Low two flag bits select the !s / !r / !a conversion.
        value = self._convert_value(value, flags & 0x03)

        fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")

        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
|
|
|
|
def FORMAT_VALUE(self, inst):
|
|
flags = inst.arg
|
|
if (flags & 0x04) == 0x04:
|
|
fmt_spec = self.pop()
|
|
else:
|
|
fmt_spec = ConstantVariable.create("")
|
|
|
|
return self._format_value(fmt_spec, flags)
|
|
|
|
    def BUILD_STRING(self, inst):
        """Handle BUILD_STRING: concatenate f-string pieces into a single
        StringFormatVariable, merging each part's symbolic args/kwargs."""
        format_string_parts: list[str] = []
        args: list[VariableTracker] = []
        kwargs: dict[str, VariableTracker] = {}
        for part in self.popn(inst.arg):
            if isinstance(part, ConstantVariable):
                format_string_parts.append("{}")
                args.append(part)
            elif isinstance(part, variables.StringFormatVariable):
                format_string_parts.append(part.format_string)
                args.extend(part.sym_args)
                # Overlapping kwarg names between parts cannot be merged.
                if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
                    unimplemented_v2(
                        gb_type="BUILD_STRING key conflict",
                        context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}",
                        explanation="Failed to build format string due to key conflict",
                        hints=[*graph_break_hints.USER_ERROR],
                    )
                kwargs.update(part.sym_kwargs)
            else:
                unimplemented_v2(
                    gb_type="BUILD_STRING type error",
                    context=str(part),
                    explanation="Format string part type is not correct - expected constant or format string.",
                    hints=[*graph_break_hints.USER_ERROR],
                )
        self.push(
            variables.StringFormatVariable.create(
                "".join(format_string_parts), args, kwargs
            )
        )
|
|
|
|
def IS_OP(self, inst):
|
|
assert inst.argval == 0 or inst.argval == 1
|
|
if inst.argval == 0:
|
|
new_argval = "is"
|
|
else:
|
|
new_argval = "is not"
|
|
new_inst = create_instruction("COMPARE_OP", argval=new_argval)
|
|
self.COMPARE_OP(new_inst)
|
|
|
|
def CONTAINS_OP(self, inst):
|
|
assert inst.argval == 0 or inst.argval == 1
|
|
left, right = self.popn(2)
|
|
op = inst.argval
|
|
self.push(right.call_method(self, "__contains__", [left], {}))
|
|
if op == 1:
|
|
self.UNARY_NOT(inst)
|
|
|
|
def LIST_EXTEND(self, inst):
|
|
v = self.pop()
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg]
|
|
assert isinstance(obj, ListVariable)
|
|
assert obj.is_mutable()
|
|
obj.call_method(self, "extend", [v], {})
|
|
|
|
def LIST_TO_TUPLE(self, inst):
|
|
self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
|
|
|
|
    def STOPITERATION_ERROR(self, inst):
        """Convert a StopIteration escaping a generator body into RuntimeError
        (PEP 479 semantics), replacing the exception variable on TOS."""
        # wrap the generator body in a try: ... except StopIteration: ... which
        # converts the StopIteration into a RuntimeError
        # https://peps.python.org/pep-0479/
        # https://github.com/python/cpython/pull/99006
        # https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2
        val = self.stack[-1]
        assert self._isinstance_exception(val)
        if val.exc_type is StopIteration:  # type: ignore[attr-defined]
            new_val = variables.BuiltinVariable(RuntimeError).call_function(
                self,  # type: ignore[arg-type]
                [],
                {},
            )
            self.stack[-1] = new_val
|
|
|
|
def DICT_MERGE(self, inst):
|
|
v = self.pop()
|
|
assert inst.argval > 0
|
|
obj = self.stack[-inst.arg].realize()
|
|
assert isinstance(obj, ConstDictVariable)
|
|
assert obj.is_mutable()
|
|
obj.call_method(self, "update", [v], {})
|
|
|
|
DICT_UPDATE = DICT_MERGE
|
|
|
|
    def GEN_START(self, inst):
        # GEN_START (3.10) pops TOS — the value sent to start the generator
        # (None per the dis documentation).
        self.pop()
|
|
|
|
def GET_LEN(self, inst):
|
|
tos = self.stack[-1]
|
|
if tos.is_python_constant():
|
|
self.push(ConstantVariable.create(len(tos.as_python_constant())))
|
|
else:
|
|
self.push(tos.call_method(self, "__len__", [], {}))
|
|
|
|
def MATCH_MAPPING(self, inst):
|
|
tos = self.stack[-1]
|
|
assert isinstance(tos, ConstDictVariable)
|
|
if isinstance(tos.items, collections.abc.Mapping):
|
|
self.push(ConstantVariable.create(True))
|
|
else:
|
|
self.push(ConstantVariable.create(False))
|
|
|
|
def MATCH_SEQUENCE(self, inst):
|
|
tos = self.stack[-1]
|
|
assert tos.is_python_constant()
|
|
tos_value = tos.as_python_constant()
|
|
if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
|
|
tos_value, (str, bytes, bytearray)
|
|
):
|
|
self.push(ConstantVariable.create(True))
|
|
else:
|
|
self.push(ConstantVariable.create(False))
|
|
|
|
    def MATCH_KEYS(self, inst):
        """Handle MATCH_KEYS for `match` mapping patterns.

        TOS is the tuple of keys, TOS1 the subject dict. On success push the
        tuple of corresponding values; on failure push None (3.11+). On 3.10
        additionally push a True/False flag.
        """
        tos = self.stack[-1]
        tos1 = self.stack[-2]
        assert isinstance(tos1, ConstDictVariable)

        if all(k in tos1 for k in tos):  # type: ignore[attr-defined]
            self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos]))  # type: ignore[attr-defined,arg-type]
            if sys.version_info < (3, 11):
                self.push(ConstantVariable.create(True))
        else:
            self.push(ConstantVariable.create(None))
            if sys.version_info < (3, 11):
                self.push(ConstantVariable.create(False))
|
|
|
|
    def LOAD_ASSERTION_ERROR(self, inst):
        # Push the AssertionError builtin (emitted for `assert` statements).
        self.load_builtin_from_argval("AssertionError")
|
|
|
|
    # Unary opcodes: pop one operand, push the result of the operator.
    UNARY_POSITIVE = stack_op(operator.pos)
    UNARY_NEGATIVE = stack_op(operator.neg)
    UNARY_NOT = stack_op(operator.not_)
    UNARY_INVERT = stack_op(operator.invert)

    # Binary opcodes: pop two operands, push the result.
    BINARY_POWER = stack_op(operator.pow)
    BINARY_MULTIPLY = stack_op(operator.mul)
    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
    BINARY_MODULO = stack_op(operator.mod)
    BINARY_REMAINDER = stack_op(operator.mod)
    BINARY_ADD = stack_op(operator.add)
    BINARY_SUBTRACT = stack_op(operator.sub)
    # Subscripting is wrapped so unsupported cases trigger a graph break
    # instead of an error.
    BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
    BINARY_LSHIFT = stack_op(operator.lshift)
    BINARY_RSHIFT = stack_op(operator.rshift)
    BINARY_AND = stack_op(operator.and_)
    BINARY_OR = stack_op(operator.or_)
    BINARY_XOR = stack_op(operator.xor)

    # In-place opcodes map to the corresponding operator.i* functions.
    INPLACE_POWER = stack_op(operator.ipow)
    INPLACE_MULTIPLY = stack_op(operator.imul)
    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
    INPLACE_MODULO = stack_op(operator.imod)
    INPLACE_REMAINDER = stack_op(operator.imod)
    INPLACE_ADD = stack_op(operator.iadd)
    INPLACE_SUBTRACT = stack_op(operator.isub)
    INPLACE_LSHIFT = stack_op(operator.ilshift)
    INPLACE_RSHIFT = stack_op(operator.irshift)
    INPLACE_AND = stack_op(operator.iand)
    INPLACE_XOR = stack_op(operator.ixor)
    INPLACE_OR = stack_op(operator.ior)
|
|
|
|
    # 3.11 opcodes

    def RESUME(self, inst):
        """Handle the 3.11+ RESUME marker instruction.

        RESUME with arg 0 marks the start of a function and is kept as a
        prefix instruction (replayed verbatim at the top of generated code);
        any other arg marks a resume after yield/await and must not appear
        while we are still accepting prefix instructions.
        """
        if inst.arg == 0:
            self.append_prefix_inst(inst)
            # Function prologue is over; later MAKE_CELL etc. are not prefixes.
            self.accept_prefix_inst = False
        else:
            assert not self.accept_prefix_inst

    if sys.version_info >= (3, 11):

        def BINARY_OP(self, inst):
            # 3.11 folds all binary/in-place ops into a single BINARY_OP
            # opcode; dispatch to the legacy per-op handler via the
            # module-level lookup table (built from dis._nb_ops below).
            return _binary_op_lookup[inst.arg](self, inst)

    def PRECALL(self, inst):
        # PRECALL (3.11 only) is a specialization hint with no semantics.
        pass

    def KW_NAMES(self, inst):
        """Record the keyword-argument-names tuple for the next CALL (3.11/3.12)."""
        kw_names = self.code_options["co_consts"][inst.arg]
        assert isinstance(kw_names, tuple)
        for name in kw_names:
            assert isinstance(name, str)
        # KW_NAMES must not stack; the previous value is consumed by _call.
        assert self.kw_names is None
        self.kw_names = ConstantVariable.create(value=kw_names)  # type: ignore[assignment]

    def PUSH_NULL(self, inst):
        # Push the NULL sentinel used by the 3.11+ calling convention.
        self.push(NullVariable())
|
    def _call(self, inst, call_kw=False):
        """Model the 3.11+ CALL convention (and 3.13+ CALL_KW when call_kw=True).

        Pops the callable, the NULL/self slot, the positional arguments, and
        (for keyword calls) the keyword argument values, then dispatches to
        `call_function`. Always clears `self.kw_names` afterwards.
        """
        # see https://docs.python.org/3.11/library/dis.html#opcode-CALL
        # for convention
        if call_kw:
            # TOS is kw_names for CALL_KW instruction
            assert sys.version_info >= (3, 13)
            kw_names = self.pop()
            assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant()
            kw_names = kw_names.as_python_constant()
        else:
            # Set earlier by the KW_NAMES handler (3.11/3.12); empty otherwise.
            kw_names = self.kw_names.value if self.kw_names else ()

        # inst.arg positional/keyword values plus the two callable/NULL slots.
        contents = self.popn(inst.arg + 2)
        if sys.version_info >= (3, 13):
            # NULL and callable swapped
            fn = contents[0]
            # A non-NULL second slot is the bound "self" argument.
            args = [] if isinstance(contents[1], NullVariable) else [contents[1]]
        else:
            if isinstance(contents[0], NullVariable):
                fn = contents[1]
                args = []
            else:
                # No NULL marker: slot 0 is the callable, slot 1 the first arg.
                fn = contents[0]
                args = [contents[1]]

        if kw_names:
            # The trailing len(kw_names) values pair up with kw_names in order.
            args = args + contents[2 : -len(kw_names)]
            kwargs_list = contents[-len(kw_names) :]
            kwargs = dict(zip(kw_names, kwargs_list))
            assert len(kwargs) == len(kw_names)
        else:
            args = args + contents[2:]
            kwargs = {}

        try:
            # if call_function fails, need to set kw_names to None, otherwise
            # a subsequent call may have self.kw_names set to an old value
            self.call_function(fn, args, kwargs)
        finally:
            self.kw_names = None
|
|
    @break_graph_if_unsupported(push=1)
    def CALL(self, inst):
        """3.11+ CALL opcode; may graph break, resuming with the call result pushed."""
        self._call(inst)
|
|
def COPY(self, inst):
|
|
self.push(self.stack[-inst.arg])
|
|
|
|
def SWAP(self, inst):
|
|
self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]
|
|
|
|
    # 3.11 directional jump variants share the generic handlers.
    JUMP_BACKWARD = jump
    JUMP_BACKWARD_NO_INTERRUPT = jump

    POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)

    def CACHE(self, inst):
        # Inline-cache slots (3.11+) carry no semantics; ignore them.
        pass
|
|
    def BEFORE_WITH(self, inst):
        """3.11 BEFORE_WITH: enter a context manager (replaces SETUP_WITH)."""
        self.setup_or_before_with(inst)

    def setup_or_before_with(self, inst):
        """Shared handler for SETUP_WITH (<=3.10) and BEFORE_WITH (3.11+).

        Pops the context manager, pushes the exit function and the result of
        `__enter__`, and (when restorable) records a block-stack entry so a
        graph break inside the `with` body can be resumed.
        """
        ctx = self.pop()
        if not isinstance(
            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
        ):
            unimplemented_v2(
                gb_type="Unsupported context manager",
                context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}",
                explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.",
                hints=[
                    "Avoid using the unsupported context manager.",
                    "File an issue to PyTorch. Simple context managers can potentially be supported, "
                    "but note that context managers can't be supported in general",
                ],
            )

        if (
            isinstance(ctx, GenericContextWrappingVariable)
            and not ctx.supports_graph_breaks()
        ):
            # Track managers that cannot survive a graph break so
            # should_compile_partial_graph can refuse to split inside them.
            self.active_generic_context_managers.append(ctx)

        # Need this redundant check for mypy
        assert isinstance(
            ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
        )

        exit = WithExitFunctionVariable(
            ctx,
            inst.target,
        )

        if sys.version_info >= (3, 11):
            # See create_call_resume_at for block stack details.
            # Only push a block if the current instruction's block is a
            # with block that is not nested in a try block - that is, the current
            # instruction's block target is the same as the top block's target.
            if inst.exn_tab_entry and (
                not self.block_stack
                or inst.exn_tab_entry.target is not self.block_stack[-1].target
            ):
                target = None
            else:
                target = self.next_instruction.exn_tab_entry.target
        else:
            target = inst.target

        self.push(exit)

        if target:
            if isinstance(self, InstructionTranslator):
                # Root frame: also remember the ctx so the resume function can
                # re-enter it after a graph break.
                self.block_stack.append(
                    BlockStackEntry(inst, target, len(self.stack), ctx)
                )
            else:
                self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))

        self.push(ctx.enter(self))
|
|
    def append_prefix_inst(self, inst):
        """Collect a prologue instruction (RESUME/MAKE_CELL/COPY_FREE_VARS/
        RETURN_GENERATOR) to be replayed verbatim at the top of generated code."""
        assert self.accept_prefix_inst
        self.prefix_insts.append(inst)

    def MAKE_CELL(self, inst):
        if sys.version_info >= (3, 12) and not self.accept_prefix_inst:
            # In 3.12+, MAKE_CELL is no longer necessarily a prefix instruction.
            # It can be generated by inlined comprehensions.
            assert isinstance(self.symbolic_locals[inst.argval], NullVariable)
            self.symbolic_locals[inst.argval] = (
                self.output.side_effects.track_cell_new()
            )
        else:
            self.append_prefix_inst(inst)

    def COPY_FREE_VARS(self, inst):
        # Prologue-only; replay verbatim in generated code.
        self.append_prefix_inst(inst)

    def RETURN_GENERATOR(self, inst):
        # Prologue marker of generator/coroutine code objects; replay verbatim.
        self.append_prefix_inst(inst)
|
|
    # 3.12 opcodes
    # BINARY/STORE_SLICE opcodes are broken down into
    # BUILD_SLICE 2 and BINARY/STORE_SUBSCR

    def END_FOR(self, inst):
        """Discard loop leftovers at loop exit: one value on 3.13+, two before."""
        if sys.version_info >= (3, 13):
            self.pop()
        else:
            self.popn(2)

    def LOAD_FAST_CHECK(self, inst):
        """LOAD_FAST variant that must raise UnboundLocalError for unset locals."""
        if isinstance(self.symbolic_locals[inst.argval], NullVariable):
            # The load would raise at runtime; graph break instead of guessing.
            unimplemented_v2(
                gb_type="LOAD_FAST_CHECK on uninitialized variable",
                context=inst.argval,
                explanation=f"Attempted to load uninitialized local variable {inst.argval}",
                hints=[*graph_break_hints.USER_ERROR],
            )
        self.LOAD_FAST(inst)

    def LOAD_FAST_AND_CLEAR(self, inst):
        """Load a local (or NULL if unbound) and clear it; used by inlined
        comprehensions to save/restore the iteration variable."""
        if inst.argval not in self.symbolic_locals:
            self.push(NullVariable())
        else:
            self.LOAD_FAST(inst)
        # The local is always cleared (set to NULL), loaded or not.
        self.symbolic_locals[inst.argval] = NullVariable()
|
|
|
    def LOAD_SUPER_ATTR(self, inst):
        """3.12 LOAD_SUPER_ATTR: call the 2-argument super() already on the
        stack, then load the requested attribute or method from it."""
        # Reuse CALL_FUNCTION with argc=2 to invoke super(cls, self).
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        if inst.arg & 1:
            # Low bit set: method-load form (LOAD_METHOD pairing).
            self.LOAD_METHOD(inst)
        else:
            self._load_attr(inst)

    def CALL_INTRINSIC_1(self, inst):
        """3.12 CALL_INTRINSIC_1: dispatch on the intrinsic operand in argval."""
        if inst.argval == 3:
            # INTRINSIC_STOPITERATION_ERROR
            self.STOPITERATION_ERROR(inst)
        elif inst.argval == 5:
            # INTRINSIC_UNARY_POSITIVE
            self.UNARY_POSITIVE(inst)
        elif inst.argval == 6:
            # INTRINSIC_LIST_TO_TUPLE
            self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
        else:
            unimplemented_v2(
                gb_type="Missing CALL_INTRINSIC_1 handler",
                context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
                explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )

    def END_SEND(self, inst):
        """3.12 END_SEND: drop the item under TOS, keeping the sent/returned value."""
        tos = self.pop()
        self.pop()
        self.push(tos)
|
|
    # 3.13 opcodes
    # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST
    # are broken down.
    @break_graph_if_unsupported(push=1)
    def CALL_KW(self, inst):
        """3.13 CALL_KW: like CALL, but the kw-names tuple is on TOS."""
        self._call(inst, call_kw=True)

    def TO_BOOL(self, inst):
        # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython)
        # So we can skip this instruction as long as we remember to codegen a TO_BOOL
        # before conditional jumps/UNARY_NOT.
        assert self.next_instruction.opname in (
            "POP_JUMP_IF_TRUE",
            "POP_JUMP_IF_FALSE",
            "UNARY_NOT",
        )
|
|
|
def SET_FUNCTION_ATTRIBUTE(self, inst):
|
|
flags = inst.arg
|
|
fn = self.pop()
|
|
assert isinstance(fn, NestedUserFunctionVariable)
|
|
attr = self.pop()
|
|
|
|
if flags & 0x08:
|
|
fn.closure = attr
|
|
elif flags & 0x04:
|
|
fn.annotations = attr
|
|
elif flags & 0x02:
|
|
fn.kwdefaults = attr
|
|
elif flags & 0x01:
|
|
fn.defaults = attr
|
|
|
|
self.push(fn)
|
|
|
|
    def CONVERT_VALUE(self, inst):
        """3.13 CONVERT_VALUE: apply the !s/!r/!a f-string conversion to TOS."""
        self.push(self._convert_value(self.pop(), inst.argval))

    def FORMAT_SIMPLE(self, inst):
        """3.13 FORMAT_SIMPLE: format TOS with an empty format spec."""
        self._format_value(ConstantVariable.create(""), 0)

    def FORMAT_WITH_SPEC(self, inst):
        """3.13 FORMAT_WITH_SPEC: format using the spec popped from TOS."""
        self._format_value(self.pop(), 0)
|
|
|
    def is_non_empty_graph(self):
        """Return True once the output graph has more than one call node.

        NOTE(review): the threshold is `> 1`, not `> 0`, so a graph with
        exactly one call reports False -- confirm this is intended.
        """
        if self.output.count_calls() > 1:
            # perf optimization only
            # Once true it can never become false again, so memoize by
            # shadowing the bound method on this instance.
            self.is_non_empty_graph = lambda: True  # type: ignore[method-assign]
            return True
        return False
|
|
|
def format_frame_summary(self, additional_stack_frames=None):
|
|
if additional_stack_frames is None:
|
|
additional_stack_frames = []
|
|
return "".join(
|
|
traceback.format_list(
|
|
[self.frame_summary()] + list(reversed(additional_stack_frames))
|
|
)
|
|
)
|
|
|
|
def frame_summary(self):
|
|
return traceback.FrameSummary(
|
|
getattr(self.f_code, "co_filename", "<unknown>"),
|
|
self.lineno,
|
|
getattr(self.f_code, "co_name", "<unknown>"),
|
|
lookup_line=False,
|
|
)
|
|
|
|
def is_co_filename_from_nn_modules(self):
|
|
filename = getattr(self.f_code, "co_filename", "<unknown>")
|
|
nn_modules_pattern = re.compile(r".*torch/nn/modules.*")
|
|
return nn_modules_pattern.match(filename) is not None
|
|
|
|
    def store_global_weakref_by_id(self, prefix, value):
        """Install `weakref.ref(value)` as a global of the output code and
        guard that the referent is still alive; return the global's name."""
        global_name = self.output.install_global_by_id(prefix, weakref.ref(value))
        install_guard(
            GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE)
        )
        return global_name
|
|
|
|
    @property
    def fake_mode(self):
        # FakeTensorMode shared with the enclosing tracing context.
        return self.output.tracing_context.fake_mode

    @contextlib.contextmanager
    def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
        """
        Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node).
        """
        prior = self.strict_checks_fn
        self.strict_checks_fn = check_fn
        try:
            yield
        finally:
            # Restore the previous checker even if the body raised.
            self.strict_checks_fn = prior

    def speculate(self) -> SpeculationEntry:
        """Record a speculation entry for the instruction just executed so a
        failure here can be replayed deterministically on restart."""
        assert self.instruction_pointer is not None
        # instruction_pointer already points past the current instruction.
        assert self.instruction_pointer > 0
        return self.speculation_log.next(
            self.f_code.co_filename,
            self.lineno,
            self.instruction_pointer - 1,
            self.instructions[self.instruction_pointer - 1],
        )
|
|
|
|
    def __init__(
        self,
        output: OutputGraph,
        instructions: list[Instruction],
        f_locals: dict[str, Any],
        f_globals: dict[str, Any],
        f_builtins: dict[str, Any],
        code_options: dict[str, Any],
        symbolic_locals: dict[str, VariableTracker],
        symbolic_globals: dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
        speculation_log: SpeculationLog,
        exn_vt_stack: ExceptionStack,
        distributed_state: Optional[DistributedState],
        # This determines whether to use the execution recorder.
        closure: Optional[tuple[types.CellType]] = None,
    ) -> None:
        """Initialize shared translator state (used by both the root
        InstructionTranslator and InliningInstructionTranslator)."""
        super().__init__()
        self.speculation_log = speculation_log
        self.distributed_state = distributed_state

        # Mutable state checkpointed by copy_graphstate()
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.stack = []
        self.instruction_pointer = 0
        self.start_point = None
        self.current_instruction = create_instruction("NOP")
        self.block_stack = []
        # states before SETUP_WITH for checkpointing and fallback
        self.active_generic_context_managers: list[GenericContextWrappingVariable] = []
        self.lineno = -1
        self.kw_names = None
        # While True, RESUME/MAKE_CELL/etc. are recorded into prefix_insts.
        self.accept_prefix_inst = True
        self.prefix_insts = []
        self.exn_vt_stack = exn_vt_stack

        # Properties of the input/output code
        self.instructions: list[Instruction] = instructions
        self.indexof: dict[Instruction, int] = get_indexof(self.instructions)
        self.f_locals: dict[str, Any] = (
            f_locals  # needed for recording accessed locals for replay
        )
        self.f_globals: dict[str, Any] = f_globals
        self.f_builtins: dict[str, Any] = f_builtins
        self.code_options: dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code

        # Execution record for replaying errors
        if closure is not None and config.replay_record_enabled:
            self.exec_recorder = ExecutionRecorder(
                code=f_code, closure=closure, code_options=code_options
            )
        else:
            self.exec_recorder = None
        # Stack of module being parsed, current nn.module is at the end of ordered dict.
        # The first field of tuple is the fully qualified name of current module
        # in original hierarchy. The second field is the type of current nn.module
        self.nn_module_stack: dict[str, tuple[str, type[Any]]] = {}
        self.num_calls: dict[str, int] = {}
        # Flag to indicate whether tracing is used for export.
        self.export = export
        self.one_graph = False

        self.current_speculation = None

        self.strict_checks_fn = None

        if sys.version_info >= (3, 10):
            from .resume_execution import (
                CO_ASYNC_GENERATOR,
                CO_COROUTINE,
                CO_GENERATOR,
                CO_ITERABLE_COROUTINE,
            )

            if f_code.co_flags & (
                CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
            ):
                # Generator/coroutine frames start with an implicit value on
                # the stack; model it so stack depths line up.
                self.push(BuiltinVariable(None))

        self.inline_depth = inline_depth
        self.inconsistent_side_effects = False
        # Per-index cache of wrapped co_consts so repeated loads reuse VTs.
        self._constants_cache: list[Optional[VariableTracker]] = [None] * len(
            f_code.co_consts
        )
        # Prime linecache so later error messages can resolve source lines.
        linecache.lazycache(f_code.co_filename, f_globals)
|
|
|
|
|
|
class InstructionTranslator(InstructionTranslatorBase):
    """Translator for the root (non-inlined) frame being compiled."""

    @staticmethod
    def current_tx() -> "InstructionTranslator":
        # Root translator for the current thread; set via set_current_tx().
        return tls.current_tx

    @contextlib.contextmanager
    def set_current_tx(self):
        """Make `self` the thread-local current translator for the duration,
        restoring the previous one (possibly None) on exit."""
        prior = getattr(tls, "current_tx", None)
        tls.current_tx = self
        try:
            yield
        finally:
            tls.current_tx = prior
|
|
|
|
    def __init__(
        self,
        instructions: list[Instruction],
        f_code,
        f_locals,
        f_globals,
        f_builtins,
        closure,
        torch_function_mode_stack,
        code_options,
        compiler_fn,
        one_graph,
        export,
        export_constraints,
        frame_state,
        speculation_log: SpeculationLog,
        exn_vt_stack: ExceptionStack,
        distributed_state: Optional[DistributedState],
    ) -> None:
        """Set up the root-frame translator: build the OutputGraph and
        populate `symbolic_locals` from the frame's plain locals, the cells
        it creates, and the cells it captures from enclosing frames."""
        _step_logger()(
            logging.INFO,
            f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
        )
        super().__init__(
            output=OutputGraph(
                code_options,
                compiler_fn,
                self,
                export,
                export_constraints,
                frame_state,
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
                torch_function_mode_stack=torch_function_mode_stack,
            ),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            closure=closure,
            code_options=code_options,
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            f_code=f_code,
            export=export,
            inline_depth=0,
            speculation_log=speculation_log,
            exn_vt_stack=exn_vt_stack,
            distributed_state=distributed_state,
        )

        self._throw_if_in_functorch()

        # as soon as we create the tracing context we should keep it active, so any calls
        # into dynamo apis can rely on finding it
        with tracing(self.output.tracing_context), self.set_current_tx():
            self.one_graph: bool = one_graph
            self.export = export
            if self.export:
                assert self.one_graph, (
                    "Export without one graph - something has gone wrong."
                )

            self.symbolic_locals = {}
            # Populate `symbolic_locals` with non-cell variables.
            cell_and_freevars: set[str] = set(self.cell_and_freevars())

            # Optional per-local dynamism annotations supplied by callers.
            dynamism = code_context.get_context(f_code).get("dynamism", None)
            for name, value in f_locals.items():
                if name not in cell_and_freevars:
                    local_dynamism = None
                    if dynamism:
                        local_dynamism = frozenset(dynamism.get(name, {}).items())
                    # Lazy wrapping defers guard creation until first use.
                    var = LazyVariableTracker.create(
                        value,
                        LocalSource(
                            name,
                            is_input=True,
                            dynamism=local_dynamism,
                        ),
                    )
                    self.symbolic_locals[name] = var

            # Populate `symbolic_locals` with cells created by this frame,
            # effectively implementing the `MAKE_CELL` instructions.
            side_effects = self.output.side_effects
            for name in self.cellvars():
                if name in f_locals:
                    # This models cells that are also function inputs.
                    value = f_locals[name]
                    # NOTE: root frame inputs that are captured by a nested
                    # function become special cell objects -- they exist in
                    # `f_locals` as contents of the cells, rather than the cells
                    # objects themselves.
                    #
                    # In Dynamo, we choose to represent such input cell objects
                    # as newly created (rather than pre-existing) cell objects,
                    # because
                    #
                    # 1. The reason for representing a pre-existing cell object
                    # is to emit guard or codegen mutations. However, local
                    # cells should never be used for guards. Moreover, at this
                    # point these input cell objects should've never been
                    # accessed by anyone else, since Dynamo intercepts the frame
                    # right after its evaluation starts, i.e., right after these
                    # cell objects are created. So they should have no external
                    # reference, meaning no mutation needs to be propagated.
                    #
                    # 2. This conveniently allows codegen to prune away
                    # mutations to these cells, unless they escape the frame.
                    contents_source = LocalSource(
                        name, is_input=True, is_derefed_cell_contents=True
                    )
                    contents_var: VariableTracker = LazyVariableTracker.create(
                        value, contents_source
                    )
                    cell_var = side_effects.track_cell_new()
                    side_effects.store_cell(cell_var, contents_var)
                else:
                    cell_var = side_effects.track_cell_new()
                cell_var.local_name = name
                self.symbolic_locals[name] = cell_var

            # Populate `symbolic_locals` with cells captured by this frame,
            # effectively implementing the `COPY_FREE_VARS` instruction.
            for name, cell in zip(self.freevars(), closure):
                cell_source = LocalCellSource(name)
                contents_source = LocalSource(name, is_derefed_cell_contents=True)
                try:
                    contents_var = LazyVariableTracker.create(
                        cell.cell_contents, contents_source
                    )
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                cell_var = side_effects.track_cell_existing(
                    cell_source, cell, contents_var
                )
                cell_var.local_name = name
                self.symbolic_locals[name] = cell_var

            self.symbolic_torch_function_state = SymbolicTorchFunctionState(
                torch_function_mode_stack
            )

            self.debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] = []
            if export:
                # export gets confused if we never realize unused inputs
                # in export mode just eagerly realize everything
                self.symbolic_locals = variables.LazyVariableTracker.realize_all(
                    self.symbolic_locals
                )
|
|
    def _throw_if_in_functorch(self):
        """Graph break if compilation starts inside an active functorch
        transform (vmap/grad/jvp) with a real (non-eager) backend."""
        # Fallback to eager in case of a graph break inside vmap
        eager = torch._dynamo.lookup_backend("eager")
        # Unwrap compiler wrappers to compare against the eager backend.
        compiler_fn = inspect.getattr_static(
            self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
        )
        ci = torch._C._functorch.peek_interpreter_stack()
        forbidden_keys = (
            torch._C._functorch.TransformType.Vmap,
            torch._C._functorch.TransformType.Grad,
            torch._C._functorch.TransformType.Jvp,
        )

        if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
            name = ci.key().name.lower()
            msg = (
                "If you are reaching here, it means dynamo failed for one of the following reasons:\n"
                # Calling a torch.compiled function
                f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. "
                f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. "
                "For more information, see PyTorch issue #128711.\n"
                # if it reaches here, it means Dynamo failed to inline a functorch function
                f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
            )
            unimplemented_v2(
                gb_type="Unsupported functorch tracing attempt",
                context="",
                explanation=msg,
                hints=[],
            )
|
|
|
|
def get_example_value(self, source: Source):
|
|
if isinstance(source, LocalSource):
|
|
return self.f_locals[source.local_name]
|
|
if isinstance(source, GlobalSource):
|
|
return self.f_globals[source.global_name]
|
|
raise KeyError
|
|
|
|
    def run(self):
        """Drive the main translation loop (implemented by the base class)."""
        super().run()
|
|
|
    def should_compile_partial_graph(self):
        """Whether a graph break at the current instruction may compile a
        partial graph and resume, instead of skipping the whole frame."""
        if sys.version_info >= (3, 11):
            # Do not compile if current instruction's block is not the top with block
            entry = self.current_instruction.exn_tab_entry
            if entry and (
                not self.block_stack or entry.target is not self.block_stack[-1].target
            ):
                return False
        # All active blocks must be restorable, and neither fullgraph mode nor
        # a non-break-safe generic context manager may be active.
        return (
            all(b.can_restore() for b in self.block_stack)
            and not self.one_graph
            and not self.active_generic_context_managers
        )
|
|
|
|
    def create_call_resume_at(self, inst):
        """Emit bytecode that, after the compiled prefix runs, calls a freshly
        synthesized resume function starting at `inst`.

        Returns the instruction list to append to the transformed code.
        Live locals and the value stack become the resume function's
        arguments; NULLs and context-manager objects are specially encoded
        since they cannot be passed as ordinary arguments.
        """
        self.instruction_pointer = None

        # A break exactly at a return needs no resume function.
        if inst.opname == "RETURN_VALUE":
            return [create_instruction("RETURN_VALUE")]
        elif inst.opname == "RETURN_CONST":
            return [create_instruction("RETURN_CONST", argval=inst.argval)]

        # Only locals still live at `inst` are passed to the resume function.
        reads = livevars_analysis(self.instructions, inst)
        all_argnames = tuple(
            k
            for k in self.symbolic_locals.keys()
            if k in reads and k not in self.cell_and_freevars()
        )
        # NOTE: do not use isinstance, since it realizes lazy VT's
        argnames = tuple(
            k
            for k in all_argnames
            if not type.__instancecheck__(NullVariable, self.symbolic_locals[k])
        )
        argnames_null = tuple(
            k
            for k in all_argnames
            if type.__instancecheck__(NullVariable, self.symbolic_locals[k])
        )
        if sys.version_info < (3, 12):
            assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

        cg = PyCodegen(self)

        # Handle inactive context variables.
        # The resume function assumes that context variables are the class, NOT the object.
        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
        stack_ctx_vars = []
        for i, var in enumerate(self.stack):
            if type.__instancecheck__(ContextWrappingVariable, var):
                ctx = cast(ContextWrappingVariable, var)
                target_values = (
                    () if ctx.target_values is None else tuple(ctx.target_values)
                )
                stack_ctx_vars.append((i, target_values))
                # Replace the current stack var with the context class
                ctx.reconstruct_type(cg)
                cg.extend_output(create_swap(len(self.stack) - i + 1))
                cg.append_output(create_instruction("POP_TOP"))

        argnames_ctx_vars = []
        for name in argnames:
            if type.__instancecheck__(
                ContextWrappingVariable, var := self.symbolic_locals[name]
            ):
                ctx = cast(ContextWrappingVariable, var)
                target_values = (
                    () if ctx.target_values is None else tuple(ctx.target_values)
                )
                argnames_ctx_vars.append((name, target_values))
                # Replace the local with the context class
                ctx.reconstruct_type(cg)
                cg.append_output(create_instruction("STORE_FAST", argval=name))

        # Python does not allow null to be an arg to a function, so
        # we remove nulls from the stack and restore them in the
        # prologue of the resume function

        # sorted list of indices of nulls on the stack
        null_idxes: list[int] = []
        if sys.version_info >= (3, 11):
            # find indices of NullVariables
            for i, var in enumerate(self.stack):
                if type.__instancecheck__(NullVariable, var):
                    null_idxes.append(i)
            # generate bytecode to pop the nulls
            null_cnt = 0
            for i, var in enumerate(reversed(self.stack)):
                if type.__instancecheck__(NullVariable, var):
                    # Rotate the NULL up to TOS, then pop it.
                    for j in range(2, i + 2 - null_cnt):
                        cg.append_output(create_instruction("SWAP", arg=j))
                    cg.extend_output(cg.pop_null())
                    null_cnt += 1

        # we popped all nulls from the stack at runtime,
        # so we should not count NullVariables
        stack_len = len(self.stack) - len(null_idxes)
        nargs = stack_len + len(argnames)

        name = unique_id(f"__resume_at_{inst.offset}")

        # Synthesize (or reuse) the resume code object for this break point.
        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            inst.offset,
            tuple(b.target.offset for b in self.block_stack),
            stack_len,
            argnames,
            argnames_null,
            tuple(b.resume_fn() for b in self.block_stack),
            tuple(stack_ctx_vars),
            tuple(argnames_ctx_vars),
            tuple(null_idxes),
        )

        # Add original GraphModule context to the resume function to handle
        # the case of a graph break while tracing a GraphModule
        orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
            "orig_graphmodule", lambda: None
        )()
        if orig_graphmodule_maybe is not None:
            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
                orig_graphmodule_maybe
            )

        if new_code.co_freevars:
            # expose code object for debugging purposes
            self.output.install_global_unsafe(name, new_code)
            cg.make_function_with_closure(name, new_code, True, stack_len)
        else:
            # This is safe: we pre-generate a unique name
            self.output.install_global_unsafe(
                name, types.FunctionType(new_code, self.f_globals, name)
            )
            cg.extend_output(cg.load_function_name(name, True, stack_len))

        # Call the resume function with the surviving stack + live locals.
        cg.extend_output([cg.create_load(k) for k in argnames])
        cg.extend_output(create_call_function(nargs, False))
        cg.append_output(create_instruction("RETURN_VALUE"))
        return cg.get_instructions()
|
|
|
|
def symbolic_locals_contain_module_class(self):
|
|
for v in self.symbolic_locals.values():
|
|
if isinstance(v, UserDefinedClassVariable) and issubclass(
|
|
v.as_python_constant(), torch.nn.Module
|
|
):
|
|
return True
|
|
return False
|
|
|
|
    def replace_tos_if_return_is_generator(self):
        """If TOS is a locally-created generator object, force-unpack it and
        replace it with a list iterator before returning."""
        if (
            len(self.stack)
            and (tos := self.stack[-1])
            and isinstance(tos, LocalGeneratorObjectVariable)
        ):
            self.stack[-1] = ListIteratorVariable(
                tos.force_unpack_var_sequence(self),
                mutation_type=ValueMutationNew(),
            )
|
|
|
|
    def _return(self, inst):
        """Shared handler for RETURN_VALUE/RETURN_CONST on the root frame.

        Compiles the traced subgraph, emits the matching return instruction,
        and raises ReturnValueOp to unwind the translation loop. Raises
        exc.SkipFrame when the graph is empty and skipping is allowed.
        """
        self.replace_tos_if_return_is_generator()
        assert self.instruction_pointer is not None
        assert self.start_point is not None
        # Record how many instructions this trace covered.
        get_metrics_context().increment(
            "ir_count", self.instruction_pointer - self.start_point
        )

        if (
            not config.allow_empty_graphs
            and self.output.count_calls() == 0
            and not self.inconsistent_side_effects
            and not self.symbolic_locals_contain_module_class()
            and not self.export
            and not self.one_graph
        ):
            # Nothing worth compiling; let the frame run in eager Python.
            raise exc.SkipFrame("because no content in function call")

        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
            f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
        )
        log.debug("%s triggered compile", inst.opname)
        self.output.compile_subgraph(
            self,
            reason=GraphCompileReason(
                "return_value", [self.frame_summary()], graph_break=False
            ),
        )
        return_inst = (
            create_instruction("RETURN_VALUE")
            if inst.opname == "RETURN_VALUE"
            else create_instruction("RETURN_CONST", argval=inst.argval)
        )
        self.output.add_output_instructions([return_inst])
        raise ReturnValueOp
|
|
|
|
    def RETURN_VALUE(self, inst):
        # End of the frame: compile and emit the return (see _return).
        self._return(inst)

    def RETURN_CONST(self, inst):
        # 3.12 RETURN_CONST: same handling; _return reads inst.argval.
        self._return(inst)
|
|
|
|
|
|
if sys.version_info >= (3, 11):
    # Maps the 3.11+ BINARY_OP oparg (an index into dis._nb_ops) to the
    # corresponding legacy BINARY_*/INPLACE_* handler on InstructionTranslator.
    # e.g. "NB_ADD" -> BINARY_ADD, "NB_INPLACE_ADD" -> INPLACE_ADD.
    _binary_op_lookup = [
        getattr(
            InstructionTranslator,
            opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}",
        )
        for opname, _ in dis._nb_ops  # type: ignore[attr-defined]
    ]
|
|
|
|
|
|
class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    # Set when the inlined function returns; holds its return value VT.
    symbolic_result: Optional[VariableTracker]

    @classmethod
    def inline_call(cls, parent, func, args, kwargs):
        """Inline `func(*args, **kwargs)` into `parent`'s trace and return
        the symbolic result."""
        # Attribute "unimplemented" counter bumps to the inline_call bucket
        # while inlining is in progress.
        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
            tracer = cls.build_inline_tracer(parent, func, args, kwargs)
            return tracer.inline_call_()
|
|
|
|
    @staticmethod
    def check_inlineable(func):
        """Decide whether `func` may be inlined.

        Returns the (non-skipped) trace_rules result on success; every
        rejection path calls unimplemented_v2, which raises.
        """
        if func.has_self():
            unimplemented_v2(
                gb_type="Inline attempt with __self__",
                context=str(func),
                explanation="Attempted to inline a function with the `__self__` attribute. "
                "Dynamo is expected to decompose method calls into function calls with a `self` argument.",
                hints=[],
            )

        result = trace_rules.check_verbose(func, is_inlined_call=True)
        if result.skipped:
            from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

            # _origin marks this as coming from an internal dynamo known function that is safe to
            # trace through.
            if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
                produce_trampoline_autograd_apply,
            ]:
                # Known sound
                return trace_rules.SkipResult(
                    False, "allowlist in dynamo known function"
                )
            fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
            hints = [
                f"Avoid calling the function `{fn_qualname}`.",
            ]
            # Suggest trace_rules edits only for non-dynamo-internal files.
            if "_dynamo" not in func.get_filename():
                hints += [
                    f"Remove the function `{fn_qualname}` or the file `{func.get_filename()}` "
                    "from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of "
                    "attempting to trace into the function.",
                    "Please file an issue to PyTorch.",
                    # TODO suggest mark_force_inline when implemented
                ]
            unimplemented_v2(
                gb_type="Attempted to inline function marked as skipped",
                context=f"qualname: {fn_qualname}, name: {func.get_name()}, "
                f"filename: `{func.get_filename()}`, skip reason: {result.reason}",
                explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` "
                "should not be traced.",
                hints=hints,
            )

        if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
            func.get_function(), "_torchdynamo_disable", False
        ):
            unimplemented_v2(
                gb_type="Skip inlining `torch.compiler.disable()`d function",
                context=str(func.get_function()),
                explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`",
                hints=[
                    "Remove the `torch.compiler.disable` call",
                ],
            )
        else:
            return result
|
|
|
|
    @staticmethod
    def build_inline_tracer(
        parent,
        func: VariableTracker,
        args: list[VariableTracker],
        kwargs,
    ):
        """Validate that `func` can be inlined, bind its arguments, and build
        the appropriate inlining tracer (generator-aware when needed)."""
        if isinstance(func, SkipFunctionVariable):
            unimplemented_v2(
                gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
                context=f"Attempted to inline a SkipFunctionVariable {func}",
                explanation="Attempted to inline a function that was previously determined to be marked as intentionally skipped.",
                hints=[],
            )
        assert isinstance(
            func,
            (
                UserFunctionVariable,
                NestedUserFunctionVariable,
                LocalGeneratorFunctionVariable,
                LocalGeneratorObjectVariable,
            ),
        )
        result = InliningInstructionTranslator.check_inlineable(func)
        assert result.skipped is False
        try:
            # Map args/kwargs onto the callee's parameters.
            sub_locals = func.bind_args(parent, args, kwargs)
        except TypeError as e:
            # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
            raise ArgsMismatchError(  # noqa: B904
                "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
                    reason=str(e),
                    func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
                    args=[arg.python_type() for arg in args],
                    kwargs=kwargs,
                ),
            )

        for v in itertools.chain(sub_locals.values()):
            if not isinstance(v, VariableTracker):
                unimplemented_v2(
                    gb_type="Encountered unconverted argument when attempting to inline",
                    context=f"func: {func}, arg: {v}",
                    explanation="An argument to an inlined function was not successfully converted to a VariableTracker.",
                    hints=[*graph_break_hints.DYNAMO_BUG],
                )

        code: types.CodeType = func.get_code()
        if code.co_name in ("__setitem__", "__setattr__") and not (
            args and isinstance(args[0], variables.UserDefinedObjectVariable)
        ):
            unimplemented_v2(
                gb_type="Unsupported __setitem__/__setattr__ inline attempt",
                context=f"code name: {code.co_name}, args: {args}",
                explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.",
                hints=[],
            )

        suffix = ""
        # TODO: mlazos, add support for enabling multiple artifact logs
        # with a single alias
        if torch._logging._internal.log_state.is_artifact_enabled("bytecode"):
            suffix = f"\n{dis.Bytecode(code).dis()}"
        if sys.version_info >= (3, 11):
            cur_inst = parent.current_instruction
            parent_code = parent.f_code
            header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)

            # Deferred so the (possibly expensive) source lookup only runs
            # if the log record is actually emitted.
            def get_trace_call_log_str():
                line = get_instruction_source_311(parent_code, cur_inst).rstrip()
                return f"TRACE inlined call {code.co_name} from {header}\n{line}"

            trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
        log.debug("INLINING %s%s, %s", code, suffix, result.reason)

        # Detect inline GraphModule calls in order to propagate node metadata,
        # by checking if the first argument (self) is a variable tracking a GraphModule.
        if args and isinstance(args[0], NNModuleVariable):
            module = parent.output.get_submodule(args[0].module_key)
            if isinstance(module, torch.fx.GraphModule):
                # The inline call might not actually be a call to `forward`,
                # but it is enough to add a context for `forward` in case it is called.
                code_context.get_context(module.forward.__code__)[
                    "orig_graphmodule"
                ] = weakref.ref(module)

        tracer: InliningInstructionTranslator
        if is_generator(code):
            tracer = InliningGeneratorInstructionTranslator(
                parent,
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                func,
            )
        else:
            # need the line below to make MyPy happy
            assert not isinstance(func, LocalGeneratorObjectVariable)
            tracer = InliningInstructionTranslator(
                parent,
                code,
                sub_locals,
                parent.symbolic_globals,
                parent.symbolic_torch_function_state,
                func,
            )
        return tracer
|
|
|
|
def inline_call_(self):
|
|
parent = self.parent
|
|
code = self.f_code
|
|
|
|
strict_ctx: Any = contextlib.nullcontext()
|
|
if parent.strict_checks_fn:
|
|
strict_ctx = self.strict_translation_mode(parent.strict_checks_fn)
|
|
try:
|
|
with strict_ctx:
|
|
self.run()
|
|
except exc.ObservedException as e:
|
|
msg = f"Observed exception DURING INLING {code} : {e}"
|
|
log.debug(msg)
|
|
# bubble up the exception to the parent frame.
|
|
raise
|
|
except exc.SkipFrame as e:
|
|
msg = f"SKIPPED INLINING {code}: {e}"
|
|
log.debug(msg)
|
|
raise Unsupported(msg) from e
|
|
except Exception:
|
|
log.debug("FAILED INLINING %s", code)
|
|
raise
|
|
assert self.symbolic_result is not None
|
|
|
|
if self.f_globals is parent.f_globals:
|
|
# Merge symbolic_globals back if parent and child are in the same namespace
|
|
parent.symbolic_globals.update(self.symbolic_globals)
|
|
|
|
parent.inconsistent_side_effects |= self.inconsistent_side_effects
|
|
|
|
log.debug("DONE INLINING %s", code)
|
|
|
|
if config.enable_faithful_generator_behavior or (
|
|
isinstance(self, InliningGeneratorInstructionTranslator)
|
|
and self.is_generator_from_ctx_manager
|
|
):
|
|
if (
|
|
is_generator(code)
|
|
and isinstance(self, InliningGeneratorInstructionTranslator)
|
|
and self.generator_exhausted
|
|
):
|
|
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
|
# When the generator returns None, we raise StopIteration
|
|
exc.raise_observed_exception(StopIteration, self)
|
|
else:
|
|
return self.symbolic_result
|
|
else:
|
|
if is_generator(code):
|
|
assert isinstance(self, InliningGeneratorInstructionTranslator)
|
|
assert self.symbolic_result.as_python_constant() is None
|
|
return ListIteratorVariable(
|
|
self.generated_items,
|
|
mutation_type=ValueMutationNew(),
|
|
)
|
|
else:
|
|
return self.symbolic_result
|
|
|
|
    def __init__(
        self,
        parent: InstructionTranslatorBase,
        code: types.CodeType,
        symbolic_locals: dict[str, VariableTracker],
        symbolic_globals: dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        funcvar: BaseUserFunctionVariable,
    ) -> None:
        """Construct a translator that inlines ``funcvar`` one level below ``parent``.

        The child shares the parent's output graph, speculation log, exception
        VT stack and distributed state, so everything traced here contributes
        to the same compilation unit as the parent frame.
        """
        f_globals = funcvar.get_globals()  # type: ignore[attr-defined]
        # __builtins__ may be either a dict or a module depending on how the
        # function's globals were created; normalize to a dict.
        f_builtins = f_globals["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        instructions = cleaned_instructions(code)
        propagate_line_nums(instructions)
        super().__init__(
            output=parent.output,
            f_locals={},
            f_globals=f_globals,
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            symbolic_torch_function_state=symbolic_torch_function_state,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in get_code_keys()},
            f_code=code,
            export=parent.export,
            inline_depth=parent.inline_depth + 1,
            speculation_log=parent.speculation_log,
            exn_vt_stack=parent.exn_vt_stack,
            distributed_state=parent.distributed_state,
        )
        self.funcvar = funcvar
        self.parent = parent
        self.num_calls = parent.num_calls
        # Filled in by RETURN_VALUE / RETURN_CONST (and YIELD_VALUE in
        # faithful-generator mode) once the inlined frame finishes.
        self.symbolic_result = None
        self.nn_module_stack = parent.nn_module_stack.copy()
        self.one_graph = parent.one_graph
|
|
|
|
    @property
    def fake_mode(self):
        # Inlined frames do not own a fake mode; always delegate to the
        # parent translator so the whole inline chain shares one.
        return self.parent.fake_mode
|
|
|
|
    def run_ctx_mgr(self):
        # Use the parent's frame summary as the tracing context so errors
        # raised while inlining are attributed to the call site.
        return TracingContext.current_frame(self.parent.frame_summary())
|
|
|
|
    def should_compile_partial_graph(self):
        """Inlined frames never compile partial graphs."""
        return False  # inlining functions is all-or-nothing
|
|
|
|
    def create_call_resume_at(self, offset):
        """Always a graph break: inlined frames cannot build resume functions."""
        unimplemented_v2(
            gb_type="Graph break in inlined function",
            context="",
            explanation="Graph breaks in an inlined call are not supported.",
            hints=[],
        )
|
|
|
|
def RETURN_VALUE(self, inst):
|
|
self.symbolic_result = self.pop() # type: ignore[assignment]
|
|
self.instruction_pointer = None
|
|
raise ReturnValueOp
|
|
|
|
def RETURN_CONST(self, inst):
|
|
self.symbolic_result = self._load_const(inst)
|
|
self.instruction_pointer = None
|
|
raise ReturnValueOp
|
|
|
|
    def get_globals_source_and_value(self, name):
        """Resolve ``name`` against this inlined frame's globals.

        Returns a ``(fglobals_value, fglobals_vt, global_source)`` triple:
        the raw globals container (module or dict), a VariableTracker for it,
        and a source locating the individual global ``name``.
        """
        if "__name__" in self.f_globals:
            # The globals belong to a named module: reference the global as
            # an attribute of the (re-)imported module.
            module_name = self.f_globals["__name__"]
            module_source = self.import_source(module_name)
            if "torch_package" in module_name:
                # torch.package modules are fetched from the package
                # importer's cache — presumably they are not reachable via a
                # normal import (TODO confirm).
                fglobals_value = (
                    torch.package.package_importer._package_imported_modules[
                        module_name
                    ]
                )  # type: ignore[assignment]
            else:
                fglobals_value = _import_module(module_name)
            fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
            global_source = AttrSource(module_source, name)
        else:
            # Anonymous namespace (no __name__): install the globals dict
            # itself and reference the global as a dict item.
            globals_name = self.output.install_global_by_id(
                "___unnamed_scope", self.f_globals
            )
            globals_source = GlobalSource(globals_name)
            fglobals_value = self.f_globals  # type: ignore[assignment]
            fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
            global_source = DictGetItemSource(globals_source, name)  # type: ignore[assignment]
        return fglobals_value, fglobals_vt, global_source
|
|
|
|
    def _load_global(self, inst):
        """LOAD_GLOBAL handler for inlined frames.

        Delegates to the root translator's handler when this frame shares the
        root's global scope; otherwise resolves the name against this frame's
        own globals, preferring any pending traced mutation.
        """
        if self.output.global_scope is self.f_globals:
            # If the global scope matches that of the root frame, use handler in
            # root frame instruction translator, to enforce consistency.
            super()._load_global(inst)
        else:
            name = inst.argval

            _, fglobals_vt, global_source = self.get_globals_source_and_value(name)
            if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
                # A traced STORE_GLOBAL already wrote this name; read the
                # pending value instead of the (stale) real global.
                self.push(self.output.side_effects.load_attr(fglobals_vt, name))
            else:
                try:
                    value = self.f_globals[name]
                except KeyError:
                    # Mirror CPython: an unresolved global falls back to builtins.
                    return self.load_builtin(inst)

                self.push(VariableTracker.build(self, value, global_source))
|
|
|
|
    def STORE_GLOBAL(self, inst):
        """STORE_GLOBAL handler for inlined frames.

        Delegates to the root translator's handler when this frame shares the
        root's global scope; otherwise records the write as a pending
        attribute mutation on this frame's globals VariableTracker.
        """
        if self.output.global_scope is self.f_globals:
            # If the global scope matches that of the root frame, use handler in
            # root frame instruction translator, to enforce consistency.
            super().STORE_GLOBAL(inst)
        else:
            value = self.pop()
            if isinstance(value, RemovableHandleVariable):
                # Hook handles escaping to globals cannot be replayed safely.
                unimplemented_v2(
                    gb_type="Storing Tensor hook handle in globals (inline call)",
                    context=inst.argval,
                    explanation="This is not supported.",
                    hints=[],
                )
            name = inst.argval
            _fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
            self.output.side_effects.store_attr(fglobals_vt, name, value)
|
|
|
|
|
|
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
|
generated_items: list[VariableTracker]
|
|
    # Flag whether or not the InlineGenerator should consume the entire iterator
|
|
|
|
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Values yielded so far while inlining this generator.
        self.generated_items = []
        # Set by RETURN_VALUE / RETURN_CONST once the generator body returns.
        self.generator_exhausted = False
        # Presumably set by callers for context-manager-backed generators —
        # when True, YIELD_VALUE stops tracing at each yield (see YIELD_VALUE).
        self.is_generator_from_ctx_manager = False
|
|
|
|
    def YIELD_VALUE(self, inst: Instruction):
        """YIELD_VALUE handler: record the yielded value.

        Appends TOS to ``generated_items`` and pushes None as the result of
        the ``yield`` expression. In faithful-generator mode (or for
        context-manager generators) tracing stops here by raising
        ``YieldValueOp`` so the caller can resume the generator lazily.

        Raises:
            exc.InfiniteGeneratorError: if more than ``MAX_ITERATOR_LIMIT``
                values have been yielded (likely an infinite generator).
        """
        top = self.pop()
        self.generated_items.append(top)
        if len(self.generated_items) > MAX_ITERATOR_LIMIT:
            raise exc.InfiniteGeneratorError(
                "Too many yield values in generator. Maybe you are inlining an infinite generator. "
                f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}",
            )
        self.push(ConstantVariable.create(None))
        if (
            config.enable_faithful_generator_behavior
            or self.is_generator_from_ctx_manager
        ):
            # The yielded value becomes this frame's symbolic result so the
            # parent sees it as the generator's next() value.
            self.symbolic_result = top
            # Stop tracing
            raise YieldValueOp
|
|
|
|
def GET_YIELD_FROM_ITER(self, inst):
|
|
tos = self.stack[-1]
|
|
if not isinstance(tos, ListIteratorVariable):
|
|
self.pop()
|
|
res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type]
|
|
self.push(res)
|
|
|
|
    def RETURN_VALUE(self, inst):
        # A return out of the generator body means the generator is done.
        self.generator_exhausted = True
        return super().RETURN_VALUE(inst)
|
|
|
|
    def RETURN_CONST(self, inst):
        # A constant return out of the generator body means it is done.
        self.generator_exhausted = True
        return super().RETURN_CONST(inst)
|
|
|
|
    def YIELD_FROM(self, inst):
        """YIELD_FROM handler (pre-3.11): forward values from a sub-iterator.

        Stack layout on entry: ``[iterator, sent_value]``. Only
        ``send(None)`` is supported. Each value produced by the sub-iterator
        is pushed and routed through YIELD_VALUE, and the instruction pointer
        is rewound so this opcode repeats until the sub-iterator stops.
        """
        assert len(self.stack) >= 2
        val = self.pop()
        tos = self.stack[-1]
        if not (isinstance(val, ConstantVariable) and val.value is None):
            # invoke send
            # Unreachable code - if you hit this, you are implementing generator support and have
            # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
            # subgenerator and lines up with this line in Python 3.10
            # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
            unimplemented_v2(
                gb_type="Unreachable sub-generator code",
                context="",
                explanation="Should only be encountered while implementing generator support.",
                hints=[],
            )

        try:
            val = tos.next_variable(self)
        except (StopIteration, exc.ObservedUserStopIteration) as ex:
            if isinstance(ex, exc.ObservedUserStopIteration):
                exc.handle_observed_exception(self)

            # The iterator is exhausted. Stop the loop and return.
            self.pop()
            self.push(ConstantVariable.create(ex.value))
        else:
            # Repeat the YIELD_FROM instruction in the next eval loop
            assert (
                isinstance(self.instruction_pointer, int)
                and self.instruction_pointer > 0
            )
            self.instruction_pointer -= 1

            self.push(val)
            # Add the value to yield into generated_items and replace the top of the stack with None
            self.YIELD_VALUE(inst)
|
|
|
|
    def SEND(self, inst):
        """SEND handler (3.11+): advance a sub-iterator with a sent value.

        Stack layout on entry: ``[receiver, sent_value]``. Only
        ``send(None)`` on iterator-like receivers is supported. On
        exhaustion the receiver's return value replaces it on the stack and
        control jumps past the yield loop (stack cleanup differs between
        3.11 and 3.12 — see inline comments).
        """
        assert len(self.stack) >= 2
        val = self.pop()
        tos = self.stack[-1]
        if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or (
            isinstance(tos, UserDefinedObjectVariable)
            and isinstance(tos.value, collections.abc.Iterator)
        ):
            if isinstance(val, ConstantVariable) and val.value is None:
                try:
                    val = tos.next_variable(self)
                except (StopIteration, exc.ObservedUserStopIteration) as ex:
                    # To implement SEND, we have to look at the implementation
                    # when the iterator returns StopIteration. This translates to this code
                    # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
                    # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
                    # The implementation is different in 3.11 and 3.12. In 3.12, we rely
                    # on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
                    if sys.version_info < (3, 12):
                        self.pop()  # Python 3.12 uses new opcode END_SEND
                    self.push(ConstantVariable.create(ex.value))
                    self.jump(inst)
                else:
                    self.push(val)
            else:
                # invoke send
                # Unreachable code - if you hit this, you are implementing generator support and have
                # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
                # subgenerator and lines up with this line in Python 3.11
                # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
                unimplemented_v2(
                    gb_type="Unreachable sub-generator code",
                    context="",
                    explanation="Should only be encountered while implementing generator support.",
                    hints=[],
                )
        else:
            unimplemented_v2(
                gb_type="SEND with bad type",
                context=f"TOS type: {typestr(tos)}",
                explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
                hints=[],
            )
|