import sys
from typing import Optional

import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter


INDEX_TYPE = "int64_t"


# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(StrPrinter):
    """Base sympy printer holding the printing rules shared by the Python and
    C++ printers below.

    Operations whose spelling differs between targets raise
    NotImplementedError here, so the error names the printer class that is
    missing an implementation instead of silently emitting sympy's default
    textual form.
    """

    # override this so that _print_FloorDiv is used
    printmethod = "_torch_sympystr"

    def _print_Mul(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, "*", precedence(expr))

    def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
        return self.stringify(expr.args, " + ", precedence(expr))

    def _print_Relational(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))

    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"])

    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"])

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)

    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
        s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
        return f"({s})"

    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr: sympy.Expr) -> str:
        return self._print(expr.args[0])

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention. We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural; you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural. This
    # means exp is guaranteed to be integer.
    def _print_Pow(self, expr: sympy.Expr) -> str:
        base, exp = expr.args
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
        return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR. The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers. You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )


class PythonPrinter(ExprPrinter):
    """Prints sympy expressions as Python source.

    Uses Python operators (``//``, ``%``, ``/``, ``**``), ``math.*`` helpers,
    builtins (``abs``, ``min``, ``max``, ``round``) and ``torch.sym_float``.
    """

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        # NB: We use sym_float here because the printer is used for cache
        # serialization, and cache guards get evaluated with SymInt to
        # propagate guards to the parent ShapeEnv. However, this comes at a
        # runtime cost for guards involving float. If this is unacceptable
        # overhead, what you want to do is have two separate printers for
        # SymInt, one for when the inputs are guaranteed to be int, and
        # another for when they could be SymInt.
        #
        # NB: sym_min/sym_max also have this problem, but I chose not to fix
        # those.
        #
        # See https://github.com/pytorch/pytorch/issues/142507 for more
        # context.
        return f"torch.sym_float({self._print(expr.args[0])})"

    def _print_And(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " and ", precedence(expr))

    def _print_Or(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " or ", precedence(expr))

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        x, div, mod = (
            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
        )
        if div != "1":
            x = f"({x} // {div})"
        return f"({x} % {mod})"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "math.inf"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        return f"{x} // {div}"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)

    def _helper_sqrt(self, expr: sympy.Expr) -> str:
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])

    def _print_floor(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr: sympy.Expr) -> str:
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr: sympy.Expr) -> str:
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class CppPrinter(ExprPrinter):
    """Prints sympy expressions as C++ source.

    Integer-valued results are cast to ``INDEX_TYPE``. Floating-point casts,
    divisions and limits use ``double``. Every ``static_cast`` /
    ``std::numeric_limits`` emitted here carries an explicit template
    argument, because C++ rejects those templates without one.
    """

    def _print_Integer(self, expr: sympy.Expr) -> str:
        return (
            f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
        )

    def _print_Where(self, expr: sympy.Expr) -> str:
        c, p, q = (
            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
        )
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        x, div, mod = expr.args
        x = self.doprint(x)
        if div != 1:
            div = self.doprint(div)
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<{INDEX_TYPE}>({x}), static_cast<{INDEX_TYPE}>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.doprint(mod)
        return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        # Integer floor-division goes through the index type; anything else is
        # computed in double precision.
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<{INDEX_TYPE}>({x}), static_cast<{INDEX_TYPE}>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        return f"c10::div_mod({x}, {div})"

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr: sympy.Expr) -> str:
        # Uses float constants to perform FP div
        base, exp = expr.args

        if exp == 0.5 or exp == -0.5:
            base = self._print(base)
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
            elif exp < -1:
                r = (
                    "1.0/("
                    + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
                    + ")"
                )
            elif exp == -1:
                # Parenthesize any non-atomic base so the whole base is the
                # divisor: Pow(x + y, -1) must print as 1.0/(x + y), not
                # 1.0/x + y. Atoms such as symbols stay bare (1.0/x).
                r = "1.0/" + self.parenthesize(base, PRECEDENCE["Atom"] - 0.5)
            else:  # exp == 0
                r = "1.0"

            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # Print the base with this printer's rules (C++ integer literals,
            # custom functions) rather than sympy's default str().
            # TODO: float vs double
            return f"std::pow({self._print(base)}, {float(exp)})"

    def _print_Rational(self, expr: sympy.Expr) -> str:
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr: sympy.Expr) -> str:
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min({il})"

    def _print_Max(self, expr: sympy.Expr) -> str:
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max({il})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )

        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
        return "true"

    def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
        return "false"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "std::numeric_limits<double>::infinity()"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return f"-{self._print_Infinity(expr)}"