Adding all project files

This commit is contained in:
Martina Burlando 2025-08-02 02:00:33 +02:00
parent 6c9e127bdc
commit cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions

View file

@ -0,0 +1,16 @@
from .ode import (allhints, checkinfsol, classify_ode,
constantsimp, dsolve, homogeneous_order)
from .lie_group import infinitesimals
from .subscheck import checkodesol
from .systems import (canonical_odes, linear_ode_to_matrix,
linodesolve)
__all__ = [
'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp',
'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix',
'linodesolve'
]

View file

@ -0,0 +1,272 @@
r'''
This module contains the implementation of the 2nd_hypergeometric hint for
dsolve. This is an incomplete implementation of the algorithm described in [1].
The algorithm solves 2nd order linear ODEs of the form
.. math:: y'' + A(x) y' + B(x) y = 0\text{,}
where `A` and `B` are rational functions. The algorithm should find any
solution of the form
.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,}
where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function".
Currently only the 2F1 case is implemented in SymPy but the other cases are
described in the paper and could be implemented in future (contributions
welcome!).
References
==========
.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order
linear ODEs, (2004).
https://arxiv.org/abs/math-ph/0402063
'''
from sympy.core import S, Pow
from sympy.core.function import expand
from sympy.core.relational import Eq
from sympy.core.symbol import Symbol, Wild
from sympy.functions import exp, sqrt, hyper
from sympy.integrals import Integral
from sympy.polys import roots, gcd
from sympy.polys.polytools import cancel, factor
from sympy.simplify import collect, simplify, logcombine # type: ignore
from sympy.simplify.powsimp import powdenest
from sympy.solvers.ode.ode import get_numbered_constants
def match_2nd_hypergeometric(eq, func):
x = func.args[0]
df = func.diff(x)
a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)])
b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)])
c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)])
deq = a3*(func.diff(x, 2)) + b3*df + c3*func
r = collect(eq,
[func.diff(x, 2), func.diff(x), func]).match(deq)
if r:
if not all(val.is_polynomial() for val in r.values()):
n, d = eq.as_numer_denom()
eq = expand(n)
r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq)
if r and r[a3]!=0:
A = cancel(r[b3]/r[a3])
B = cancel(r[c3]/r[a3])
return [A, B]
else:
return []
def equivalence_hypergeometric(A, B, func):
# This method for finding the equivalence is only for 2F1 type.
# We can extend it for 1F1 and 0F1 type also.
x = func.args[0]
# making given equation in normal form
I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B))
# computing shifted invariant(J1) of the equation
J1 = factor(cancel(x**2*I1 + S(1)/4))
num, dem = J1.as_numer_denom()
num = powdenest(expand(num))
dem = powdenest(expand(dem))
# this function will compute the different powers of variable(x) in J1.
# then it will help in finding value of k. k is power of x such that we can express
# J1 = x**k * J0(x**k) then all the powers in J0 become integers.
def _power_counting(num):
_pow = {0}
for val in num:
if val.has(x):
if isinstance(val, Pow) and val.as_base_exp()[0] == x:
_pow.add(val.as_base_exp()[1])
elif val == x:
_pow.add(val.as_base_exp()[1])
else:
_pow.update(_power_counting(val.args))
return _pow
pow_num = _power_counting((num, ))
pow_dem = _power_counting((dem, ))
pow_dem.update(pow_num)
_pow = pow_dem
k = gcd(_pow)
# computing I0 of the given equation
I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True)
I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True)))
# Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing
# x with x**(1/k) should result in I0 being a rational function of x or
# otherwise the hypergeometric solver cannot be used. Note that k can be a
# non-integer rational such as 2/7.
if not I0.is_rational_function(x):
return None
num, dem = I0.as_numer_denom()
max_num_pow = max(_power_counting((num, )))
dem_args = dem.args
sing_point = []
dem_pow = []
# calculating singular point of I0.
for arg in dem_args:
if arg.has(x):
if isinstance(arg, Pow):
# (x-a)**n
dem_pow.append(arg.as_base_exp()[1])
sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0])
else:
# (x-a) type
dem_pow.append(arg.as_base_exp()[1])
sing_point.append(list(roots(arg, x).keys())[0])
dem_pow.sort()
# checking if equivalence is exists or not.
if equivalence(max_num_pow, dem_pow) == "2F1":
return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"}
else:
return None
def match_2nd_2F1_hypergeometric(I, k, sing_point, func):
x = func.args[0]
a = Wild("a")
b = Wild("b")
c = Wild("c")
t = Wild("t")
s = Wild("s")
r = Wild("r")
alpha = Wild("alpha")
beta = Wild("beta")
gamma = Wild("gamma")
delta = Wild("delta")
# I0 of the standard 2F1 equation.
I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2)
if sing_point != [0, 1]:
# If singular point is [0, 1] then we have standard equation.
eqs = []
sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)]
# making equations for the finding the mobius transformation
for i in range(3):
if i<len(sing_point):
eqs.append(Eq(sing_eqs[i], sing_point[i]))
else:
eqs.append(Eq(1/sing_eqs[i], 0))
# solving above equations for the mobius transformation
_beta = -alpha*sing_point[0]
_delta = -gamma*sing_point[1]
_gamma = alpha
if len(sing_point) == 3:
_gamma = (_beta + sing_point[2]*alpha)/(sing_point[2] - sing_point[1])
mob = (alpha*x + beta)/(gamma*x + delta)
mob = mob.subs(beta, _beta)
mob = mob.subs(delta, _delta)
mob = mob.subs(gamma, _gamma)
mob = cancel(mob)
t = (beta - delta*x)/(gamma*x - alpha)
t = cancel(((t.subs(beta, _beta)).subs(delta, _delta)).subs(gamma, _gamma))
else:
mob = x
t = x
# applying mobius transformation in I to make it into I0.
I = I.subs(x, t)
I = I*(t.diff(x))**2
I = factor(I)
dict_I = {x**2:0, x:0, 1:0}
I0_num, I0_dem = I0.as_numer_denom()
# collecting coeff of (x**2, x), of the standard equation.
# substituting (a-b) = s, (a+b) = r
dict_I0 = {x**2:s**2 - 1, x:(2*(1-r)*c + (r+s)*(r-s)), 1:c*(c-2)}
# collecting coeff of (x**2, x) from I0 of the given equation.
dict_I.update(collect(expand(cancel(I*I0_dem)), [x**2, x], evaluate=False))
eqs = []
# We are comparing the coeff of powers of different x, for finding the values of
# parameters of standard equation.
for key in [x**2, x, 1]:
eqs.append(Eq(dict_I[key], dict_I0[key]))
# We can have many possible roots for the equation.
# I am selecting the root on the basis that when we have
# standard equation eq = x*(x-1)*f(x).diff(x, 2) + ((a+b+1)*x-c)*f(x).diff(x) + a*b*f(x)
# then root should be a, b, c.
_c = 1 - factor(sqrt(1+eqs[2].lhs))
if not _c.has(Symbol):
_c = min(list(roots(eqs[2], c)))
_s = factor(sqrt(eqs[0].lhs + 1))
_r = _c - factor(sqrt(_c**2 + _s**2 + eqs[1].lhs - 2*_c))
_a = (_r + _s)/2
_b = (_r - _s)/2
rn = {'a':simplify(_a), 'b':simplify(_b), 'c':simplify(_c), 'k':k, 'mobius':mob, 'type':"2F1"}
return rn
def equivalence(max_num_pow, dem_pow):
# this function is made for checking the equivalence with 2F1 type of equation.
# max_num_pow is the value of maximum power of x in numerator
# and dem_pow is list of powers of different factor of form (a*x b).
# reference from table 1 in paper - "Non-Liouvillian solutions for second order
# linear ODEs" by L. Chan, E.S. Cheb-Terrab.
# We can extend it for 1F1 and 0F1 type also.
if max_num_pow == 2:
if dem_pow in [[2, 2], [2, 2, 2]]:
return "2F1"
elif max_num_pow == 1:
if dem_pow in [[1, 2, 2], [2, 2, 2], [1, 2], [2, 2]]:
return "2F1"
elif max_num_pow == 0:
if dem_pow in [[1, 1, 2], [2, 2], [1, 2, 2], [1, 1], [2], [1, 2], [2, 2]]:
return "2F1"
return None
def get_sol_2F1_hypergeometric(eq, func, match_object):
x = func.args[0]
from sympy.simplify.hyperexpand import hyperexpand
from sympy.polys.polytools import factor
C0, C1 = get_numbered_constants(eq, num=2)
a = match_object['a']
b = match_object['b']
c = match_object['c']
A = match_object['A']
sol = None
if c.is_integer == False:
sol = C0*hyper([a, b], [c], x) + C1*hyper([a-c+1, b-c+1], [2-c], x)*x**(1-c)
elif c == 1:
y2 = Integral(exp(Integral((-(a+b+1)*x + c)/(x**2-x), x))/(hyperexpand(hyper([a, b], [c], x))**2), x)*hyper([a, b], [c], x)
sol = C0*hyper([a, b], [c], x) + C1*y2
elif (c-a-b).is_integer == False:
sol = C0*hyper([a, b], [1+a+b-c], 1-x) + C1*hyper([c-a, c-b], [1+c-a-b], 1-x)*(1-x)**(c-a-b)
if sol:
# applying transformation in the solution
subs = match_object['mobius']
dtdx = simplify(1/(subs.diff(x)))
_B = ((a + b + 1)*x - c).subs(x, subs)*dtdx
_B = factor(_B + ((x**2 -x).subs(x, subs))*(dtdx.diff(x)*dtdx))
_A = factor((x**2 - x).subs(x, subs)*(dtdx**2))
e = exp(logcombine(Integral(cancel(_B/(2*_A)), x), force=True))
sol = sol.subs(x, match_object['mobius'])
sol = sol.subs(x, x**match_object['k'])
e = e.subs(x, x**match_object['k'])
if not A.is_zero:
e1 = Integral(A/2, x)
e1 = exp(logcombine(e1, force=True))
sol = cancel((e/e1)*x**((-match_object['k']+1)/2))*sol
sol = Eq(func, sol)
return sol
sol = cancel((e)*x**((-match_object['k']+1)/2))*sol
sol = Eq(func, sol)
return sol

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,484 @@
r"""
This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients,
nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients,
nth_linear_constant_coeff_variation_of_parameters,
and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters.
All the functions in this file are used by more than one solvers so, instead of creating
instances in other classes for using them it is better to keep it here as separate helpers.
"""
from collections import Counter
from sympy.core import Add, S
from sympy.core.function import diff, expand, _mexpand, expand_mul
from sympy.core.relational import Eq
from sympy.core.sorting import default_sort_key
from sympy.core.symbol import Dummy, Wild
from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \
atan2, conjugate
from sympy.integrals import Integral
from sympy.polys import (Poly, RootOf, rootof, roots)
from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore
from sympy.utilities import numbered_symbols
from sympy.solvers.solvers import solve
from sympy.matrices import wronskian
from .subscheck import sub_func_doit
from sympy.solvers.ode.ode import get_numbered_constants
def _test_term(coeff, func, order):
r"""
Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x),
where K is independent of x and y(x), order>= 0.
So we need to check that for each term, coeff == K*x**order from
some K. We have a few cases, since coeff may have several
different types.
"""
x = func.args[0]
f = func.func
if order < 0:
raise ValueError("order should be greater than 0")
if coeff == 0:
return True
if order == 0:
if x in coeff.free_symbols:
return False
return True
if coeff.is_Mul:
if coeff.has(f(x)):
return False
return x**order in coeff.args
elif coeff.is_Pow:
return coeff.as_base_exp() == (x, order)
elif order == 1:
return x == coeff
return False
def _get_euler_characteristic_eq_sols(eq, func, match_obj):
r"""
Returns the solution of homogeneous part of the linear euler ODE and
the list of roots of characteristic equation.
The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order
of the derivative on each term, and coeff is the coefficient of that derivative.
"""
x = func.args[0]
f = func.func
# First, set up characteristic equation.
chareq, symbol = S.Zero, Dummy('x')
for i in match_obj:
if i >= 0:
chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand()
chareq = Poly(chareq, symbol)
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
collectterms = []
# A generator of constants
constants = list(get_numbered_constants(eq, num=chareq.degree()*2))
constants.reverse()
# Create a dict root: multiplicity or charroots
charroots = Counter(chareqroots)
gsol = S.Zero
ln = log
for root, multiplicity in charroots.items():
for i in range(multiplicity):
if isinstance(root, RootOf):
gsol += (x**root) * constants.pop()
if multiplicity != 1:
raise ValueError("Value should be 1")
collectterms = [(0, root, 0)] + collectterms
elif root.is_real:
gsol += ln(x)**i*(x**root) * constants.pop()
collectterms = [(i, root, 0)] + collectterms
else:
reroot = re(root)
imroot = im(root)
gsol += ln(x)**i * (x**reroot) * (
constants.pop() * sin(abs(imroot)*ln(x))
+ constants.pop() * cos(imroot*ln(x)))
collectterms = [(i, reroot, imroot)] + collectterms
gsol = Eq(f(x), gsol)
gensols = []
# Keep track of when to use sin or cos for nonzero imroot
for i, reroot, imroot in collectterms:
if imroot == 0:
gensols.append(ln(x)**i*x**reroot)
else:
sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x))
if sin_form in gensols:
cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x))
gensols.append(cos_form)
else:
gensols.append(sin_form)
return gsol, gensols
def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True):
r"""
Helper function for the method of variation of parameters and nonhomogeneous euler eq.
See the
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters`
docstring for more information on this method.
The parameter are ``match_obj`` should be a dictionary that has the following
keys:
``list``
A list of solutions to the homogeneous equation.
``sol``
The general solution.
"""
f = func.func
x = func.args[0]
r = match_obj
psol = 0
wr = wronskian(roots, x)
if simplify_flag:
wr = simplify(wr) # We need much better simplification for
# some ODEs. See issue 4662, for example.
# To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1
wr = trigsimp(wr, deep=True, recursive=True)
if not wr:
# The wronskian will be 0 iff the solutions are not linearly
# independent.
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply " +
"variation of parameters to " + str(eq) + " (Wronskian == 0)")
if len(roots) != order:
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply " +
"variation of parameters to " +
str(eq) + " (number of terms != order)")
negoneterm = S.NegativeOne**(order)
for i in roots:
psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order]
negoneterm *= -1
if simplify_flag:
psol = simplify(psol)
psol = trigsimp(psol, deep=True)
return Eq(f(x), homogen_sol.rhs + psol)
def _get_const_characteristic_eq_sols(r, func, order):
r"""
Returns the roots of characteristic equation of constant coefficient
linear ODE and list of collectterms which is later on used by simplification
to use collect on solution.
The parameter `r` is a dict of order:coeff terms, where order is the order of the
derivative on each term, and coeff is the coefficient of that derivative.
"""
x = func.args[0]
# First, set up characteristic equation.
chareq, symbol = S.Zero, Dummy('x')
for i in r.keys():
if isinstance(i, str) or i < 0:
pass
else:
chareq += r[i]*symbol**i
chareq = Poly(chareq, symbol)
# Can't just call roots because it doesn't return rootof for unsolveable
# polynomials.
chareqroots = roots(chareq, multiple=True)
if len(chareqroots) != order:
chareqroots = [rootof(chareq, k) for k in range(chareq.degree())]
chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs())
# Create a dict root: multiplicity or charroots
charroots = Counter(chareqroots)
# We need to keep track of terms so we can run collect() at the end.
# This is necessary for constantsimp to work properly.
collectterms = []
gensols = []
conjugate_roots = [] # used to prevent double-use of conjugate roots
# Loop over roots in theorder provided by roots/rootof...
for root in chareqroots:
# but don't repoeat multiple roots.
if root not in charroots:
continue
multiplicity = charroots.pop(root)
for i in range(multiplicity):
if chareq_is_complex:
gensols.append(x**i*exp(root*x))
collectterms = [(i, root, 0)] + collectterms
continue
reroot = re(root)
imroot = im(root)
if imroot.has(atan2) and reroot.has(atan2):
# Remove this condition when re and im stop returning
# circular atan2 usages.
gensols.append(x**i*exp(root*x))
collectterms = [(i, root, 0)] + collectterms
else:
if root in conjugate_roots:
collectterms = [(i, reroot, imroot)] + collectterms
continue
if imroot == 0:
gensols.append(x**i*exp(reroot*x))
collectterms = [(i, reroot, 0)] + collectterms
continue
conjugate_roots.append(conjugate(root))
gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x))
gensols.append(x**i*exp(reroot*x) * cos( imroot * x))
# This ordering is important
collectterms = [(i, reroot, imroot)] + collectterms
return gensols, collectterms
# Ideally these kind of simplification functions shouldn't be part of solvers.
# odesimp should be improved to handle these kind of specific simplifications.
def _get_simplified_sol(sol, func, collectterms):
r"""
Helper function which collects the solution on
collectterms. Ideally this should be handled by odesimp.It is used
only when the simplify is set to True in dsolve.
The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is
the multiplicity of the root, reroot is real part and imroot being the imaginary part.
"""
f = func.func
x = func.args[0]
collectterms.sort(key=default_sort_key)
collectterms.reverse()
assert len(sol) == 1 and sol[0].lhs == f(x)
sol = sol[0].rhs
sol = expand_mul(sol)
for i, reroot, imroot in collectterms:
sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x))
sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x))
for i, reroot, imroot in collectterms:
sol = collect(sol, x**i*exp(reroot*x))
sol = powsimp(sol)
return Eq(f(x), sol)
def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero):
r"""
Returns a trial function match if undetermined coefficients can be applied
to ``expr``, and ``None`` otherwise.
A trial expression can be found for an expression for use with the method
of undetermined coefficients if the expression is an
additive/multiplicative combination of constants, polynomials in `x` (the
independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and
`e^{a x}` terms (in other words, it has a finite number of linearly
independent derivatives).
Note that you may still need to multiply each term returned here by
sufficient `x` to make it linearly independent with the solutions to the
homogeneous equation.
This is intended for internal use by ``undetermined_coefficients`` hints.
SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of
only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So,
for example, you will need to manually convert `\sin^2(x)` into `[1 +
\cos(2 x)]/2` to properly apply the method of undetermined coefficients on
it.
Examples
========
>>> from sympy import log, exp
>>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match
>>> from sympy.abc import x
>>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x)
{'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}}
>>> _undetermined_coefficients_match(log(x), x)
{'test': False}
"""
a = Wild('a', exclude=[x])
b = Wild('b', exclude=[x])
expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1)
retdict = {}
def _test_term(expr, x) -> bool:
r"""
Test if ``expr`` fits the proper form for undetermined coefficients.
"""
if not expr.has(x):
return True
if expr.is_Add:
return all(_test_term(i, x) for i in expr.args)
if expr.is_Mul:
if expr.has(sin, cos):
foundtrig = False
# Make sure that there is only one trig function in the args.
# See the docstring.
for i in expr.args:
if i.has(sin, cos):
if foundtrig:
return False
else:
foundtrig = True
return all(_test_term(i, x) for i in expr.args)
if expr.is_Function:
return expr.func in (sin, cos, exp, sinh, cosh) and \
bool(expr.args[0].match(a*x + b))
if expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \
expr.exp >= 0:
return True
if expr.is_Pow and expr.base.is_number:
return bool(expr.exp.match(a*x + b))
return expr.is_Symbol or bool(expr.is_number)
def _get_trial_set(expr, x, exprs=set()):
r"""
Returns a set of trial terms for undetermined coefficients.
The idea behind undetermined coefficients is that the terms expression
repeat themselves after a finite number of derivatives, except for the
coefficients (they are linearly dependent). So if we collect these,
we should have the terms of our trial function.
"""
def _remove_coefficient(expr, x):
r"""
Returns the expression without a coefficient.
Similar to expr.as_independent(x)[1], except it only works
multiplicatively.
"""
term = S.One
if expr.is_Mul:
for i in expr.args:
if i.has(x):
term *= i
elif expr.has(x):
term = expr
return term
expr = expand_mul(expr)
if expr.is_Add:
for term in expr.args:
if _remove_coefficient(term, x) in exprs:
pass
else:
exprs.add(_remove_coefficient(term, x))
exprs = exprs.union(_get_trial_set(term, x, exprs))
else:
term = _remove_coefficient(expr, x)
tmpset = exprs.union({term})
oldset = set()
while tmpset != oldset:
# If you get stuck in this loop, then _test_term is probably
# broken
oldset = tmpset.copy()
expr = expr.diff(x)
term = _remove_coefficient(expr, x)
if term.is_Add:
tmpset = tmpset.union(_get_trial_set(term, x, tmpset))
else:
tmpset.add(term)
exprs = tmpset
return exprs
def is_homogeneous_solution(term):
r""" This function checks whether the given trialset contains any root
of homogeneous equation"""
return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero
retdict['test'] = _test_term(expr, x)
if retdict['test']:
# Try to generate a list of trial solutions that will have the
# undetermined coefficients. Note that if any of these are not linearly
# independent with any of the solutions to the homogeneous equation,
# then they will need to be multiplied by sufficient x to make them so.
# This function DOES NOT do that (it doesn't even look at the
# homogeneous equation).
temp_set = set()
for i in Add.make_args(expr):
act = _get_trial_set(i, x)
if eq_homogeneous is not S.Zero:
while any(is_homogeneous_solution(ts) for ts in act):
act = {x*ts for ts in act}
temp_set = temp_set.union(act)
retdict['trialset'] = temp_set
return retdict
def _solve_undetermined_coefficients(eq, func, order, match, trialset):
r"""
Helper function for the method of undetermined coefficients.
See the
:py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients`
docstring for more information on this method.
The parameter ``trialset`` is the set of trial functions as returned by
``_undetermined_coefficients_match()['trialset']``.
The parameter ``match`` should be a dictionary that has the following
keys:
``list``
A list of solutions to the homogeneous equation.
``sol``
The general solution.
"""
r = match
coeffs = numbered_symbols('a', cls=Dummy)
coefflist = []
gensols = r['list']
gsol = r['sol']
f = func.func
x = func.args[0]
if len(gensols) != order:
raise NotImplementedError("Cannot find " + str(order) +
" solutions to the homogeneous equation necessary to apply" +
" undetermined coefficients to " + str(eq) +
" (number of terms != order)")
trialfunc = 0
for i in trialset:
c = next(coeffs)
coefflist.append(c)
trialfunc += c*i
eqs = sub_func_doit(eq, f(x), trialfunc)
coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1))))
eqs = _mexpand(eqs)
for i in Add.make_args(eqs):
s = separatevars(i, dict=True, symbols=[x])
if coeffsdict.get(s[x]):
coeffsdict[s[x]] += s['coeff']
else:
coeffsdict[s[x]] = s['coeff']
coeffvals = solve(list(coeffsdict.values()), coefflist)
if not coeffvals:
raise NotImplementedError(
"Could not solve `%s` using the "
"method of undetermined coefficients "
"(unable to solve for coefficients)." % eq)
psol = trialfunc.subs(coeffvals)
return Eq(f(x), gsol.rhs + psol)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,893 @@
r"""
This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`,
a function which gives all rational particular solutions to first order
Riccati ODEs. A general first order Riccati ODE is given by -
.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2
where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x`
with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE
anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation
is a Bernoulli ODE. The algorithm presented below can find rational
solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution,
or prove that no rational solution exists for the equation.
Background
==========
A Riccati equation can be transformed to its normal form
.. math:: y' + y^2 = a(x)
using the transformation
.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2}
where `a(x)` is given by
.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2
Thus, we can develop an algorithm to solve for the Riccati equation
in its normal form, which would in turn give us the solution for
the original Riccati equation.
Algorithm
=========
The algorithm implemented here is presented in the Ph.D thesis
"Rational and Algebraic Solutions of First-Order Algebraic ODEs"
by N. Thieu Vo. The entire thesis can be found here -
https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf
We have only implemented the Rational Riccati solver (Algorithm 11,
Pg 78-82 in Thesis). Before we proceed towards the implementation
of the algorithm, a few definitions to understand are -
1. Valuation of a Rational Function at `\infty`:
The valuation of a rational function `p(x)` at `\infty` is equal
to the difference between the degree of the denominator and the
numerator of `p(x)`.
NOTE: A general definition of valuation of a rational function
at any value of `x` can be found in Pg 63 of the thesis, but
is not of any interest for this algorithm.
2. Zeros and Poles of a Rational Function:
Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function
of `x`. Then -
a. The Zeros of `a(x)` are the roots of `S(x)`.
b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty`
can also be a pole of a(x). We say that `a(x)` has a pole at
`\infty` if `a(\frac{1}{x})` has a pole at 0.
Every pole is associated with an order that is equal to the multiplicity
of its appearance as a root of `T(x)`. A pole is called a simple pole if
it has an order 1. Similarly, a pole is called a multiple pole if it has
an order `\ge` 2.
Necessary Conditions
====================
For a Riccati equation in its normal form,
.. math:: y' + y^2 = a(x)
we can define
a. A pole is called a movable pole if it is a pole of `y(x)` and is not
a pole of `a(x)`.
b. Similarly, a pole is called a non-movable pole if it is a pole of both
`y(x)` and `a(x)`.
Then, the algorithm states that a rational solution exists only if -
a. Every pole of `a(x)` must be either a simple pole or a multiple pole
of even order.
b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2.
This algorithm finds all possible rational solutions for the Riccati ODE.
If no rational solutions are found, it means that no rational solutions
exist.
The algorithm works for Riccati ODEs where the coefficients are rational
functions in the independent variable `x` with rational number coefficients
i.e. in `Q(x)`. The coefficients in the rational function cannot be floats,
irrational numbers, symbols or any other kind of expression. The reasons
for this are -
1. When using symbols, different symbols could take the same value and this
would affect the multiplicity of poles if symbols are present here.
2. An integer degree bound is required to calculate a polynomial solution
to an auxiliary differential equation, which in turn gives the particular
solution for the original ODE. If symbols/floats/irrational numbers are
present, we cannot determine if the expression for the degree bound is an
integer or not.
Solution
========
With these definitions, we can state a general form for the solution of
the equation. `y(x)` must have the form -
.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i
where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`,
`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values
of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The
coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})`
can be determined from `a(x)`. We will have 2 choices each of these vectors
and part of the procedure is figuring out which of the 2 should be used
to get the solution correctly.
Implementation
==============
In this implementation, we use ``Poly`` to represent a rational function
rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot
represent rational functions directly using ``Poly``, we instead represent
a rational function with 2 ``Poly`` objects - one for its numerator and
the other for its denominator.
The code is written to match the steps given in the thesis (Pg 82)
Step 0 : Match the equation -
Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise
an error
Step 1 : Transform the equation to its normal form as explained in the
theory section.
Step 2 : Initialize an empty set of solutions, ``sol``.
Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``.
Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}`
to ``sol``.
Step 5 : Find the poles and their multiplicities of `a(x)`. Let
the number of poles be `n`. Also find the valuation of `a(x)` at
`\infty` using ``val_at_inf``.
NOTE: Although the algorithm considers `\infty` as a pole, it is
not mentioned if it a part of the set of finite poles. `\infty`
is NOT a part of the set of finite poles. If a pole exists at
`\infty`, we use its multiplicity to find the laurent series of
`a(x)` about `\infty`.
Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using
``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)``
combinations of choosing between 2 choices for each of the `n` c-vectors
and 1 d-vector.
NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig
mistake. The term `- d_N` must be replaced with `-N d_N`. The same
has been explained in the code as well.
For each of these above combinations, do
Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of
the polynomial solution we must find for the auxiliary equation.
Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is
one part of y(x) -
.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i
Step 10 : If `m` is a non-negative integer -
Step 11: Find a polynomial solution of degree `m` for the auxiliary equation.
There are 2 cases possible -
a. `m` is a non-negative integer: We can solve for the coefficients
in `p(x)` using Undetermined Coefficients.
b. `m` is not a non-negative integer: In this case, we cannot find
a polynomial solution to the auxiliary equation, and hence, we ignore
this value of `m`.
Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}`
to ``sol``.
Step 13 : For each solution in ``sol``, apply an inverse transformation,
so that the solutions of the original equation are found using the
solutions of the equation in its normal form.
"""
from itertools import product
from sympy.core import S
from sympy.core.add import Add
from sympy.core.numbers import oo, Float
from sympy.core.function import count_ops
from sympy.core.relational import Eq
from sympy.core.symbol import symbols, Symbol, Dummy
from sympy.functions import sqrt, exp
from sympy.functions.elementary.complexes import sign
from sympy.integrals.integrals import Integral
from sympy.polys.domains import ZZ
from sympy.polys.polytools import Poly
from sympy.polys.polyroots import roots
from sympy.solvers.solveset import linsolve
def riccati_normal(w, x, b1, b2):
"""
Given a solution `w(x)` to the equation
.. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2
and rational function coefficients `b_1(x)` and
`b_2(x)`, this function transforms the solution to
give a solution `y(x)` for its corresponding normal
Riccati ODE
.. math:: y'(x) + y(x)^2 = a(x)
using the transformation
.. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2
"""
return -b2*w - b2.diff(x)/(2*b2) - b1/2
def riccati_inverse_normal(y, x, b1, b2, bp=None):
"""
Inverse transforming the solution to the normal
Riccati ODE to get the solution to the Riccati ODE.
"""
# bp is the expression which is independent of the solution
# and hence, it need not be computed again
if bp is None:
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
# w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x))
return -y/b2 + bp
def riccati_reduced(eq, f, x):
"""
Convert a Riccati ODE into its corresponding
normal Riccati ODE.
"""
match, funcs = match_riccati(eq, f, x)
# If equation is not a Riccati ODE, exit
if not match:
return False
# Using the rational functions, find the expression for a(x)
b0, b1, b2 = funcs
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
b2.diff(x, 2)/(2*b2)
# Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x)
return f(x).diff(x) + f(x)**2 - a
def linsolve_dict(eq, syms):
"""
Get the output of linsolve as a dict
"""
# Convert tuple type return value of linsolve
# to a dictionary for ease of use
sol = linsolve(eq, syms)
if not sol:
return {}
return dict(zip(syms, list(sol)[0]))
def match_riccati(eq, f, x):
"""
A function that matches and returns the coefficients
if an equation is a Riccati ODE
Parameters
==========
eq: Equation to be matched
f: Dependent variable
x: Independent variable
Returns
=======
match: True if equation is a Riccati ODE, False otherwise
funcs: [b0, b1, b2] if match is True, [] otherwise. Here,
b0, b1 and b2 are rational functions which match the equation.
"""
# Group terms based on f(x)
if isinstance(eq, Eq):
eq = eq.lhs - eq.rhs
eq = eq.expand().collect(f(x))
cf = eq.coeff(f(x).diff(x))
# There must be an f(x).diff(x) term.
# eq must be an Add object since we are using the expanded
# equation and it must have atleast 2 terms (b2 != 0)
if cf != 0 and isinstance(eq, Add):
# Divide all coefficients by the coefficient of f(x).diff(x)
# and add the terms again to get the same equation
eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x))
# Match the equation with the pattern
b1 = -eq.coeff(f(x))
b2 = -eq.coeff(f(x)**2)
b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand()
funcs = [b0, b1, b2]
# Check if coefficients are not symbols and floats
if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs):
return False, []
# If b_0(x) contains f(x), it is not a Riccati ODE
if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x),
b1.is_rational_function(x), b2.is_rational_function(x))):
return False, []
return True, funcs
return False, []
def val_at_inf(num, den, x):
# Valuation of a rational function at oo = deg(denom) - deg(numer)
return den.degree(x) - num.degree(x)
def check_necessary_conds(val_inf, muls):
"""
The necessary conditions for a rational solution
to exist are as follows -
i) Every pole of a(x) must be either a simple pole
or a multiple pole of even order.
ii) The valuation of a(x) at infinity must be even
or be greater than or equal to 2.
Here, a simple pole is a pole with multiplicity 1
and a multiple pole is a pole with multiplicity
greater than 1.
"""
return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \
all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls)
def inverse_transform_poly(num, den, x):
"""
A function to make the substitution
x -> 1/x in a rational function that
is represented using Poly objects for
numerator and denominator.
"""
# Declare for reuse
one = Poly(1, x)
xpoly = Poly(x, x)
# Check if degree of numerator is same as denominator
pwr = val_at_inf(num, den, x)
if pwr >= 0:
# Denominator has greater degree. Substituting x with
# 1/x would make the extra power go to the numerator
if num.expr != 0:
num = num.transform(one, xpoly) * x**pwr
den = den.transform(one, xpoly)
else:
# Numerator has greater degree. Substituting x with
# 1/x would make the extra power go to the denominator
num = num.transform(one, xpoly)
den = den.transform(one, xpoly) * x**(-pwr)
return num.cancel(den, include=True)
def limit_at_inf(num, den, x):
"""
Find the limit of a rational function
at oo
"""
# pwr = degree(num) - degree(den)
pwr = -val_at_inf(num, den, x)
# Numerator has a greater degree than denominator
# Limit at infinity would depend on the sign of the
# leading coefficients of numerator and denominator
if pwr > 0:
return oo*sign(num.LC()/den.LC())
# Degree of numerator is equal to that of denominator
# Limit at infinity is just the ratio of leading coeffs
elif pwr == 0:
return num.LC()/den.LC()
# Degree of numerator is less than that of denominator
# Limit at infinity is just 0
else:
return 0
def construct_c_case_1(num, den, x, pole):
# Find the coefficient of 1/(x - pole)**2 in the
# Laurent series expansion of a(x) about pole.
num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True)
r = (num1.subs(x, pole))/(den1.subs(x, pole))
# If multiplicity is 2, the coefficient to be added
# in the c-vector is c = (1 +- sqrt(1 + 4*r))/2
if r != -S(1)/4:
return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]]
return [[S.Half]]
def construct_c_case_2(num, den, x, pole, mul):
# Generate the coefficients using the recurrence
# relation mentioned in (5.14) in the thesis (Pg 80)
# r_i = mul/2
ri = mul//2
# Find the Laurent series coefficients about the pole
ser = rational_laurent_series(num, den, x, pole, mul, 6)
# Start with an empty memo to store the coefficients
# This is for the plus case
cplus = [0 for i in range(ri)]
# Base Case
cplus[ri-1] = sqrt(ser[2*ri])
# Iterate backwards to find all coefficients
s = ri - 1
sm = 0
for s in range(ri-1, 0, -1):
sm = 0
for j in range(s+1, ri):
sm += cplus[j-1]*cplus[ri+s-j-1]
if s!= 1:
cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1])
# Memo for the minus case
cminus = [-x for x in cplus]
# Find the 0th coefficient in the recurrence
cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1])
cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1])
# Add both the plus and minus cases' coefficients
if cplus != cminus:
return [cplus, cminus]
return cplus
def construct_c_case_3():
# If multiplicity is 1, the coefficient to be added
# in the c-vector is 1 (no choice)
return [[1]]
def construct_c(num, den, x, poles, muls):
"""
Helper function to calculate the coefficients
in the c-vector for each pole.
"""
c = []
for pole, mul in zip(poles, muls):
c.append([])
# Case 3
if mul == 1:
# Add the coefficients from Case 3
c[-1].extend(construct_c_case_3())
# Case 1
elif mul == 2:
# Add the coefficients from Case 1
c[-1].extend(construct_c_case_1(num, den, x, pole))
# Case 2
else:
# Add the coefficients from Case 2
c[-1].extend(construct_c_case_2(num, den, x, pole, mul))
return c
def construct_d_case_4(ser, N):
# Initialize an empty vector
dplus = [0 for i in range(N+2)]
# d_N = sqrt(a_{2*N})
dplus[N] = sqrt(ser[2*N])
# Use the recurrence relations to find
# the value of d_s
for s in range(N-1, -2, -1):
sm = 0
for j in range(s+1, N):
sm += dplus[j]*dplus[N+s-j]
if s != -1:
dplus[s] = (ser[N+s] - sm)/(2*dplus[N])
# Coefficients for the case of d_N = -sqrt(a_{2*N})
dminus = [-x for x in dplus]
# The third equation in Eq 5.15 of the thesis is WRONG!
# d_N must be replaced with N*d_N in that equation.
dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N])
dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N])
if dplus != dminus:
return [dplus, dminus]
return dplus
def construct_d_case_5(ser):
# List to store coefficients for plus case
dplus = [0, 0]
# d_0 = sqrt(a_0)
dplus[0] = sqrt(ser[0])
# d_(-1) = a_(-1)/(2*d_0)
dplus[-1] = ser[-1]/(2*dplus[0])
# Coefficients for the minus case are just the negative
# of the coefficients for the positive case.
dminus = [-x for x in dplus]
if dplus != dminus:
return [dplus, dminus]
return dplus
def construct_d_case_6(num, den, x):
# s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to
# s_oo = lim x->oo x**2 * a(x)
s_inf = limit_at_inf(Poly(x**2, x)*num, den, x)
# d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2
if s_inf != -S(1)/4:
return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]]
return [[S.Half]]
def construct_d(num, den, x, val_inf):
"""
Helper function to calculate the coefficients
in the d-vector based on the valuation of the
function at oo.
"""
N = -val_inf//2
# Multiplicity of oo as a pole
mul = -val_inf if val_inf < 0 else 0
ser = rational_laurent_series(num, den, x, oo, mul, 1)
# Case 4
if val_inf < 0:
d = construct_d_case_4(ser, N)
# Case 5
elif val_inf == 0:
d = construct_d_case_5(ser)
# Case 6
else:
d = construct_d_case_6(num, den, x)
return d
def rational_laurent_series(num, den, x, r, m, n):
r"""
The function computes the Laurent series coefficients
of a rational function.
Parameters
==========
num: A Poly object that is the numerator of `f(x)`.
den: A Poly object that is the denominator of `f(x)`.
x: The variable of expansion of the series.
r: The point of expansion of the series.
m: Multiplicity of r if r is a pole of `f(x)`. Should
be zero otherwise.
n: Order of the term upto which the series is expanded.
Returns
=======
series: A dictionary that has power of the term as key
and coefficient of that term as value.
Below is a basic outline of how the Laurent series of a
rational function `f(x)` about `x_0` is being calculated -
1. Substitute `x + x_0` in place of `x`. If `x_0`
is a pole of `f(x)`, multiply the expression by `x^m`
where `m` is the multiplicity of `x_0`. Denote the
the resulting expression as g(x). We do this substitution
so that we can now find the Laurent series of g(x) about
`x = 0`.
2. We can then assume that the Laurent series of `g(x)`
takes the following form -
.. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m
where `a_m` denotes the Laurent series coefficients.
3. Multiply the denominator to the RHS of the equation
and form a recurrence relation for the coefficients `a_m`.
"""
one = Poly(1, x, extension=True)
if r == oo:
# Series at x = oo is equal to first transforming
# the function from x -> 1/x and finding the
# series at x = 0
num, den = inverse_transform_poly(num, den, x)
r = S(0)
if r:
# For an expansion about a non-zero point, a
# transformation from x -> x + r must be made
num = num.transform(Poly(x + r, x, extension=True), one)
den = den.transform(Poly(x + r, x, extension=True), one)
# Remove the pole from the denominator if the series
# expansion is about one of the poles
num, den = (num*x**m).cancel(den, include=True)
# Equate coefficients for the first terms (base case)
maxdegree = 1 + max(num.degree(), den.degree())
syms = symbols(f'a:{maxdegree}', cls=Dummy)
diff = num - den * Poly(syms[::-1], x)
coeff_diffs = diff.all_coeffs()[::-1][:maxdegree]
(coeffs, ) = linsolve(coeff_diffs, syms)
# Use the recursion relation for the rest
recursion = den.all_coeffs()[::-1]
div, rec_rhs = recursion[0], recursion[1:]
series = list(coeffs)
while len(series) < n:
next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div
series.append(-next_coeff)
series = {m - i: val for i, val in enumerate(series)}
return series
def compute_m_ybar(x, poles, choice, N):
"""
Helper function to calculate -
1. m - The degree bound for the polynomial
solution that must be found for the auxiliary
differential equation.
2. ybar - Part of the solution which can be
computed using the poles, c and d vectors.
"""
ybar = 0
m = Poly(choice[-1][-1], x, extension=True)
# Calculate the first (nested) summation for ybar
# as given in Step 9 of the Thesis (Pg 82)
dybar = []
for i, polei in enumerate(poles):
for j, cij in enumerate(choice[i]):
dybar.append(cij/(x - polei)**(j + 1))
m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add
ybar += Add(*dybar)
# Calculate the second summation for ybar
for i in range(N+1):
ybar += choice[-1][i]*x**i
return (m.expr, ybar)
def solve_aux_eq(numa, dena, numy, deny, x, m):
"""
Helper function to find a polynomial solution
of degree m for the auxiliary differential
equation.
"""
# Assume that the solution is of the type
# p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m
psyms = symbols(f'C0:{m}', cls=Dummy)
K = ZZ[psyms]
psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K)
# Eq (5.16) in Thesis - Pg 81
auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol
if m >= 1:
px = psol.diff(x)
auxeq += px*(2*numy*deny*dena)
if m >= 2:
auxeq += px.diff(x)*(deny**2*dena)
if m != 0:
# m is a non-zero integer. Find the constant terms using undetermined coefficients
return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True
else:
# m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation
return S.One, auxeq, auxeq == 0
def remove_redundant_sols(sol1, sol2, x):
"""
Helper function to remove redundant
solutions to the differential equation.
"""
# If y1 and y2 are redundant solutions, there is
# some value of the arbitrary constant for which
# they will be equal
syms1 = sol1.atoms(Symbol, Dummy)
syms2 = sol2.atoms(Symbol, Dummy)
num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()]
num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()]
# Cross multiply
e = num1*den2 - den1*num2
# Check if there are any constants
syms = list(e.atoms(Symbol, Dummy))
if len(syms):
# Find values of constants for which solutions are equal
redn = linsolve(e.all_coeffs(), syms)
if len(redn):
# Return the general solution over a particular solution
if len(syms1) > len(syms2):
return sol2
# If both have constants, return the lesser complex solution
elif len(syms1) == len(syms2):
return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2
else:
return sol1
def get_gen_sol_from_part_sol(part_sols, a, x):
""""
Helper function which computes the general
solution for a Riccati ODE from its particular
solutions.
There are 3 cases to find the general solution
from the particular solutions for a Riccati ODE
depending on the number of particular solution(s)
we have - 1, 2 or 3.
For more information, see Section 6 of
"Methods of Solution of the Riccati Differential Equation"
by D. R. Haaheim and F. M. Stein
"""
# If no particular solutions are found, a general
# solution cannot be found
if len(part_sols) == 0:
return []
# In case of a single particular solution, the general
# solution can be found by using the substitution
# y = y1 + 1/z and solving a Bernoulli ODE to find z.
elif len(part_sols) == 1:
y1 = part_sols[0]
i = exp(Integral(2*y1, x))
z = i * Integral(a/i, x)
z = z.doit()
if a == 0 or z == 0:
return y1
return y1 + 1/z
# In case of 2 particular solutions, the general solution
# can be found by solving a separable equation. This is
# the most common case, i.e. most Riccati ODEs have 2
# rational particular solutions.
elif len(part_sols) == 2:
y1, y2 = part_sols
# One of them already has a constant
if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0:
u = exp(Integral(y2 - y1, x)).doit()
# Introduce a constant
else:
C1 = Dummy('C1')
u = C1*exp(Integral(y2 - y1, x)).doit()
if u == 1:
return y2
return (y2*u - y1)/(u - 1)
# In case of 3 particular solutions, a closed form
# of the general solution can be obtained directly
else:
y1, y2, y3 = part_sols[:3]
C1 = Dummy('C1')
return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3)
def solve_riccati(fx, x, b0, b1, b2, gensol=False):
"""
The main function that gives particular/general
solutions to Riccati ODEs that have atleast 1
rational particular solution.
"""
# Step 1 : Convert to Normal Form
a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \
b2.diff(x, 2)/(2*b2)
a_t = a.together()
num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()]
num, den = num.cancel(den, include=True)
# Step 2
presol = []
# Step 3 : a(x) is 0
if num == 0:
presol.append(1/(x + Dummy('C1')))
# Step 4 : a(x) is a non-zero constant
elif x not in num.free_symbols.union(den.free_symbols):
presol.extend([sqrt(a), -sqrt(a)])
# Step 5 : Find poles and valuation at infinity
poles = roots(den, x)
poles, muls = list(poles.keys()), list(poles.values())
val_inf = val_at_inf(num, den, x)
if len(poles):
# Check necessary conditions (outlined in the module docstring)
if not check_necessary_conds(val_inf, muls):
raise ValueError("Rational Solution doesn't exist")
# Step 6
# Construct c-vectors for each singular point
c = construct_c(num, den, x, poles, muls)
# Construct d vectors for each singular point
d = construct_d(num, den, x, val_inf)
# Step 7 : Iterate over all possible combinations and return solutions
# For each possible combination, generate an array of 0's and 1's
# where 0 means pick 1st choice and 1 means pick the second choice.
# NOTE: We could exit from the loop if we find 3 particular solutions,
# but it is not implemented here as -
# a. Finding 3 particular solutions is very rare. Most of the time,
# only 2 particular solutions are found.
# b. In case we exit after finding 3 particular solutions, it might
# happen that 1 or 2 of them are redundant solutions. So, instead of
# spending some more time in computing the particular solutions,
# we will end up computing the general solution from a single
# particular solution which is usually slower than computing the
# general solution from 2 or 3 particular solutions.
c.append(d)
choices = product(*c)
for choice in choices:
m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2)
numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()]
# Step 10 : Check if a valid solution exists. If yes, also check
# if m is a non-negative integer
if m.is_nonnegative == True and m.is_integer == True:
# Step 11 : Find polynomial solutions of degree m for the auxiliary equation
psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m)
# Step 12 : If valid polynomial solution exists, append solution.
if exists:
# m == 0 case
if psol == 1 and coeffs == 0:
# p(x) = 1, so p'(x)/p(x) term need not be added
presol.append(ybar)
# m is a positive integer and there are valid coefficients
elif len(coeffs):
# Substitute the valid coefficients to get p(x)
psol = psol.xreplace(coeffs)
# y(x) = ybar(x) + p'(x)/p(x)
presol.append(ybar + psol.diff(x)/psol)
# Remove redundant solutions from the list of existing solutions
remove = set()
for i in range(len(presol)):
for j in range(i+1, len(presol)):
rem = remove_redundant_sols(presol[i], presol[j], x)
if rem is not None:
remove.add(rem)
sols = [x for x in presol if x not in remove]
# Step 15 : Inverse transform the solutions of the equation in normal form
bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2)
# If general solution is required, compute it from the particular solutions
if gensol:
sols = [get_gen_sol_from_part_sol(sols, a, x)]
# Inverse transform the particular solutions
presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols]
return presol

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,392 @@
from sympy.core import S, Pow
from sympy.core.function import (Derivative, AppliedUndef, diff)
from sympy.core.relational import Equality, Eq
from sympy.core.symbol import Dummy
from sympy.core.sympify import sympify
from sympy.logic.boolalg import BooleanAtom
from sympy.functions import exp
from sympy.series import Order
from sympy.simplify.simplify import simplify, posify, besselsimp
from sympy.simplify.trigsimp import trigsimp
from sympy.simplify.sqrtdenest import sqrtdenest
from sympy.solvers import solve
from sympy.solvers.deutils import _preprocess, ode_order
from sympy.utilities.iterables import iterable, is_sequence
def sub_func_doit(eq, func, new):
r"""
When replacing the func with something else, we usually want the
derivative evaluated, so this function helps in making that happen.
Examples
========
>>> from sympy import Derivative, symbols, Function
>>> from sympy.solvers.ode.subscheck import sub_func_doit
>>> x, z = symbols('x, z')
>>> y = Function('y')
>>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x)
2
>>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x),
... 1/(x*(z + 1/x)))
x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x))
...- 1/(x**2*(z + 1/x)**2)
"""
reps= {func: new}
for d in eq.atoms(Derivative):
if d.expr == func:
reps[d] = new.diff(*d.variable_count)
else:
reps[d] = d.xreplace({func: new}).doit(deep=False)
return eq.xreplace(reps)
def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True):
r"""
Substitutes ``sol`` into ``ode`` and checks that the result is ``0``.
This works when ``func`` is one function, like `f(x)` or a list of
functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can
be a single solution or a list of solutions. Each solution may be an
:py:class:`~sympy.core.relational.Equality` that the solution satisfies,
e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an
:py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it
will not be necessary to explicitly identify the function, but if the
function cannot be inferred from the original equation it can be supplied
through the ``func`` argument.
If a sequence of solutions is passed, the same sort of container will be
used to return the result for each solution.
It tries the following methods, in order, until it finds zero equivalence:
1. Substitute the solution for `f` in the original equation. This only
works if ``ode`` is solved for `f`. It will attempt to solve it first
unless ``solve_for_func == False``.
2. Take `n` derivatives of the solution, where `n` is the order of
``ode``, and check to see if that is equal to the solution. This only
works on exact ODEs.
3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time
solving for the derivative of `f` of that order (this will always be
possible because `f` is a linear operator). Then back substitute each
derivative into ``ode`` in reverse order.
This function returns a tuple. The first item in the tuple is ``True`` if
the substitution results in ``0``, and ``False`` otherwise. The second
item in the tuple is what the substitution results in. It should always
be ``0`` if the first item is ``True``. Sometimes this function will
return ``False`` even when an expression is identically equal to ``0``.
This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not
reduce the expression to ``0``. If an expression returned by this
function vanishes identically, then ``sol`` really is a solution to
the ``ode``.
If this function seems to hang, it is probably because of a hard
simplification.
To use this function to test, test the first item of the tuple.
Examples
========
>>> from sympy import (Eq, Function, checkodesol, symbols,
... Derivative, exp)
>>> x, C1, C2 = symbols('x,C1,C2')
>>> f, g = symbols('f g', cls=Function)
>>> checkodesol(f(x).diff(x), Eq(f(x), C1))
(True, 0)
>>> assert checkodesol(f(x).diff(x), C1)[0]
>>> assert not checkodesol(f(x).diff(x), x)[0]
>>> checkodesol(f(x).diff(x, 2), x**2)
(False, 2)
>>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))]
>>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))]
>>> checkodesol(eqs, sol)
(True, [0, 0])
"""
if iterable(ode):
return checksysodesol(ode, sol, func=func)
if not isinstance(ode, Equality):
ode = Eq(ode, 0)
if func is None:
try:
_, func = _preprocess(ode.lhs)
except ValueError:
funcs = [s.atoms(AppliedUndef) for s in (
sol if is_sequence(sol, set) else [sol])]
funcs = set().union(*funcs)
if len(funcs) != 1:
raise ValueError(
'must pass func arg to checkodesol for this case.')
func = funcs.pop()
if not isinstance(func, AppliedUndef) or len(func.args) != 1:
raise ValueError(
"func must be a function of one variable, not %s" % func)
if is_sequence(sol, set):
return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol])
if not isinstance(sol, Equality):
sol = Eq(func, sol)
elif sol.rhs == func:
sol = sol.reversed
if order == 'auto':
order = ode_order(ode, func)
solved = sol.lhs == func and not sol.rhs.has(func)
if solve_for_func and not solved:
rhs = solve(sol, func)
if rhs:
eqs = [Eq(func, t) for t in rhs]
if len(rhs) == 1:
eqs = eqs[0]
return checkodesol(ode, eqs, order=order,
solve_for_func=False)
x = func.args[0]
# Handle series solutions here
if sol.has(Order):
assert sol.lhs == func
Oterm = sol.rhs.getO()
solrhs = sol.rhs.removeO()
Oexpr = Oterm.expr
assert isinstance(Oexpr, Pow)
sorder = Oexpr.exp
assert Oterm == Order(x**sorder)
odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand()
neworder = Order(x**(sorder - order))
odesubs = odesubs + neworder
assert odesubs.getO() == neworder
residual = odesubs.removeO()
return (residual == 0, residual)
s = True
testnum = 0
while s:
if testnum == 0:
# First pass, try substituting a solved solution directly into the
# ODE. This has the highest chance of succeeding.
ode_diff = ode.lhs - ode.rhs
if sol.lhs == func:
s = sub_func_doit(ode_diff, func, sol.rhs)
s = besselsimp(s)
else:
testnum += 1
continue
ss = simplify(s.rewrite(exp))
if ss:
# with the new numer_denom in power.py, if we do a simple
# expansion then testnum == 0 verifies all solutions.
s = ss.expand(force=True)
else:
s = 0
testnum += 1
elif testnum == 1:
# Second pass. If we cannot substitute f, try seeing if the nth
# derivative is equal, this will only work for odes that are exact,
# by definition.
s = simplify(
trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) -
trigsimp(ode.lhs) + trigsimp(ode.rhs))
# s2 = simplify(
# diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \
# ode.lhs + ode.rhs)
testnum += 1
elif testnum == 2:
# Third pass. Try solving for df/dx and substituting that into the
# ODE. Thanks to Chris Smith for suggesting this method. Many of
# the comments below are his, too.
# The method:
# - Take each of 1..n derivatives of the solution.
# - Solve each nth derivative for d^(n)f/dx^(n)
# (the differential of that order)
# - Back substitute into the ODE in decreasing order
# (i.e., n, n-1, ...)
# - Check the result for zero equivalence
if sol.lhs == func and not sol.rhs.has(func):
diffsols = {0: sol.rhs}
elif sol.rhs == func and not sol.lhs.has(func):
diffsols = {0: sol.lhs}
else:
diffsols = {}
sol = sol.lhs - sol.rhs
for i in range(1, order + 1):
# Differentiation is a linear operator, so there should always
# be 1 solution. Nonetheless, we test just to make sure.
# We only need to solve once. After that, we automatically
# have the solution to the differential in the order we want.
if i == 1:
ds = sol.diff(x)
try:
sdf = solve(ds, func.diff(x, i))
if not sdf:
raise NotImplementedError
except NotImplementedError:
testnum += 1
break
else:
diffsols[i] = sdf[0]
else:
# This is what the solution says df/dx should be.
diffsols[i] = diffsols[i - 1].diff(x)
# Make sure the above didn't fail.
if testnum > 2:
continue
else:
# Substitute it into ODE to check for self consistency.
lhs, rhs = ode.lhs, ode.rhs
for i in range(order, -1, -1):
if i == 0 and 0 not in diffsols:
# We can only substitute f(x) if the solution was
# solved for f(x).
break
lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i])
rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i])
ode_or_bool = Eq(lhs, rhs)
ode_or_bool = simplify(ode_or_bool)
if isinstance(ode_or_bool, (bool, BooleanAtom)):
if ode_or_bool:
lhs = rhs = S.Zero
else:
lhs = ode_or_bool.lhs
rhs = ode_or_bool.rhs
# No sense in overworking simplify -- just prove that the
# numerator goes to zero
num = trigsimp((lhs - rhs).as_numer_denom()[0])
# since solutions are obtained using force=True we test
# using the same level of assumptions
## replace function with dummy so assumptions will work
_func = Dummy('func')
num = num.subs(func, _func)
## posify the expression
num, reps = posify(num)
s = simplify(num).xreplace(reps).xreplace({_func: func})
testnum += 1
else:
break
if not s:
return (True, s)
elif s is True: # The code above never was able to change s
raise NotImplementedError("Unable to test if " + str(sol) +
" is a solution to " + str(ode) + ".")
else:
return (False, s)
def checksysodesol(eqs, sols, func=None):
r"""
Substitutes corresponding ``sols`` for each functions into each ``eqs`` and
checks that the result of substitutions for each equation is ``0``. The
equations and solutions passed can be any iterable.
This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`.
For each function, ``sols`` can have a single solution or a list of solutions.
In most cases it will not be necessary to explicitly identify the function,
but if the function cannot be inferred from the original equation it
can be supplied through the ``func`` argument.
When a sequence of equations is passed, the same sequence is used to return
the result for each equation with each function substituted with corresponding
solutions.
It tries the following method to find zero equivalence for each equation:
Substitute the solutions for functions, like `x(t)` and `y(t)` into the
original equations containing those functions.
This function returns a tuple. The first item in the tuple is ``True`` if
the substitution results for each equation is ``0``, and ``False`` otherwise.
The second item in the tuple is what the substitution results in. Each element
of the ``list`` should always be ``0`` corresponding to each equation if the
first item is ``True``. Note that sometimes this function may return ``False``,
but with an expression that is identically equal to ``0``, instead of returning
``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot
reduce the expression to ``0``. If an expression returned by each function
vanishes identically, then ``sols`` really is a solution to ``eqs``.
If this function seems to hang, it is probably because of a difficult simplification.
Examples
========
>>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function
>>> from sympy.solvers.ode.subscheck import checksysodesol
>>> C1, C2 = symbols('C1:3')
>>> t = symbols('t')
>>> x, y = symbols('x, y', cls=Function)
>>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12))
>>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3),
... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)]
>>> checksysodesol(eq, sol)
(True, [0, 0])
>>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3))
>>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2),
... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)]
>>> checksysodesol(eq, sol)
(True, [0, 0])
"""
def _sympify(eq):
return list(map(sympify, eq if iterable(eq) else [eq]))
eqs = _sympify(eqs)
for i in range(len(eqs)):
if isinstance(eqs[i], Equality):
eqs[i] = eqs[i].lhs - eqs[i].rhs
if func is None:
funcs = []
for eq in eqs:
derivs = eq.atoms(Derivative)
func = set().union(*[d.atoms(AppliedUndef) for d in derivs])
funcs.extend(func)
funcs = list(set(funcs))
if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\
and len({func.args for func in funcs})!=1:
raise ValueError("func must be a function of one variable, not %s" % func)
for sol in sols:
if len(sol.atoms(AppliedUndef)) != 1:
raise ValueError("solutions should have one function only")
if len(funcs) != len({sol.lhs for sol in sols}):
raise ValueError("number of solutions provided does not match the number of equations")
dictsol = {}
for sol in sols:
func = list(sol.atoms(AppliedUndef))[0]
if sol.rhs == func:
sol = sol.reversed
solved = sol.lhs == func and not sol.rhs.has(func)
if not solved:
rhs = solve(sol, func)
if not rhs:
raise NotImplementedError
else:
rhs = sol.rhs
dictsol[func] = rhs
checkeq = []
for eq in eqs:
for func in funcs:
eq = sub_func_doit(eq, func, dictsol[func])
ss = simplify(eq)
if ss != 0:
eq = ss.expand(force=True)
if eq != 0:
eq = sqrtdenest(eq).simplify()
else:
eq = 0
checkeq.append(eq)
if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0:
return (True, checkeq)
else:
return (False, checkeq)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,152 @@
from sympy.core.function import Function
from sympy.core.numbers import Rational
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (atan, sin, tan)
from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals)
from sympy.solvers.ode.subscheck import checkodesol
from sympy.testing.pytest import XFAIL
C1 = Symbol('C1')
x, y = symbols("x y")
f = Function('f')
xi = Function('xi')
eta = Function('eta')
def test_heuristic1():
a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0")
df = f(x).diff(x)
eq = Eq(df, x**2*f(x))
eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x)
eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2)
eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x))
eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2)
eq5 = x**2*df - f(x) + x**2*exp(x - (1/x))
eqlist = [eq, eq1, eq2, eq3, eq4, eq5]
i = infinitesimals(eq, hint='abaco1_simple')
assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0},
{eta(x, f(x)): f(x), xi(x, f(x)): 0},
{eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}]
i1 = infinitesimals(eq1, hint='abaco1_simple')
assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}]
i2 = infinitesimals(eq2, hint='abaco1_simple')
assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}]
i3 = infinitesimals(eq3, hint='abaco1_simple')
assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1},
{eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}]
i4 = infinitesimals(eq4, hint='abaco1_simple')
assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0},
{eta(x, f(x)): 0,
xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}]
i5 = infinitesimals(eq5, hint='abaco1_simple')
assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}]
ilist = [i, i1, i2, i3, i4, i5]
for eq, i in (zip(eqlist, ilist)):
check = checkinfsol(eq, i)
assert check[0]
# This ODE can be solved by the Lie Group method, when there are
# better assumptions
eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2)
i = infinitesimals(eq6, hint='abaco1_product')
assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}]
assert checkinfsol(eq6, i)[0]
eq7 = x*(f(x).diff(x)) + 1 - f(x)**2
i = infinitesimals(eq7, hint='chi')
assert checkinfsol(eq7, i)[0]
def test_heuristic3():
a, b = symbols("a b")
df = f(x).diff(x)
eq = x**2*df + x*f(x) + f(x)**2 + x**2
i = infinitesimals(eq, hint='bivariate')
assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}]
assert checkinfsol(eq, i)[0]
eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x
i = infinitesimals(eq, hint='bivariate')
assert checkinfsol(eq, i)[0]
def test_heuristic_function_sum():
eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x +
(1 - 3*f(x))*(x/f(x)**2))
i = infinitesimals(eq, hint='function_sum')
assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}]
assert checkinfsol(eq, i)[0]
def test_heuristic_abaco2_similar():
a, b = symbols("a b")
F = Function('F')
eq = f(x).diff(x) - F(a*x + b*f(x))
i = infinitesimals(eq, hint='abaco2_similar')
assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}]
assert checkinfsol(eq, i)[0]
eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x)))
i = infinitesimals(eq, hint='abaco2_similar')
assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}]
assert checkinfsol(eq, i)[0]
def test_heuristic_abaco2_unique_unknown():
a, b = symbols("a b")
F = Function('F')
eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b)
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}]
assert checkinfsol(eq, i)[0]
eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x)))
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}]
assert checkinfsol(eq, i)[0]
eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a
i = infinitesimals(eq, hint='abaco2_unique_unknown')
assert checkinfsol(eq, i)[0]
def test_heuristic_linear():
a, b, m, n = symbols("a b m n")
eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1))
i = infinitesimals(eq, hint='linear')
assert checkinfsol(eq, i)[0]
@XFAIL
def test_kamke():
a, b, alpha, c = symbols("a b alpha c")
eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c
i = infinitesimals(eq, hint='sum_function') # XFAIL
assert checkinfsol(eq, i)[0]
def test_user_infinitesimals():
x = Symbol("x") # assuming x is real generates an error
eq = x*(f(x).diff(x)) + 1 - f(x)**2
sol = Eq(f(x), (C1 + x**2)/(C1 - x**2))
infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0}
assert dsolve(eq, hint='lie_group', **infinitesimals) == sol
assert checkodesol(eq, sol) == (True, 0)
@XFAIL
def test_lie_group_issue15219():
eqn = exp(f(x).diff(x)-f(x))
assert 'lie_group' not in classify_ode(eqn, f(x))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,877 @@
from sympy.core.random import randint
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.numbers import (I, Rational, oo)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.hyperbolic import tanh
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import sin
from sympy.polys.polytools import Poly
from sympy.simplify.ratsimp import ratsimp
from sympy.solvers.ode.subscheck import checkodesol
from sympy.testing.pytest import slow
from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal,
riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf,
check_necessary_conds, val_at_inf, construct_c_case_1,
construct_c_case_2, construct_c_case_3, construct_d_case_4,
construct_d_case_5, construct_d_case_6, rational_laurent_series,
solve_riccati)
f = Function('f')
x = symbols('x')
# These are the functions used to generate the tests
# SHOULD NOT BE USED DIRECTLY IN TESTS
def rand_rational(maxint):
return Rational(randint(-maxint, maxint), randint(1, maxint))
def rand_poly(x, degree, maxint):
return Poly([rand_rational(maxint) for _ in range(degree+1)], x)
def rand_rational_function(x, degree, maxint):
degnum = randint(1, degree)
degden = randint(1, degree)
num = rand_poly(x, degnum, maxint)
den = rand_poly(x, degden, maxint)
while den == Poly(0, x):
den = rand_poly(x, degden, maxint)
return num / den
def find_riccati_ode(ratfunc, x, yf):
y = ratfunc
yp = y.diff(x)
q1 = rand_rational_function(x, 1, 3)
q2 = rand_rational_function(x, 1, 3)
while q2 == 0:
q2 = rand_rational_function(x, 1, 3)
q0 = ratsimp(yp - q1*y - q2*y**2)
eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2)
sol = Eq(yf, y)
assert checkodesol(eq, sol) == (True, 0)
return eq, q0, q1, q2
# Testing functions start
def test_riccati_transformation():
"""
This function tests the transformation of the
solution of a Riccati ODE to the solution of
its corresponding normal Riccati ODE.
Each test case 4 values -
1. w - The solution to be transformed
2. b1 - The coefficient of f(x) in the ODE.
3. b2 - The coefficient of f(x)**2 in the ODE.
4. y - The solution to the normal Riccati ODE.
"""
tests = [
(
x/(x - 1),
(x**2 + 7)/3*x,
x,
-x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x)
),
(
(2*x + 3)/(2*x + 2),
(3 - 3*x)/(x + 1),
5*x,
-5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x)
),
(
-1/(2*x**2 - 1),
0,
(2 - x)/(4*x - 2),
(2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \
2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False))
),
(
x,
(8*x - 12)/(12*x + 9),
x**3/(6*x - 9),
-x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \
- 9)**2 + 3*x**2/(6*x - 9))/(2*x**3)
)]
for w, b1, b2, y in tests:
assert y == riccati_normal(w, x, b1, b2)
assert w == riccati_inverse_normal(y, x, b1, b2).cancel()
# Test bp parameter in riccati_inverse_normal
tests = [
(
(-2*x - 1)/(2*x**2 + 2*x - 2),
-2/x,
(-x - 1)/(4*x),
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \
- 2)) + 1/x
),
(
3/(2*x**2),
-2/x,
(-x - 1)/(4*x),
8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
-2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3)
)]
for w, b1, b2, bp, y in tests:
assert y == riccati_normal(w, x, b1, b2)
assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel()
def test_riccati_reduced():
"""
This function tests the transformation of a
Riccati ODE to its normal Riccati ODE.
Each test case 2 values -
1. eq - A Riccati ODE.
2. normal_eq - The normal Riccati ODE of eq.
"""
tests = [
(
f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2,
f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2)
),
(
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x,
-3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \
-x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \
- (-1 + (x + 1)/x)/(x*(-x - 1))
),
(
f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2),
-(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \
f(x)**2 + f(x).diff(x) - 1/(2*x + 1)
),
(
f(x).diff(x) - f(x)**2/x,
f(x)**2 + f(x).diff(x) + 1/(4*x**2)
),
(
-3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x,
f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \
+ 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2)
),
(
6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x,
False
),
(
f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2),
False
)]
for eq, normal_eq in tests:
assert normal_eq == riccati_reduced(eq, f, x)
def test_match_riccati():
"""
This function tests if an ODE is Riccati or not.
Each test case has 5 values -
1. eq - The Riccati ODE.
2. match - Boolean indicating if eq is a Riccati ODE.
3. b0 -
4. b1 - Coefficient of f(x) in eq.
5. b2 - Coefficient of f(x)**2 in eq.
"""
tests = [
# Test Rational Riccati ODEs
(
f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \
- 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \
- (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2),
True,
45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \
(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \
315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \
+ 846*x**2 + 180*x - 72),
Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2),
1/(3*x + 1)
),
(
f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \
1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \
/(324*x**3 + 216*x**2),
True,
-4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \
+ 24*x) + 265/(324*x + 216),
Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6),
x/3 - 1
),
(
f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \
+ 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \
360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \
S(3)/2) - (x/3 - 3)*f(x)**2/(3*x),
True,
304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \
108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \
360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \
x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \
189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \
53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \
+ 53*x**3 - 63*x**2 + 40*x - 12),
Mul(-1, 3 - 2*x, evaluate=False)/(x - 3),
Mul(-1, 9 - x, evaluate=False)/(9*x)
),
# Test Non-Rational Riccati ODEs
(
f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \
x*f(x)**2/(x**(S(3)/4)),
False, 0, 0, 0
),
(
f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2,
False, 0, 0, 0
),
(
f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2,
False, 0, 0, 0
),
# Test Non-Riccati ODEs
(
(1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x),
False, 0, 0, 0
),
(
f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3,
False, 0, 0, 0
),
(
f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \
+ 3) + f(x)**2,
False, 0, 0, 0
)]
for eq, res, b0, b1, b2 in tests:
match, funcs = match_riccati(eq, f, x)
assert match == res
if res:
assert [b0, b1, b2] == funcs
def test_val_at_inf():
"""
This function tests the valuation of rational
function at oo.
Each test case has 3 values -
1. num - Numerator of rational function.
2. den - Denominator of rational function.
3. val_inf - Valuation of rational function at oo
"""
tests = [
# degree(denom) > degree(numer)
(
Poly(10*x**3 + 8*x**2 - 13*x + 6, x),
Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x),
7
),
(
Poly(1, x),
Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x),
4
),
# degree(denom) == degree(numer)
(
Poly(-6*x**3 - 8*x**2 + 8*x - 6, x),
Poly(-5*x**3 + 12*x**2 - 6*x - 9, x),
0
),
# degree(denom) < degree(numer)
(
Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x),
Poly(-14*x**2 + x, x),
-6
),
(
Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x),
Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x),
-2
)]
for num, den, val in tests:
assert val_at_inf(num, den, x) == val
def test_necessary_conds():
"""
This function tests the necessary conditions for
a Riccati ODE to have a rational particular solution.
"""
# Valuation at Infinity is an odd negative integer
assert check_necessary_conds(-3, [1, 2, 4]) == False
# Valuation at Infinity is a positive integer lesser than 2
assert check_necessary_conds(1, [1, 2, 4]) == False
# Multiplicity of a pole is an odd integer greater than 1
assert check_necessary_conds(2, [3, 1, 6]) == False
# All values are correct
assert check_necessary_conds(-10, [1, 2, 8, 12]) == True
def test_inverse_transform_poly():
"""
This function tests the substitution x -> 1/x
in rational functions represented using Poly.
"""
fns = [
(15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6),
(180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12),
(-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80),
(60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30),
(30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \
+ 15*x**2 + 45*x - 15)
]
for f in fns:
num, den = [Poly(e, x) for e in f.as_numer_denom()]
num, den = inverse_transform_poly(num, den, x)
assert f.subs(x, 1/x).cancel() == num/den
def test_limit_at_inf():
"""
This function tests the limit at oo of a
rational function.
Each test case has 3 values -
1. num - Numerator of rational function.
2. den - Denominator of rational function.
3. limit_at_inf - Limit of rational function at oo
"""
tests = [
# deg(denom) > deg(numer)
(
Poly(-12*x**2 + 20*x + 32, x),
Poly(32*x**3 + 72*x**2 + 3*x - 32, x),
0
),
# deg(denom) < deg(numer)
(
Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x),
Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x),
oo
),
# deg(denom) < deg(numer), one of the leading coefficients is negative
(
Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \
+ 630*x + 3780, x),
Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x),
-oo
),
# deg(denom) == deg(numer)
(
Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x),
Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x),
S(1)/7
),
(
Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x),
Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x),
S(288)/607
)]
for num, den, lim in tests:
assert limit_at_inf(num, den, x) == lim
def test_construct_c_case_1():
"""
This function tests the Case 1 in the step
to calculate coefficients of c-vectors.
Each test case has 4 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. pole - Pole of a(x) for which c-vector is being
calculated.
4. c - The c-vector for the pole.
"""
tests = [
(
Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True),
Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True),
S(0),
[[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]]
),
(
Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True),
Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True),
S(7)/4,
[[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]]
),
(
Poly(4*x + 2, x, extension=True),
Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \
+ 8*sqrt(3) + 16, x, domain='QQ<sqrt(3)>'),
(S(1) + sqrt(3))/2,
[[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \
[S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]]
)]
for num, den, pole, c in tests:
assert construct_c_case_1(num, den, x, pole) == c
def test_construct_c_case_2():
"""
This function tests the Case 2 in the step
to calculate coefficients of c-vectors.
Each test case has 5 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. pole - Pole of a(x) for which c-vector is being
calculated.
4. mul - The multiplicity of the pole.
5. c - The c-vector for the pole.
"""
tests = [
# Testing poles with multiplicity 2
(
Poly(1, x, extension=True),
Poly((x - 1)**2*(x - 2), x, extension=True),
1, 2,
[[-I*(-1 - I)/2], [I*(-1 + I)/2]]
),
(
Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True),
Poly((3*x - 1)**2*(x + 2)**2, x, extension=True),
S(1)/3, 2,
[[-S(89)/98], [-S(9)/98]]
),
# Testing poles with multiplicity 4
(
Poly(x**3 - x**2 + 4*x, x, extension=True),
Poly((x - 2)**4*(x + 5)**2, x, extension=True),
2, 4,
[[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \
[-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]]
),
(
Poly(3*x**5 + x**4 + 3, x, extension=True),
Poly((4*x + 1)**4*(x + 2), x, extension=True),
-S(1)/4, 4,
[[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \
[-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]]
),
# Testing poles with multiplicity 6
(
Poly(x**3 + 2, x, extension=True),
Poly((3*x - 1)**6*(x**2 + 1), x, extension=True),
S(1)/3, 6,
[[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \
[-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]]
),
(
Poly(x**2 + 12, x, extension=True),
Poly((x - sqrt(2))**6, x, extension=True),
sqrt(2), 6,
[[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \
[-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]]
)]
for num, den, pole, mul, c in tests:
assert construct_c_case_2(num, den, x, pole, mul) == c
def test_construct_c_case_3():
"""
This function tests the Case 3 in the step
to calculate coefficients of c-vectors.
"""
assert construct_c_case_3() == [[1]]
def test_construct_d_case_4():
"""
This function tests the Case 4 in the step
to calculate coefficients of the d-vector.
Each test case has 4 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. mul - Multiplicity of oo as a pole.
4. d - The d-vector.
"""
tests = [
# Tests with multiplicity at oo = 2
(
Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True),
Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True),
2,
[[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \
[-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]]
),
(
Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True),
Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True),
2,
[[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]]
),
# Tests with multiplicity at oo = 4
(
Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True),
Poly(3*x**2 + 10*x + 7, x, extension=True),
4,
[[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \
- 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \
sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]]
),
(
Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True),
Poly(12*x - 2, x, extension=True),
4,
[[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \
[-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]]
),
# Tests with multiplicity at oo = 4
(
Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True),
Poly(x + 2, x, extension=True),
6,
[[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]]
),
(
Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True),
Poly(2*x - 2, x, extension=True),
6,
[[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \
3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \
sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]]
)]
for num, den, mul, d in tests:
ser = rational_laurent_series(num, den, x, oo, mul, 1)
assert construct_d_case_4(ser, mul//2) == d
def test_construct_d_case_5():
"""
This function tests the Case 5 in the step
to calculate coefficients of the d-vector.
Each test case has 3 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. d - The d-vector.
"""
tests = [
(
Poly(2*x**3 + x**2 + x - 2, x, extension=True),
Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True),
[[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]]
),
(
Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'),
Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'),
[[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]]
),
(
Poly(x**2 - x + 1, x, domain='ZZ'),
Poly(3*x**2 + 7*x + 3, x, domain='ZZ'),
[[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]]
)]
for num, den, d in tests:
# Multiplicity of oo is 0
ser = rational_laurent_series(num, den, x, oo, 0, 1)
assert construct_d_case_5(ser) == d
def test_construct_d_case_6():
"""
This function tests the Case 6 in the step
to calculate coefficients of the d-vector.
Each test case has 3 values -
1. num - Numerator of the rational function a(x).
2. den - Denominator of the rational function a(x).
3. d - The d-vector.
"""
tests = [
(
Poly(-2*x**2 - 5, x, domain='ZZ'),
Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'),
[[S(1)/2 + I/2], [S(1)/2 - I/2]]
),
(
Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'),
Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'),
[[1], [0]]
),
(
Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'),
Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \
- 27*x - 42, x, domain='ZZ'),
[[1], [0]]
)]
for num, den, d in tests:
assert construct_d_case_6(num, den, x) == d
def test_rational_laurent_series():
"""
This function tests the computation of coefficients
of Laurent series of a rational function.
Each test case has 5 values -
1. num - Numerator of the rational function.
2. den - Denominator of the rational function.
3. x0 - Point about which Laurent series is to
be calculated.
4. mul - Multiplicity of x0 if x0 is a pole of
the rational function (0 otherwise).
5. n - Number of terms upto which the series
is to be calculated.
"""
tests = [
# Laurent series about simple pole (Multiplicity = 1)
(
Poly(x**2 - 3*x + 9, x, extension=True),
Poly(x**2 - x, x, extension=True),
S(1), 1, 6,
{1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9}
),
# Laurent series about multiple pole (Multiplicity > 1)
(
Poly(64*x**3 - 1728*x + 1216, x, extension=True),
Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True),
S(9)/8, 2, 3,
{0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \
1: S(209149)/75645}
),
(
Poly(1, x, extension=True),
Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \
+ (4 + 8*sqrt(2))*x - 4, x, extension=True),
sqrt(2), 4, 6,
{4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \
+ sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \
)/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4}
),
# Laurent series about oo
(
Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True),
Poly(x**2 - 5, x, extension=True),
oo, 3, 6,
{3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17}
),
# Laurent series at x0 where x0 is not a pole of the function
# Using multiplicity as 0 (as x0 will not be a pole)
(
Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True),
Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True),
S(2)/5, 0, 1,
{0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \
-3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016}
),
(
Poly(-7*x**2 + 2*x - 4, x, extension=True),
Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True),
oo, 0, 6,
{0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7}
)]
for num, den, x0, mul, n, ser in tests:
assert ser == rational_laurent_series(num, den, x, x0, mul, n)
def check_dummy_sol(eq, solse, dummy_sym):
"""
Helper function to check if actual solution
matches expected solution if actual solution
contains dummy symbols.
"""
if isinstance(eq, Eq):
eq = eq.lhs - eq.rhs
_, funcs = match_riccati(eq, f, x)
sols = solve_riccati(f(x), x, *funcs)
C1 = Dummy('C1')
sols = [sol.subs(C1, dummy_sym) for sol in sols]
assert all(x[0] for x in checkodesol(eq, sols))
assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse))
def test_solve_riccati():
"""
This function tests the computation of rational
particular solutions for a Riccati ODE.
Each test case has 2 values -
1. eq - Riccati ODE to be solved.
2. sol - Expected solution to the equation.
Some examples have been taken from the paper - "Statistical Investigation of
First-Order Algebraic ODEs and their Rational General Solutions" by
Georg Grasegger, N. Thieu Vo, Franz Winkler
https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf
"""
C0 = Dummy('C0')
# Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2,
# a, b, c are rational functions of x
tests = [
# a(x) is a constant
(
Eq(f(x).diff(x) + f(x)**2 - 2, 0),
[Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))]
),
# a(x) is a constant
(
f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2,
[Eq(f(x), (-2*C0 - x)/(C0*x + x**2))]
),
# a(x) is a constant
(
2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x),
[Eq(f(x), (C0 + 2*x**2)/(C0 + x))]
),
# Pole with multiplicity 1
(
Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)),
[Eq(f(x), 1/(x**2 - x))]
),
# One pole of multiplicity 2
(
x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x),
[Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)]
),
(
x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x),
[Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)]
),
# Multiple poles of multiplicity 2
(
-f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \
- 1)**2),
[Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \
- 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \
57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \
- 3*x + 1))]
),
# Regression: Poles with even multiplicity > 2 fixed
(
f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \
7*x**2 - 20*x + 4)/(4*x**4),
[Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \
- 2*x**2))]
),
# Regression: Poles with even multiplicity > 2 fixed
(
Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\
(x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \
792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)),
[Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))]
),
# More than 2 poles with multiplicity 2
# Regression: Fixed mistake in necessary conditions
(
Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \
(8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2),
[Eq(f(x), (1 - 4*x)/(2*x - 2))]
),
# Regression: Fixed mistake in necessary conditions
(
Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \
+ 3*f(x)**2/(6*x + 2)),
[Eq(f(x), (2*x + 1)/(2*x - 2))]
),
# Imaginary poles
(
f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \
- 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2),
[Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \
- x)), Eq(f(x), -1/(x - 1))],
),
# Imaginary coefficients in equation
(
f(x).diff(x) - 2*I*(f(x)**2 + 1)/x,
[Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)]
),
# Regression: linsolve returning empty solution
# Large value of m (> 10)
(
Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\
(2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)),
[Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \
18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\
x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\
((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \
415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\
*x**3 + 32744250*x**2 + 16372125*x)))]
),
# Regression: Fixed bug due to a typo in paper
(
Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6),
[Eq(f(x), 6*x)]
),
# Regression: Fixed bug due to a typo in paper
(
Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \
+ 9 + (1 - x)*f(x)/x + 3/x),
[Eq(f(x), -3*x/2 - 3)]
)]
for eq, sol in tests:
check_dummy_sol(eq, sol, C0)
@slow
def test_solve_riccati_slow():
"""
This function tests the computation of rational
particular solutions for a Riccati ODE.
Each test case has 2 values -
1. eq - Riccati ODE to be solved.
2. sol - Expected solution to the equation.
"""
C0 = Dummy('C0')
tests = [
# Very large values of m (989 and 991)
(
Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \
(54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \
2187*x + 2187) + 495),
[Eq(f(x), (18*x + 6)/(2*x - 9))]
)]
for eq, sol in tests:
check_dummy_sol(eq, sol, C0)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,203 @@
from sympy.core.function import (Derivative, Function, diff)
from sympy.core.numbers import (I, Rational, pi)
from sympy.core.relational import Eq
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.functions.special.error_functions import (Ei, erf, erfi)
from sympy.integrals.integrals import Integral
from sympy.solvers.ode.subscheck import checkodesol, checksysodesol
from sympy.functions import besselj, bessely
from sympy.testing.pytest import raises, slow
C0, C1, C2, C3, C4 = symbols('C0:5')
u, x, y, z = symbols('u,x:z', real=True)
f = Function('f')
g = Function('g')
h = Function('h')
@slow
def test_checkodesol():
# For the most part, checkodesol is well tested in the tests below.
# These tests only handle cases not checked below.
raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x)))
raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y),
x), f(x, y)))
assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \
(False, -f(x).diff(x) + f(x, y).diff(x) - 1)
assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True
assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1)
sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0)
assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0)
assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \
(False, 60*x**4*((log(x) + 1)**2 + log(x))*(
log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9)
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \
(True, 0)
assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0),
solve_for_func=False) == (True, 0)
assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x),
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \
[(True, 0), (True, 0), (False, C2)]
assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x),
Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \
{(True, 0), (True, 0), (False, C2)}
assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \
[(True, 0), (True, 0)]
assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0)
# Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that
# checkodesol tries back substituting f(x) when it can.
eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x)
sol3 = Eq(f(x), log(log(C1/x)**(-x)))
assert not checkodesol(eq3, sol3)[1].has(f(x))
# This case was failing intermittently depending on hash-seed:
eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x))
sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))
assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0]
eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x)
sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x))
assert checkodesol(eq, sol) == (True, 0)
eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))]
sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]
assert checkodesol(eqs, sol) == (True, [0, 0])
def test_checksysodesol():
x, y, z = symbols('x, y, z', cls=Function)
t = Symbol('t')
eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t)))
sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \
Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t)))
sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \
Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \
Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t)))
sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \
Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \
C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23))
sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \
Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \
sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23))
sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \
Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t)))
sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\
exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \
C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t)))
sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\
exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \
C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t)))
sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \
Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \
C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t)))
root0 = -sqrt(-sqrt(47) + 7)
root1 = sqrt(-sqrt(47) + 7)
root2 = -sqrt(sqrt(47) + 7)
root3 = sqrt(sqrt(47) + 7)
sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \
Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \
C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12))
root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2))
root1 = sqrt(-sqrt(109)/2 + Rational(15, 2))
root2 = -sqrt(sqrt(109)/2 + Rational(15, 2))
root3 = sqrt(sqrt(109)/2 + Rational(15, 2))
sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \
Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \
C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0))
sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \
C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \
+ C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t)))
I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t
I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t
sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t)))
sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \
Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t)))
sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \
Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \
Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t))))
sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \
Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \
Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t)))
sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \
Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \
Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t)))
sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \
Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))]
assert checksysodesol(eq, sol) == (True, [0, 0, 0])
eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5))
sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5))
sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \
Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))]
assert checksysodesol(eq, sol) == (True, [0, 0])
eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2))
sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)}
assert checksysodesol(eq, sol) == (True, [0, 0])

File diff suppressed because it is too large Load diff