Spaces:
Sleeping
Sleeping
File size: 5,140 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from .matexpr import MatrixExpr
from .special import Identity
from sympy.core import S
from sympy.core.expr import ExprBuilder
from sympy.core.cache import cacheit
from sympy.core.power import Pow
from sympy.core.sympify import _sympify
from sympy.matrices import MatrixBase
from sympy.matrices.exceptions import NonSquareMatrixError
class MatPow(MatrixExpr):
    """Matrix expression representing the power ``base**exp``.

    Parameters
    ==========

    base
        A matrix (``base.is_Matrix`` must be true). It is rejected only when
        it is definitely non-square.
    exp
        The exponent; it is sympified on construction.
    evaluate : bool, optional
        When true, the new object is immediately simplified with
        ``doit(deep=False)``. Defaults to ``False``.

    Raises
    ======

    TypeError
        If ``base`` is not a matrix.
    NonSquareMatrixError
        If ``base.is_square`` is ``False``.
    """

    def __new__(cls, base, exp, evaluate=False, **options):
        # ``options`` is accepted for signature compatibility; it is not used.
        base = _sympify(base)
        if not base.is_Matrix:
            raise TypeError("MatPow base should be a matrix")
        # Compare with ``is False`` rather than testing falsiness, so a base
        # whose squareness is undetermined (``is_square`` not ``False``) is
        # still accepted.
        if base.is_square is False:
            raise NonSquareMatrixError("Power of non-square matrix %s" % base)

        exp = _sympify(exp)
        obj = super().__new__(cls, base, exp)
        if evaluate:
            obj = obj.doit(deep=False)
        return obj

    @property
    def base(self):
        """The matrix being raised to a power (first argument)."""
        return self.args[0]

    @property
    def exp(self):
        """The exponent (second argument)."""
        return self.args[1]

    @property
    def shape(self):
        """Shape of the result, which equals the shape of ``base``."""
        return self.base.shape

    @cacheit
    def _get_explicit_matrix(self):
        """Return ``base.as_explicit()**exp``, memoised via ``cacheit``."""
        return self.base.as_explicit()**self.exp

    def _entry(self, i, j, **kwargs):
        """Return the ``(i, j)`` element of this power.

        The expression is first evaluated with ``doit()``. If the result is
        still a ``MatPow``:

        * a positive integer power is expanded into an explicit ``MatMul``;
        * otherwise, if the shape is not symbolic, the element is read from
          the power of the explicit matrix;
        * otherwise an unevaluated ``MatrixElement`` of ``self`` is returned.
        """
        from sympy.matrices.expressions import MatMul
        A = self.doit()
        if isinstance(A, MatPow):
            # We still have a MatPow, make an explicit MatMul out of it.
            if A.exp.is_Integer and A.exp.is_positive:
                A = MatMul(*[A.base for k in range(A.exp)])
            elif not self._is_shape_symbolic():
                return A._get_explicit_matrix()[i, j]
            else:
                # Leave the expression unevaluated:
                from sympy.matrices.expressions.matexpr import MatrixElement
                return MatrixElement(self, i, j)
        return A[i, j]

    def doit(self, **hints):
        """Evaluate the power as far as possible.

        With ``deep`` (the default) the arguments are evaluated first.
        Nested ``MatPow`` bases are flattened by multiplying exponents.
        Explicit ``MatrixBase`` bases delegate to their own ``**``. Exponents
        1, 0 and -1 are handled here. Any other exponent is passed to the
        base's ``_eval_power`` hook when it exists; otherwise a new
        ``MatPow`` is returned.
        """
        if hints.get('deep', True):
            base, exp = (arg.doit(**hints) for arg in self.args)
        else:
            base, exp = self.args

        # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6
        # NOTE(review): exponents are multiplied unconditionally; in general
        # (A**a)**b == A**(a*b) only holds for integer exponents -- confirm
        # this is intended for fractional/symbolic exponents.
        while isinstance(base, MatPow):
            exp *= base.args[1]
            base = base.args[0]

        if isinstance(base, MatrixBase):
            # Delegate
            return base ** exp

        # Handle simple cases so that _eval_power() in MatrixExpr sub-classes
        # can ignore them
        if exp == S.One:
            return base
        if exp == S.Zero:
            return Identity(base.rows)
        if exp == S.NegativeOne:
            from sympy.matrices.expressions import Inverse
            return Inverse(base).doit(**hints)

        eval_power = getattr(base, '_eval_power', None)
        if eval_power is not None:
            return eval_power(exp)

        return MatPow(base, exp)

    def _eval_transpose(self):
        """Return ``MatPow(base.T, exp)``."""
        base, exp = self.args
        return MatPow(base.transpose(), exp)

    def _eval_adjoint(self):
        """Return ``MatPow(base.adjoint(), exp)``."""
        base, exp = self.args
        return MatPow(base.adjoint(), exp)

    def _eval_conjugate(self):
        """Return ``MatPow(base.conjugate(), exp)``."""
        base, exp = self.args
        return MatPow(base.conjugate(), exp)

    def _eval_derivative(self, x):
        """Differentiate by reusing ``Pow._eval_derivative`` on this object.

        NOTE(review): this borrows the scalar ``Pow`` rule; confirm it is
        valid for non-commuting matrix bases.
        """
        return Pow._eval_derivative(self, x)

    def _eval_derivative_matrix_lines(self, x):
        """Return the matrix-derivative "lines" of this power w.r.t. ``x``.

        * A 1x1 base with an ``x``-free exponent is handled with the scalar
          rule ``exp*base**(exp - 1)`` wrapped in an ``ArrayContraction``.
        * An exponent of 0 differentiates the evaluated identity.
        * An exponent of -1 defers to ``Inverse``.
        * Other integer exponents are expanded into a ``MatMul`` of ``base``
          (or of ``Inverse(base)`` for negative powers).

        Raises
        ======

        NotImplementedError
            For any other exponent (fractional, float or symbolic), which
            cannot be expanded into a finite product.
        """
        from sympy.tensor.array.expressions.array_expressions import (
            ArrayContraction, ArrayTensorProduct)
        from .matmul import MatMul
        from .inverse import Inverse

        exp = self.exp
        if self.base.shape == (1, 1) and not exp.has(x):
            lr = self.base._eval_derivative_matrix_lines(x)
            for i in lr:
                subexpr = ExprBuilder(
                    ArrayContraction,
                    [
                        ExprBuilder(
                            ArrayTensorProduct,
                            [
                                Identity(1),
                                i._lines[0],
                                exp*self.base**(exp-1),
                                i._lines[1],
                                Identity(1),
                            ]
                        ),
                        (0, 3, 4), (5, 7, 8)
                    ],
                    validator=ArrayContraction._validate
                )
                i._first_pointer_parent = subexpr.args[0].args
                i._first_pointer_index = 0
                i._second_pointer_parent = subexpr.args[0].args
                i._second_pointer_index = 4
                i._lines = [subexpr]
            return lr

        if (exp == 0) == True:
            return self.doit()._eval_derivative_matrix_lines(x)
        if (exp == -1) == True:
            return Inverse(self.base)._eval_derivative_matrix_lines(x)
        # Expanding into a product of factors needs a concrete integer
        # exponent: ``range`` cannot take a fractional, float or symbolic
        # value and would otherwise fail with an unrelated TypeError.
        if exp.is_Integer and exp > 0:
            newexpr = MatMul.fromiter([self.base for i in range(exp)])
        elif exp.is_Integer and exp < 0:
            newexpr = MatMul.fromiter([Inverse(self.base) for i in range(-exp)])
        else:
            raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x))
        return newexpr._eval_derivative_matrix_lines(x)

    def _eval_inverse(self):
        """Return the power with negated exponent, ``MatPow(base, -exp)``."""
        return MatPow(self.base, -self.exp)
|