730 lines
25 KiB
Python
730 lines
25 KiB
Python
import re
|
|
|
|
import sympy
|
|
from sympy.external import import_module
|
|
from sympy.parsing.latex.errors import LaTeXParsingError
|
|
|
|
# ``lark`` is treated as an optional dependency: when ``import_module`` hands
# back a falsy value, minimal stand-in classes are defined so that the names
# used further down in this module still exist.
lark = import_module("lark")

if not lark:
    class Transformer:  # type: ignore
        # Stand-in base class; ``transform`` accepts anything and does nothing.
        def transform(self, *args):
            pass

    class Token:  # type: ignore
        # Stand-in so that ``isinstance(x, Token)`` checks remain valid.
        pass

    class Tree:  # type: ignore
        # Stand-in so that ``isinstance(x, Tree)`` checks remain valid.
        pass
else:
    from lark import Transformer, Token, Tree  # type: ignore
|
|
|
|
|
|
# noinspection PyPep8Naming,PyMethodMayBeStatic
|
|
class TransformToSymPyExpr(Transformer):
    """Returns a SymPy expression that is generated by traversing the ``lark.Tree``
    passed to the ``.transform()`` function.

    Notes
    =====

    **This class is never supposed to be used directly.**

    In order to tweak the behavior of this class, it has to be subclassed and then after
    the required modifications are made, the name of the new class should be passed to
    the :py:class:`LarkLaTeXParser` class by using the ``transformer`` argument in the
    constructor.

    Each public method is a callback named after a grammar rule or terminal. Rule
    callbacks receive ``tokens``, the list of already-transformed children of the
    matched node; terminal callbacks receive the matched token itself.

    Parameters
    ==========

    visit_tokens : bool, optional
        For information about what this option does, see `here
        <https://lark-parser.readthedocs.io/en/latest/visitors.html#lark.visitors.Transformer>`_.

        Note that the option must be set to ``True`` for the default parser to work.
    """

    # Terminal callbacks that need no custom logic: the token is handed
    # straight to the SymPy constructor.
    SYMBOL = sympy.Symbol
    DIGIT = sympy.core.numbers.Integer

    # ------------------------------------------------------------------
    # Terminals
    # ------------------------------------------------------------------

    def CMD_INFTY(self, tokens):
        """Return SymPy's positive infinity, ignoring the token contents."""
        return sympy.oo

    def GREEK_SYMBOL_WITH_PRIMES(self, tokens):
        """Return a ``Symbol`` named after a Greek-letter command token."""
        # we omit the first character because it is a backslash. Also, if the variable name has "var" in it,
        # like "varphi" or "varepsilon", we remove that too
        variable_name = re.sub("var", "", tokens[1:])

        return sympy.Symbol(variable_name)

    def LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens):
        """Return ``Symbol("base_{sub}")`` for a token of the form ``base_sub`` or ``base_{sub}``."""
        base, sub = tokens.value.split("_")
        if sub.startswith("{"):
            # drop the surrounding braces; they are re-added uniformly below
            return sympy.Symbol("%s_{%s}" % (base, sub[1:-1]))
        else:
            return sympy.Symbol("%s_{%s}" % (base, sub))

    def GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens):
        """Like ``LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT`` but the base is a Greek-letter command."""
        base, sub = tokens.value.split("_")
        # strip the leading backslash and any "var" prefix from the base
        greek_letter = re.sub("var", "", base[1:])

        if sub.startswith("{"):
            return sympy.Symbol("%s_{%s}" % (greek_letter, sub[1:-1]))
        else:
            return sympy.Symbol("%s_{%s}" % (greek_letter, sub))

    def LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens):
        """Like ``LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT`` but the subscript is a Greek-letter command."""
        base, sub = tokens.value.split("_")
        if sub.startswith("{"):
            # skip the opening brace and the backslash, and drop the closing brace
            greek_letter = sub[2:-1]
        else:
            # skip only the backslash
            greek_letter = sub[1:]

        greek_letter = re.sub("var", "", greek_letter)
        return sympy.Symbol("%s_{%s}" % (base, greek_letter))

    def GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens):
        """Return ``Symbol("base_{sub}")`` where both base and subscript are Greek-letter commands."""
        base, sub = tokens.value.split("_")
        greek_base = re.sub("var", "", base[1:])

        if sub.startswith("{"):
            greek_sub = sub[2:-1]
        else:
            greek_sub = sub[1:]

        greek_sub = re.sub("var", "", greek_sub)
        return sympy.Symbol("%s_{%s}" % (greek_base, greek_sub))

    # ------------------------------------------------------------------
    # Atoms and grouping
    # ------------------------------------------------------------------

    def multi_letter_symbol(self, tokens):
        """Return a ``Symbol`` built from ``tokens[2]`` (plus ``tokens[4]`` when primes follow).

        Returns ``None`` implicitly when ``tokens`` has neither 4 nor 5 items.
        """
        if len(tokens) == 4:  # no primes (single quotes) on symbol
            return sympy.Symbol(tokens[2])
        if len(tokens) == 5:  # there are primes on the symbol
            return sympy.Symbol(tokens[2] + tokens[4])

    def number(self, tokens):
        """Return ``sympy.I``, a ``Float`` or an ``Integer`` depending on ``tokens[0]``."""
        if tokens[0].type == "CMD_IMAGINARY_UNIT":
            return sympy.I

        # a decimal point in the literal decides between Float and Integer
        if "." in tokens[0]:
            return sympy.core.numbers.Float(tokens[0])
        else:
            return sympy.core.numbers.Integer(tokens[0])

    def latex_string(self, tokens):
        """Return the first (already transformed) child unchanged."""
        return tokens[0]

    def group_round_parentheses(self, tokens):
        """Return ``tokens[1]`` (presumably the expression between the delimiters)."""
        return tokens[1]

    def group_square_brackets(self, tokens):
        """Return ``tokens[1]`` (presumably the expression between the delimiters)."""
        return tokens[1]

    def group_curly_parentheses(self, tokens):
        """Return ``tokens[1]`` (presumably the expression between the delimiters)."""
        return tokens[1]

    # ------------------------------------------------------------------
    # Relations (tokens[1] is presumably the operator token and is ignored)
    # ------------------------------------------------------------------

    def eq(self, tokens):
        """Return ``Eq(tokens[0], tokens[2])``."""
        return sympy.Eq(tokens[0], tokens[2])

    def ne(self, tokens):
        """Return ``Ne(tokens[0], tokens[2])``."""
        return sympy.Ne(tokens[0], tokens[2])

    def lt(self, tokens):
        """Return ``Lt(tokens[0], tokens[2])``."""
        return sympy.Lt(tokens[0], tokens[2])

    def lte(self, tokens):
        """Return ``Le(tokens[0], tokens[2])``."""
        return sympy.Le(tokens[0], tokens[2])

    def gt(self, tokens):
        """Return ``Gt(tokens[0], tokens[2])``."""
        return sympy.Gt(tokens[0], tokens[2])

    def gte(self, tokens):
        """Return ``Ge(tokens[0], tokens[2])``."""
        return sympy.Ge(tokens[0], tokens[2])

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def add(self, tokens):
        """Handle unary plus (2 items) and binary addition (3 items).

        ``MatAdd`` is used when either operand is a matrix. Returns ``None``
        implicitly for any other number of items.
        """
        if len(tokens) == 2:  # +a
            return tokens[1]
        if len(tokens) == 3:  # a + b
            lh = tokens[0]
            rh = tokens[2]

            if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh):
                return sympy.MatAdd(lh, rh)

            return sympy.Add(lh, rh)

    def sub(self, tokens):
        """Handle unary minus (2 items) and binary subtraction (3 items).

        Matrix operands are negated with ``MatMul(-1, ...)`` and combined with
        ``MatAdd``. Returns ``None`` implicitly for any other number of items.
        """
        if len(tokens) == 2:  # -a
            x = tokens[1]

            if self._obj_is_sympy_Matrix(x):
                return sympy.MatMul(-1, x)

            return -x
        if len(tokens) == 3:  # a - b
            lh = tokens[0]
            rh = tokens[2]

            if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh):
                return sympy.MatAdd(lh, sympy.MatMul(-1, rh))

            return sympy.Add(lh, -rh)

    def mul(self, tokens):
        """Return the product of ``tokens[0]`` and ``tokens[2]`` (``MatMul`` if either is a matrix)."""
        lh = tokens[0]
        rh = tokens[2]

        if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh):
            return sympy.MatMul(lh, rh)

        return sympy.Mul(lh, rh)

    def div(self, tokens):
        """Return ``tokens[0]`` divided by ``tokens[2]`` via ``_handle_division``."""
        return self._handle_division(tokens[0], tokens[2])

    def adjacent_expressions(self, tokens):
        """Interpret two expressions written next to each other.

        Depending on the operands this yields an ``OuterProduct`` (ket followed
        by bra), a ``(d, expr)`` tuple marking a differential, a ``Derivative``
        (when the left operand is a tuple), or a plain ``Mul``.
        """
        # Most of the time, if two expressions are next to each other, it means implicit multiplication,
        # but not always
        from sympy.physics.quantum import Bra, Ket
        if isinstance(tokens[0], Ket) and isinstance(tokens[1], Bra):
            from sympy.physics.quantum import OuterProduct
            return OuterProduct(tokens[0], tokens[1])
        elif tokens[0] == sympy.Symbol("d"):
            # If the leftmost token is a "d", then it is highly likely that this is a differential
            return tokens[0], tokens[1]
        elif isinstance(tokens[0], tuple):
            # then we have a derivative
            return sympy.Derivative(tokens[1], tokens[0][1])
        else:
            return sympy.Mul(tokens[0], tokens[1])

    def superscript(self, tokens):
        r"""Handle ``a^b`` as well as prime/star superscripts.

        For a matrix base, ``T`` and an odd number of primes give the
        transpose, ``H`` and an odd number of stars give the adjoint, and an
        even count gives the base back. For a non-matrix base, prime/star
        superscripts raise ``LaTeXParsingError``; anything else becomes
        ``Pow(base, sup)``.
        """
        def isprime(x):
            return isinstance(x, Token) and x.type == "PRIMES"

        def iscmdprime(x):
            return isinstance(x, Token) and (x.type == "PRIMES_VIA_CMD"
                                             or x.type == "CMD_PRIME")

        def isstar(x):
            return isinstance(x, Token) and x.type == "STARS"

        def iscmdstar(x):
            return isinstance(x, Token) and (x.type == "STARS_VIA_CMD"
                                             or x.type == "CMD_ASTERISK")

        base = tokens[0]
        if len(tokens) == 3:  # a^b OR a^\prime OR a^\ast
            sup = tokens[2]
        if len(tokens) == 5:
            # a^{'}, a^{''}, ... OR
            # a^{*}, a^{**}, ... OR
            # a^{\prime}, a^{\prime\prime}, ... OR
            # a^{\ast}, a^{\ast\ast}, ...
            sup = tokens[3]

        if self._obj_is_sympy_Matrix(base):
            if sup == sympy.Symbol("T"):
                return sympy.Transpose(base)
            if sup == sympy.Symbol("H"):
                return sympy.adjoint(base)
            if isprime(sup):
                sup = sup.value
                # an even number of transposes cancels out
                if len(sup) % 2 == 0:
                    return base
                return sympy.Transpose(base)
            if iscmdprime(sup):
                sup = sup.value
                # the number of primes is the text length in units of "\prime"
                if (len(sup)/len(r"\prime")) % 2 == 0:
                    return base
                return sympy.Transpose(base)
            if isstar(sup):
                sup = sup.value
                # need .doit() in order to be consistent with
                # sympy.adjoint() which returns the evaluated adjoint
                # of a matrix
                if len(sup) % 2 == 0:
                    return base.doit()
                return sympy.adjoint(base)
            if iscmdstar(sup):
                sup = sup.value
                # need .doit() for same reason as above
                if (len(sup)/len(r"\ast")) % 2 == 0:
                    return base.doit()
                return sympy.adjoint(base)

        if isprime(sup) or iscmdprime(sup) or isstar(sup) or iscmdstar(sup):
            raise LaTeXParsingError(f"{base} with superscript {sup} is not understood.")

        return sympy.Pow(base, sup)

    def matrix_prime(self, tokens):
        """Return the transpose of a matrix for an odd number of primes, else the matrix itself.

        Raises ``LaTeXParsingError`` when ``tokens[0]`` is not a matrix.
        """
        base = tokens[0]
        primes = tokens[1].value

        if not self._obj_is_sympy_Matrix(base):
            raise LaTeXParsingError(f"({base}){primes} is not understood.")

        if len(primes) % 2 == 0:
            return base

        return sympy.Transpose(base)

    def symbol_prime(self, tokens):
        """Return a new ``Symbol`` whose name is the base symbol's name followed by the primes."""
        base = tokens[0]
        primes = tokens[1].value

        return sympy.Symbol(f"{base.name}{primes}")

    def fraction(self, tokens):
        """Return ``tokens[1] / tokens[2]``, or a ``("derivative", variable)`` marker.

        The marker is returned when ``tokens[2]`` is a tuple (as produced for a
        differential); only its second item, the variable, is kept.
        """
        numerator = tokens[1]
        if isinstance(tokens[2], tuple):
            # we only need the variable w.r.t. which we are differentiating
            _, variable = tokens[2]

            # we will pass this information upwards
            return "derivative", variable
        else:
            denominator = tokens[2]
            return self._handle_division(numerator, denominator)

    def binomial(self, tokens):
        """Return ``binomial(tokens[1], tokens[2])``."""
        return sympy.binomial(tokens[1], tokens[2])

    # ------------------------------------------------------------------
    # Integrals
    # ------------------------------------------------------------------

    def _find_integral_bounds(self, tokens):
        """Locate the ``_`` / ``^`` markers in ``tokens`` and the bounds following them.

        Returns ``(underscore_index, caret_index, lower_bound, upper_bound)``;
        each item is ``None`` when the corresponding marker is absent.
        """
        underscore_index = tokens.index("_") if "_" in tokens else None
        caret_index = tokens.index("^") if "^" in tokens else None

        # The item right after each marker is the bound. Test the indices
        # against ``None`` explicitly so that a marker position is never
        # confused with "marker absent".
        lower_bound = tokens[underscore_index + 1] if underscore_index is not None else None
        upper_bound = tokens[caret_index + 1] if caret_index is not None else None

        return underscore_index, caret_index, lower_bound, upper_bound

    def _validate_integral_bounds(self, lower_bound, upper_bound):
        """Raise ``LaTeXParsingError`` when exactly one of the two bounds is given."""
        # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would
        # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero
        if lower_bound is not None and upper_bound is None:
            # then one was given and the other wasn't
            raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.")

        if upper_bound is not None and lower_bound is None:
            # then one was given and the other wasn't
            raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.")

    def normal_integral(self, tokens):
        r"""Build a ``sympy.Integral`` from an integral with an explicit differential.

        The integrand defaults to ``1`` when nothing stands between the
        integral sign (or its bounds) and the differential.

        Raises ``LaTeXParsingError`` when no differential symbol is found or
        when only one of the two bounds is given.
        """
        underscore_index, caret_index, lower_bound, upper_bound = self._find_integral_bounds(tokens)

        differential_symbol = self._extract_differential_symbol(tokens)

        if differential_symbol is None:
            raise LaTeXParsingError("Differential symbol was not found in the expression. "
                                    "Valid differential symbols are \"d\", \"\\text{d}\", and \"\\mathrm{d}\".")

        # else we can assume that a differential symbol was found
        differential_variable_index = tokens.index(differential_symbol) + 1
        differential_variable = tokens[differential_variable_index]

        self._validate_integral_bounds(lower_bound, upper_bound)

        # check if any expression was given or not. If it wasn't, then set the integrand to 1.
        if underscore_index is not None and underscore_index == differential_variable_index - 3:
            # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step
            # backwards after that gives us the underscore, then that means that there _was_ no integrand.
            # Example: \int^7_0 dx
            integrand = 1
        elif caret_index is not None and caret_index == differential_variable_index - 3:
            # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step
            # backwards after that gives us the caret, then that means that there _was_ no integrand.
            # Example: \int_0^7 dx
            integrand = 1
        elif differential_variable_index == 2:
            # this means we have something like "\int dx", because the "\int" symbol will always be
            # at index 0 in `tokens`
            integrand = 1
        else:
            # The Token at differential_variable_index - 1 is the differential symbol itself, so we need to go one
            # more step before that.
            integrand = tokens[differential_variable_index - 2]

        if lower_bound is not None:
            # then we have a definite integral

            # we can assume that either both the lower and upper bounds are given, or
            # neither of them are
            return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound))
        else:
            # we have an indefinite integral
            return sympy.Integral(integrand, differential_variable)

    def group_curly_parentheses_int(self, tokens):
        """Return ``(numerator, variable)`` for a braced group inside an integral.

        With 3 items the numerator is ``1``; with 4 items it is ``tokens[1]``.
        Returns ``None`` implicitly for any other number of items.
        """
        # return signature is a tuple consisting of the expression in the numerator, along with the variable of
        # integration
        if len(tokens) == 3:
            return 1, tokens[1]
        elif len(tokens) == 4:
            return tokens[1], tokens[2]
        # there are no other possibilities

    def special_fraction(self, tokens):
        """Return ``(numerator / denominator, variable)`` for a fraction whose numerator holds the differential."""
        numerator, variable = tokens[1]
        denominator = tokens[2]

        # We pass the integrand, along with information about the variable of integration, upwards
        return sympy.Mul(numerator, sympy.Pow(denominator, -1)), variable

    def integral_with_special_fraction(self, tokens):
        """Build a ``sympy.Integral`` whose integrand and variable come from ``tokens[-1]``.

        ``tokens[-1]`` is expected to be an ``(integrand, variable)`` pair such
        as the one returned by ``special_fraction``.

        Raises ``LaTeXParsingError`` when only one of the two bounds is given.
        """
        _, _, lower_bound, upper_bound = self._find_integral_bounds(tokens)

        self._validate_integral_bounds(lower_bound, upper_bound)

        integrand, differential_variable = tokens[-1]

        if lower_bound is not None:
            # then we have a definite integral

            # we can assume that either both the lower and upper bounds are given, or
            # neither of them are
            return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound))
        else:
            # we have an indefinite integral
            return sympy.Integral(integrand, differential_variable)

    # ------------------------------------------------------------------
    # Sums, products, limits, derivatives
    # ------------------------------------------------------------------

    def group_curly_parentheses_special(self, tokens):
        r"""Return ``(index_variable, lower_limit, upper_limit)`` for a sum/product header.

        The index variable and lower limit are the first and last items
        between the braces that follow ``_``; the upper limit is the first
        item after ``^``.
        """
        underscore_index = tokens.index("_")
        caret_index = tokens.index("^")

        # given the type of expressions we are parsing, we can assume that the lower limit
        # will always use braces around its arguments. This is because we don't support
        # converting unconstrained sums into SymPy expressions.

        # first we isolate the bottom limit
        left_brace_index = tokens.index("{", underscore_index)
        right_brace_index = tokens.index("}", underscore_index)

        bottom_limit = tokens[left_brace_index + 1: right_brace_index]

        # next, we isolate the upper limit
        top_limit = tokens[caret_index + 1:]

        # NOTE: supporting things like `\sum_{n = 0}^{n = 5} n^2` would require
        # isolating a braced group after the caret in the same way as is done
        # for the bottom limit above.

        index_variable = bottom_limit[0]
        lower_limit = bottom_limit[-1]
        upper_limit = top_limit[0]  # for now, the index will always be 0

        return index_variable, lower_limit, upper_limit

    def summation(self, tokens):
        """Return ``Sum(tokens[2], tokens[1])``; ``tokens[1]`` is the limits tuple."""
        return sympy.Sum(tokens[2], tokens[1])

    def product(self, tokens):
        """Return ``Product(tokens[2], tokens[1])``; ``tokens[1]`` is the limits tuple."""
        return sympy.Product(tokens[2], tokens[1])

    def limit_dir_expr(self, tokens):
        """Return ``(tokens[0], direction)`` where direction is ``"+"``, ``"-"`` or ``"+-"``.

        The direction is read from the item after ``^`` (or after the ``{``
        that follows it, when braces are present).
        """
        caret_index = tokens.index("^")

        if "{" in tokens:
            left_curly_brace_index = tokens.index("{", caret_index)
            direction = tokens[left_curly_brace_index + 1]
        else:
            direction = tokens[caret_index + 1]

        if direction == "+":
            return tokens[0], "+"
        elif direction == "-":
            return tokens[0], "-"
        else:
            return tokens[0], "+-"

    def group_curly_parentheses_lim(self, tokens):
        """Return ``(limit_variable, destination, direction)`` for the subscript of a limit.

        When ``tokens[3]`` is a tuple (as returned by ``limit_dir_expr``) it
        supplies destination and direction; otherwise the direction is ``"+-"``.
        """
        limit_variable = tokens[1]
        if isinstance(tokens[3], tuple):
            destination, direction = tokens[3]
        else:
            destination = tokens[3]
            direction = "+-"

        return limit_variable, destination, direction

    def limit(self, tokens):
        """Return ``Limit(tokens[-1], variable, destination, direction)`` using the triple in ``tokens[2]``."""
        limit_variable, destination, direction = tokens[2]

        return sympy.Limit(tokens[-1], limit_variable, destination, direction)

    def differential(self, tokens):
        """Return ``tokens[1]`` (presumably the variable following the differential symbol)."""
        return tokens[1]

    def derivative(self, tokens):
        """Return ``Derivative(tokens[-1], tokens[5])``."""
        return sympy.Derivative(tokens[-1], tokens[5])

    # ------------------------------------------------------------------
    # Function application
    # ------------------------------------------------------------------

    def list_of_expressions(self, tokens):
        """Return the argument expressions as a list, with comma tokens removed.

        A single-item ``tokens`` list is returned verbatim. Otherwise every
        ``Token`` in ``tokens`` must be a comma; any other token raises
        ``LaTeXParsingError``.
        """
        if len(tokens) == 1:
            # we return it verbatim because the function_applied node expects
            # a list
            return tokens

        def remove_tokens(args):
            if isinstance(args, Token):
                if args.type != "COMMA":
                    # An unexpected token was encountered
                    raise LaTeXParsingError("A comma token was expected, but some other token was encountered.")
                return False
            return True

        # Build the list eagerly so that both branches return the same type and
        # an unexpected token is reported here rather than whenever a caller
        # happens to consume a lazy iterator.
        return [args for args in tokens if remove_tokens(args)]

    def function_applied(self, tokens):
        """Return an undefined ``Function`` named ``tokens[0]`` applied to the arguments in ``tokens[2]``."""
        return sympy.Function(tokens[0])(*tokens[2])

    def min(self, tokens):
        """Return ``Min`` of the arguments in ``tokens[2]``."""
        return sympy.Min(*tokens[2])

    def max(self, tokens):
        """Return ``Max`` of the arguments in ``tokens[2]``."""
        return sympy.Max(*tokens[2])

    # ------------------------------------------------------------------
    # Bra-ket notation
    # ------------------------------------------------------------------

    def bra(self, tokens):
        """Return ``Bra(tokens[1])``."""
        from sympy.physics.quantum import Bra
        return Bra(tokens[1])

    def ket(self, tokens):
        """Return ``Ket(tokens[1])``."""
        from sympy.physics.quantum import Ket
        return Ket(tokens[1])

    def inner_product(self, tokens):
        """Return ``InnerProduct(Bra(tokens[1]), Ket(tokens[3]))``."""
        from sympy.physics.quantum import Bra, Ket, InnerProduct
        return InnerProduct(Bra(tokens[1]), Ket(tokens[3]))

    # ------------------------------------------------------------------
    # Trigonometric functions
    # ------------------------------------------------------------------

    def sin(self, tokens):
        """Return ``sin(tokens[1])``."""
        return sympy.sin(tokens[1])

    def cos(self, tokens):
        """Return ``cos(tokens[1])``."""
        return sympy.cos(tokens[1])

    def tan(self, tokens):
        """Return ``tan(tokens[1])``."""
        return sympy.tan(tokens[1])

    def csc(self, tokens):
        """Return ``csc(tokens[1])``."""
        return sympy.csc(tokens[1])

    def sec(self, tokens):
        """Return ``sec(tokens[1])``."""
        return sympy.sec(tokens[1])

    def cot(self, tokens):
        """Return ``cot(tokens[1])``."""
        return sympy.cot(tokens[1])

    def _power_or_inverse(self, tokens, func, inverse_func):
        """Apply ``inverse_func`` when the exponent ``tokens[2]`` is -1, else raise ``func`` to that exponent.

        The argument of the function is ``tokens[-1]``.
        """
        exponent = tokens[2]
        if exponent == -1:
            # an exponent of -1 on a trig function denotes its inverse function
            return inverse_func(tokens[-1])
        return sympy.Pow(func(tokens[-1]), exponent)

    def sin_power(self, tokens):
        """Return ``asin(x)`` for exponent -1, else ``sin(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.sin, sympy.asin)

    def cos_power(self, tokens):
        """Return ``acos(x)`` for exponent -1, else ``cos(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.cos, sympy.acos)

    def tan_power(self, tokens):
        """Return ``atan(x)`` for exponent -1, else ``tan(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.tan, sympy.atan)

    def csc_power(self, tokens):
        """Return ``acsc(x)`` for exponent -1, else ``csc(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.csc, sympy.acsc)

    def sec_power(self, tokens):
        """Return ``asec(x)`` for exponent -1, else ``sec(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.sec, sympy.asec)

    def cot_power(self, tokens):
        """Return ``acot(x)`` for exponent -1, else ``cot(x)**exponent``."""
        return self._power_or_inverse(tokens, sympy.cot, sympy.acot)

    def arcsin(self, tokens):
        """Return ``asin(tokens[1])``."""
        return sympy.asin(tokens[1])

    def arccos(self, tokens):
        """Return ``acos(tokens[1])``."""
        return sympy.acos(tokens[1])

    def arctan(self, tokens):
        """Return ``atan(tokens[1])``."""
        return sympy.atan(tokens[1])

    def arccsc(self, tokens):
        """Return ``acsc(tokens[1])``."""
        return sympy.acsc(tokens[1])

    def arcsec(self, tokens):
        """Return ``asec(tokens[1])``."""
        return sympy.asec(tokens[1])

    def arccot(self, tokens):
        """Return ``acot(tokens[1])``."""
        return sympy.acot(tokens[1])

    # ------------------------------------------------------------------
    # Hyperbolic functions
    # ------------------------------------------------------------------

    def sinh(self, tokens):
        """Return ``sinh(tokens[1])``."""
        return sympy.sinh(tokens[1])

    def cosh(self, tokens):
        """Return ``cosh(tokens[1])``."""
        return sympy.cosh(tokens[1])

    def tanh(self, tokens):
        """Return ``tanh(tokens[1])``."""
        return sympy.tanh(tokens[1])

    def asinh(self, tokens):
        """Return ``asinh(tokens[1])``."""
        return sympy.asinh(tokens[1])

    def acosh(self, tokens):
        """Return ``acosh(tokens[1])``."""
        return sympy.acosh(tokens[1])

    def atanh(self, tokens):
        """Return ``atanh(tokens[1])``."""
        return sympy.atanh(tokens[1])

    # ------------------------------------------------------------------
    # Miscellaneous functions
    # ------------------------------------------------------------------

    def abs(self, tokens):
        """Return ``Abs(tokens[1])``."""
        return sympy.Abs(tokens[1])

    def floor(self, tokens):
        """Return ``floor(tokens[1])``."""
        return sympy.floor(tokens[1])

    def ceil(self, tokens):
        """Return ``ceiling(tokens[1])``."""
        return sympy.ceiling(tokens[1])

    def factorial(self, tokens):
        """Return ``factorial(tokens[0])``."""
        return sympy.factorial(tokens[0])

    def conjugate(self, tokens):
        """Return ``conjugate(tokens[1])``."""
        return sympy.conjugate(tokens[1])

    def square_root(self, tokens):
        """Return ``sqrt(tokens[1])`` (2 items) or the ``tokens[1]``-th root of ``tokens[2]`` (3 items).

        Returns ``None`` implicitly for any other number of items.
        """
        if len(tokens) == 2:
            # then there was no square bracket argument
            return sympy.sqrt(tokens[1])
        elif len(tokens) == 3:
            # then there _was_ a square bracket argument
            return sympy.root(tokens[2], tokens[1])

    def exponential(self, tokens):
        """Return ``exp(tokens[1])``."""
        return sympy.exp(tokens[1])

    def log(self, tokens):
        """Return a logarithm whose base depends on the type of ``tokens[0]``.

        ``FUNC_LG`` uses base 10, ``FUNC_LN`` the natural logarithm, and
        ``FUNC_LOG`` the base given after ``_`` (natural logarithm when no
        base is given). Returns ``None`` implicitly for any other token type.
        """
        if tokens[0].type == "FUNC_LG":
            # we don't need to check if there's an underscore or not because having one
            # in this case would be meaningless
            # TODO: ANTLR refers to ISO 80000-2:2019. should we keep base 10 or base 2?
            return sympy.log(tokens[1], 10)
        elif tokens[0].type == "FUNC_LN":
            return sympy.log(tokens[1])
        elif tokens[0].type == "FUNC_LOG":
            # we check if a base was specified or not
            if "_" in tokens:
                # then a base was specified
                return sympy.log(tokens[3], tokens[2])
            else:
                # a base was not specified
                return sympy.log(tokens[1])

    def _extract_differential_symbol(self, s):
        r"""Return the first of ``d``, ``\text{d}``, ``\mathrm{d}`` contained in ``s``, or ``None``.

        ``s`` is any container supporting ``in``; within this class it is the
        token list of an integral, so membership means "equal to one of the
        items".
        """
        # a tuple (rather than a set) keeps the lookup order deterministic
        differential_symbols = ("d", r"\text{d}", r"\mathrm{d}")

        return next((symbol for symbol in differential_symbols if symbol in s), None)

    # ------------------------------------------------------------------
    # Matrices
    # ------------------------------------------------------------------

    def matrix(self, tokens):
        """Build a ``sympy.Matrix`` from the ``matrix_row`` subtrees found in ``tokens[1].children``.

        Column-delimiter tokens inside each row are dropped.
        """
        def is_matrix_row(x):
            return (isinstance(x, Tree) and x.data == "matrix_row")

        def is_not_col_delim(y):
            return (not isinstance(y, Token) or y.type != "MATRIX_COL_DELIM")

        matrix_body = tokens[1].children
        return sympy.Matrix([[y for y in x.children if is_not_col_delim(y)]
                             for x in matrix_body if is_matrix_row(x)])

    def determinant(self, tokens):
        r"""Return the determinant of a matrix.

        With 2 items (``\det A``) ``tokens[1]`` must already be a matrix,
        otherwise ``LaTeXParsingError`` is raised. With 3 items (``| A |``)
        the matrix is first built from ``tokens`` via ``matrix``. Returns
        ``None`` implicitly for any other number of items.
        """
        if len(tokens) == 2:  # \det A
            if not self._obj_is_sympy_Matrix(tokens[1]):
                raise LaTeXParsingError("Cannot take determinant of non-matrix.")

            return tokens[1].det()

        if len(tokens) == 3:  # | A |
            return self.matrix(tokens).det()

    def trace(self, tokens):
        """Return ``Trace(tokens[1])``; raises ``LaTeXParsingError`` for a non-matrix."""
        if not self._obj_is_sympy_Matrix(tokens[1]):
            raise LaTeXParsingError("Cannot take trace of non-matrix.")

        return sympy.Trace(tokens[1])

    def adjugate(self, tokens):
        """Return the adjugate of ``tokens[1]``; raises ``LaTeXParsingError`` for a non-matrix."""
        if not self._obj_is_sympy_Matrix(tokens[1]):
            raise LaTeXParsingError("Cannot take adjugate of non-matrix.")

        # need .doit() since MatAdd does not support .adjugate() method
        return tokens[1].doit().adjugate()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _obj_is_sympy_Matrix(self, obj):
        """Return ``obj.is_Matrix`` when that attribute exists, else whether ``obj`` is a ``sympy.Matrix``."""
        if hasattr(obj, "is_Matrix"):
            return obj.is_Matrix

        return isinstance(obj, sympy.Matrix)

    def _handle_division(self, numerator, denominator):
        """Return ``numerator * denominator**-1``.

        ``MatMul`` is used for a matrix numerator. A matrix denominator raises
        ``LaTeXParsingError`` because the intended side of the inverse is
        ambiguous.
        """
        if self._obj_is_sympy_Matrix(denominator):
            raise LaTeXParsingError("Cannot divide by matrices like this since "
                                    "it is not clear if left or right multiplication "
                                    "by the inverse is intended. Try explicitly "
                                    "multiplying by the inverse instead.")

        if self._obj_is_sympy_Matrix(numerator):
            return sympy.MatMul(numerator, sympy.Pow(denominator, -1))

        return sympy.Mul(numerator, sympy.Pow(denominator, -1))
|