107 lines
3.3 KiB
Python
107 lines
3.3 KiB
Python
"""
Matrix square root for general matrices and for upper triangular matrices.

This module exists to avoid cyclic imports. It defines the ``SqrtmError``
exception and ``_sqrtm_triu``, the blocked Schur-algorithm square root of an
upper triangular matrix.
"""

# Intentionally empty: nothing here is public API, so a star-import of this
# module exposes no names.
__all__ = []
|
|
|
|
import numpy as np
|
|
|
|
# Local imports
|
|
from .lapack import ztrsyl, dtrsyl
|
|
|
|
class SqrtmError(np.linalg.LinAlgError):
    """
    Error signalling that a matrix square root could not be computed.

    ``_sqrtm_triu`` raises it, for example, when its within-block
    computation fails with a ``RuntimeError`` (the original exception is
    chained as the cause). Because it derives from
    ``numpy.linalg.LinAlgError``, handlers for generic linear-algebra
    failures catch it as well.
    """
|
|
|
|
from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402
|
|
|
|
|
|
def _sqrtm_block_ranges(n, blocksize):
    """
    Partition ``range(n)`` into contiguous blocks of near-equal size.

    Returns a list of ``(start, stop)`` pairs covering ``0..n``. There are
    ``max(n // blocksize, 1)`` blocks whose sizes differ by at most one;
    the smaller blocks come first.
    """
    # Use at least one block.
    nblocks = max(n // blocksize, 1)

    # Compute the smaller of the two sizes of blocks that
    # we will actually use, and compute the number of large blocks.
    bsmall, nlarge = divmod(n, nblocks)
    blarge = bsmall + 1
    nsmall = nblocks - nlarge
    if nsmall * bsmall + nlarge * blarge != n:
        raise Exception('internal inconsistency')

    # Define the index range covered by each block.
    start_stop_pairs = []
    start = 0
    for count, size in ((nsmall, bsmall), (nlarge, blarge)):
        for _ in range(count):
            start_stop_pairs.append((start, start + size))
            start += size
    return start_stop_pairs


def _sqrtm_triu(T, blocksize=64):
    """
    Matrix square root of an upper triangular matrix.

    This is a helper function for `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root to evaluate
    blocksize : int, optional
        Target size of the diagonal blocks; should be a positive integer.
        The matrix is split into ``max(N // blocksize, 1)`` diagonal blocks
        whose sizes differ by at most one, so a blocksize of at least N
        gives the unblocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`. The dtype is ``float64`` when
        `T` is real with a nonnegative diagonal, ``complex128`` otherwise.

    Raises
    ------
    SqrtmError
        If the within-block computation fails (signalled by a
        ``RuntimeError``), or if LAPACK ``?trsyl`` reports an illegal
        argument value.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root,
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    T_diag = np.diag(T)
    # Real arithmetic suffices only if T is real and its diagonal (the
    # eigenvalues of a triangular matrix) is nonnegative; `initial=0.`
    # makes an empty diagonal count as real.
    keep_it_real = np.isrealobj(T) and np.min(T_diag, initial=0.) >= 0

    # Cast to complex as necessary + ensure double precision
    dtype = np.float64 if keep_it_real else np.complex128
    T = np.asarray(T, dtype=dtype, order="C")
    T_diag = np.asarray(T_diag, dtype=dtype)

    # Diagonal of the result: principal square roots of the eigenvalues.
    R = np.diag(np.sqrt(T_diag))

    _, n = T.shape  # T is square
    start_stop_pairs = _sqrtm_block_ranges(n, blocksize)
    nblocks = len(start_stop_pairs)

    # Within-block interactions (Cythonized). The compiled helper is
    # presumably filling the diagonal blocks of R in place -- see
    # `_matfuncs_sqrtm_triu.within_block_loop` to confirm.
    try:
        within_block_loop(R, T, start_stop_pairs, nblocks)
    except RuntimeError as e:
        raise SqrtmError(*e.args) from e

    # Between-block interactions (Cython would give no significant speedup).
    # For more details, see the solve_sylvester implementation
    # and the fortran dtrsyl and ztrsyl docs.
    trsyl = dtrsyl if keep_it_real else ztrsyl
    for j, (jstart, jstop) in enumerate(start_stop_pairs):
        # Walk up column j so every block R[i, k] (i < k < j) needed below
        # has already been computed.
        for i in range(j - 1, -1, -1):
            istart, istop = start_stop_pairs[i]
            S = T[istart:istop, jstart:jstop]
            if j - i > 1:
                S = S - R[istart:istop, istop:jstart].dot(
                    R[istop:jstart, jstart:jstop])

            # Solve the Sylvester equation Rii @ X + X @ Rjj = S.
            Rii = R[istart:istop, istart:istop]
            Rjj = R[jstart:jstop, jstart:jstop]
            x, scale, info = trsyl(Rii, Rjj, S)
            if info < 0:
                raise SqrtmError(
                    'LAPACK trsyl: illegal value in argument {}'.format(-info))
            # trsyl returns x with Rii @ x + x @ Rjj = scale * S, where
            # scale <= 1 was chosen to avoid overflow in x. The solution of
            # the unscaled equation is therefore x / scale (not x * scale).
            # info > 0 (nearly common eigenvalues, perturbed values used)
            # is accepted, as before.
            R[istart:istop, jstart:jstop] = x / scale

    # Return the matrix square root.
    return R
|