# mypy: allow-untyped-defs
import math
import operator
from typing import Union

import sympy

import torch
from torch.utils._sympy.functions import (
    _keep_float,
    BitwiseFn_bitwise_and,
    BitwiseFn_bitwise_or,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    IntTrueDiv,
    Max,
    Min,
    Mod,
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_log2,
    OpaqueUnaryFn_sqrt,
    PowByNatural,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToInt,
)


# The sympy interpretation of operators.  It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end.  If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return FloatTrueDiv(1.0, x)

    @staticmethod
    def square(x):
        return PowByNatural(x, 2)

    @staticmethod
    def trunc_to_int(x, dtype):
        return TruncToInt(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return sympy.ceiling(x)

    @staticmethod
    def floor_to_int(x, dtype):
        return sympy.floor(x)

    @staticmethod
    def floor(x):
        return _keep_float(sympy.floor)(x)

    @staticmethod
    def ceil(x):
        return _keep_float(sympy.ceiling)(x)

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return ToFloat(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def mod(x, y):
        return Mod(x, y)

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return FloatTrueDiv(a, b)

    @staticmethod
    def int_truediv(a, b):
        return IntTrueDiv(a, b)

    @staticmethod
    def floordiv(a, b):
        return FloorDiv(a, b)

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError("TODO: truncdiv")

    @staticmethod
    def add(a, b):
        return _keep_float(operator.add)(a, b)

    @classmethod
    def sym_sum(cls, args):
        return sympy.Add(*args)

    @staticmethod
    def mul(a, b):
        return _keep_float(operator.mul)(a, b)

    @staticmethod
    def sub(a, b):
        return _keep_float(operator.sub)(a, b)

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def log2(x):
        return OpaqueUnaryFn_log2(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return _keep_float(FloatPow)(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        return PowByNatural(a, b)

    @staticmethod
    def minimum(a, b):
        return Min(a, b)

    @staticmethod
    def maximum(a, b):
        return Max(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return RoundToInt(a)

    @staticmethod
    def round_decimal(a, b):
        return RoundDecimal(a, b)

    @staticmethod
    def bitwise_and(a, b):
        return BitwiseFn_bitwise_and(a, b)

    @staticmethod
    def bitwise_or(a, b):
        return BitwiseFn_bitwise_or(a, b)
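

# Minimal illustrative sketch (not used elsewhere in this module): shows the
# kind of values ReferenceAnalysis produces when handed a sympy symbol.  The
# helper name and the symbol "s" are arbitrary choices for the example.
def _example_reference_analysis():
    s = sympy.Symbol("s", integer=True, positive=True)
    # Arithmetic sympifies: the result is a sympy expression (2*s + 1),
    # not a plain Python int/float.
    expr = ReferenceAnalysis.add(ReferenceAnalysis.mul(s, 2), 1)
    assert isinstance(expr, sympy.Expr)
    # Comparisons on sympy expressions build relational objects such as
    # sympy.Eq rather than evaluating to a Python bool.
    rel = ReferenceAnalysis.eq(s, 3)
    assert isinstance(rel, sympy.Basic)
    return expr, rel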


# Unlike ReferenceAnalysis, does NOT sympyify; instead, works with plain
# Python types and is FX traceable.  Inheritance here is purely for code
# sharing (TODO: consider splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @classmethod
    def sym_sum(cls, args):
        if len(args) == 0:
            return 0
        if len(args) == 1:
            return args[0]
        acc = cls.add(args[0], args[1])
        for i in range(2, len(args)):
            acc = cls.add(acc, args[i])
        return acc

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return torch.sym_float(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def log2(x):
        return torch._sym_log2(x)  # type: ignore[attr-defined]

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor_to_int(x, dtype):
        return math.floor(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return math.ceil(x)

    @staticmethod
    def floor(x):
        return float(math.floor(x))

    @staticmethod
    def ceil(x):
        return float(math.ceil(x))

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def pow_by_natural(a, b):
        # Pray that safe_pow is not needed here lol.  In particular, this
        # never participates in VR low/high ranges, so overflow should be
        # unlikely.
        return a**b

    @staticmethod
    def round_to_int(a, dtype):
        return round(a)

    @staticmethod
    def round_decimal(a, b):
        return round(a, ndigits=b)

    @staticmethod
    def bitwise_and(a, b):
        return a & b

    @staticmethod
    def bitwise_or(a, b):
        return a | b


# Like PythonReferenceAnalysis, but makes some export-unfriendly choices of
# operators to make things faster.
class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
    @staticmethod
    def sym_sum(args):
        return torch.sym_sum(args)


def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return torch.ops.prims.convert_element_type.default(x, dtype)
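

# Minimal illustrative sketch (not used elsewhere in this module): unlike
# ReferenceAnalysis, PythonReferenceAnalysis stays in plain Python types, so
# the same style of calls returns int/float/bool.  The helper name and the
# concrete values are arbitrary choices for the example.
def _example_python_reference_analysis():
    # constant() coerces to the plain Python type matching the dtype.
    three = PythonReferenceAnalysis.constant(3, torch.int64)
    assert isinstance(three, int)
    # sym_sum folds the arguments with repeated add(); on plain ints this is
    # just ordinary integer addition.
    total = PythonReferenceAnalysis.sym_sum([1, 2, three])
    assert total == 6
    # floordiv/mod go straight through the Python operators.
    assert PythonReferenceAnalysis.floordiv(7, 2) == 3
    assert PythonReferenceAnalysis.mod(7, 2) == 1
    return total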


# Suppose we have some int/float arguments.  This diagram commutes:
#
#   int/float  --  PythonReferenceAnalysis.op  -->  int/float
#       |                                               |
#       |                                               |
#       torch.tensor(..., dtype=torch.int64/torch.float64)
#       |                                               |
#       V                                               V
#     Tensor   --  TensorReferenceAnalysis.op   -->   Tensor
#
# NB: int before and after must be representable in int64 (we will
# insert guards accordingly.)
#
# This is guaranteed to be FX traceable with OpOverloads only.
class TensorReferenceAnalysis:
    # NB: This is actually dead, because with Proxy tracing the factory
    # function isn't traced correctly.  Here for completeness.
    @staticmethod
    def constant(c, dtype):
        d: Union[int, float, bool]
        if dtype is torch.int64:
            d = int(c)
        elif dtype is torch.double:
            d = float(c)
        elif dtype is torch.bool:
            d = bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")
        return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)

    @staticmethod
    def or_(a, b):
        return torch.ops.aten.logical_or.default(a, b)

    @staticmethod
    def and_(a, b):
        return torch.ops.aten.logical_and.default(a, b)

    @staticmethod
    def bitwise_and(a, b):
        return torch.ops.aten.bitwise_and(a, b)

    @staticmethod
    def bitwise_or(a, b):
        return torch.ops.aten.bitwise_or(a, b)

    @staticmethod
    def eq(a, b):
        return torch.ops.aten.eq.Tensor(a, b)

    @classmethod
    def ne(cls, a, b):
        return torch.ops.aten.ne.Tensor(a, b)

    @staticmethod
    def lt(a, b):
        return torch.ops.aten.lt.Tensor(a, b)

    @staticmethod
    def gt(a, b):
        return torch.ops.aten.gt.Tensor(a, b)

    @staticmethod
    def le(a, b):
        return torch.ops.aten.le.Tensor(a, b)

    @staticmethod
    def ge(a, b):
        return torch.ops.aten.ge.Tensor(a, b)

    @staticmethod
    def not_(a):
        return torch.ops.aten.logical_not.default(a)

    @staticmethod
    def reciprocal(x):
        return torch.ops.aten.reciprocal.default(x)

    @staticmethod
    def square(x):
        # TODO: maybe composite implicit autograd doesn't work here?
        return torch.ops.aten.square.default(x)

    @staticmethod
    def trunc_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.trunc.default(x), dtype)

    @staticmethod
    def ceil_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.ceil.default(x), dtype)

    @staticmethod
    def floor_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.floor.default(x), dtype)

    @staticmethod
    def floor(x):
        return torch.ops.aten.floor.default(x)

    @staticmethod
    def ceil(x):
        return torch.ops.aten.ceil.default(x)

    @staticmethod
    def to_dtype(x, dtype):
        return _to_dtype(x, dtype)

    @staticmethod
    def mod(x, y):
        # TODO: https://github.com/pytorch/pytorch/pull/133654
        raise NotImplementedError(
            "no C-style modulus operation available from frontend atm"
        )

    @staticmethod
    def abs(x):
        return torch.ops.aten.abs.default(x)

    @staticmethod
    def neg(x):
        return torch.ops.aten.neg.default(x)

    @staticmethod
    def truediv(a, b):
        return torch.ops.aten.true_divide.Tensor(a, b)

    @staticmethod
    def int_truediv(a, b):
        raise NotImplementedError(
            "Python int truediv difficult to implement in PyTorch atm"
        )

        # TODO: This is wrong, CPython has a custom implementation of true
        # division that results in higher precision when the floats are
        # sufficiently large.  Short term fix: add a guard here
        return torch.ops.aten.true_divide.default(
            _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
        )
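
    # Explanatory note on the TODO above: CPython's int true division divides
    # the exact integer operands and rounds only once at the end, while the
    # unreachable reference implementation above rounds each operand to
    # float64 first.  That conversion drops the low bits of any integer whose
    # magnitude exceeds 2**53, so the two paths can disagree in the final bit,
    # which is why a guard would be needed before enabling it.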

    @staticmethod
    def floordiv(a, b):
        return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError(
            "no C-style truncdiv operation available from frontend atm"
        )

    @staticmethod
    def add(a, b):
        return torch.ops.aten.add.Tensor(a, b)

    @staticmethod
    def mul(a, b):
        return torch.ops.aten.mul.Tensor(a, b)

    @staticmethod
    def sub(a, b):
        return torch.ops.aten.sub.Tensor(a, b)

    @staticmethod
    def exp(x):
        return torch.ops.aten.exp.default(x)

    @staticmethod
    def log(x):
        return torch.ops.aten.log.default(x)

    @staticmethod
    def log2(x):
        return torch.ops.aten.log2.default(x)

    @staticmethod
    def sqrt(x):
        return torch.ops.aten.sqrt.default(x)

    @staticmethod
    def sin(x):
        return torch.ops.aten.sin.default(x)

    @staticmethod
    def cos(x):
        return torch.ops.aten.cos.default(x)

    @staticmethod
    def tanh(x):
        return torch.ops.aten.tanh.default(x)

    @staticmethod
    def sinh(x):
        return torch.ops.aten.sinh.default(x)

    @staticmethod
    def cosh(x):
        return torch.ops.aten.cosh.default(x)

    @staticmethod
    def tan(x):
        return torch.ops.aten.tan.default(x)

    @staticmethod
    def acos(x):
        return torch.ops.aten.acos.default(x)

    @staticmethod
    def atan(x):
        return torch.ops.aten.atan.default(x)

    @staticmethod
    def asin(x):
        return torch.ops.aten.asin.default(x)

    @staticmethod
    def pow(a, b):
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        # NB: pow handles int x int fine
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def minimum(a, b):
        return torch.ops.aten.minimum.default(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.ops.aten.maximum.default(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return torch.ops.aten.round.default(a)

    @staticmethod
    def round_decimal(a, b):
        raise NotImplementedError(
            "round decimal doesn't support Tensor second argument atm"
        )

        # return torch.ops.aten.round.decimals(a, b)
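

# Minimal illustrative sketch (not used elsewhere in this module): checks the
# commuting diagram above on a single operation.  Running floordiv on plain
# ints via PythonReferenceAnalysis and on int64 tensors via
# TensorReferenceAnalysis should give the same answer.  The helper name and
# the values 7 and -3 are arbitrary choices for the example.
def _example_commuting_diagram():
    a, b = 7, -3
    # Top edge of the diagram: plain Python ints.
    py_out = PythonReferenceAnalysis.floordiv(a, b)  # 7 // -3 == -3
    # Bottom edge of the diagram: embed the ints as int64 tensors first.
    ta = torch.tensor(a, dtype=torch.int64)
    tb = torch.tensor(b, dtype=torch.int64)
    t_out = TensorReferenceAnalysis.floordiv(ta, tb)
    # Both paths floor toward negative infinity, so they agree on -3.
    assert py_out == t_out.item() == -3
    return py_out, t_out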