File size: 12,941 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
"""LU decomposition functions."""
from warnings import warn
from numpy import asarray, asarray_chkfinite
import numpy as np
from itertools import product
# Local imports
from ._misc import _datacopied, LinAlgWarning
from .lapack import get_lapack_funcs
from ._decomp_lu_cython import lu_dispatcher
# For every NumPy typecode, the LAPACK-supported typecodes ('f', 'd', 'F', 'D')
# that ``np.can_cast`` accepts as a cast target, concatenated in that order.
# An empty string means the type cannot be cast to any of them.
lapack_cast_dict = {
    src: ''.join(dst for dst in 'fdFD' if np.can_cast(src, dst))
    for src in np.typecodes['All']
}
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute the pivoted LU factorization of a matrix.

    The factorization is::

        A = P L U

    where P is a permutation matrix, L is lower triangular with unit
    diagonal elements, and U is upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose.
    overwrite_a : bool, optional
        Whether the data in `a` may be overwritten (may increase
        performance).
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or
        NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower
        triangle. The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    Raises
    ------
    ValueError
        If the LAPACK routine reports an illegal argument (negative
        ``info``).

    Warns
    -----
    LinAlgWarning
        If the LAPACK routine reports an exactly zero diagonal entry
        (positive ``info``), i.e. the matrix is singular.

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices,
    the ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    An input with zero elements is returned without calling LAPACK: ``lu``
    is an empty array shaped like the input and ``piv`` is an empty int32
    array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True

    """
    matrix = asarray_chkfinite(a) if check_finite else asarray(a)

    # Nothing to factorize for an empty input, so LAPACK is skipped.
    if matrix.size == 0:
        return np.empty_like(matrix), np.arange(0, dtype=np.int32)

    # If the conversion above already produced a private copy, LAPACK may
    # scratch it regardless of what the caller asked for.
    may_overwrite = overwrite_a or _datacopied(matrix, a)

    (getrf,) = get_lapack_funcs(('getrf',), (matrix,))
    factors, pivots, info = getrf(matrix, overwrite_a=may_overwrite)

    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # The factorization is still returned; the caller is only warned.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return factors, pivots
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the leading dimensions of ``lu`` and ``b`` differ, or if the
        LAPACK routine reports a non-zero ``info``.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)

    # A privately owned converted copy may always be overwritten.
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if lu.shape[0] != b1.shape[0]:
        raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")

    # accommodate empty arrays
    if b1.size == 0:
        # Solve a trivial 2x2 identity system only to learn which dtype the
        # LAPACK path would return for this (lu, b) dtype combination.
        # The dtype is read from the converted array ``b1``: the raw ``b``
        # may be a list or another array_like without a ``dtype`` attribute.
        m = lu_solve((np.eye(2, dtype=lu.dtype), [0, 1]),
                     np.ones(2, dtype=b1.dtype))
        return np.empty_like(b1, dtype=m.dtype)

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # NOTE(review): the message names gesv|posv although the routine called
    # here is getrs; the text is kept unchanged for compatibility.
    raise ValueError('illegal value in %dth argument of internal gesv|posv'
                     % -info)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
p_indices=False):
"""
Compute LU decomposition of a matrix with partial pivoting.
The decomposition satisfies::
A = P @ L @ U
where ``P`` is a permutation matrix, ``L`` lower triangular with unit
diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
``True`` then ``L`` is returned already permuted and hence satisfying
``A = L @ U``.
Parameters
----------
a : (M, N) array_like
Array to decompose
permute_l : bool, optional
Perform the multiplication P*L (Default: do not permute)
overwrite_a : bool, optional
Whether to overwrite data in a (may improve performance)
check_finite : bool, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
p_indices : bool, optional
If ``True`` the permutation information is returned as row indices.
The default is ``False`` for backwards-compatibility reasons.
Returns
-------
**(If `permute_l` is ``False``)**
p : (..., M, M) ndarray
Permutation arrays or vectors depending on `p_indices`
l : (..., M, K) ndarray
Lower triangular or trapezoidal array with unit diagonal.
``K = min(M, N)``
u : (..., K, N) ndarray
Upper triangular or trapezoidal array
**(If `permute_l` is ``True``)**
pl : (..., M, K) ndarray
Permuted L matrix.
``K = min(M, N)``
u : (..., K, N) ndarray
Upper triangular or trapezoidal array
Notes
-----
Permutation matrices are costly since they are nothing but row reorder of
``L`` and hence indices are strongly recommended to be used instead if the
permutation is required. The relation in the 2D case then becomes simply
``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
to avoid complicated indexing tricks.
In 2D case, if one has the indices however, for some reason, the
permutation matrix is still needed then it can be constructed by
``np.eye(M)[P, :]``.
Examples
--------
>>> import numpy as np
>>> from scipy.linalg import lu
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = lu(A)
>>> np.allclose(A, p @ l @ u)
True
>>> p # Permutation matrix
array([[0., 1., 0., 0.], # Row index 1
[0., 0., 0., 1.], # Row index 3
[1., 0., 0., 0.], # Row index 0
[0., 0., 1., 0.]]) # Row index 2
>>> p, _, _ = lu(A, p_indices=True)
>>> p
array([1, 3, 0, 2], dtype=int32) # as given by row indices above
>>> np.allclose(A, l[p, :] @ u)
True
We can also use nd-arrays, for example, a demonstration with 4D array:
>>> rng = np.random.default_rng()
>>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
>>> p, l, u = lu(A)
>>> p.shape, l.shape, u.shape
((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
>>> np.allclose(A, p @ l @ u)
True
>>> PL, U = lu(A, permute_l=True)
>>> np.allclose(A, PL @ U)
True
"""
a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
if a1.ndim < 2:
raise ValueError('The input array must be at least two-dimensional.')
# Also check if dtype is LAPACK compatible
if a1.dtype.char not in 'fdFD':
dtype_char = lapack_cast_dict[a1.dtype.char]
if not dtype_char: # No casting possible
raise TypeError(f'The dtype {a1.dtype} cannot be cast '
'to float(32, 64) or complex(64, 128).')
a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch
overwrite_a = True
*nd, m, n = a1.shape
k = min(m, n)
real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
# Empty input
if min(*a1.shape) == 0:
if permute_l:
PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
return PL, U
else:
P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
np.empty([*nd, 0, 0], dtype=real_dchar))
L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
return P, L, U
# Scalar case
if a1.shape[-2:] == (1, 1):
if permute_l:
return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
else:
P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
else np.ones_like(a1))
return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
# Then check overwrite permission
if not _datacopied(a1, a): # "a" still alive through "a1"
if not overwrite_a:
# Data belongs to "a" so make a copy
a1 = a1.copy(order='C')
# else: Do nothing we'll use "a" if possible
# else: a1 has its own data thus free to scratch
# Then layout checks, might happen that overwrite is allowed but original
# array was read-only or non-contiguous.
if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
a1 = a1.copy(order='C')
if not nd: # 2D array
p = np.empty(m, dtype=np.int32)
u = np.zeros([k, k], dtype=a1.dtype)
lu_dispatcher(a1, u, p, permute_l)
P, L, U = (p, a1, u) if m > n else (p, u, a1)
else: # Stacked array
# Prepare the contiguous data holders
P = np.empty([*nd, m], dtype=np.int32) # perm vecs
if m > n: # Tall arrays, U will be created
U = np.zeros([*nd, k, k], dtype=a1.dtype)
for ind in product(*[range(x) for x in a1.shape[:-2]]):
lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
L = a1
else: # Fat arrays, L will be created
L = np.zeros([*nd, k, k], dtype=a1.dtype)
for ind in product(*[range(x) for x in a1.shape[:-2]]):
lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
U = a1
# Convert permutation vecs to permutation arrays
# permute_l=False needed to enter here to avoid wasted efforts
if (not p_indices) and (not permute_l):
if nd:
Pa = np.zeros([*nd, m, m], dtype=real_dchar)
# An unreadable index hack - One-hot encoding for perm matrices
nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
Pa[(*nd_ix, P)] = 1
P = Pa
else: # 2D case
Pa = np.zeros([m, m], dtype=real_dchar)
Pa[np.arange(m), P] = 1
P = Pa
return (L, U) if permute_l else (P, L, U)
|