Spaces:
Sleeping
Sleeping
from sympy.assumptions.ask import ask, Q | |
from sympy.assumptions.refine import handlers_dict | |
from sympy.core import Basic, sympify, S | |
from sympy.core.mul import mul, Mul | |
from sympy.core.numbers import Number, Integer | |
from sympy.core.symbol import Dummy | |
from sympy.functions import adjoint | |
from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust, | |
do_one, new) | |
from sympy.matrices.exceptions import NonInvertibleMatrixError | |
from sympy.matrices.matrixbase import MatrixBase | |
from sympy.utilities.exceptions import sympy_deprecation_warning | |
from sympy.matrices.expressions._shape import validate_matmul_integer as validate | |
from .inverse import Inverse | |
from .matexpr import MatrixExpr | |
from .matpow import MatPow | |
from .transpose import transpose | |
from .permutation import PermutationMatrix | |
from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix | |
# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Scalar (non-matrix) arguments may appear among the factors; they are
    gathered into a single commutative coefficient by
    :meth:`as_coeff_matrices`.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    # Neutral element: filtered out of the arguments in ``__new__`` and
    # returned for an empty product.
    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Create a (by default unevaluated) matrix product.

        Parameters
        ==========

        args
            Factors of the product: matrix expressions and/or scalars.
        evaluate : bool, optional
            If True, the product is canonicalized via :meth:`_evaluate`.
        check : optional
            Deprecated. Passing anything but ``None`` emits a deprecation
            warning; the factors are passed to ``validate``
            (``validate_matmul_integer``) unless ``check is False``.
        _sympify : bool, optional
            Whether to sympify the factors first.

        Returns
        =======

        ``cls.identity`` for an empty call, the bare scalar coefficient when
        no matrix factor remains, and otherwise the ``MatMul`` (or its
        canonical form when ``evaluate`` is True).
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = list(filter(lambda i: cls.identity != i, args))
        if _sympify:
            args = list(map(sympify, args))
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatMul is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        if check is not False:
            validate(*matrices)

        if not matrices:
            # Should it be
            #
            # return Basic.__new__(cls, factor, GenericIdentity()) ?
            return factor

        if evaluate:
            return cls._evaluate(obj)

        return obj

    # Must be a classmethod: ``__new__`` calls ``cls._evaluate(obj)`` and
    # ``expand`` calls ``self._evaluate(expanded)`` with a single argument.
    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` rewritten by the module-level ``canonicalize`` rules."""
        return canonicalize(expr)

    # Must be a property: callers index it directly (``self.shape[1]``).
    @property
    def shape(self):
        """``(rows of the first matrix factor, cols of the last matrix factor)``."""
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j, expand=True, **kwargs):
        """Return the ``(i, j)`` entry of the product as an explicit sum.

        For matrix factors ``M_1, ..., M_k`` this builds
        ``coeff * Sum(M_1[i, d_1] * M_2[d_1, d_2] * ... * M_k[d_{k-1}, j])``
        over fresh dummy indices ``d_1, ..., d_{k-1}``, each running from 0 to
        the corresponding inner dimension minus one. A single matrix factor
        simply yields ``coeff * M_1[i, j]``.

        ``kwargs['dummy_generator']``, if given, supplies the dummy indices and
        is forwarded to the factors. The sum is evaluated with ``doit()`` only
        when ``expand`` holds (an ``ImmutableMatrix`` among the entries forces
        it on) and at least one summation bound is an integer.
        """
        # Avoid cyclic imports
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        indices = [None]*(len(matrices) + 1)
        ind_ranges = [None]*(len(matrices) - 1)
        indices[0] = i
        indices[-1] = j

        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # Loop variables are named ``k`` so the row-index parameter ``i`` is
        # not shadowed.
        for k in range(1, len(matrices)):
            indices[k] = next(dummy_generator)

        for k, arg in enumerate(matrices[:-1]):
            ind_ranges[k] = arg.shape[1] - 1
        matrices = [arg._entry(indices[k], indices[k+1], dummy_generator=dummy_generator)
                    for k, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the arguments into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of all non-matrix arguments and ``matrices``
        is a new list of the matrix arguments, in order.

        Raises
        ======

        NotImplementedError
            If the scalar coefficient is noncommutative.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        """Like :meth:`as_coeff_matrices`, but the matrix factors are wrapped
        in an unevaluated ``MatMul``."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def expand(self, **kwargs):
        """Expand via the parent-class implementation, then canonicalize."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        """Adjoint of the product: the adjoints of all arguments, reversed."""
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        """Pull a non-unit scalar coefficient out of the trace.

        Returns ``factor * trace(rest)`` when the coefficient differs from 1
        and ``None`` otherwise. NOTE(review): the ``None`` case presumably
        lets the caller fall back to an unevaluated trace -- confirm.
        """
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())

    def _eval_determinant(self):
        """``det(c * M_1 * ... * M_k) = c**rows * prod(det(G))``.

        The matrix factors are grouped into square sub-products ``G`` by
        :func:`only_squares`.
        """
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        """``(A*B)**-1 = B**-1 * A**-1`` when every matrix factor is square.

        Scalars are inverted with ``**-1``. If any matrix factor is not known
        to be square, an unevaluated ``Inverse(self)`` is returned.
        """
        if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)):
            return MatMul(*(
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]
                )
            ).doit()
        return Inverse(self)

    def doit(self, **hints):
        """Evaluate the arguments (when ``deep``, the default) and
        canonicalize the resulting product."""
        deep = hints.get('deep', True)
        if deep:
            args = tuple(arg.doit(**hints) for arg in self.args)
        else:
            args = self.args

        # treat scalar*MatrixSymbol or scalar*MatPow separately
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, cset=False, warn=True, **kwargs):
        """Split the arguments into ``[commutative, noncommutative]`` lists.

        With ``cset=True`` the commutative part is returned as a set; if that
        collapses repeated arguments and ``warn`` is true, ``ValueError`` is
        raised.
        """
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]

        if cset:
            clen = len(coeff_c)
            coeff_c = set(coeff_c)
            if clen and warn and len(coeff_c) != clen:
                raise ValueError('repeated commutative arguments: %s' %
                                 [ci for ci in coeff_c if list(self.args).count(ci) > 1])
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        """Derivative lines of the product with respect to ``x``.

        For every argument that contains ``x``, each of that argument's
        derivative lines gets ``append_first`` called with the product of the
        transposed left-hand factors (in reverse order) and ``append_second``
        with the product of the right-hand factors. An identity of matching
        size stands in for an empty side.
        """
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines
# Register MatMul with the ``mul`` dispatcher (imported from sympy.core.mul)
# for the (Mul, MatMul) operand-class pair.
# NOTE(review): presumably this makes ``mul(...)`` build a MatMul when a Mul is
# combined with a MatMul -- confirm against sympy.core.mul.
mul.register_handlerclass((Mul, MatMul), MatMul)

# Rules
#
# Each rule below takes a MatMul and returns either it unchanged or a
# rewritten expression; they are combined into ``canonicalize`` further down.
def newmul(*args):
    """Rebuild a MatMul from ``args`` using the strategies ``new`` helper.

    A leading argument equal to 1 (a unit coefficient) is dropped first so
    the result does not carry a redundant factor.
    """
    head, tail = args[0], args[1:]
    factors = tail if head == 1 else args
    return new(MatMul, *factors)
def any_zeros(mul):
    """Collapse the product to a ``ZeroMatrix`` if any factor is zero.

    A factor counts as zero when it is a zero scalar (``is_zero``) or a
    ``ZeroMatrix``. The result has the rows of the first matrix factor and
    the columns of the last one; otherwise ``mul`` is returned unchanged.
    """
    def _is_zero_factor(arg):
        return arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)

    if not any(map(_is_zero_factor, mul.args)):
        return mul
    mats = [arg for arg in mul.args if arg.is_Matrix]
    first, last = mats[0], mats[-1]
    return ZeroMatrix(first.rows, last.cols)
def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    Runs of adjacent explicit matrices (``MatrixBase``) and numbers are
    multiplied out into a single factor; any other factor (e.g. a
    ``MatrixSymbol``) acts as a barrier and is kept as-is. Products without
    an explicit matrix are returned unchanged.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
        return matmul

    explicit = (MatrixBase, Number)
    merged = []
    acc, *remaining = matmul.args
    for arg in remaining:
        if isinstance(acc, explicit) and isinstance(arg, explicit):
            # Both sides are concrete: multiply them out right away.
            acc = acc * arg
        else:
            merged.append(acc)
            acc = arg
    merged.append(acc)
    return MatMul(*merged)
def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id.
    This is necessary because MatMul may contain both MatrixExprs and Exprs
    as args: the scalar coefficient is split off first, identities are
    stripped from the matrix part only, and the two are recombined.

    See Also
    ========

    sympy.strategies.rm_id
    """
    # Separate Exprs from MatrixExprs in args
    factor, mmul = mul.as_coeff_mmul()
    # Only a definite ``is_Identity is True`` counts, so factors whose
    # identity-ness is unknown are kept.
    drop_identities = rm_id(lambda x: x.is_Identity is True)
    stripped = drop_identities(mmul)
    if stripped == mmul:
        return mul
    # Recombine and return
    return newmul(factor, *stripped.args)
def factor_in_front(mul):
    """Rebuild ``mul`` with its scalar coefficient as the leading argument.

    The product is returned untouched when the coefficient is 1.
    """
    factor, matrices = mul.as_coeff_matrices()
    if factor == 1:
        return mul
    return newmul(factor, *matrices)
def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$

    The matrix factors are scanned left to right, each one compared against
    the last factor kept so far. The scalar coefficient is re-attached in
    front and the result is always rebuilt with :func:`newmul`.
    """
    factor, args = mul.as_coeff_matrices()
    # ``new_args`` is the already-combined prefix; ``new_args[-1]`` is the
    # factor the next input ``args[i]`` may merge into.
    new_args = [args[0]]

    for i in range(1, len(args)):
        A = new_args[-1]
        B = args[i]

        # B is the inverse of a product whose factors equal the tail of the
        # combined prefix: (X*Y) * (X*Y)**-1 -> I.
        if isinstance(B, Inverse) and isinstance(B.arg, MatMul):
            Bargs = B.arg.args
            l = len(Bargs)
            if list(Bargs) == new_args[-l:]:
                new_args = new_args[:-l] + [Identity(B.shape[0])]
                continue

        # A is the inverse of a product whose factors are the next ``l``
        # inputs: (X*Y)**-1 * (X*Y) -> I. Those upcoming inputs are
        # overwritten in ``args`` (a local list from as_coeff_matrices) with
        # identities, which later iterations then process as ordinary factors.
        if isinstance(A, Inverse) and isinstance(A.arg, MatMul):
            Aargs = A.arg.args
            l = len(Aargs)
            if list(Aargs) == args[i:i+l]:
                identity = Identity(A.shape[0])
                new_args[-1] = identity
                for j in range(i, i+l):
                    args[j] = identity
                continue

        # If either factor is known to be non-square, no power can be
        # formed; keep B as-is (``is_square`` of None does not trigger this).
        if A.is_square == False or B.is_square == False:
            new_args.append(B)
            continue

        # Split each side into (base, exponent); a plain matrix has
        # exponent 1.
        if isinstance(A, MatPow):
            A_base, A_exp = A.args
        else:
            A_base, A_exp = A, S.One

        if isinstance(B, MatPow):
            B_base, B_exp = B.args
        else:
            B_base, B_exp = B, S.One

        # Same base: X**a * X**b -> X**(a + b).
        if A_base == B_base:
            new_exp = A_exp + B_exp
            new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
            continue
        # A's base equals the inverse of B's base:
        # X**a * (X**-1)**b -> X**(a - b). Explicit MatrixBase bases are
        # skipped, and a non-invertible base simply disables this case.
        elif not isinstance(B_base, MatrixBase):
            try:
                B_base_inv = B_base.inverse()
            except NonInvertibleMatrixError:
                B_base_inv = None
            if B_base_inv is not None and A_base == B_base_inv:
                new_exp = A_exp - B_exp
                new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
                continue
        new_args.append(B)

    return newmul(factor, *new_args)
def combine_permutations(mul):
    """Refine products of permutation matrices as the products of cycles.

    Adjacent ``PermutationMatrix`` factors ``P(a) * P(b)`` are replaced by
    the single ``PermutationMatrix(a * b)``. Products with fewer than two
    arguments are returned unchanged.
    """
    args = mul.args
    if len(args) < 2:
        return mul

    merged = [args[0]]
    for nxt in args[1:]:
        prev = merged[-1]
        both_perms = (isinstance(prev, PermutationMatrix)
                      and isinstance(nxt, PermutationMatrix))
        if both_perms:
            merged[-1] = PermutationMatrix(prev.args[0] * nxt.args[0])
        else:
            merged.append(nxt)

    return MatMul(*merged)
def combine_one_matrices(mul):
    """
    Combine products of OneMatrix

    e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)

    Each merge multiplies the scalar coefficient by the shared inner
    dimension.
    """
    factor, args = mul.as_coeff_matrices()
    out = [args[0]]
    for B in args[1:]:
        A = out[-1]
        if isinstance(A, OneMatrix) and isinstance(B, OneMatrix):
            # ones(m, k) * ones(k, n) == k * ones(m, n)
            out[-1] = OneMatrix(A.shape[0], B.shape[1])
            factor *= A.shape[1]
        else:
            out.append(B)
    return newmul(factor, *out)
def distribute_monom(mul):
    """
    Simplify MatMul expressions by distributing a rational coefficient
    over a MatAdd factor.

    e.g. 2*(A+B) -> 2*A + 2*B

    Only two-argument products are considered; anything else is returned
    unchanged.
    """
    args = mul.args
    if len(args) != 2:
        return mul
    from .matadd import MatAdd
    first, second = args
    if first.is_MatAdd and second.is_Rational:
        return MatAdd(*[MatMul(term, second).doit() for term in first.args])
    if second.is_MatAdd and first.is_Rational:
        return MatAdd(*[MatMul(first, term).doit() for term in second.args])
    return mul
# Rewrite rules for MatMul, in priority order. Per sympy.strategies,
# ``do_one`` applies the first rule that changes the expression, ``typed``
# dispatches it on MatMul nodes, and ``exhaust`` repeats until nothing
# changes -- so the order below is significant.
rules = (
    distribute_monom,
    any_zeros,
    remove_ids,
    combine_one_matrices,
    combine_powers,
    unpack,
    rm_id(lambda x: x == 1),
    merge_explicit,
    factor_in_front,
    flatten,
    combine_permutations,
)

canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
def only_squares(*matrices):
    """Group consecutive factors into square sub-products.

    Walks ``matrices`` left to right and closes a group as soon as the
    current factor's column count equals the row count of the group's first
    factor; each group is multiplied out with ``MatMul(...).doit()``.

    Raises
    ======

    RuntimeError
        If the overall product is not square.
    """
    if matrices[0].rows != matrices[-1].cols:
        raise RuntimeError("Invalid matrices being multiplied")
    groups = []
    begin = 0
    for stop, mat in enumerate(matrices, start=1):
        if mat.cols == matrices[begin].rows:
            groups.append(MatMul(*matrices[begin:stop]).doit())
            begin = stop
    return groups
def refine_MatMul(expr, assumptions):
    """
    Cancel adjacent matrix factors whose product is the identity.

    Two neighbouring matrix factors ``M`` and ``N`` are replaced by an
    identity matrix when

    * ``N == M.T`` and ``N`` is orthogonal under ``assumptions``, or
    * ``N == adjoint(M)`` (the conjugate transpose) and ``N`` is unitary
      under ``assumptions``.

    Scalar factors are kept in front, in their original order. A product
    without matrix factors is returned unchanged.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []
    exprargs = []
    for arg in expr.args:
        if arg.is_Matrix:
            exprargs.append(arg)
        else:
            newargs.append(arg)

    if not exprargs:
        return expr

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        elif arg == adjoint(last) and ask(Q.unitary(arg), assumptions):
            # U * U^H == I for unitary U. The elementwise conjugate is *not*
            # the inverse: a real rotation R is unitary with conjugate(R) == R,
            # yet R*R != I in general.
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)
# Register refine_MatMul in ``handlers_dict`` (imported from
# sympy.assumptions.refine) under the key 'MatMul'.
# NOTE(review): presumably ``refine`` looks handlers up by class name --
# confirm in sympy.assumptions.refine.
handlers_dict['MatMul'] = refine_MatMul