Spaces:
Sleeping
Sleeping
from sympy.core.relational import Eq | |
from sympy.core.expr import Expr | |
from sympy.core.numbers import Integer | |
from sympy.logic.boolalg import Boolean, And | |
from sympy.matrices.expressions.matexpr import MatrixExpr | |
from sympy.matrices.exceptions import ShapeError | |
from typing import Union | |
def is_matadd_valid(*args: MatrixExpr) -> Boolean:
    """Return the symbolic condition under which ``MatAdd`` or
    ``HadamardProduct`` of the given matrices makes sense.

    Explanation
    ===========

    Elementwise operations need every operand to have the same shape, so
    the result is the conjunction of ``Eq`` relations between the row
    counts of neighbouring arguments and between their column counts.

    When fewer than two matrices are given there is nothing to compare and
    the empty conjunction ``And()`` is returned.

    Parameters
    ==========

    args
        The list of arguments of matrices to be tested for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matadd_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matadd_valid(A, B)
    Eq(m, p) & Eq(n, q)
    """
    # Collect the dimensions without unpacking ``zip(*...)``: that idiom
    # raises ValueError when ``args`` is empty.
    shapes = [arg.shape for arg in args]
    rows = [n_rows for n_rows, _ in shapes]
    cols = [n_cols for _, n_cols in shapes]
    return And(
        *(Eq(i, j) for i, j in zip(rows, rows[1:])),
        *(Eq(i, j) for i, j in zip(cols, cols[1:])),
    )
def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
    """Return the symbolic condition under which ``MatMul`` makes sense.

    Explanation
    ===========

    Only the arguments that are ``MatrixExpr`` instances take part in the
    check; every other argument is treated as a scalar factor and ignored.
    The result is the conjunction of ``Eq(cols, rows)`` relations between
    each matrix and the matrix that follows it.

    When fewer than two matrices are present (including the case where all
    arguments are scalars) there is nothing to compare and the empty
    conjunction ``And()`` is returned.

    Parameters
    ==========

    args
        The list of arguments of matrices and scalar expressions to be tested
        for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matmul_valid

    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matmul_valid(A, B)
    Eq(n, p)
    """
    # Collect the dimensions without unpacking ``zip(*...)``: that idiom
    # raises ValueError when no argument is a matrix.
    shapes = [arg.shape for arg in args if isinstance(arg, MatrixExpr)]
    rows = [n_rows for n_rows, _ in shapes]
    cols = [n_cols for _, n_cols in shapes]
    # Column count of each matrix must match the row count of the next one.
    return And(*(Eq(i, j) for i, j in zip(cols, rows[1:])))
def is_square(arg: MatrixExpr, /) -> Boolean:
    """Return the symbolic condition under which the matrix is square.

    Parameters
    ==========

    arg
        The matrix to be tested for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_square

    >>> m, n = symbols('m n')
    >>> A = MatrixSymbol('A', m, n)
    >>> is_square(A)
    Eq(m, n)
    """
    n_rows = arg.rows
    n_cols = arg.cols
    return Eq(n_rows, n_cols)
def validate_matadd_integer(*args: MatrixExpr) -> None:
    """Check that the explicitly integer dimensions of the addends agree.

    Only dimensions that are ``int`` or ``Integer`` instances are compared;
    any other dimension is skipped. Row counts are checked before column
    counts.

    Raises
    ======

    ShapeError
        If two or more distinct integer row counts, or two or more distinct
        integer column counts, occur among the arguments.
    """
    rows, cols = zip(*(mat.shape for mat in args))
    for dims in (rows, cols):
        # Distinct concrete sizes along this axis; more than one is a clash.
        concrete = {dim for dim in dims if isinstance(dim, (int, Integer))}
        if len(concrete) > 1:
            raise ShapeError(f"Matrices have mismatching shape: {dims}")
def validate_matmul_integer(*args: MatrixExpr) -> None:
    """Check that explicitly integer inner dimensions of a product agree.

    Each argument is paired with the one that follows it, and the left
    factor's column count is compared with the right factor's row count.
    A pair is skipped unless both values are ``int`` or ``Integer``
    instances.

    Raises
    ======

    ShapeError
        If a pair of adjacent factors has differing integer inner
        dimensions.
    """
    concrete_types = (int, Integer)
    for left, right in zip(args, args[1:]):
        inner_left = left.cols
        inner_right = right.rows
        both_concrete = (
            isinstance(inner_left, concrete_types)
            and isinstance(inner_right, concrete_types)
        )
        if not both_concrete:
            # At least one side is not a concrete integer; skip this pair.
            continue
        if inner_left != inner_right:
            raise ShapeError("Matrices are not aligned", inner_left, inner_right)