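# NOTE(review): "Spaces: Sleeping" appeared here as what looks like web-page
# header residue from extraction; it is not part of the module.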
from functools import reduce | |
import operator | |
from sympy.core import Basic, sympify | |
from sympy.core.add import add, Add, _could_extract_minus_sign | |
from sympy.core.sorting import default_sort_key | |
from sympy.functions import adjoint | |
from sympy.matrices.matrixbase import MatrixBase | |
from sympy.matrices.expressions.transpose import transpose | |
from sympy.strategies import (rm_id, unpack, flatten, sort, condition, | |
exhaust, do_one, glom) | |
from sympy.matrices.expressions.matexpr import MatrixExpr | |
from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix | |
from sympy.matrices.expressions._shape import validate_matadd_integer as validate | |
from sympy.utilities.iterables import sift | |
from sympy.utilities.exceptions import sympy_deprecation_warning | |
# XXX: MatAdd should perhaps not subclass directly from Add
class MatAdd(MatrixExpr, Add):
    """A Sum of Matrix Expressions

    MatAdd inherits from and operates like SymPy Add

    Examples
    ========

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """

    is_MatAdd = True

    # Additive identity: returned for an empty sum and stripped from the
    # arguments of every non-empty one.
    identity = GenericZeroMatrix()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Build a matrix sum from ``args``.

        Parameters
        ==========

        args
            The summands. Occurrences of ``cls.identity`` are dropped.
        evaluate
            When true, the new object is passed through ``cls._evaluate``
            before being returned.
        check
            Deprecated. Any value other than ``None`` emits a deprecation
            warning; ``False`` additionally skips shape validation.
        _sympify
            When true, every argument is run through ``sympify`` first.

        Raises
        ======

        TypeError
            If any argument is not a ``MatrixExpr``.
        """
        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericZeroMatrix().shape
        args = [arg for arg in args if cls.identity != arg]

        # Covers both "no arguments given" and "only identities given": a
        # MatAdd without arguments has no shape, because ``shape`` reads
        # ``self.args[0]``.
        if not args:
            return cls.identity

        if _sympify:
            args = [sympify(arg) for arg in args]

        if not all(isinstance(arg, MatrixExpr) for arg in args):
            raise TypeError("Mix of Matrix and Scalar symbols")

        obj = Basic.__new__(cls, *args)

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        # Validation runs for check=None (the default) and check=True.
        if check is not False:
            validate(*args)

        if evaluate:
            obj = cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` after applying the module's ``canonicalize``."""
        return canonicalize(expr)

    @property
    def shape(self):
        """Shape of the sum, taken from its first argument."""
        return self.args[0].shape

    def could_extract_minus_sign(self):
        """Delegate to ``_could_extract_minus_sign`` from ``sympy.core.add``."""
        return _could_extract_minus_sign(self)

    def expand(self, **kwargs):
        """Expand via the parent classes, then re-canonicalize the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _entry(self, i, j, **kwargs):
        """Entry ``(i, j)`` of the sum: the ``Add`` of each argument's entry."""
        return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

    def _eval_transpose(self):
        """Transpose each argument and evaluate the resulting sum."""
        return MatAdd(*[transpose(arg) for arg in self.args]).doit()

    def _eval_adjoint(self):
        """Take the adjoint of each argument and evaluate the resulting sum."""
        return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

    def _eval_trace(self):
        """Trace of the sum: the scalar ``Add`` of each argument's trace."""
        # Imported here rather than at module level (presumably to avoid a
        # circular import -- TODO confirm).
        from .trace import trace
        return Add(*[trace(arg) for arg in self.args]).doit()

    def doit(self, **hints):
        """Evaluate the sum and return it in canonical form.

        With ``deep=True`` (the default) every argument's ``doit`` is called
        with the same hints first; otherwise the arguments are used as they
        are.
        """
        deep = hints.get('deep', True)
        if deep:
            args = [arg.doit(**hints) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatAdd(*args))

    def _eval_derivative_matrix_lines(self, x):
        """Concatenate the derivative matrix lines of every argument."""
        add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
        return [j for i in add_lines for j in i]
# Register MatAdd as the handler class for the (Add, MatAdd) pair on the
# ``add`` dispatcher imported from sympy.core.add.
add.register_handlerclass((Add, MatAdd), MatAdd)
def factor_of(arg):
    """Return the first item of ``arg.as_coeff_mmul()`` (the coefficient)."""
    return arg.as_coeff_mmul()[0]


def matrix_of(arg):
    """Return ``unpack`` applied to the second item of ``arg.as_coeff_mmul()``."""
    return unpack(arg.as_coeff_mmul()[1])
def combine(cnt, mat):
    """Return ``mat`` scaled by ``cnt``.

    When ``cnt == 1`` the matrix is returned untouched, so no product with a
    unit coefficient is created.
    """
    return mat if cnt == 1 else cnt * mat
def merge_explicit(matadd):
    """ Merge explicit MatrixBase arguments

    All arguments that are ``MatrixBase`` instances are summed into a single
    matrix; the remaining arguments are kept. If there are fewer than two such
    arguments, ``matadd`` is returned unchanged.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    explicit = groups[True]

    # Nothing to merge unless at least two explicit matrices are present.
    if len(explicit) <= 1:
        return matadd

    merged = reduce(operator.add, explicit)
    return MatAdd(*(groups[False] + [merged]))
# Rewrite rules tried, in this order, by ``canonicalize``.
rules = (
    # Treat anything equal to 0, or any ZeroMatrix, as an identity argument.
    rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)),
    unpack,
    flatten,
    # Group arguments by matrix_of, using factor_of and combine on the groups.
    glom(matrix_of, factor_of, combine),
    merge_explicit,
    sort(default_sort_key),
)

# Keep applying do_one(*rules) while the expression is still a MatAdd, until
# ``exhaust`` stops (see the sympy.strategies documentation for exact
# stopping semantics).
canonicalize = exhaust(
    condition(lambda x: isinstance(x, MatAdd), do_one(*rules))
)