# mypy: allow-untyped-defs
import math
import operator
from typing import Union

import sympy

import torch
from torch.utils._sympy.functions import (
    _keep_float,
    BitwiseFn_bitwise_and,
    BitwiseFn_bitwise_or,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    IntTrueDiv,
    Max,
    Min,
    Mod,
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_log2,
    OpaqueUnaryFn_sqrt,
    PowByNatural,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToInt,
)


# The sympy interpretation of operators.  It will also sometimes work with
# plain int/float, but certain operations will leave you with a sympy.Basic
# in the end.  If you want the Python/FX traceable interpretation, check
# PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
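#
# A minimal usage sketch (added for illustration; the symbol name below is
# made up): feeding sympy expressions through these ops builds new sympy
# expressions instead of evaluating to Python scalars, e.g.
#
#   s = sympy.Symbol("s", integer=True)
#   ReferenceAnalysis.eq(s, 3)        # -> sympy.Eq(s, 3)
#   ReferenceAnalysis.floordiv(s, 2)  # -> FloorDiv(s, 2)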
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return FloatTrueDiv(1.0, x)

    @staticmethod
    def square(x):
        return PowByNatural(x, 2)

    @staticmethod
    def trunc_to_int(x, dtype):
        return TruncToInt(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return sympy.ceiling(x)

    @staticmethod
    def floor_to_int(x, dtype):
        return sympy.floor(x)

    @staticmethod
    def floor(x):
        return _keep_float(sympy.floor)(x)

    @staticmethod
    def ceil(x):
        return _keep_float(sympy.ceiling)(x)

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return ToFloat(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def mod(x, y):
        return Mod(x, y)

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return FloatTrueDiv(a, b)

    @staticmethod
    def int_truediv(a, b):
        return IntTrueDiv(a, b)

    @staticmethod
    def floordiv(a, b):
        return FloorDiv(a, b)

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError("TODO: truncdiv")

    @staticmethod
    def add(a, b):
        return _keep_float(operator.add)(a, b)

    @classmethod
    def sym_sum(cls, args):
        return sympy.Add(*args)

    @staticmethod
    def mul(a, b):
        return _keep_float(operator.mul)(a, b)

    @staticmethod
    def sub(a, b):
        return _keep_float(operator.sub)(a, b)

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def log2(x):
        return OpaqueUnaryFn_log2(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return _keep_float(FloatPow)(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        return PowByNatural(a, b)

    @staticmethod
    def minimum(a, b):
        return Min(a, b)

    @staticmethod
    def maximum(a, b):
        return Max(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return RoundToInt(a)

    @staticmethod
    def round_decimal(a, b):
        return RoundDecimal(a, b)

    @staticmethod
    def bitwise_and(a, b):
        return BitwiseFn_bitwise_and(a, b)

    @staticmethod
    def bitwise_or(a, b):
        return BitwiseFn_bitwise_or(a, b)


# Unlike ReferenceAnalysis, this does NOT sympyify; instead, it works with
# plain Python types and is FX traceable.  Inheritance here is purely for
# code sharing (TODO: consider splitting out a BaseReferenceAnalysis).
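#
# A small sketch (added for illustration): with plain Python ints these ops
# stay in Python-land, which is what keeps them traceable, e.g.
#
#   PythonReferenceAnalysis.floordiv(7, 2)              # -> 3 (plain int)
#   PythonReferenceAnalysis.sym_sum([1, 2, 3])          # -> 6, via a chain of adds
#   PythonReferenceAnalysis.to_dtype(3, torch.float64)  # -> 3.0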
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @classmethod
    def sym_sum(cls, args):
        if len(args) == 0:
            return 0
        if len(args) == 1:
            return args[0]
        acc = cls.add(args[0], args[1])
        for i in range(2, len(args)):
            acc = cls.add(acc, args[i])
        return acc

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def mod(x, y):
        return x % y

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def to_dtype(x, dtype):
        if dtype == torch.float64:
            return torch.sym_float(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def log2(x):
        return torch._sym_log2(x)  # type: ignore[attr-defined]

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor_to_int(x, dtype):
        return math.floor(x)

    @staticmethod
    def ceil_to_int(x, dtype):
        return math.ceil(x)

    @staticmethod
    def floor(x):
        return float(math.floor(x))

    @staticmethod
    def ceil(x):
        return float(math.ceil(x))

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def pow_by_natural(a, b):
        # Pray that safe_pow is not needed here lol.  In particular, this
        # never participates in VR low/high ranges, so overflow should be
        # unlikely
        return a**b

    @staticmethod
    def round_to_int(a, dtype):
        return round(a)

    @staticmethod
    def round_decimal(a, b):
        return round(a, ndigits=b)

    @staticmethod
    def bitwise_and(a, b):
        return a & b

    @staticmethod
    def bitwise_or(a, b):
        return a | b


# Like PythonReferenceAnalysis, but makes some export-unfriendly choices of
# operators to make things faster.
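#
# For example (illustrative; u0/u1/u2 stand in for whatever ints/SymInts the
# caller has): where PythonReferenceAnalysis.sym_sum reduces a list with a
# left-to-right chain of adds, the class below hands the whole list to
# torch.sym_sum in a single call:
#
#   OptimizedPythonReferenceAnalysis.sym_sum([u0, u1, u2])  # one torch.sym_sum(...)
#   PythonReferenceAnalysis.sym_sum([u0, u1, u2])           # ((u0 + u1) + u2)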
class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
    @staticmethod
    def sym_sum(args):
        return torch.sym_sum(args)


def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    return torch.ops.prims.convert_element_type.default(x, dtype)


# Suppose we have some int/float arguments.  This diagram commutes:
#
#   int/float  -- PythonReferenceAnalysis.op -->  int/float
#       |                                             |
#       |                                             |
#      torch.tensor(..., dtype=torch.int64/torch.float64)
#       |                                             |
#       V                                             V
#    Tensor    -- TensorReferenceAnalysis.op -->   Tensor
#
# NB: int before and after must be representable in int64 (we will
# insert guards accordingly.)
#
# This is guaranteed to be FX traceable with OpOverloads only.
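#
# A concrete instance of the square above (added for illustration):
#
#   PythonReferenceAnalysis.floordiv(7, 2)   # -> 3
#   TensorReferenceAnalysis.floordiv(
#       torch.tensor(7, dtype=torch.int64),
#       torch.tensor(2, dtype=torch.int64),
#   )                                        # -> tensor(3)
#
# i.e. lifting the arguments to int64 tensors first and applying the tensor
# op yields the same value, wrapped in a Tensor.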
class TensorReferenceAnalysis:
    # NB: This is actually dead, because with Proxy tracing the factory
    # function isn't traced correctly.  Here for completeness.
    @staticmethod
    def constant(c, dtype):
        d: Union[int, float, bool]
        if dtype is torch.int64:
            d = int(c)
        elif dtype is torch.double:
            d = float(c)
        elif dtype is torch.bool:
            d = bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")
        return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)

    @staticmethod
    def or_(a, b):
        return torch.ops.aten.logical_or.default(a, b)

    @staticmethod
    def and_(a, b):
        return torch.ops.aten.logical_and.default(a, b)

    @staticmethod
    def bitwise_and(a, b):
        return torch.ops.aten.bitwise_and(a, b)

    @staticmethod
    def bitwise_or(a, b):
        return torch.ops.aten.bitwise_or(a, b)

    @staticmethod
    def eq(a, b):
        return torch.ops.aten.eq.Tensor(a, b)

    @classmethod
    def ne(cls, a, b):
        return torch.ops.aten.ne.Tensor(a, b)

    @staticmethod
    def lt(a, b):
        return torch.ops.aten.lt.Tensor(a, b)

    @staticmethod
    def gt(a, b):
        return torch.ops.aten.gt.Tensor(a, b)

    @staticmethod
    def le(a, b):
        return torch.ops.aten.le.Tensor(a, b)

    @staticmethod
    def ge(a, b):
        return torch.ops.aten.ge.Tensor(a, b)

    @staticmethod
    def not_(a):
        return torch.ops.aten.logical_not.default(a)

    @staticmethod
    def reciprocal(x):
        return torch.ops.aten.reciprocal.default(x)

    @staticmethod
    def square(x):
        # TODO: maybe composite implicit autograd doesn't work here?
        return torch.ops.aten.square.default(x)

    @staticmethod
    def trunc_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.trunc.default(x), dtype)

    @staticmethod
    def ceil_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.ceil.default(x), dtype)

    @staticmethod
    def floor_to_int(x, dtype):
        return _to_dtype(torch.ops.aten.floor.default(x), dtype)

    @staticmethod
    def floor(x):
        return torch.ops.aten.floor.default(x)

    @staticmethod
    def ceil(x):
        return torch.ops.aten.ceil.default(x)

    @staticmethod
    def to_dtype(x, dtype):
        return _to_dtype(x, dtype)

    @staticmethod
    def mod(x, y):
        # TODO: https://github.com/pytorch/pytorch/pull/133654
        raise NotImplementedError(
            "no C-style modulus operation available from frontend atm"
        )

    @staticmethod
    def abs(x):
        return torch.ops.aten.abs.default(x)

    @staticmethod
    def neg(x):
        return torch.ops.aten.neg.default(x)

    @staticmethod
    def truediv(a, b):
        return torch.ops.aten.true_divide.Tensor(a, b)

    @staticmethod
    def int_truediv(a, b):
        raise NotImplementedError(
            "Python int truediv difficult to implement in PyTorch atm"
        )

        # TODO: This is wrong, CPython has a custom implementation of true
        # division that results in higher precision when the floats are
        # sufficiently large.  Short term fix: add a guard here
        return torch.ops.aten.true_divide.default(
            _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
        )

    @staticmethod
    def floordiv(a, b):
        return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")

    @staticmethod
    def truncdiv(a, b):
        raise NotImplementedError(
            "no C-style truncdiv operation available from frontend atm"
        )

    @staticmethod
    def add(a, b):
        return torch.ops.aten.add.Tensor(a, b)

    @staticmethod
    def mul(a, b):
        return torch.ops.aten.mul.Tensor(a, b)

    @staticmethod
    def sub(a, b):
        return torch.ops.aten.sub.Tensor(a, b)

    @staticmethod
    def exp(x):
        return torch.ops.aten.exp.default(x)

    @staticmethod
    def log(x):
        return torch.ops.aten.log.default(x)

    @staticmethod
    def log2(x):
        return torch.ops.aten.log2.default(x)

    @staticmethod
    def sqrt(x):
        return torch.ops.aten.sqrt.default(x)

    @staticmethod
    def sin(x):
        return torch.ops.aten.sin.default(x)

    @staticmethod
    def cos(x):
        return torch.ops.aten.cos.default(x)

    @staticmethod
    def tanh(x):
        return torch.ops.aten.tanh.default(x)

    @staticmethod
    def sinh(x):
        return torch.ops.aten.sinh.default(x)

    @staticmethod
    def cosh(x):
        return torch.ops.aten.cosh.default(x)

    @staticmethod
    def tan(x):
        return torch.ops.aten.tan.default(x)

    @staticmethod
    def acos(x):
        return torch.ops.aten.acos.default(x)

    @staticmethod
    def atan(x):
        return torch.ops.aten.atan.default(x)

    @staticmethod
    def asin(x):
        return torch.ops.aten.asin.default(x)

    @staticmethod
    def pow(a, b):
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def pow_by_natural(a, b):
        # NB: pow handles int x int fine
        return torch.ops.aten.pow.Tensor_Tensor(a, b)

    @staticmethod
    def minimum(a, b):
        return torch.ops.aten.minimum.default(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.ops.aten.maximum.default(a, b)

    @staticmethod
    def round_to_int(a, dtype):
        return torch.ops.aten.round.default(a)

    @staticmethod
    def round_decimal(a, b):
        raise NotImplementedError(
            "round decimal doesn't support Tensor second argument atm"
        )

        # return torch.ops.aten.round.decimals(a, b)