team-10/env/Lib/site-packages/torch/utils/_sympy/printers.py
2025-08-02 07:34:44 +02:00

477 lines
19 KiB
Python

import sys
from typing import Optional
import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter
# C++ type name spliced into the ``static_cast<...>`` expressions emitted by
# CppPrinter below.
INDEX_TYPE = "int64_t"
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(StrPrinter):
    """Printing rules intended to hold for both the Python and C/C++ targets.

    A number of operations are declared here only to raise
    ``NotImplementedError``.  That is deliberate: it keeps sympy's default
    printing from silently emitting something like ``ToFloat(...)``, and the
    error message names the printer class that is missing the rule.
    """

    # override this so that _print_FloorDiv is used
    printmethod = "_torch_sympystr"

    def _unsupported(self, method_name: str) -> NotImplementedError:
        """Build the error for a printing rule this printer does not provide.

        The exception is returned rather than raised so that the ``raise``
        statement stays in the rule itself.
        """
        return NotImplementedError(f"{method_name} not implemented for {type(self)}")

    def _print_Mul(self, expr: sympy.Expr) -> str:
        """Join the factors with ``*``, parenthesizing by the node's precedence."""
        level = precedence(expr)
        return self.stringify(expr.args, "*", level)

    def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
        """Join the terms with `` + ``; ``order`` is accepted but not used."""
        level = precedence(expr)
        return self.stringify(expr.args, " + ", level)

    def _print_Relational(self, expr: sympy.Expr) -> str:
        """Join both sides with the relation's own operator (``expr.rel_op``)."""
        separator = f" {expr.rel_op} "
        return self.stringify(expr.args, separator, precedence(expr))

    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
        """Join the operands with `` & ``."""
        level = PRECEDENCE["BitwiseAnd"]
        return self.stringify(expr.args, " & ", level)

    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` | ``."""
        level = PRECEDENCE["BitwiseOr"]
        return self.stringify(expr.args, " | ", level)

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``; anything below atom level is parenthesized."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", just_below_atom)

    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
        """Join the operands with `` / `` and wrap the whole quotient in parentheses."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        quotient = self.stringify(expr.args, " / ", just_below_atom)
        return "(" + quotient + ")"

    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
        """Print exactly as ``_print_FloorDiv`` would."""
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr: sympy.Expr) -> str:
        """Print the wrapped argument; the wrapper itself leaves no trace."""
        wrapped = expr.args[0]
        return self._print(wrapped)

    # This must be implemented because sympy will collect x * x into Pow(x, 2),
    # without any explicit intervention.  We print it just like x * x; notably,
    # we never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural.  You should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This means
    # exp is guaranteed to be an integer.
    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Expand ``base**n`` into ``base*base*...`` (``"1"`` when ``n`` is 0).

        The exponent is asserted to be a non-negative integer.
        """
        base, raw_exp = expr.args
        assert raw_exp == int(raw_exp), raw_exp
        repeat = int(raw_exp)
        assert repeat >= 0
        if repeat == 0:
            return "1"
        return self.stringify([base] * repeat, "*", PRECEDENCE["Mul"])

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_ToFloat")

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_Infinity")

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_NegativeInfinity")

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_FloorDiv")

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_PythonMod")

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_IntTrueDiv")

    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_PowByNatural")

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_FloatPow")

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_TruncToInt")

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_RoundToInt")

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_RoundDecimal")

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_TruncToFloat")
class PythonPrinter(ExprPrinter):
    """Printer that renders sympy expressions as Python source text.

    The emitted text refers to ``math.*``, ``torch.sym_float`` and the
    ``abs``/``max``/``min``/``round`` builtins by name.
    """

    def _math_call(self, func_name: str, expr: sympy.Expr) -> str:
        """Render a one-argument node as ``math.<func_name>(<printed arg>)``."""
        assert len(expr.args) == 1
        return f"math.{func_name}({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        """Render as a ``torch.sym_float(...)`` call."""
        assert len(expr.args) == 1
        # NB: We use sym_float here because the printer is used for cache
        # serialization, and cache guards get evaluated with SymInt to
        # propagate guards to the parent ShapeEnv.  However, this comes at a
        # runtime cost for guards involving float.  If this is unacceptable
        # overhead, what you want to do is have two separate printers for
        # SymInt, one for when the inputs are guaranteed to be int, and
        # another for when they could be SymInt.
        #
        # NB: sym_min/sym_max also have this problem, but I chose not to fix
        # those.
        #
        # See https://github.com/pytorch/pytorch/issues/142507 for more
        # context.
        operand = self._print(expr.args[0])
        return f"torch.sym_float({operand})"

    def _print_And(self, expr: sympy.Expr) -> str:
        """Join the operands with `` and ``."""
        level = precedence(expr)
        return self.stringify(expr.args, " and ", level)

    def _print_Or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` or ``."""
        level = precedence(expr)
        return self.stringify(expr.args, " or ", level)

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        """Render ``(x // div) % mod``, dropping the division when div prints as ``1``."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        dividend, divisor, modulus = [
            self.parenthesize(arg, just_below_atom) for arg in expr.args
        ]
        if divisor == "1":
            quotient = dividend
        else:
            quotient = f"({dividend} // {divisor})"
        return f"({quotient} % {modulus})"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        """Render positive infinity as ``math.inf``."""
        return "math.inf"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        """Render negative infinity as ``-math.inf``."""
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", just_below_atom)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        """Render as ``x // div`` with each operand parenthesized unless atomic."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        dividend, divisor = [
            self.parenthesize(arg, just_below_atom) for arg in expr.args
        ]
        return f"{dividend} // {divisor}"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        """Join the operands with `` / ``."""
        just_below_atom = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " / ", just_below_atom)

    def _helper_sqrt(self, expr: sympy.Expr) -> str:
        """Wrap the printed form of ``expr`` in ``math.sqrt(...)``."""
        radicand = self._print(expr)
        return f"math.sqrt({radicand})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        """Delegate to ``_helper_sqrt`` with the node's first argument."""
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        """Join base and exponent with `` ** ``."""
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        """Join base and exponent with `` ** ``."""
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    def _print_floor(self, expr: sympy.Expr) -> str:
        return self._math_call("floor", expr)

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        return self._math_call("floor", expr)

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        # This also could have been int(), they'll do the same thing for float
        return self._math_call("trunc", expr)

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        return self._math_call("ceil", expr)

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        return self._math_call("ceil", expr)

    def _print_Abs(self, expr: sympy.Expr) -> str:
        """Render with the ``abs`` builtin."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"abs({operand})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr: sympy.Expr) -> str:
        """Render as a ``max(...)`` call over all printed arguments."""
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({rendered})"

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Render as a ``min(...)`` call over all printed arguments."""
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({rendered})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        return self._math_call("cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        return self._math_call("cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        return self._math_call("acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        return self._math_call("sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        return self._math_call("sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        return self._math_call("asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        return self._math_call("tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        return self._math_call("tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        return self._math_call("atan", expr)

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        """Render with the one-argument form of the ``round`` builtin."""
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"round({operand})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Render as ``round(number, ndigits)``; ``ndigits`` must be a sympy Integer."""
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        operand = self._print(number)
        return f"round({operand}, {ndigits})"
class CppPrinter(ExprPrinter):
    """Printer that renders sympy expressions as C++ source text.

    The emitted text refers to ``std::`` math functions and to the
    ``c10::div_floor_integer`` / ``c10::div_floor_floating`` / ``c10::div_mod``
    helpers by name; integer-valued results are cast to ``INDEX_TYPE``.
    """

    def _print_Integer(self, expr: sympy.Expr) -> str:
        """Render an integer literal, suffixed ``LL`` on darwin/win32 and ``L`` elsewhere."""
        return (
            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
        )

    def _print_Where(self, expr: sympy.Expr) -> str:
        """Render as the ternary ``c ? p : q`` with non-atomic operands parenthesized."""
        c, p, q = (
            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
        )
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        """Render ``floor_div(x, div) % mod``; the division is skipped when div is 1."""
        x, div, mod = expr.args
        x = self.doprint(x)
        if div != 1:
            div = self.doprint(div)
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.doprint(mod)
        return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        """Render as a ``c10::div_floor_integer`` or ``c10::div_floor_floating`` call.

        The integer variant is chosen when ``expr.is_integer`` is truthy.
        """
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr: sympy.Expr) -> str:
        """Render as ``std::floor(...)``, cast to ``INDEX_TYPE`` when integer-valued."""
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        """Render as ``std::floor(...)``, cast to ``INDEX_TYPE`` when integer-valued."""
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        """Render as ``std::trunc(...)``, always cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        """Render as ``std::trunc(...)`` with no cast."""
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        """Render as a ``static_cast<double>(...)``."""
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        """Render as a ``c10::div_mod(x, div)`` call."""
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        return f"c10::div_mod({x}, {div})"

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        """Render as a division of both operands cast to ``double``."""
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        """Not supported by this printer; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        """Render as ``std::pow(base, exp)``."""
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Render a sympy ``Pow`` node.

        * exponent ``0.5`` / ``-0.5``: ``std::sqrt(base)`` / ``1.0/std::sqrt(base)``
        * integer exponent: repeated multiplication, ``1.0/(...)`` for negative
          exponents and ``1.0`` for zero; cast to ``INDEX_TYPE`` when the whole
          expression is integer-valued
        * anything else: ``std::pow(base, <float exponent>)``
        """
        # Uses float constants to perform FP div
        base, exp = expr.args
        if exp == 0.5 or exp == -0.5:
            base = self._print(base)
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
            elif exp < -1:
                r = (
                    "1.0/("
                    + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
                    + ")"
                )
            elif exp == -1:
                # Parenthesize anything that is not atomic: printing the base
                # bare would turn 1/(x + y) into "1.0/x + y".
                r = "1.0/" + self.parenthesize(base, PRECEDENCE["Atom"] - 0.5)
            else:  # exp == 0
                r = "1.0"
            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # Print the base through this printer (as _print_FloatPow does)
            # instead of interpolating the sympy object, which would fall back
            # to sympy's default string form for nested nodes.
            # TODO: float vs double
            return f"std::pow({self._print(base)}, {float(exp)})"

    def _print_Rational(self, expr: sympy.Expr) -> str:
        """Render ``p/q`` as a floating-point division, or just ``p`` when ``q`` is 1."""
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        """Render as ``std::ceil(...)``, cast to ``INDEX_TYPE`` when integer-valued."""
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        """Render as ``std::ceil(...)``, cast to ``INDEX_TYPE`` when integer-valued."""
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Render as ``std::min``.

        Two arguments are each cast to ``INDEX_TYPE``; any other count is
        passed uncast as a braced initializer list.
        """
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min({il})"

    def _print_Max(self, expr: sympy.Expr) -> str:
        """Render as ``std::max``.

        Two arguments are each cast to ``INDEX_TYPE``; any other count is
        passed uncast as a braced initializer list.
        """
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max({il})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        """Render as ``std::abs(...)``."""
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        """Render as ``std::cos(...)``."""
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        """Render as ``std::cosh(...)``."""
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        """Render as ``std::acos(...)``."""
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        """Render as ``std::sin(...)``."""
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        """Render as ``std::sinh(...)``."""
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        """Render as ``std::asin(...)``."""
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        """Render as ``std::tan(...)``."""
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        """Render as ``std::tanh(...)``."""
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        """Render as ``std::atan(...)``."""
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        """Render as ``std::sqrt(...)`` of the first argument."""
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        """Render as ``std::lrint(...)``."""
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Render decimal rounding as ``nearbyint(10**ndigits * x) * 10**-ndigits``.

        Raises:
            ValueError: if ``number`` is integer-valued; that case is not
                supported here.
        """
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # NOTE(review): the assert only lets negative ndigits through to
            # the error below, so presumably the non-negative case is handled
            # before printing (by the sympy function) -- confirm.
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
        """Render the boolean constant as ``true``."""
        return "true"

    def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
        """Render the boolean constant as ``false``."""
        return "false"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        """Render positive infinity via ``std::numeric_limits<double>``."""
        return "std::numeric_limits<double>::infinity()"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        """Render negative infinity as the negated ``_print_Infinity`` output."""
        return f"-{self._print_Infinity(expr)}"