291 lines
10 KiB
Python
291 lines
10 KiB
Python
"""Module for differentiation using CSE."""
|
|
|
|
from sympy import cse, Matrix, Derivative, MatrixBase
|
|
from sympy.utilities.iterables import iterable
|
|
|
|
|
|
def _remove_cse_from_derivative(replacements, reduced_expressions):
    """
    Remove CSE replacement symbols from the arguments of ``Derivative`` terms.

    This function postprocesses the output of a common subexpression
    elimination (CSE) operation. Every ``Derivative`` found in the
    replacement subexpressions or in the reduced expressions has the
    replacement symbols it contains substituted by their subexpressions,
    repeatedly, until no replacement symbol is left inside it. This is
    necessary to ensure that the forward Jacobian function correctly handles
    derivative terms. Nodes whose subtrees need no change are returned as the
    very same objects.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation.

    reduced_expressions : list of SymPy expressions or matrices
        The reduced expressions with all the replacements from the
        replacements list above. Matrices are processed entrywise and keep
        their shape.

    Returns
    =======

    processed_replacements : list of (Symbol, expression) pairs
        Processed replacement list, in the same format of the
        ``replacements`` input list.

    processed_reduced : list of SymPy expressions or matrices
        Processed reduced list, in the same format of the
        ``reduced_expressions`` input list.
    """
    repl_dict = dict(replacements)

    def replace_all(node):
        # A subexpression may itself refer to earlier replacement symbols, so
        # keep substituting until no replacement symbol remains free.
        result = node
        while True:
            symbols_dict = {k: repl_dict[k] for k in result.free_symbols
                            if k in repl_dict}
            if not symbols_dict:
                return result
            result = result.xreplace(symbols_dict)

    def traverse(node):
        if isinstance(node, Derivative):
            return replace_all(node)
        if not node.args:
            return node
        new_args = [traverse(arg) for arg in node.args]
        # Only rebuild nodes whose arguments actually changed; rebuilding an
        # untouched node through ``node.func`` is wasted work and could
        # re-evaluate it into a different form.
        if all(new is old for new, old in zip(new_args, node.args)):
            return node
        return node.func(*new_args)

    def process(red_exp):
        if isinstance(red_exp, MatrixBase):
            # ``applyfunc`` keeps the matrix shape; rebuilding the matrix from
            # its flattened entries would turn a row (or any 2-D matrix) into
            # a column vector.
            return red_exp.applyfunc(traverse)
        # Scalar reduced expressions (not iterable) are processed directly.
        return traverse(red_exp)

    processed_replacements = [
        (rep_sym, traverse(sub_exp)) for rep_sym, sub_exp in replacements
    ]
    processed_reduced = [process(red_exp) for red_exp in reduced_expressions]

    return processed_replacements, processed_reduced
|
|
|
|
|
|
def _forward_jacobian_cse(replacements, reduced_expr, wrt):
    """
    Core function to compute the Jacobian of an input Matrix of expressions
    through forward accumulation. Takes directly the output of a CSE operation
    (replacements and reduced_expr), and an iterable of variables (wrt) with
    respect to which to differentiate the reduced expression and returns the
    reduced Jacobian matrix and the ``replacements`` list.

    The function also returns, for each subexpression, the set of replacement
    symbols it refers to, which is useful in the substitution process.

    The Jacobian is accumulated as ``J = F1 + F2 * C``, where ``F1`` holds the
    partial derivatives of the reduced expressions with respect to ``wrt``,
    ``F2`` their partial derivatives with respect to the replacement symbols,
    and row ``i`` of ``C`` the total derivative of the ``i``-th subexpression
    with respect to ``wrt``. ``C`` is built row by row with the chain rule,
    relying on the CSE ordering in which each subexpression only refers to
    replacement symbols defined before it.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation.

    reduced_expr : list
        The reduced expressions with all the replacements from the
        replacements list above. Only the first element is used; it must be
        a row or a column matrix.

    wrt : iterable
        Iterable of expressions with respect to which to compute the
        Jacobian matrix.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation. Compared to the input replacement list,
        the output one doesn't contain replacement symbols inside
        ``Derivative``'s arguments. Empty when there are no replacements.

    jacobian : list
        A one-element list holding the Jacobian matrix, of shape
        ``(len(reduced_expr[0]), len(wrt))`` and of the same class as
        ``reduced_expr[0]``, with elements in reduced form (replacement symbols
        are present).

    precomputed_fs : list of sets
        The ``i``-th set holds the replacement symbols occurring free in the
        ``i``-th (processed) subexpression. Useful in the substitution
        process. Empty when there are no replacements.

    Raises
    ======

    TypeError
        If ``reduced_expr`` is empty or its first element is not a row or
        column matrix, or if ``wrt`` is not an iterable forming a row or
        column matrix.
    """
    if not reduced_expr or not isinstance(reduced_expr[0], MatrixBase):
        raise TypeError("``expr`` must be of matrix type")

    if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
        raise TypeError("``expr`` must be a row or a column matrix")

    if not iterable(wrt):
        raise TypeError("``wrt`` must be an iterable of variables")
    elif not isinstance(wrt, MatrixBase):
        wrt = Matrix(wrt)

    if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
        raise TypeError("``wrt`` must be a row or a column matrix")

    replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr)

    red = reduced_expr[0]
    red_cls = red.__class__
    rep_sym = [sym for sym, _ in replacements]
    sub_expr = [sub for _, sub in replacements]

    l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(red)

    # F1: partial derivatives of the reduced expressions w.r.t. ``wrt``,
    # treating replacement symbols as independent variables.
    f1 = red_cls.from_dok(l_red, l_wrt, {
        (i, j): diff_value
        for i, r in enumerate(red)
        for j, w in enumerate(wrt)
        if (diff_value := r.diff(w)) != 0
    })

    if not replacements:
        return [], [f1], []

    # Column index of each replacement symbol, so only the symbols actually
    # present in an expression need to be visited.
    rep_index = {sym: k for k, sym in enumerate(rep_sym)}
    rep_sym_set = set(rep_sym)

    # F2: partial derivatives of the reduced expressions w.r.t. the
    # replacement symbols.
    f2 = Matrix.from_dok(l_red, l_sub, {
        (i, rep_index[s]): diff_value
        for i, r in enumerate(red)
        for s in r.free_symbols & rep_sym_set
        if (diff_value := r.diff(s)) != 0
    })

    precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr]

    # Row 0 of C: the first subexpression refers to no replacement symbol.
    c_matrix = Matrix.from_dok(1, l_wrt, {
        (0, j): diff_value
        for j, w in enumerate(wrt)
        if (diff_value := sub_expr[0].diff(w)) != 0
    })

    for i in range(1, l_sub):
        # b_i (1 x i): partials of sub_expr[i] w.r.t. the earlier replacement
        # symbols rep_sym[0..i-1]; only those can appear in a CSE output.
        bi_matrix = Matrix.from_dok(1, i, {
            (0, col): diff_value
            for s in precomputed_fs[i]
            if (col := rep_index[s]) < i
            and (diff_value := sub_expr[i].diff(s)) != 0
        })

        # a_i (1 x l_wrt): direct partials of sub_expr[i] w.r.t. ``wrt``.
        ai_matrix = Matrix.from_dok(1, l_wrt, {
            (0, j): diff_value
            for j, w in enumerate(wrt)
            if (diff_value := sub_expr[i].diff(w)) != 0
        })

        # Chain rule: c_i = b_i * C[0..i-1] + a_i (skip the product when b_i
        # has no nonzero entries).
        if bi_matrix._rep.nnz():
            ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
        else:
            ci_matrix = ai_matrix
        c_matrix = Matrix.vstack(c_matrix, ci_matrix)

    jacobian = f2.multiply(c_matrix).add(f1)

    return replacements, [red_cls(jacobian)], precomputed_fs
|
|
|
|
|
|
def _forward_jacobian_norm_in_cse_out(expr, wrt):
    """
    Compute the Jacobian of a matrix of expressions by forward accumulation,
    returning it in CSE-reduced form.

    ``expr`` (a SymPy Matrix of expressions) is first passed through ``cse``;
    the resulting replacements and reduced expressions are then handed to
    ``_forward_jacobian_cse`` together with ``wrt``, and its three outputs are
    returned.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.

    wrt : iterable
        The vector with respect to which to perform the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and the common subexpressions they stand for,
        as produced by the CSE step. No replacement symbol appears inside the
        arguments of a ``Derivative``.

    jacobian : list of SymPy expressions
        A one-element list holding the Jacobian matrix, with elements in
        reduced form (replacement symbols are present).

    precomputed_fs: list
        List of sets holding the replacement symbols present in each
        subexpression. Useful in the substitution process.
    """
    cse_replacements, cse_reduced = cse(expr)
    return _forward_jacobian_cse(cse_replacements, cse_reduced, wrt)
|
|
|
|
|
|
def _forward_jacobian(expr, wrt):
    """
    Compute the Jacobian of a matrix of expressions by forward accumulation.

    Takes a SymPy Matrix of expressions (``expr``) and an iterable of
    variables (``wrt``) and returns the Jacobian matrix of ``expr`` with
    respect to ``wrt``, with all common-subexpression symbols substituted
    back.

    Explanation
    ===========

    Expressions often contain repeated subexpressions. In a tree
    representation those subexpressions are duplicated and therefore
    differentiated several times, which is wasteful.

    If a directed acyclic graph (DAG) is used instead, each repeated
    subexpression exists only once. This function combines a DAG
    representation of the expression (obtained through common subexpression
    elimination) with forward accumulation (repeated symbolic application of
    the chain rule) to compute the Jacobian of ``expr`` with respect to
    ``wrt`` more efficiently.

    This is meant to speed up differentiation of large expressions that
    contain many common subexpressions. For small, simple expressions it is
    likely slower than SymPy's standard differentiation functions and methods.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.

    wrt : iterable
        The vector with respect to which to do the differentiation.
        Can be a matrix or an iterable of variables.

    See Also
    ========

    Directed Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
    """
    cse_pairs, reduced = cse(expr)
    # Replacement symbols in CSE order.
    ordered_syms = [sym for sym, _ in cse_pairs]

    cse_pairs, jac_list, sub_free = _forward_jacobian_cse(cse_pairs, reduced, wrt)
    jac = jac_list[0]

    if not cse_pairs:
        return jac

    # Walk the replacements in CSE order, expanding each subexpression in
    # terms of already-expanded earlier ones, so every entry of ``expanded``
    # ends up free of replacement symbols.
    expanded = dict(cse_pairs)
    for sym, deps in zip(ordered_syms, sub_free):
        expanded[sym] = expanded[sym].xreplace({dep: expanded[dep] for dep in deps})

    return jac.xreplace(expanded)
|