# Scraped file-viewer header (not part of the module).
# Space status: Sleeping. File size: 4,773 bytes. Revision: 6a86ad5.
# The viewer's line-number gutter (1-156) was removed.
from functools import reduce
import operator
from sympy.core import Basic, sympify
from sympy.core.add import add, Add, _could_extract_minus_sign
from sympy.core.sorting import default_sort_key
from sympy.functions import adjoint
from sympy.matrices.matrixbase import MatrixBase
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
exhaust, do_one, glom)
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
from sympy.matrices.expressions._shape import validate_matadd_integer as validate
from sympy.utilities.iterables import sift
from sympy.utilities.exceptions import sympy_deprecation_warning
# XXX: MatAdd should perhaps not subclass directly from Add
class MatAdd(MatrixExpr, Add):
    """A Sum of Matrix Expressions

    MatAdd inherits from and operates like SymPy Add

    Examples
    ========

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """
    is_MatAdd = True

    # Returned when there is nothing to add.
    identity = GenericZeroMatrix()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Build a sum of matrix expressions.

        Parameters
        ==========

        args
            The summands. Occurrences of ``cls.identity`` are dropped.
        evaluate
            If true, the new object is passed through ``cls._evaluate``.
        check
            Deprecated. Any value other than ``None`` emits a deprecation
            warning; ``False`` additionally skips ``validate``.
        _sympify
            If true, every remaining argument is passed through ``sympify``.

        Raises
        ======

        TypeError
            If any argument is not a ``MatrixExpr``.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericZeroMatrix().shape
        args = [arg for arg in args if cls.identity != arg]
        if not args:
            # Every summand was the generic zero. An argument-less MatAdd
            # would have no usable ``shape`` (it reads ``self.args[0]``) and
            # nothing for ``validate`` to inspect, so answer with the
            # identity exactly as in the no-argument case above.
            return cls.identity

        if _sympify:
            args = list(map(sympify, args))

        if not all(isinstance(arg, MatrixExpr) for arg in args):
            raise TypeError("Mix of Matrix and Scalar symbols")

        obj = Basic.__new__(cls, *args)

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        # Validation runs both for the default (None) and for check=True.
        if check is not False:
            validate(*args)

        if evaluate:
            obj = cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` after applying ``canonicalize``."""
        return canonicalize(expr)

    @property
    def shape(self):
        """Shape of the sum, taken from its first argument."""
        return self.args[0].shape

    def could_extract_minus_sign(self):
        """Delegate to the ``_could_extract_minus_sign`` helper from ``Add``'s module."""
        return _could_extract_minus_sign(self)

    def expand(self, **kwargs):
        """Expand via the parent classes, then canonicalize the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _entry(self, i, j, **kwargs):
        """Return entry (i, j) as the ``Add`` of each argument's (i, j) entry."""
        return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

    def _eval_transpose(self):
        """Return the evaluated sum of the transposed arguments."""
        return MatAdd(*[transpose(arg) for arg in self.args]).doit()

    def _eval_adjoint(self):
        """Return the evaluated sum of the adjoints of the arguments."""
        return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

    def _eval_trace(self):
        """Return the evaluated scalar sum of the traces of the arguments."""
        from .trace import trace
        return Add(*[trace(arg) for arg in self.args]).doit()

    def doit(self, **hints):
        """Canonicalize the sum.

        Unless ``deep=False`` is given in ``hints``, ``doit(**hints)`` is
        first called on every argument.
        """
        if hints.get('deep', True):
            args = [arg.doit(**hints) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatAdd(*args))

    def _eval_derivative_matrix_lines(self, x):
        """Concatenate the derivative matrix lines of all arguments, in order."""
        return [
            line
            for arg in self.args
            for line in arg._eval_derivative_matrix_lines(x)
        ]
# Register MatAdd with the ``add`` dispatcher for the (Add, MatAdd) class pair.
# NOTE(review): presumably this makes mixed Add/MatAdd operands dispatch to
# MatAdd -- confirm against the dispatcher's ``register_handlerclass`` docs.
add.register_handlerclass((Add, MatAdd), MatAdd)
def factor_of(arg):
    """Return the first item of ``arg.as_coeff_mmul()``.

    Going by the method name this is the scalar coefficient of ``arg``;
    it is passed to ``glom`` in ``rules`` below.
    """
    return arg.as_coeff_mmul()[0]
def matrix_of(arg):
    """Return ``unpack`` applied to the second item of ``arg.as_coeff_mmul()``.

    Going by the method name this is the matrix part of ``arg``;
    it is passed to ``glom`` in ``rules`` below.
    """
    return unpack(arg.as_coeff_mmul()[1])
def combine(cnt, mat):
    """Return ``cnt * mat``, or ``mat`` itself when ``cnt`` equals 1."""
    return mat if cnt == 1 else cnt * mat
def merge_explicit(matadd):
    """ Merge explicit MatrixBase arguments

    When more than one argument of ``matadd`` is a ``MatrixBase`` instance,
    those arguments are added together and a new ``MatAdd`` is returned
    holding the remaining arguments followed by that single sum. Otherwise
    ``matadd`` is returned unchanged.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    by_kind = sift(matadd.args, lambda term: isinstance(term, MatrixBase))
    explicit = by_kind[True]

    # Nothing to merge with zero or one explicit matrix.
    if len(explicit) <= 1:
        return matadd

    explicit_sum = reduce(operator.add, explicit)
    return MatAdd(*by_kind[False], explicit_sum)
def _is_zero_term(term):
    """Tell whether ``term`` equals 0 or is a ``ZeroMatrix``."""
    return term == 0 or isinstance(term, ZeroMatrix)


def _is_matadd(expr):
    """Tell whether ``expr`` is a ``MatAdd``."""
    return isinstance(expr, MatAdd)


# Rewrite rules tried, in this order, by ``canonicalize``.
rules = (
    rm_id(_is_zero_term),
    unpack,
    flatten,
    glom(matrix_of, factor_of, combine),
    merge_explicit,
    sort(default_sort_key),
)

# The rules are only attempted on MatAdd instances, wrapped in ``exhaust``.
_first_applicable_rule = do_one(*rules)
canonicalize = exhaust(condition(_is_matadd, _first_applicable_rule))
# End of scraped listing.