# mypy: allow-untyped-defs

"""
This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.

It includes functionality for:
- Constructing bytecode sequences for Python operations
- Managing stack operations and variable tracking
- Handling graph outputs and their conversions
- Supporting different Python versions (3.11+, 3.12+, 3.13+)
- Converting high-level operations to low-level bytecode instructions
- Managing constant loading and attribute access
- Supporting function creation and closure handling
"""

import collections
import dataclasses
import re
import sys
import types
from collections import Counter
from typing import Optional, Union

import torch.nn
from torch.utils._ordered_set import OrderedSet

from . import graph_break_hints, utils
from .bytecode_transformation import (
    add_push_null,
    add_push_null_call_function_ex,
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_const,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import IncorrectUsage, unimplemented_v2
from .source import AttrSource, ChainedSource, DictGetItemSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import ValueMutationExisting, VariableTracker
from .variables.functions import (
    ContextlibContextManagerLocalGeneratorObjectVariable,
    LocalGeneratorObjectVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
        overridden_sources=None,
    ) -> None:
        self.root = root
        self.top_of_stack: Optional[Union[VariableTracker, Source]] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: dict[int, GraphOutputEntry] = {}
        self._output: list[Instruction] = []
        # This determines which VariableTracker/Source should be stored as
        # locals, and maps the VariableTracker/Source to the local variable
        # name. Note that it could map to None initially, in which case we'll
        # overwrite it to map to real temporary names via `add_cache`.
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.value_from_source: bool = True
        # This serves as a way for codegen to use a different source; we need
        # this because sometimes we can't easily modify the original source
        # without affecting other components, e.g., guards.
        self.overridden_sources: dict[Source, Source] = overridden_sources or {}

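    # Illustrative sketch (comments only, not executed): an assumed, typical
    # usage pattern for PyCodegen given a live InstructionTranslator `tx`. The
    # real call sites live elsewhere in Dynamo (e.g. output_graph.py and
    # side_effects.py); the names below are placeholders, not an exact API.
    #
    #   cg = PyCodegen(tx, root=some_root_module, graph_output_var="graph_out_0")
    #   cg(some_source)               # emit bytecode that pushes the value of a Source
    #   cg(some_variable_tracker)     # emit bytecode that reconstructs a VariableTracker
    #   insts = cg.get_instructions() # the collected Instruction objects
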
    def restore_stack(self, stack_values, *, value_from_source=True):
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"

    def add_push_null(self, gen_fn, call_function_ex=False):
        """
        `gen_fn` generates instructions via PyCodegen methods
        that push a single callable to the stack.

        `add_push_null` pushes a NULL to the stack before or after the
        instructions generated by `gen_fn`, depending on Python version.

        Will attempt to use the NULL push bit for instructions
        with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
        """
        old_len = len(self._output)
        if sys.version_info < (3, 13):
            # gen_fn may DUP_TOP instead if TOS is not cleared.
            # Will cause problems since NULL will be pushed right
            # before the generated instructions in <= 3.12
            self.clear_tos()
        gen_fn()
        # inplace modify self._output
        added_insts = self._output[old_len:]
        del self._output[old_len:]
        if call_function_ex:
            self._output.extend(add_push_null_call_function_ex(added_insts))
        else:
            self._output.extend(add_push_null(added_insts))
        if sys.version_info >= (3, 13):
            # NULL will be at top of stack
            self.clear_tos()

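    # Illustrative sketch (comments only, not executed): how callers are
    # expected to pair `add_push_null` with a call, assuming a PyCodegen
    # instance `cg`. On 3.11/3.12 the NULL ends up under the callable; on
    # 3.13+ it ends up above it, and `create_call_function` accounts for that.
    # "some_helper" is a placeholder name, not a real util.
    #
    #   cg.add_push_null(lambda: cg.load_import_from(utils.__name__, "some_helper"))
    #   cg(arg_source)                                     # push one argument
    #   cg.extend_output(create_call_function(1, False))   # emit the call
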
    def __call__(self, value, allow_cache=True):
        """
        Generate code such that top-of-stack (TOS) is set to value.

        `allow_cache` controls the behavior in the following manner. `value` can
        either be a VariableTracker or a Source.

        If `value` is a `Source`, `allow_cache` must be True (invariant asserted
        below). If the source was reconstructed earlier, we will reuse the
        generated code by loading from top of stack or tempvars.

        If `value` is a `VariableTracker`, we have the following cases:

        1) `allow_cache=True`
            a) If value.source is not None, we will emit the code based on
            `value.source` to handle aliasing.
            b) If value.source is None (e.g., reconstructing a local list
            returned by the compiled function), we will reconstruct the variable
            tracker (w/o any source) to emit bytecode that generates a new
            python object.

        In both cases of value.source being None or not, if the value was
        reconstructed earlier, we will reuse the generated code by loading from
        top of stack or tempvars.

        2) `allow_cache=False` - This is a special case (allow_cache defaults to
        True).
            a) If value.source is not None, we reconstruct the variable
            tracker and emit a new python object. What about aliasing? The
            place where we use this config also has follow-up code where the
            original python object is assigned to this new python object to
            handle aliasing (check side_effects.py and search for
            allow_cache=False).

            b) If value.source is None, this is not allowed. TODO - assert this.

        Notable effects:
        1. `self.top_of_stack` will be set to `value`, if we don't codegen
        `value` based on source.
        2. `self.uses[value]` will increment, if we don't codegen `value` based
        on source or cache/top-of-stack reuse; in other words, if we codegen
        as if `value` is modeling some brand new python value.
        """
        if isinstance(value, Source):
            # If the source needs to be overridden, use the new one.
            source = self.overridden_sources.get(value, value)
            assert allow_cache is True, "allow_cache must be True for Source"
            if self.top_of_stack is value:
                self._output.append(create_dup_top())
                return

            if self.tempvars.get(source) is not None:
                self._output.append(self.create_load(self.tempvars[source]))
                self.top_of_stack = source
                return

            try:
                self.call_reconstruct(source)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure: source.reconstruct not implemented",
                    context=str(source),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.",
                    hints=[*graph_break_hints.DYNAMO_BUG],
                )

            self._output.append(create_dup_top())
            self.add_cache(source)
            self.top_of_stack = source

            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if allow_cache:
            if self.top_of_stack is value:
                output.append(create_dup_top())
                return

            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.is_realized() and isinstance(
            value, ContextlibContextManagerLocalGeneratorObjectVariable
        ):
            raise IncorrectUsage(
                "NYI: Returning a @contextmanager object from a torch.compile function"
            )

        # Dynamo normally prefers codegen from source to account for aliasing.
        if (
            value.source is not None
            and allow_cache
            and not (
                value.is_realized() and isinstance(value, LocalGeneratorObjectVariable)
            )
        ):
            # There's a corner case for export: for instance, if the computation
            # graph is just identity on an input tensor, Dynamo would just emit
            # a `LOAD_FAST` from the input source, rather than generating an
            # identity FX graph.
            #
            # However, export wants to maximize graph capture; in the case
            # above, export _wants to_ obtain an identity FX graph (even though
            # it appears unnecessarily expensive for `torch.compile`), so we
            # have the following option to override Dynamo's preference for
            # codegen from source. Moreover, this option applies recursively,
            # for cases like an input tensor being returned in a new dictionary.
            #
            # And why the `ValueMutationExisting` check? Not sure, so leaving it
            # to keep the old behavior, as when `value_from_source` was
            # introduced. TODO sort out the invariants among side effect,
            # codegen and export.
            if (
                isinstance(value.mutation_type, ValueMutationExisting)
                or self.value_from_source
            ):
                return self(value.source)

        if value.is_python_constant() and is_safe_constant(value.as_python_constant()):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.add_push_null(
                lambda: self.load_import_from(utils.__name__, "to_subclass")
            )
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), add=True
                )
            )
            output.extend(create_call_function(2, False))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here. Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(
                value.as_tensor(self.tx, torch.float64)
            )

            def gen_fn():
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.append(self.create_load_attr("item"))

            self.add_push_null(gen_fn)
            output.extend(create_call_function(0, False))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.add_push_null(
                    lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
                )
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.extend(create_call_function(1, False))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:

                def gen_fn():
                    self.load_graph_output(graph_outputs[graph_outputs_key].index)
                    output.append(self.create_load_attr("item"))

                self.add_push_null(gen_fn)
                output.extend(create_call_function(0, False))
            else:
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_const_unchecked(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented_v2(
                    gb_type="Reconstruction failure",
                    context=str(value),
                    explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.",
                    hints=[
                        "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable "
                        "that Dynamo cannot reconstruct, then remove it from the return statement.",
                        *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                        "Report an issue to PyTorch if you need reconstruction support. Note that objects that don't have "
                        "reconstruction rules may be fundamentally unreconstructable.",
                    ],
                )
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value

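    # Illustrative sketch (comments only, not executed) of the caching behavior
    # documented in `__call__`, assuming a PyCodegen instance `cg` and a
    # sourceless VariableTracker `vt` (handled by the final `else` branch above)
    # that was marked as a temp var:
    #
    #   cg(vt)   # first call: reconstructs `vt`, then DUP_TOP + store into a temp local
    #   cg(vt)   # second call: emits a cheap LOAD_FAST of that temp instead
    #
    # If `vt` is already at top-of-stack, a single DUP_TOP is emitted instead.
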
    def add_graph_output(self, value):
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key

    def load_graph_output(self, index):
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self.create_load_const(index))
        output.append(self.create_binary_subscr())

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def create_binary_subscr(self) -> Instruction:
        return create_instruction("BINARY_SUBSCR")

    def setup_globally_cached(self, name, value):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, add=True)]

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> list[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
        return create_instruction(inst_name, argval=name)

    def create_load_deref(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("LOAD_DEREF", argval=name)

    def create_store(self, name) -> Instruction:
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("STORE_FAST", argval=name)

    def create_store_deref(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        return create_instruction("STORE_DEREF", argval=name)

    def create_load_global(self, name, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_instruction("LOAD_GLOBAL", argval=name)

    def create_load_const(self, value) -> Instruction:
        return create_load_const(value)

    def create_load_const_unchecked(self, value) -> Instruction:
        return create_load_const(value, checked=False)

    def load_method(self, name):
        self.tx.output.update_co_names(name)
        self.append_output(create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("LOAD_ATTR", argval=name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global fn_name on the stack num_on_stack down"""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
            if num_on_stack > 0:
                output.extend(
                    [
                        *self.rot_n(num_on_stack + 2),
                        *self.rot_n(num_on_stack + 2),
                    ]
                )
        else:
            output.extend(
                [
                    self.create_load_global(fn_name, add=True),
                    *self.rot_n(num_on_stack + 1),
                ]
            )
        return output

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self.create_load_const_unchecked(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]

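    # Illustrative note (comments only): `rot_n(n)` emits bytecode that moves
    # TOS down by n-1 positions within the top n stack entries, i.e. roughly
    # [a, b, ..., z, tos] -> [tos, a, b, ..., z]. When no native rotate
    # instruction exists for the requested n (the AttributeError path above),
    # the fallback packs the n items into a tuple, calls `rot_n_helper(n)` via
    # CALL_FUNCTION_EX to reorder them, and unpacks the result back onto the
    # stack, which is equivalent but slower.
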
    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self.create_load_const_unchecked(lambda: None),
            # 3.13 swapped NULL and callable
            *(
                (create_instruction("SWAP", arg=2),)
                if sys.version_info >= (3, 13)
                else ()
            ),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))

    def load_deref(self, varname):
        self.append_output(self.create_load_deref(varname))

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output

        def gen_fn():
            # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
            # requires that in the generated bytecode, these cells would keep
            # their original local names, which we ensure via
            # `CellVariable.local_name`.
            for var in freevars:
                assert var in self.cell_and_freevars()
                output.append(self.create_load_closure(var))
            output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
            output.append(self.create_load_const(code))
            if sys.version_info < (3, 11):
                output.append(self.create_load_const(fn_name))
            if sys.version_info >= (3, 13):
                output.extend(
                    [
                        create_instruction("MAKE_FUNCTION"),
                        create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
                    ]
                )
            else:
                output.append(create_instruction("MAKE_FUNCTION", arg=0x08))

        if push_null and sys.version_info >= (3, 11):
            self.add_push_null(gen_fn)
            output.extend(self.rot_n(num_on_stack + 2))
            output.extend(self.rot_n(num_on_stack + 2))
        else:
            gen_fn()
            output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()

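    # Illustrative note (comments only): for a code object with free variables
    # ("a", "b"), the `gen_fn` above is expected to emit roughly:
    #   LOAD_CLOSURE a; LOAD_CLOSURE b   (LOAD_FAST on 3.13+)
    #   BUILD_TUPLE 2
    #   LOAD_CONST <code>
    #   LOAD_CONST <fn_name>             (3.10 and earlier only)
    #   MAKE_FUNCTION 0x08               (MAKE_FUNCTION + SET_FUNCTION_ATTRIBUTE on 3.13+)
    # i.e. the closure tuple gets attached to the newly created function object.
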
    def create_load_python_module(self, mod) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, add=True)

    def mark_source_temp(self, source: Source) -> None:
        """
        Mark a source as a temp variable, so that it can be reused.
        """
        if source not in self.tempvars:
            self.tempvars[source] = None

    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs

        seen_sources: OrderedSet[Source] = OrderedSet()

        def collect_temp_source(source):
            if source in seen_sources:
                # This source is used at least twice, so it can be reused
                self.mark_source_temp(source)
                # Don't trace the source further. This prevents us from marking
                # too many nodes as temp sources.
                return

            seen_sources.add(source)

            if isinstance(source, ChainedSource):
                collect_temp_source(source.base)

            if isinstance(source, DictGetItemSource) and isinstance(
                source.index, Source
            ):
                collect_temp_source(source.index)

        # Collect all the sources that are used more than once, so that we can
        # generate tmp variables in the generated pre-graph bytecode. This
        # essentially implements CSE.
        for arg in graphargs:
            if arg.source is not None:
                collect_temp_source(arg.source)

        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                self.add_push_null(
                    lambda: self.extend_output(
                        [
                            self.create_load_python_module(torch),
                            self.create_load_attr("_as_tensor_fullprec"),
                        ]
                    )
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))

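    # Illustrative sketch (comments only, not executed) of the bytecode shape
    # `make_call_generated_code` aims to produce, assuming a compiled function
    # name and two graph args whose sources are local variables `x` and `y`
    # (all names are placeholders):
    #
    #   LOAD_GLOBAL <fn_name>    (plus a pushed NULL on 3.11+)
    #   LOAD_FAST x
    #   LOAD_FAST y
    #   CALL 2                   (CALL_FUNCTION 2 on older versions)
    #
    # Sources that appear more than once across graph args are stored into temp
    # locals first, so repeated chains like `a.b.c` are only evaluated once.
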
    def load_import_from(self, module_name, object_name) -> None:
        source = AttrSource(self.tx.import_source(module_name), object_name)
        # Note: This approach is somewhat aggressive because typically, a source is marked
        # as a tempvar only when it is used more than once. In this case, we're marking it
        # as a tempvar without performing that analysis. However, this is a simple solution,
        # and in many cases, load imports are reused multiple times.
        self.mark_source_temp(source)
        self(source)

    def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]:
        if sys.version_info >= (3, 13):
            output = create_call_function(nargs, push_null)
            assert output[-1].opname == "CALL"
            output.insert(-1, self.create_load_const(kw_names))
            output[-1] = create_instruction("CALL_KW", arg=nargs)
            return output
        elif sys.version_info >= (3, 11):
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]

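    # Illustrative note (comments only): for a call like fn(a, b, c=1, d=2),
    # `create_call_function_kw(4, ("c", "d"), push_null)` is expected to yield,
    # depending on the running Python version:
    #   3.13+:     ... LOAD_CONST ("c", "d"); CALL_KW 4
    #   3.11/3.12: KW_NAMES ("c", "d") inserted just before CALL/PRECALL
    #   <=3.10:    LOAD_CONST ("c", "d"); CALL_FUNCTION_KW 4
    # where nargs counts positional and keyword arguments together.
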
    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)