""" Matrix square root for general matrices and for upper triangular matrices. This module exists to avoid cyclic imports. """ __all__ = [] import numpy as np # Local imports from .lapack import ztrsyl, dtrsyl class SqrtmError(np.linalg.LinAlgError): pass from ._matfuncs_sqrtm_triu import within_block_loop # noqa: E402 def _sqrtm_triu(T, blocksize=64): """ Matrix square root of an upper triangular matrix. This is a helper function for `sqrtm` and `logm`. Parameters ---------- T : (N, N) array_like upper triangular Matrix whose square root to evaluate blocksize : int, optional If the blocksize is not degenerate with respect to the size of the input array, then use a blocked algorithm. (Default: 64) Returns ------- sqrtm : (N, N) ndarray Value of the sqrt function at `T` References ---------- .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013) "Blocked Schur Algorithms for Computing the Matrix Square Root, Lecture Notes in Computer Science, 7782. pp. 171-182. """ T_diag = np.diag(T) keep_it_real = np.isrealobj(T) and np.min(T_diag, initial=0.) >= 0 # Cast to complex as necessary + ensure double precision if not keep_it_real: T = np.asarray(T, dtype=np.complex128, order="C") T_diag = np.asarray(T_diag, dtype=np.complex128) else: T = np.asarray(T, dtype=np.float64, order="C") T_diag = np.asarray(T_diag, dtype=np.float64) R = np.diag(np.sqrt(T_diag)) # Compute the number of blocks to use; use at least one block. n, n = T.shape nblocks = max(n // blocksize, 1) # Compute the smaller of the two sizes of blocks that # we will actually use, and compute the number of large blocks. bsmall, nlarge = divmod(n, nblocks) blarge = bsmall + 1 nsmall = nblocks - nlarge if nsmall * bsmall + nlarge * blarge != n: raise Exception('internal inconsistency') # Define the index range covered by each block. start_stop_pairs = [] start = 0 for count, size in ((nsmall, bsmall), (nlarge, blarge)): for i in range(count): start_stop_pairs.append((start, start + size)) start += size # Within-block interactions (Cythonized) try: within_block_loop(R, T, start_stop_pairs, nblocks) except RuntimeError as e: raise SqrtmError(*e.args) from e # Between-block interactions (Cython would give no significant speedup) for j in range(nblocks): jstart, jstop = start_stop_pairs[j] for i in range(j-1, -1, -1): istart, istop = start_stop_pairs[i] S = T[istart:istop, jstart:jstop] if j - i > 1: S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart, jstart:jstop]) # Invoke LAPACK. # For more details, see the solve_sylvester implementation # and the fortran dtrsyl and ztrsyl docs. Rii = R[istart:istop, istart:istop] Rjj = R[jstart:jstop, jstart:jstop] if keep_it_real: x, scale, info = dtrsyl(Rii, Rjj, S) else: x, scale, info = ztrsyl(Rii, Rjj, S) R[istart:istop, jstart:jstop] = x * scale # Return the matrix square root. return R