298 lines
10 KiB
Python
298 lines
10 KiB
Python
|
|
||
|
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
|
||
|
from sympy.matrices.expressions import MatrixExpr
|
||
|
from sympy.core.mul import Mul
|
||
|
from sympy.printing.precedence import PRECEDENCE
|
||
|
from sympy.external import import_module
|
||
|
from sympy.codegen.cfunctions import Sqrt
|
||
|
from sympy import S
|
||
|
from sympy import Integer
|
||
|
|
||
|
import sympy
|
||
|
|
||
|
torch = import_module('torch')
|
||
|
|
||
|
|
||
|
class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter):
|
||
|
|
||
|
printmethod = "_torchcode"
|
||
|
|
||
|
mapping = {
|
||
|
sympy.Abs: "torch.abs",
|
||
|
sympy.sign: "torch.sign",
|
||
|
|
||
|
# XXX May raise error for ints.
|
||
|
sympy.ceiling: "torch.ceil",
|
||
|
sympy.floor: "torch.floor",
|
||
|
sympy.log: "torch.log",
|
||
|
sympy.exp: "torch.exp",
|
||
|
Sqrt: "torch.sqrt",
|
||
|
sympy.cos: "torch.cos",
|
||
|
sympy.acos: "torch.acos",
|
||
|
sympy.sin: "torch.sin",
|
||
|
sympy.asin: "torch.asin",
|
||
|
sympy.tan: "torch.tan",
|
||
|
sympy.atan: "torch.atan",
|
||
|
sympy.atan2: "torch.atan2",
|
||
|
# XXX Also may give NaN for complex results.
|
||
|
sympy.cosh: "torch.cosh",
|
||
|
sympy.acosh: "torch.acosh",
|
||
|
sympy.sinh: "torch.sinh",
|
||
|
sympy.asinh: "torch.asinh",
|
||
|
sympy.tanh: "torch.tanh",
|
||
|
sympy.atanh: "torch.atanh",
|
||
|
sympy.Pow: "torch.pow",
|
||
|
|
||
|
sympy.re: "torch.real",
|
||
|
sympy.im: "torch.imag",
|
||
|
sympy.arg: "torch.angle",
|
||
|
|
||
|
# XXX May raise error for ints and complexes
|
||
|
sympy.erf: "torch.erf",
|
||
|
sympy.loggamma: "torch.lgamma",
|
||
|
|
||
|
sympy.Eq: "torch.eq",
|
||
|
sympy.Ne: "torch.ne",
|
||
|
sympy.StrictGreaterThan: "torch.gt",
|
||
|
sympy.StrictLessThan: "torch.lt",
|
||
|
sympy.LessThan: "torch.le",
|
||
|
sympy.GreaterThan: "torch.ge",
|
||
|
|
||
|
sympy.And: "torch.logical_and",
|
||
|
sympy.Or: "torch.logical_or",
|
||
|
sympy.Not: "torch.logical_not",
|
||
|
sympy.Max: "torch.max",
|
||
|
sympy.Min: "torch.min",
|
||
|
|
||
|
# Matrices
|
||
|
sympy.MatAdd: "torch.add",
|
||
|
sympy.HadamardProduct: "torch.mul",
|
||
|
sympy.Trace: "torch.trace",
|
||
|
|
||
|
# XXX May raise error for integer matrices.
|
||
|
sympy.Determinant: "torch.det",
|
||
|
}
|
||
|
|
||
|
_default_settings = dict(
|
||
|
AbstractPythonCodePrinter._default_settings,
|
||
|
torch_version=None,
|
||
|
requires_grad=False,
|
||
|
dtype="torch.float64",
|
||
|
)
|
||
|
|
||
|
def __init__(self, settings=None):
|
||
|
super().__init__(settings)
|
||
|
|
||
|
version = self._settings['torch_version']
|
||
|
self.requires_grad = self._settings['requires_grad']
|
||
|
self.dtype = self._settings['dtype']
|
||
|
if version is None and torch:
|
||
|
version = torch.__version__
|
||
|
self.torch_version = version
|
||
|
|
||
|
def _print_Function(self, expr):
|
||
|
|
||
|
op = self.mapping.get(type(expr), None)
|
||
|
if op is None:
|
||
|
return super()._print_Basic(expr)
|
||
|
children = [self._print(arg) for arg in expr.args]
|
||
|
if len(children) == 1:
|
||
|
return "%s(%s)" % (
|
||
|
self._module_format(op),
|
||
|
children[0]
|
||
|
)
|
||
|
else:
|
||
|
return self._expand_fold_binary_op(op, children)
|
||
|
|
||
|
# mirrors the tensorflow version
|
||
|
_print_Expr = _print_Function
|
||
|
_print_Application = _print_Function
|
||
|
_print_MatrixExpr = _print_Function
|
||
|
_print_Relational = _print_Function
|
||
|
_print_Not = _print_Function
|
||
|
_print_And = _print_Function
|
||
|
_print_Or = _print_Function
|
||
|
_print_HadamardProduct = _print_Function
|
||
|
_print_Trace = _print_Function
|
||
|
_print_Determinant = _print_Function
|
||
|
|
||
|
def _print_Inverse(self, expr):
|
||
|
return '{}({})'.format(self._module_format("torch.linalg.inv"),
|
||
|
self._print(expr.args[0]))
|
||
|
|
||
|
def _print_Transpose(self, expr):
|
||
|
if expr.arg.is_Matrix and expr.arg.shape[0] == expr.arg.shape[1]:
|
||
|
# For square matrices, we can use the .t() method
|
||
|
return "{}({}).t()".format("torch.transpose", self._print(expr.arg))
|
||
|
else:
|
||
|
# For non-square matrices or more general cases
|
||
|
# transpose first and second dimensions (typical matrix transpose)
|
||
|
return "{}.permute({})".format(
|
||
|
self._print(expr.arg),
|
||
|
", ".join([str(i) for i in range(len(expr.arg.shape))])[::-1]
|
||
|
)
|
||
|
|
||
|
def _print_PermuteDims(self, expr):
|
||
|
return "%s.permute(%s)" % (
|
||
|
self._print(expr.expr),
|
||
|
", ".join(str(i) for i in expr.permutation.array_form)
|
||
|
)
|
||
|
|
||
|
def _print_Derivative(self, expr):
|
||
|
# this version handles multi-variable and mixed partial derivatives. The tensorflow version does not.
|
||
|
variables = expr.variables
|
||
|
expr_arg = expr.expr
|
||
|
|
||
|
# Handle multi-variable or repeated derivatives
|
||
|
if len(variables) > 1 or (
|
||
|
len(variables) == 1 and not isinstance(variables[0], tuple) and variables.count(variables[0]) > 1):
|
||
|
result = self._print(expr_arg)
|
||
|
var_groups = {}
|
||
|
|
||
|
# Group variables by base symbol
|
||
|
for var in variables:
|
||
|
if isinstance(var, tuple):
|
||
|
base_var, order = var
|
||
|
var_groups[base_var] = var_groups.get(base_var, 0) + order
|
||
|
else:
|
||
|
var_groups[var] = var_groups.get(var, 0) + 1
|
||
|
|
||
|
# Apply gradients in sequence
|
||
|
for var, order in var_groups.items():
|
||
|
for _ in range(order):
|
||
|
result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(var))
|
||
|
return result
|
||
|
|
||
|
# Handle single variable case
|
||
|
if len(variables) == 1:
|
||
|
variable = variables[0]
|
||
|
if isinstance(variable, tuple) and len(variable) == 2:
|
||
|
base_var, order = variable
|
||
|
if not isinstance(order, Integer): raise NotImplementedError("Only integer orders are supported")
|
||
|
result = self._print(expr_arg)
|
||
|
for _ in range(order):
|
||
|
result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(base_var))
|
||
|
return result
|
||
|
return "torch.autograd.grad({}, {})[0]".format(self._print(expr_arg), self._print(variable))
|
||
|
|
||
|
return self._print(expr_arg) # Empty variables case
|
||
|
|
||
|
def _print_Piecewise(self, expr):
|
||
|
from sympy import Piecewise
|
||
|
e, cond = expr.args[0].args
|
||
|
if len(expr.args) == 1:
|
||
|
return '{}({}, {}, {})'.format(
|
||
|
self._module_format("torch.where"),
|
||
|
self._print(cond),
|
||
|
self._print(e),
|
||
|
0)
|
||
|
|
||
|
return '{}({}, {}, {})'.format(
|
||
|
self._module_format("torch.where"),
|
||
|
self._print(cond),
|
||
|
self._print(e),
|
||
|
self._print(Piecewise(*expr.args[1:])))
|
||
|
|
||
|
def _print_Pow(self, expr):
|
||
|
# XXX May raise error for
|
||
|
# int**float or int**complex or float**complex
|
||
|
base, exp = expr.args
|
||
|
if expr.exp == S.Half:
|
||
|
return "{}({})".format(
|
||
|
self._module_format("torch.sqrt"), self._print(base))
|
||
|
return "{}({}, {})".format(
|
||
|
self._module_format("torch.pow"),
|
||
|
self._print(base), self._print(exp))
|
||
|
|
||
|
def _print_MatMul(self, expr):
|
||
|
# Separate matrix and scalar arguments
|
||
|
mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
|
||
|
args = [arg for arg in expr.args if arg not in mat_args]
|
||
|
# Handle scalar multipliers if present
|
||
|
if args:
|
||
|
return "%s*%s" % (
|
||
|
self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
|
||
|
self._expand_fold_binary_op("torch.matmul", mat_args)
|
||
|
)
|
||
|
else:
|
||
|
return self._expand_fold_binary_op("torch.matmul", mat_args)
|
||
|
|
||
|
def _print_MatPow(self, expr):
|
||
|
return self._expand_fold_binary_op("torch.mm", [expr.base]*expr.exp)
|
||
|
|
||
|
def _print_MatrixBase(self, expr):
|
||
|
data = "[" + ", ".join(["[" + ", ".join([self._print(j) for j in i]) + "]" for i in expr.tolist()]) + "]"
|
||
|
params = [str(data)]
|
||
|
params.append(f"dtype={self.dtype}")
|
||
|
if self.requires_grad:
|
||
|
params.append("requires_grad=True")
|
||
|
|
||
|
return "{}({})".format(
|
||
|
self._module_format("torch.tensor"),
|
||
|
", ".join(params)
|
||
|
)
|
||
|
|
||
|
def _print_isnan(self, expr):
|
||
|
return f'torch.isnan({self._print(expr.args[0])})'
|
||
|
|
||
|
def _print_isinf(self, expr):
|
||
|
return f'torch.isinf({self._print(expr.args[0])})'
|
||
|
|
||
|
def _print_Identity(self, expr):
|
||
|
if all(dim.is_Integer for dim in expr.shape):
|
||
|
return "{}({})".format(
|
||
|
self._module_format("torch.eye"),
|
||
|
self._print(expr.shape[0])
|
||
|
)
|
||
|
else:
|
||
|
# For symbolic dimensions, fall back to a more general approach
|
||
|
return "{}({}, {})".format(
|
||
|
self._module_format("torch.eye"),
|
||
|
self._print(expr.shape[0]),
|
||
|
self._print(expr.shape[1])
|
||
|
)
|
||
|
|
||
|
def _print_ZeroMatrix(self, expr):
|
||
|
return "{}({})".format(
|
||
|
self._module_format("torch.zeros"),
|
||
|
self._print(expr.shape)
|
||
|
)
|
||
|
|
||
|
def _print_OneMatrix(self, expr):
|
||
|
return "{}({})".format(
|
||
|
self._module_format("torch.ones"),
|
||
|
self._print(expr.shape)
|
||
|
)
|
||
|
|
||
|
def _print_conjugate(self, expr):
|
||
|
return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})"
|
||
|
|
||
|
def _print_ImaginaryUnit(self, expr):
|
||
|
return "1j" # uses the Python built-in 1j notation for the imaginary unit
|
||
|
|
||
|
def _print_Heaviside(self, expr):
|
||
|
args = [self._print(expr.args[0]), "0.5"]
|
||
|
if len(expr.args) > 1:
|
||
|
args[1] = self._print(expr.args[1])
|
||
|
return f"{self._module_format('torch.heaviside')}({args[0]}, {args[1]})"
|
||
|
|
||
|
def _print_gamma(self, expr):
|
||
|
return f"{self._module_format('torch.special.gamma')}({self._print(expr.args[0])})"
|
||
|
|
||
|
def _print_polygamma(self, expr):
|
||
|
if expr.args[0] == S.Zero:
|
||
|
return f"{self._module_format('torch.special.digamma')}({self._print(expr.args[1])})"
|
||
|
else:
|
||
|
raise NotImplementedError("PyTorch only supports digamma (0th order polygamma)")
|
||
|
|
||
|
_module = "torch"
|
||
|
_einsum = "einsum"
|
||
|
_add = "add"
|
||
|
_transpose = "t"
|
||
|
_ones = "ones"
|
||
|
_zeros = "zeros"
|
||
|
|
||
|
def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings):
|
||
|
printer = TorchPrinter(settings={'requires_grad': requires_grad, 'dtype': dtype})
|
||
|
return printer.doprint(expr, **settings)
|