team-10/venv/Lib/site-packages/torch/fx/experimental/sym_node.py
2025-08-02 02:00:33 +02:00

1822 lines
58 KiB
Python

# mypy: allow-untyped-defs
from __future__ import annotations
"""
This file does three things:
- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt, SymFloat at import time
- Does not depend on sympy at import time
As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
to avoid having to load SymPy at import time, as doing so is *very* slow.
"""
import builtins
import functools
import inspect
import itertools
import logging
import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, TYPE_CHECKING, Union
import torch
import torch._logging.structured as structured
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import ( # noqa: F401
sym_float,
sym_ite,
sym_max,
sym_min,
sym_not,
SymBool,
SymFloat,
SymInt,
)
from torch._logging import dtrace_structured
# ShapeEnv is only needed for type annotations here; importing it lazily
# (TYPE_CHECKING only) presumably avoids an import cycle with
# symbolic_shapes — NOTE(review): confirm.
if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Module logger; the guard_* methods on SymNode warn through it when a
# conversion of an evaluated result fails.
log = logging.getLogger(__name__)
# Artifact logger registered under the "sym_node" artifact name.
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")

# Public API. The math-op registration loop later in this module appends the
# generated sym_<op> names to this list.
__all__ = ["SymNode", "method_to_operator", "magic_methods"]

# The symbolic scalar type(s) used for isinstance checks throughout this file.
from torch.types import py_sym_types as SymTypes
def _to_symtype(t):
if t is bool:
return SymBool
if t is int:
return SymInt
if t is float:
return SymFloat
return t
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
"""
This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object.
"""
# Note [optimized_summation]: indicates that SymNode is an Add expression of the form
# a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
# for common patterns see _optimized_add.
# The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression,
# so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
# a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
_optimized_summation: bool = False
def __init__(
self,
expr,
shape_env,
pytype,
hint: Optional[Union[int, float, bool]],
constant=None,
fx_node=None,
optimized_summation=False,
):
self._expr = expr
self.shape_env = shape_env
self.pytype = pytype
self._optimized_summation = optimized_summation
# What's the difference between hint and constant?
#
# - A constant is known to be invariant across invocations of the model;
# it will always be this value. We only really know this when we
# encounter an honest-to-goodness literal (when wrapping it into
# a SymNode, we set constant.) Most of the time, constant is None
#
# - A hint is a *particular* value from the particular run we are
# tracing, but it may vary the next time around. It's useful to
# keep this around, as if we need a concrete value from a SymNode,
# we will return the hint and guard on the expression that produced
# it giving the same hint next time around. The hint is not
# guaranteed to be set either: if you have an unbacked SymNode,
# there won't be any hint; it was the result of some tensor-dependent
# computation, but we don't know what it actually is because we
# haven't actually run the tensor computation.
#
# If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
# in hopes that we've learned enough about the unbacked symints to
# discharge the hint; otherwise, you're likely to just error out.
#
# (A previous version of this system had some optimizations to only
# recompute when it was possible we had learned enough about the
# unbacked symint that a hint was now possible, but as we added more
# potential refinements to unbacked symints this got harder to keep
# in sync, so we've deleted it for now.)
def compute_hint():
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
# This occasionally gets exercised by, e.g.,
# convert_shape_to_symint. It's just a nicety so you don't HAVE
# to have a correct hint on hand when making a SymNode.
# Don't attempt to compute for unbacked, this can be quite
# expensive.
if has_free_unbacked_symbols(self.expr):
return None
hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
if hint is not None:
hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
return hint
if hint is not None:
assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
"Cannot create SymNode of type "
f"{pytype} with incompatible hint of type {type(hint)}"
)
if self.shape_env and self.shape_env._translation_validation_enabled:
# This is technically not TV, but this assert is expensive so
# let's only do it when we're already doing expensive things
computed_hint = compute_hint()
assert (
hint == computed_hint
), f"{hint} != {computed_hint} (for {self.expr})"
else:
hint = compute_hint()
self._hint = hint
self.constant: Optional[Union[int, float, bool]] = constant
# Record the FX node of the current node if we are doing translation
# validation. They will be used for building the input assertions for
# the translation validation problem.
tx_validation_en = (
self.shape_env and self.shape_env._translation_validation_enabled
)
self.fx_node = tx_validation_en and fx_node
def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
def _value_eq(self, other: SymNode) -> bool:
# Purposely don't include the shape_env in the eq.
return (
self._expr == other._expr
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
)
def _value_hash(self) -> int:
# Purposely don't include the shape_env in the hash.
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
@property
def expr(self):
return self.shape_env.replace(self._expr)
@property
def hint(self):
return self._hint
def has_hint(self):
return self._hint is not None
def require_hint(self, fallback=None):
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
if self._hint is None:
if fallback is not None:
# Say we have some expr like 2*u0 + s0
# The hint will be None, since the expr contains at least 1 unbacked.
# We will:
# - replace every backed free symbol with its corresponding hint
# - replace every unbacked free symbol with the fallback
# - regenerate the expression with those symbol replacements
# Note: this is not really complete either, since right now
# this logic does not take into account any value ranges
# for the unbacked symints, we may need to beef it up at some point.
unbacked_symbols = free_unbacked_symbols(self.expr)
replacements = {
s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s]
for s in self.expr.free_symbols
}
return self.expr.xreplace(replacements)
# NB: we expect this to raise
return self.shape_env.size_hint(self.expr)
return self._hint
def maybe_as_int(self):
if self.expr.is_number:
return int(self.expr)
else:
return None
# NB: This does conversions, not sure if this is good or not
def maybe_as_float(self):
import sympy
if isinstance(self.expr, sympy.Float):
return float(self.expr)
else:
return None
def maybe_as_bool(self):
import sympy
if self.expr is sympy.true:
return True
elif self.expr is sympy.false:
return False
else:
return None
def is_int(self):
return self.pytype is int
def is_float(self):
return self.pytype is float
def is_bool(self):
return self.pytype is bool
def is_nested_int(self):
# Unbacked SymInts cannot be nested int today
return (
self._hint is not None
and isinstance(self._hint, SymInt)
and self._hint.node.is_nested_int()
)
def wrap_int(self, num):
assert type(num) is int
import sympy
return SymNode(
sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
)
def wrap_float(self, num):
assert type(num) is float
import sympy
return SymNode(
sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
)
def wrap_bool(self, num):
assert type(num) is bool
import sympy
return SymNode(
sympy.true if num else sympy.false,
self.shape_env,
bool,
num,
constant=num,
fx_node=num,
)
def clone(self):
return self
def str(self):
return f"{self.expr}"
def __str__(self):
return self.str()
def __repr__(self):
rep = [
f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
]
if self._hint is not None:
rep.append(f"hint={self._hint}")
if self.constant is not None:
rep.append(f"constant={self.constant}")
if self.fx_node is not None:
rep.append(f"fx_node={self.fx_node}")
return ", ".join(rep) + ")"
def _graph_repr(self) -> builtins.str:
# Representation used by GraphModule to create a pythonic version of a graph
return self.str()
# These methods call the metaprogrammed methods, they're hand written
# here so we get good stack traces
def abs(self) -> SymNode:
return self._abs() # type: ignore[attr-defined]
def pos(self) -> SymNode:
return self._pos() # type: ignore[attr-defined]
def round(self, ndigits=None) -> SymNode:
return self._round(ndigits) # type: ignore[attr-defined]
def trunc(self) -> SymNode:
return self._trunc() # type: ignore[attr-defined]
def add(self, other) -> SymNode:
return self._add(other) # type: ignore[attr-defined]
def sub(self, other) -> SymNode:
return self._sub(other) # type: ignore[attr-defined]
def mul(self, other) -> SymNode:
return self._mul(other) # type: ignore[attr-defined]
def mod(self, other) -> SymNode:
return self._mod(other) # type: ignore[attr-defined]
def float_pow(self, other) -> SymNode:
return self._float_pow(other) # type: ignore[attr-defined]
def pow_by_natural(self, other) -> SymNode:
return self._pow_by_natural(other) # type: ignore[attr-defined]
def and_(self, other) -> SymNode:
return self._and_(other) # type: ignore[attr-defined]
def or_(self, other) -> SymNode:
return self._or_(other) # type: ignore[attr-defined]
def float_truediv(self, other) -> SymNode:
return self._float_truediv(other) # type: ignore[attr-defined]
def int_truediv(self, other) -> SymNode:
return self._int_truediv(other) # type: ignore[attr-defined]
def int_floordiv(self, other) -> SymNode:
return self._int_floordiv(other) # type: ignore[attr-defined]
def lshift(self, other) -> SymNode:
return self._lshift(other) # type: ignore[attr-defined]
def rshift(self, other) -> SymNode:
return self._rshift(other) # type: ignore[attr-defined]
def sym_not(self) -> SymNode: # noqa: F811
return self._sym_not() # type: ignore[attr-defined]
def eq(self, other) -> SymNode:
return self._eq(other) # type: ignore[attr-defined]
def ne(self, other) -> SymNode:
return self._ne(other) # type: ignore[attr-defined]
def gt(self, other) -> SymNode:
return self._gt(other) # type: ignore[attr-defined]
def lt(self, other) -> SymNode:
return self._lt(other) # type: ignore[attr-defined]
def le(self, other) -> SymNode:
return self._le(other) # type: ignore[attr-defined]
def ge(self, other) -> SymNode:
return self._ge(other) # type: ignore[attr-defined]
def floor(self) -> SymNode:
return self._floor() # type: ignore[attr-defined]
def is_integer(self) -> SymNode:
return self._is_integer() # type: ignore[attr-defined]
def sym_float(self) -> SymNode: # noqa: F811
return self._sym_float() # type: ignore[attr-defined]
def sym_int(self) -> SymNode:
return self._sym_int() # type: ignore[attr-defined]
def ceil(self) -> SymNode:
return self._ceil() # type: ignore[attr-defined]
def neg(self) -> SymNode:
return self._neg() # type: ignore[attr-defined]
def sym_min(self, other) -> SymNode: # noqa: F811
return self._sym_min(other) # type: ignore[attr-defined]
def sym_max(self, other) -> SymNode: # noqa: F811
return self._sym_max(other) # type: ignore[attr-defined]
def sym_ite(self, then_val, else_val) -> SymNode:
return self._sym_ite(then_val, else_val) # type: ignore[attr-defined]
def is_contiguous(self, sizes, strides) -> SymNode:
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
# Make C++ happy
def sym_or(self, other):
return self.or_(other)
def sym_and(self, other):
return self.and_(other)
# Integer bitwise ops
def bitwise_and(self, other):
return self._bitwise_and(other) # type: ignore[attr-defined]
def bitwise_or(self, other):
return self._bitwise_or(other) # type: ignore[attr-defined]
# There is no int_truediv available from C++
def truediv(self, other):
return self.float_truediv(other)
def floordiv(self, other) -> SymNode:
return self.int_floordiv(other)
# We didn't bind integer pow in C++
def pow(self, other):
return self.float_pow(other)
def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
def int_(self):
return self.guard_int("", 0) # NB: uses Python backtrace
# This one is currently done by hand, but if we add other variadic
# functions consider factoring it out to be metaprogrammed too. Note that
# some load bearing logic is directly in torch.sym_sum
def sym_sum(self, args) -> SymNode:
import sympy
# Inner impl
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
if get_proxy_mode():
return to_node(
self,
handle_sym_dispatch(
torch.sym_sum,
(tuple(wrap_node(a) for a in args),),
{},
),
)
exprs = [a.expr for a in args]
out = sympy.Add(*exprs)
size_hints = []
out_hint = None
for a in args:
if a.hint is None:
break
size_hints.append(a.hint)
else:
out_hint = sum(size_hints)
fx_node, _ = self.shape_env._create_fx_call_function(
torch.sym_sum, (tuple(a.fx_node for a in args),)
)
# NB: Only for integers!
return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
def evaluate(self, size_oblivious=False):
return self.shape_env.evaluate_sym_node(self, size_oblivious)
# You can manually trigger a guard with this function
def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.evaluate()
try:
return int(r)
except Exception:
log.warning("Failed to convert to int: %s", r)
raise
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.evaluate()
try:
return float(r)
except Exception:
log.warning("Failed to convert to float: %s", r)
raise
def guard_bool(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.evaluate()
try:
return bool(r)
except Exception:
log.warning("Failed to convert to bool: %s", r)
raise
def expect_true(self, file, line):
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
if (
self.has_hint()
and not free_unbacked_symbols(self.expr)
and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
):
# OK to generate guards
return self.guard_bool(file, line)
# Generate a deferred runtime assert (this might actually end up doing
# a regular guard if we can!)
# TODO: file/line here is very important, because the assert has been
# deferred so you can't backtrace easily
return self.shape_env.defer_runtime_assert(
self.expr, f"{file}:{line}", fx_node=self.fx_node
)
def expect_size(self, file, line):
from torch.fx.experimental.symbolic_shapes import _advise_is_size
b = self.ge(self.wrap_int(0))
# Generate a deferred runtime assert
r = b.expect_true(file, line)
# Refine compile time range, but only if it's unbacked.
# If you refine range for hinted variables, you can end up making
# improper deductions since compile time reasoning may be
# incompatible with runtime reasoning.
if r and not self.has_hint():
_advise_is_size(SymInt(self))
return r
def guard_size_oblivious(self, file, line):
"""
Like guard_bool, but if we encounter unbacked symbols, if those symbols
are size-like, we will treat them as >= 2 for the purposes of the analysis.
This CHANGES the runtime semantics, but all size-oblivious sites have been
audited to ensure that the runtime semantics don't change in a material way.
Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
an unbacked one size, or a tensor reporting as non-contiguous even if it's
contiguous if it would have been reported contiguous due to being empty.
"""
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.evaluate(size_oblivious=True)
try:
return bool(r)
except Exception:
log.warning("Failed to convert to bool: %s", r)
raise
def bool_(self):
return self.guard_bool("", 0)
def is_symbolic(self):
return True
def nested_int(self):
return None
def is_constant(self):
return False
# TODO: this probably needs the sizes-strides eval functions
# Maps a SymNode method name to the callable that performs the same
# operation on concrete values. method_to_operator() reads it. The magic
# method implementation uses the result to compute an output hint from input
# hints, and as the target handed to handle_sym_dispatch under a proxy mode.
# The math-op loop below adds an entry for each sym_<op>.
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "bitwise_and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "bitwise_or": operator.or_,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}

# Names of magic methods that take a single operand. The math-op loop below
# extends this set with each sym_<op>.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
# Adding math ops: sqrt, cos, sin, ...
def _get_sym_node_fn(name):
def fn(self):
return getattr(self, f"_sym_{name}")()
return fn
# Unary math functions exposed on SymNode as sym_<name>. The same tuple is
# reused later in this module to register the SymPy lowerings
# (_sympy_<name>) and is then deleted from the module namespace.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
)
# For every math op:
#   - install SymNode.sym_<name>, forwarding to SymNode._sym_<name>
#   - map "sym_<name>" to torch._sym_<name> for concrete-value evaluation
#   - mark it unary and export it through __all__
# NB: loop variable names matter; later module code `del`s them.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}

# Every single-operand method, magic or not.
unary_methods = unary_magic_methods | unary_nonmagic_methods

# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}

# Method names that are Python keywords; the SymNode attribute carries a
# trailing underscore instead (SymNode.and_ / SymNode.or_).
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
# remap necessary because an op name can have a bitwise and boolean implementation
bitwise_ops = {
    "bitwise_and": "and",
    "bitwise_or": "or",
}

# Result-type categories. Judging by their names, these list methods whose
# result is always float / int / bool respectively.
# NOTE(review): their consumers are outside this view — confirm.
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

# Every sym_<math op> is float-valued as well.
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)

always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
# SymPy lowerings for the binary methods that have a `__foo__` as well as an `__rfoo__`
# (collected in `reflectable_magic_methods` below)
def _sympy_float_truediv(a, b):
from torch.utils._sympy.functions import FloatTrueDiv
return FloatTrueDiv(a, b)
def _sympy_int_truediv(a, b):
from torch.utils._sympy.functions import IntTrueDiv
return IntTrueDiv(a, b)
def _sympy_floordiv(a, b):
from torch.utils._sympy.functions import FloorDiv
return FloorDiv(a, b)
def _sympy_mod(a, b):
from torch.utils._sympy.functions import Mod, PythonMod
if a.is_nonnegative and b.is_nonnegative:
return Mod(a, b)
else:
return PythonMod(a, b)
def _sympy_pow_by_natural(a, b):
from torch.utils._sympy.functions import PowByNatural
return PowByNatural(a, b)
def _sympy_float_pow(a, b):
from torch.utils._sympy.functions import FloatPow
return FloatPow(a, b)
def _sympy_and(a, b):
import sympy
return sympy.And(a, b)
def _sympy_or(a, b):
import sympy
return sympy.Or(a, b)
def _sympy_lshift(a, b):
from torch.utils._sympy.functions import LShift
return LShift(a, b)
def _sympy_rshift(a, b):
from torch.utils._sympy.functions import RShift
return RShift(a, b)
def _binary_search_insert_arg(ordered_args, new_arg):
"""
If new_arg is found in ordered_args None is returned, else the new
ordered_args with new_arg inserted
"""
if len(ordered_args) == 0:
return [new_arg]
from sympy.core.basic import _args_sortkey as sort_key, Basic
# Fast path when new_arg > ordered_args[-1].
if sort_key(ordered_args[-1]) < sort_key(new_arg):
return ordered_args + [new_arg]
# Fast path when new_arg < ordered_args[0].
if sort_key(ordered_args[0]) > sort_key(new_arg):
return [new_arg] + ordered_args
low, high = 0, len(ordered_args) - 1
while low <= high:
mid = (low + high) // 2
compare_result = Basic.compare(ordered_args[mid], new_arg)
if compare_result == 0:
return None
elif compare_result < 0:
low = mid + 1
else:
high = mid - 1
ordered_args.insert(low, new_arg)
return ordered_args
def _optimized_add(
    lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
):
    """
    Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea
    is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols,
    and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following.
    1. Avoid running other optimizations when the Add is constructed.
    2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n)
    (comparing terms is expensive and shows in the profiles).
    The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
    (2) the result sympy expression.

    Args:
        lhs, rhs: SymPy expressions to add.
        lhs_is_optimized_summation, rhs_is_optimized_summation: caller-known
            flags (see Note [optimized_summation]). They are OR-ed with a
            structural check here, so passing False is always safe.
    """
    import sympy
    from sympy.core.basic import _args_sortkey as sortkey

    def make_optimized(ordered_args):
        # evaluate=False skips SymPy's canonicalization; the caller has already
        # produced the args in sorted, duplicate-free order.
        result = sympy.Add(*ordered_args, evaluate=False)
        return (True, result)

    from torch.utils._sympy.functions import _is_symbols_binary_summation

    lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
    rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)

    # Both sides are sums of unique symbols: concatenate when their ranges do
    # not interleave (overlapping ranges fall through to the generic Add).
    if lhs_is_optimized_summation and rhs_is_optimized_summation:
        # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
            return make_optimized(lhs._args + rhs._args)
        #  (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
        if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
            return make_optimized(rhs._args + lhs._args)

    # Sum + single symbol: binary-insert the symbol; None means it was already
    # a term, so the result would not be a unique-symbol sum.
    # (a0+a2) + a1 => (a0+a1+a2)
    if lhs_is_optimized_summation and rhs.is_symbol:
        new_args = _binary_search_insert_arg(list(lhs._args), rhs)
        if new_args is not None:
            return make_optimized(new_args)

    # a1 + (a0+a2)=> (a0+a1+a2)
    if rhs_is_optimized_summation and lhs.is_symbol:
        new_args = _binary_search_insert_arg(list(rhs._args), lhs)
        if new_args is not None:
            return make_optimized(new_args)

    # Generic fallback: let SymPy build the Add, then re-check the property.
    result = sympy.Add(lhs, rhs)
    return (_is_symbols_binary_summation(result), result)
def _bitwise_and(a, b):
from torch.utils._sympy.functions import BitwiseFn_bitwise_and
return BitwiseFn_bitwise_and(a, b)
def _bitwise_or(a, b):
from torch.utils._sympy.functions import BitwiseFn_bitwise_or
return BitwiseFn_bitwise_or(a, b)
# Binary magic methods that have a reflected `__rfoo__` form as well as
# `__foo__`, mapped to the function that builds the SymPy result.
# NB: "add" maps to _optimized_add, which returns an
# (is_optimized_summation, expr) tuple rather than a bare expression.
reflectable_magic_methods = {
    "add": _optimized_add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "bitwise_and": _bitwise_and,
    "or": _sympy_or,
    "bitwise_or": _bitwise_or,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
def _floor_ceil_helper(a, fn):
    """Apply the rounding function ``fn`` to ``a``, short-circuiting integral values.

    - ``Float * integer_expr`` whose Float coefficient equals its Integer
      truncation is rewritten to ``Integer * integer_expr``.
    - A Float equal to its Integer truncation, or an Integer, is returned as
      ``sympy.Integer``.
    - Otherwise the result is ``fn(a)``.

    NOTE(review): ``magic_methods`` maps floor/ceil to _sympy_floor/_sympy_ceil,
    not to this helper. Also, the Float-vs-Integer ``==`` tests depend on the
    SymPy version: newer SymPy (1.13+, as I understand it) no longer treats
    Float(2.0) == Integer(2) as True. Confirm before relying on the shortcuts.
    """
    import sympy

    if isinstance(a, sympy.Mul):
        aa = a.args
        if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
            coef = sympy.Integer(aa[0])
            if aa[0] == coef:  # structural equality test
                return coef * aa[1]
    # Precedence: (integral Float) or (Integer).
    if (
        isinstance(a, sympy.Float)
        and a == sympy.Integer(a)
        or isinstance(a, sympy.Integer)
    ):
        return sympy.Integer(a)
    return fn(a)
def _sympy_floor(a):
from torch.utils._sympy.functions import FloorToInt
return FloorToInt(a)
# NB: this is Python trunc semantics which returns an int. Do NOT use this to
# represent torch.trunc (which is float to float)
def _sympy_trunc(a):
from torch.utils._sympy.functions import TruncToInt
return TruncToInt(a)
def _sympy_ceil(a):
from torch.utils._sympy.functions import CeilToInt
return CeilToInt(a)
def _sympy_eq(a, b):
import sympy
return sympy.Eq(a, b)
def _sympy_ne(a, b):
import sympy
return sympy.Ne(a, b)
def _sympy_gt(a, b):
import sympy
return sympy.Gt(a, b)
def _sympy_lt(a, b):
import sympy
return sympy.Lt(a, b)
def _sympy_le(a, b):
import sympy
return sympy.Le(a, b)
def _sympy_ge(a, b):
import sympy
return sympy.Ge(a, b)
def _sympy_min(a, b):
from torch.utils._sympy.functions import Min
return Min(a, b)
def _sympy_max(a, b):
from torch.utils._sympy.functions import Max
return Max(a, b)
def _sympy_ite(a, t, f):
import sympy
return sympy.Piecewise((t, a), (f, True))
# Handle to this module so generated functions can be attached as
# module-level attributes below.
current_module = sys.modules[__name__]


def _get_sym_math_fn(name):
    """Return a function lowering ``sym_<name>`` to SymPy.

    The returned function looks up ``OpaqueUnaryFn_<name>`` in
    ``torch.utils._sympy.functions`` at call time. That keeps the SymPy
    import out of module import time. It then applies the result to its
    argument.
    """

    def fn(a):
        import torch.utils._sympy.functions

        return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)

    return fn


# Define a module-level _sympy_<name> (e.g. _sympy_sqrt) for every math op.
# __name__/__qualname__ are set so they look like ordinary functions.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Remove the loop temporaries from the module namespace.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
def _sympy_abs(a):
import sympy
return sympy.Abs(a)
def _sympy_round(number, ndigits=None):
from torch.utils._sympy.functions import RoundDecimal, RoundToInt
if ndigits is None:
return RoundToInt(number)
else:
return RoundDecimal(number, ndigits)
def _sympy_sym_float(a):
from torch.utils._sympy.functions import ToFloat
# NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
# reports that it is an integer
return ToFloat(a)
def _sympy_is_integer(a):
import sympy
from torch.utils._sympy.functions import ToFloat
return sympy.Eq(ToFloat(sympy.floor(a)), a)
# Full table: SymNode magic-method name -> function building the SymPy result.
# Starts from the reflectable binary methods and adds unary, comparison,
# min/max and if-then-else entries. The loop below appends sym_<math op>.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}

# Wire in the generated _sympy_<op> lowerings for the math ops.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# Loop temporaries and helpers are no longer needed at module scope.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):
    """SymPy boolean: is the (sizes, strides) layout row-major contiguous?"""
    innermost_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)


def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """SymPy boolean: are ``strides`` dense when dims are walked in ``dim_order``?

    ``dim_order`` lists dimensions from innermost (expected stride 1)
    outward. Each dimension must have size 1 or the expected stride.
    Alternatively, the whole layout counts as contiguous if any size is zero.
    A ``dim_order`` whose length differs from ``len(sizes)`` yields
    ``sympy.false``.
    """
    import sympy

    if len(sizes) != len(dim_order):
        return sympy.false

    one = sympy.S.One
    verdict = sympy.true
    expected_stride = one
    for axis in dim_order:
        verdict &= sympy.Eq(sizes[axis], one) | sympy.Eq(strides[axis], expected_stride)
        expected_stride *= sizes[axis]

    # Empty tensors are trivially contiguous.
    for extent in sizes:
        verdict |= sympy.Eq(extent, sympy.S.Zero)

    return verdict


# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
# happens you will need to refactor this


def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """Channels-last (NHWC) contiguity for a 4-d (N, C, H, W) layout."""
    return sympy_is_contiguous_generic(sizes, strides, dim_order=[1, 3, 2, 0])


def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """Channels-last (NDHWC) contiguity for a 5-d (N, C, D, H, W) layout."""
    return sympy_is_contiguous_generic(sizes, strides, dim_order=[1, 4, 3, 2, 0])
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
import sympy
from torch.utils._sympy.functions import Max
dim = len(sizes)
if dim != len(dim_order):
return sympy.false
m = sympy.S.Zero
r = sympy.true
# special case for trivial C dimension. default to NCHW
r &= sympy.Ne(strides[1], 0)
for d in dim_order:
r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
# Fallback to NCHW as default layout for ambiguous cases
# This is the flaw of implicit memory_format from strides.
# N111 tensor with identical strides for size 1 dimension;
# Two cases could lead us here:
# a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
# b. N11W contiguous Tensor sliced on the W-dimension.
# ([N,1,1,1]@[W,W,W,W])
if d == 0:
r &= sympy.Ne(m, strides[1])
# This is necessary to:
# 1. distinguish the memory_format of N1H1;
# [H, 1, 1, 1] channels_last stride
# [H, H, 1, 1] contiguous stride
# 2. permutation of 1C1W:
# [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
# [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
# channels_last
m = strides[d] * Max(sizes[d], 1)
return r
def sympy_is_channels_last_strides_2d(sizes, strides):
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
def sympy_is_channels_last_strides_3d(sizes, strides):
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
# Methods that take (sizes, strides) rather than a single operand, mapped to
# the function that builds their SymPy result.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}

# When the output hint of a binary method is known, these methods are computed
# by calling the given builtin on the wrapped operands. That replaces building
# a SymPy expression (see the alternate_impl branch in binary_magic_impl).
alternate_impl_if_hinted_methods = {
    "sym_min": builtins.min,
    "sym_max": builtins.max,
}
def to_node(self, num):
if isinstance(num, SymTypes):
return num.node
elif type(num) is bool:
return self.wrap_bool(num)
elif type(num) is int:
return self.wrap_int(num)
elif type(num) is float:
return self.wrap_float(num)
else:
# NotImplemented is important so that Python tries the
# other magic method
return NotImplemented
def wrap_node(x):
    """Turn a SymNode into the value handed back to user code.

    Constant nodes unwrap to their Python literal. Otherwise the node is
    wrapped in SymInt, SymFloat or SymBool according to its pytype. An
    AssertionError is raised for any other pytype.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    # Predicates are looked up lazily, in the same order as the type checks.
    for predicate_name, wrapper in (
        ("is_int", SymInt),
        ("is_float", SymFloat),
        ("is_bool", SymBool),
    ):
        if getattr(x, predicate_name)():
            return wrapper(x)
    raise AssertionError(f"unrecognized return type {x}")


def method_to_operator(method):
    """Return the concrete-value operator registered for SymNode ``method``.

    Raises KeyError for unregistered method names.
    """
    op = METHOD_TO_OPERATOR[method]
    return op
def _make_node_magic(method, func):
    """Install the ``SymNode._<method_attr>`` implementation for ``method``.

    ``func`` is the sympy-level implementation of the operation. It is wrapped
    in ``lru_cache(256)`` so repeated applications to identical expressions
    are cached. Depending on ``method`` this installs a unary implementation
    (``method in unary_methods``), a ``sym_ite`` implementation, a ``round``
    implementation or, otherwise, a binary implementation.

    Each implementation forwards to ``handle_sym_dispatch`` when a proxy mode
    is active. Otherwise it computes the new sympy expression, a concrete hint
    (when the inputs have hints), the result's Python type and an FX node,
    and returns a new ``SymNode``.
    """
    func = lru_cache(256)(func)

    # Some methods are installed under a trailing-underscore attribute name
    # (see magic_methods_on_operator_with_trailing_underscore).
    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    def capture_provenance(fn):
        """Wrap ``fn`` so that, when structured dtrace logging is enabled, an
        ``expression_created`` event describing the result is emitted."""

        @functools.wraps(fn)
        def wrapper(self, other=None):
            if other is None:
                result = fn(self)
            else:
                result = fn(self, other)
            if torch._logging._internal.GET_DTRACE_STRUCTURED:
                if other is not None:
                    arguments = [self, other]
                else:
                    arguments = [self]

                def get_id(sym_node) -> Optional[int]:
                    # We don't want to return an ID if the input is a constant
                    # (or if it is the result node itself).
                    import sympy

                    if sym_node.constant is not None:
                        return None
                    elif id(sym_node) == id(result):
                        return None
                    elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
                        return None
                    elif sym_node.expr in (sympy.true, sympy.false):
                        return None
                    return id(sym_node)

                dtrace_structured(
                    "expression_created",
                    metadata_fn=lambda: {
                        "method": method,
                        "result": str(result),
                        "result_id": id(result),
                        "arguments": [str(a) for a in arguments],
                        # Call get_id once per argument and drop the Nones.
                        "argument_ids": [
                            arg_id
                            for arg_id in map(get_id, arguments)
                            if arg_id is not None
                        ],
                        "user_stack": structured.get_user_stack(3),
                        "stack": structured.get_framework_stack(3),
                    },
                )
            return result

        return wrapper

    @capture_provenance
    def binary_magic_impl(self, other):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)

        # Concrete result on the hints, only when both operands have one.
        out_hint = None
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # For methods listed in alternate_impl_if_hinted_methods (sym_min /
        # sym_max), a hinted result is computed with the builtin on the
        # wrapped operands instead of a new symbolic expression.
        alternate_impl = alternate_impl_if_hinted_methods.get(method)
        if alternate_impl and out_hint is not None:
            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

        if get_proxy_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        assert isinstance(other, SymNode)
        optimized_summation = False
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges: use Mod only when both operands are provably
                # non-negative, otherwise PythonMod.
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            elif method == "add":
                # see Note [optimized_summation]
                (optimized_summation, out) = func(
                    self.expr,
                    other.expr,
                    self._optimized_summation,
                    other._optimized_summation,
                )
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
        pytype: type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        # Coerce a concrete hint to the result's Python type; symbolic hints
        # are left as they are.
        if (
            pytype is not None
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )
        result = SymNode(
            out,
            self.shape_env,
            pytype,
            out_hint,
            fx_node=fx_node,
            optimized_summation=optimized_summation,  # see Note [optimized_summation]
        )
        return result

    @capture_provenance
    def unary_magic_impl(self):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)
        if get_proxy_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        # Log the method name, as the other paths do; ``func`` is an
        # lru_cache wrapper whose repr does not identify the operation.
        sym_node_log.debug("%s %s -> %s", method, expr, out)
        out_hint = None
        if self.hint is not None:
            out_hint = op(self.hint)
        pytype: type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            # The output hint is only known when the predicate's hint is;
            # without it we cannot tell which branch is taken.
            if pred_node.hint is None:
                out_hint = None
            else:
                out_hint = then_node.hint if pred_node.hint else else_node.hint
            if get_proxy_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        def round_impl(self, ndigits=None):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            op = builtins.round
            if get_proxy_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            # round(x) yields an int; round(x, ndigits) keeps the input type.
            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            # ezyang(May 2024): LOL
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
def _make_node_sizes_strides(method, func):
    """Install the sizes/strides predicate ``method`` on ``SymNode``.

    ``SymNode._<method>`` evaluates ``func`` over the sympy expressions of the
    given size and stride nodes. If this module has no attribute named
    ``method`` yet, a user-facing ``sizes_strides_user`` function is also
    installed under that name.
    """
    # NB: don't LRU cache, lots of arguments

    def _collect_hints(nodes):
        # Return the list of hints of ``nodes``, or None at the first missing one.
        hints = []
        for node in nodes:
            hint = node.hint
            if hint is None:
                return None
            hints.append(hint)
        return hints

    def sizes_strides_impl(self, sizes, strides):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        # The module-level function registered under the same name.
        user_fn = getattr(sys.modules[__name__], method)
        if get_proxy_mode():
            wrapped_args = (
                [wrap_node(s) for s in sizes],
                [wrap_node(s) for s in strides],
            )
            return to_node(self, handle_sym_dispatch(user_fn, wrapped_args, {}))

        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise

        # bool is never expandable
        # A concrete hint exists only if every size and every stride has one.
        out_hint = None
        size_hints = _collect_hints(sizes)
        if size_hints is not None:
            stride_hints = _collect_hints(strides)
            if stride_hints is not None:
                out_hint = user_fn(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: type = int if method.endswith("_indicator") else bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # If any entry is a SymInt, dispatch to the node-level method, using
        # the first SymInt's node to convert every other entry into a node.
        first_symint = next(
            (a for a in itertools.chain(sizes, strides) if isinstance(a, SymInt)),
            None,
        )
        if first_symint is not None:
            anchor = first_symint.node
            return wrap_node(
                getattr(anchor, method)(
                    [to_node(anchor, b) for b in sizes],
                    [to_node(anchor, b) for b in strides],
                )
            )

        # No SymInt entries: evaluate directly.
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        # TODO: this is an awful implementation
        return bool(
            func(
                [sympy.sympify(a) for a in sizes],
                [sympy.sympify(a) for a in strides],
            )
        )

    # Skip for is_non_overlapping_and_dense_indicator
    # (an existing module attribute with this name is never overwritten).
    this_module = sys.modules[__name__]
    if not hasattr(this_module, method):
        setattr(this_module, method, sizes_strides_user)
# Install the SymNode-level implementations: the magic methods first, then
# the sizes/strides predicates.
for _registry, _installer in (
    (magic_methods, _make_node_magic),
    (sizes_strides_methods, _make_node_sizes_strides),
):
    for method, func in _registry.items():
        _installer(method, func)
del _registry, _installer
def _make_user_magic(method, user_type):
    """Install the user-facing magic method ``method`` on ``user_type``.

    ``user_type`` is ``SymInt``, ``SymFloat`` or ``SymBool``. The installed
    functions do the following:

    - apply Python's promotion rules to the operands;
    - short-circuit when an operand is a constant;
    - wrap plain Python operands into nodes, so the ``SymNode``
      implementations only ever see nodes;
    - re-wrap the resulting node as a user-facing symbolic value.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        # Plain Python values are returned unchanged. Constant SymBools are
        # unwrapped through guard_bool. Any other input is rejected.
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    def is_constant(x):
        # True for plain Python numbers/bools and for symbolic values whose
        # node reports is_constant().
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    # Promotion rules for binary operations. NB: we preserve PYTHON semantics
    # - if args are same type, do nothing
    # - if one arg is float, promote other arg to float
    # - nb: this applies to floordiv, even though output is integral
    # (it's still float)
    # - pow is funny business
    # - if both ints
    # - trigger a guard on exponent >= 0
    # - if non-negative, output is int
    # - otherwise, output is float
    # - otherwise, promote other arg to float
    # - nb: complex is impossible to handle correctly lol, with
    # negative base and integral float need to diverge semantics and
    # just always return complex. Neener neener pretend this problem
    # doesn't exist
    # - equality is pain: Python does the fancy thing where it unpacks the
    # mantissa from the float and then compares that against the int.
    # Which means it is able to tell that
    # 9007199254740993 != 9007199254740992. (rather than if the LHS was
    # promoted to float, in which case it would have truncated to the RHS
    # and subsequently been equal). We'll model this exactly by having
    # special mixed type equality operations. Unfortunately, we need to
    # do this for all comparison operations (maybe I'll only implement
    # compare)
    # - sym_ite mumble mumble really shouldn't allow mixed but whatever

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        def promote(x):
            return x

    def promote2(self, other):
        # TODO: Remove eq and other relations from this list.
        # CPython has fancy implementations for these to get as much precision
        # as possible instead of just promoting to float64 and praying, so we
        # need to handle them specially too.
        # Also, note that int_truediv doesn't go through this path: both
        # arguments are "int" so there isn't any promotion
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            # TODO: remove these
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes
    # this means that operations involving SymBool return plain bools.
    # Alternatively, we could also rewrap into constant Symbool (i.e. by
    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
    # today for no particular reason.
    def unary_magic_impl(self):
        self = promote(self)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        # Computes ``self <op> other``.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        # Reflected form: Python calls this for ``other <op> self``, so
        # ``other`` is the LEFT operand everywhere below.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            # Keep the reflected operand order: other <op> self.
            return (method_to_operator(method))(other, get_constant(self))
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        orig = getattr(user_type, method)
        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            assert (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            # wrap_node may already return a plain Python constant (which
            # has no ``.node``), so use is_constant like the other impls.
            return get_constant(ret) if is_constant(ret) else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattr(user_type, f"__{method}__", round_magic_impl)
    else:
        # Bitwise ops are installed under the dunder name given by bitwise_ops.
        method_name = method
        if method in bitwise_ops:
            method_name = bitwise_ops[method]
        setattr(user_type, f"__{method_name}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
# Install the user-facing magic methods on SymBool / SymInt / SymFloat.
for method, func in magic_methods.items():  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
    elif method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
    else:
        # SymBool also receives methods listed as bool-capable or as
        # bool-to-int promoting.
        if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
            _make_user_magic(method, SymBool)
        _make_user_magic(method, SymInt)
        # Bitwise ops are not installed on SymFloat.
        if method not in bitwise_ops:
            _make_user_magic(method, SymFloat)
# Don't leave the loop variables behind as module attributes.
del method
del func