Spaces:
Sleeping
Sleeping
from sympy.core import S | |
from sympy.core.function import Lambda | |
from sympy.core.power import Pow | |
from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter | |
from .codeprinter import CodePrinter | |
# Names from the math-module table that are NOT carried over to NumPy.
_not_in_numpy = ['erf', 'erfc', 'factorial', 'gamma', 'loggamma']

# (sympy name, target name) pairs inherited from pycode's math table.
_in_numpy = [(name, target)
             for name, target in _known_functions_math.items()
             if name not in _not_in_numpy]

_known_functions_numpy = dict(_in_numpy)
# NumPy-specific spellings; these override any inherited entry of the same
# name while keeping that entry's original position.
_known_functions_numpy.update({
    'acos': 'arccos',
    'acosh': 'arccosh',
    'asin': 'arcsin',
    'asinh': 'arcsinh',
    'atan': 'arctan',
    'atan2': 'arctan2',
    'atanh': 'arctanh',
    'exp2': 'exp2',
    'sign': 'sign',
    'logaddexp': 'logaddexp',
    'logaddexp2': 'logaddexp2',
    'isnan': 'isnan',
})

# SymPy constant class name -> NumPy attribute name.
_known_constants_numpy = {
    'Exp1': 'e',
    'Pi': 'pi',
    'EulerGamma': 'euler_gamma',
    'NaN': 'nan',
    'Infinity': 'inf',
}

# Fully qualified ``numpy.<name>`` targets used by NumPyPrinter.
_numpy_known_functions = {name: f'numpy.{target}'
                          for name, target in _known_functions_numpy.items()}
_numpy_known_constants = {name: f'numpy.{target}'
                          for name, target in _known_constants_numpy.items()}
class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
    """
    Numpy printer which handles vectorized piecewise functions,
    logical operators, etc.

    The target array module is chosen by the ``_module`` class attribute.
    Subclasses override it, together with the ``_kf`` (known functions) and
    ``_kc`` (known constants) tables, to emit code for other NumPy-compatible
    modules.
    """

    _module = 'numpy'
    _kf = _numpy_known_functions
    _kc = _numpy_known_constants

    def __init__(self, settings=None):
        """
        `settings` is passed on through ``super().__init__`` (ultimately to
        CodePrinter.__init__()).

        The printer's ``language`` string and its ``printmethod`` hook name
        are both derived from the ``_module`` class attribute, e.g.
        ``'_numpycode'`` for ``'numpy'``.
        """
        self.language = "Python with {}".format(self._module)
        self.printmethod = "_{}code".format(self._module)
        # Layer this printer's known functions over the plain-Python ones so
        # that module-specific entries win on name clashes.
        self._kf = {**PythonCodePrinter._kf, **self._kf}
        super().__init__(settings=settings)

    def _print_seq(self, seq):
        "General sequence printer: converts to tuple"
        # Print tuples here instead of lists because numba supports
        # tuples in nopython mode.
        items = [self._print(item) for item in seq]
        if not items:
            # '({},)' would render as '(,)', which is a syntax error.
            return '()'
        return '({},)'.format(', '.join(items))

    def _print_NegativeInfinity(self, expr):
        return '-' + self._print(S.Infinity)

    def _print_MatMul(self, expr):
        "Matrix multiplication printer"
        coeff, matrices = expr.as_coeff_matrices()
        if coeff is not S.One:
            # A scalar coefficient is appended as the final ``.dot`` operand.
            factors = list(matrices) + [coeff]
        else:
            factors = expr.args
        return '({})'.format(').dot('.join(self._print(i) for i in factors))

    def _print_MatPow(self, expr):
        "Matrix power printer"
        return '{}({}, {})'.format(
            self._module_format(self._module + '.linalg.matrix_power'),
            self._print(expr.args[0]), self._print(expr.args[1]))

    def _print_Inverse(self, expr):
        "Matrix inverse printer"
        return '{}({})'.format(
            self._module_format(self._module + '.linalg.inv'),
            self._print(expr.args[0]))

    def _print_DotProduct(self, expr):
        # DotProduct allows any shape order, but numpy.dot does matrix
        # multiplication, so we have to make sure it gets 1 x n by n x 1.
        arg1, arg2 = expr.args
        if arg1.shape[0] != 1:
            arg1 = arg1.T
        if arg2.shape[1] != 1:
            arg2 = arg2.T

        return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
                               self._print(arg1),
                               self._print(arg2))

    def _print_MatrixSolve(self, expr):
        return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
                               self._print(expr.matrix),
                               self._print(expr.vector))

    def _print_ZeroMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.zeros'),
                               self._print(expr.shape))

    def _print_OneMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.ones'),
                               self._print(expr.shape))

    def _print_FunctionMatrix(self, expr):
        from sympy.abc import i, j
        lamda = expr.lamda
        if not isinstance(lamda, Lambda):
            # Wrap a plain callable so its parameters and body can be printed.
            lamda = Lambda((i, j), lamda(i, j))
        return '{}(lambda {}: {}, {})'.format(
            self._module_format(self._module + '.fromfunction'),
            ', '.join(self._print(arg) for arg in lamda.args[0]),
            self._print(lamda.args[1]), self._print(expr.shape))

    def _print_HadamardProduct(self, expr):
        # Nest binary calls: multiply(a, multiply(b, c)).
        func = self._module_format(self._module + '.multiply')
        return ''.join('{}({}, '.format(func, self._print(arg))
                       for arg in expr.args[:-1]) + "{}{}".format(
                           self._print(expr.args[-1]),
                           ')' * (len(expr.args) - 1))

    def _print_KroneckerProduct(self, expr):
        # Nest binary calls: kron(a, kron(b, c)).
        func = self._module_format(self._module + '.kron')
        return ''.join('{}({}, '.format(func, self._print(arg))
                       for arg in expr.args[:-1]) + "{}{}".format(
                           self._print(expr.args[-1]),
                           ')' * (len(expr.args) - 1))

    def _print_Adjoint(self, expr):
        return '{}({}({}))'.format(
            self._module_format(self._module + '.conjugate'),
            self._module_format(self._module + '.transpose'),
            self._print(expr.args[0]))

    def _print_DiagonalOf(self, expr):
        # Extract the diagonal, then reshape it into a column.
        vect = '{}({})'.format(
            self._module_format(self._module + '.diag'),
            self._print(expr.arg))
        return '{}({}, (-1, 1))'.format(
            self._module_format(self._module + '.reshape'), vect)

    def _print_DiagMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.diagflat'),
                               self._print(expr.args[0]))

    def _print_DiagonalMatrix(self, expr):
        # Mask the argument with an identity of the same shape.
        return '{}({}, {}({}, {}))'.format(
            self._module_format(self._module + '.multiply'),
            self._print(expr.arg), self._module_format(self._module + '.eye'),
            self._print(expr.shape[0]), self._print(expr.shape[1]))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        from sympy.logic.boolalg import ITE, simplify_logic

        def print_cond(cond):
            """Print a condition, simplifying it first if it contains an ITE.

            Printing an ITE inside a condition is a problem, so it is
            rewritten with simplify_logic.
            """
            if cond.has(ITE):
                return self._print(simplify_logic(cond))
            else:
                return self._print(cond)

        exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
        conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
        # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
        # it will behave the same as passing the 'default' kwarg to select()
        # *as long as* it is the last element in expr.args.
        # If this is not the case, it may be triggered prematurely.
        return '{}({}, {}, default={})'.format(
            self._module_format(self._module + '.select'), conds, exprs,
            self._print(S.NaN))

    def _print_Relational(self, expr):
        "Relational printer for Equality and Unequality"
        op = {
            '==': 'equal',
            '!=': 'not_equal',
            '<': 'less',
            '<=': 'less_equal',
            '>': 'greater',
            '>=': 'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '{op}({lhs}, {rhs})'.format(
                op=self._module_format(self._module + '.' + op[expr.rel_op]),
                lhs=lhs, rhs=rhs)
        return super()._print_Relational(expr)

    def _print_And(self, expr):
        "Logical And printer"
        # We have to override LambdaPrinter because it uses Python 'and' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(
            self._module_format(self._module + '.logical_and'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Or(self, expr):
        "Logical Or printer"
        # We have to override LambdaPrinter because it uses Python 'or' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
        return '{}.reduce(({}))'.format(
            self._module_format(self._module + '.logical_or'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Not(self, expr):
        "Logical Not printer"
        # We have to override LambdaPrinter because it uses Python 'not' keyword.
        # If LambdaPrinter didn't define it, we would still have to define our
        # own because StrPrinter doesn't define it.
        return '{}({})'.format(
            self._module_format(self._module + '.logical_not'),
            ','.join(self._print(i) for i in expr.args))

    def _print_Pow(self, expr, rational=False):
        # XXX Workaround for negative integer power error: the exponent is
        # turned into a float (presumably because the array module rejects
        # integers raised to negative integer powers).
        if expr.exp.is_integer and expr.exp.is_negative:
            expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
        return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')

    def _print_Min(self, expr):
        # Element-wise minimum over the stacked arguments.
        return '{}({}.asarray([{}]), axis=0)'.format(
            self._module_format(self._module + '.amin'),
            self._module_format(self._module),
            ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        # Element-wise maximum over the stacked arguments.
        return '{}({}.asarray([{}]), axis=0)'.format(
            self._module_format(self._module + '.amax'),
            self._module_format(self._module),
            ','.join(self._print(i) for i in expr.args))

    def _print_arg(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.angle'),
                           self._print(expr.args[0]))

    def _print_im(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.imag'),
                           self._print(expr.args[0]))

    def _print_Mod(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.mod'),
                           ', '.join(self._print(arg) for arg in expr.args))

    def _print_re(self, expr):
        return "%s(%s)" % (self._module_format(self._module + '.real'),
                           self._print(expr.args[0]))

    def _print_sinc(self, expr):
        # The argument is divided by pi before being passed to the module's
        # sinc (numpy.sinc is the normalized sin(pi*x)/(pi*x)).
        return "%s(%s)" % (self._module_format(self._module + '.sinc'),
                           self._print(expr.args[0] / S.Pi))

    def _print_MatrixBase(self, expr):
        func = self.known_functions.get(expr.__class__.__name__, None)
        if func is None:
            func = self._module_format(self._module + '.array')
        return "%s(%s)" % (func, self._print(expr.tolist()))

    def _print_Identity(self, expr):
        shape = expr.shape
        if all(dim.is_Integer for dim in shape):
            return "%s(%s)" % (self._module_format(self._module + '.eye'),
                               self._print(expr.shape[0]))
        else:
            raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")

    def _print_BlockMatrix(self, expr):
        return '{}({})'.format(self._module_format(self._module + '.block'),
                               self._print(expr.args[0].tolist()))

    def _print_NDimArray(self, expr):
        if len(expr.shape) == 1:
            # Route through _module_format like every other array call so the
            # module import is recorded and printer settings are honoured.
            return '{}({})'.format(self._module_format(self._module + '.array'),
                                   self._print(expr.args[0]))
        if len(expr.shape) == 2:
            return self._print(expr.tomatrix())
        # Should be possible to extend to more dimensions
        # (super() already binds self; passing it again raised TypeError.)
        return super()._print_not_supported(expr)

    # Array-module function names used by the ArrayPrinter base.
    _add = "add"
    _einsum = "einsum"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"

    # No translation for these here; SciPyPrinter provides its own.
    _print_lowergamma = CodePrinter._print_not_supported
    _print_uppergamma = CodePrinter._print_not_supported
    _print_fresnelc = CodePrinter._print_not_supported
    _print_fresnels = CodePrinter._print_not_supported
# Attach a generic ``_print_<Name>`` method to NumPyPrinter for every entry
# of the NumPy function and constant tables.
for func in _numpy_known_functions:
    setattr(NumPyPrinter, '_print_' + func, _print_known_func)
for const in _numpy_known_constants:
    setattr(NumPyPrinter, '_print_' + const, _print_known_const)
# SymPy function class name -> ``scipy.special`` function name.
_known_functions_scipy_special = dict(
    Ei='expi',
    erf='erf',
    erfc='erfc',
    besselj='jv',
    bessely='yv',
    besseli='iv',
    besselk='kv',
    cosm1='cosm1',
    powm1='powm1',
    factorial='factorial',
    gamma='gamma',
    loggamma='gammaln',
    digamma='psi',
    polygamma='polygamma',
    RisingFactorial='poch',
    jacobi='eval_jacobi',
    gegenbauer='eval_gegenbauer',
    chebyshevt='eval_chebyt',
    chebyshevu='eval_chebyu',
    legendre='eval_legendre',
    hermite='eval_hermite',
    laguerre='eval_laguerre',
    assoc_laguerre='eval_genlaguerre',
    beta='beta',
    LambertW='lambertw',
)

# SymPy constant class name -> ``scipy.constants`` attribute name.
_known_constants_scipy_constants = dict(
    GoldenRatio='golden_ratio',
    Pi='pi',
)

# Fully qualified targets used by SciPyPrinter.
_scipy_known_functions = {name: f"scipy.special.{target}"
                          for name, target in _known_functions_scipy_special.items()}
_scipy_known_constants = {name: f"scipy.constants.{target}"
                          for name, target in _known_constants_scipy_constants.items()}
class SciPyPrinter(NumPyPrinter):
    """
    NumPy printer extended with SciPy translations.

    It emits ``scipy.special`` functions, ``scipy.constants`` constants,
    ``scipy.sparse`` matrices and ``scipy.integrate`` quadrature for
    integrals.
    """

    _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
    _kc = {**NumPyPrinter._kc, **_scipy_known_constants}

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        self.language = "Python with SciPy and NumPy"

    def _print_SparseRepMatrix(self, expr):
        """Print a sparse matrix as a ``scipy.sparse.coo_matrix`` call.

        The call uses the ``(data, (row, col))`` constructor form, built
        from the matrix's dictionary-of-keys entries.
        """
        i, j, data = [], [], []
        for (r, c), v in expr.todok().items():
            i.append(r)
            j.append(c)
            # Print entries with this printer rather than ``str`` so that
            # symbolic entries (e.g. sin(x)) get module-qualified names.
            data.append(self._print(v))
        return "{name}(([{data}], ({i}, {j})), shape={shape})".format(
            name=self._module_format('scipy.sparse.coo_matrix'),
            data=', '.join(data), i=i, j=j, shape=expr.shape
        )

    _print_ImmutableSparseMatrix = _print_SparseRepMatrix

    # SciPy's lpmv has a different order of arguments from assoc_legendre
    def _print_assoc_legendre(self, expr):
        # assoc_legendre(n, m, x) -> lpmv(m, n, x)
        return "{0}({2}, {1}, {3})".format(
            self._module_format('scipy.special.lpmv'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]))

    def _print_lowergamma(self, expr):
        # gamma(a) * gammainc(a, x) un-regularizes SciPy's result.
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_uppergamma(self, expr):
        # gamma(a) * gammaincc(a, x) un-regularizes SciPy's result.
        return "{0}({2})*{1}({2}, {3})".format(
            self._module_format('scipy.special.gamma'),
            self._module_format('scipy.special.gammaincc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]))

    def _print_betainc(self, expr):
        # (betainc(a, b, x2) - betainc(a, b, x1)) * beta(a, b)
        betainc = self._module_format('scipy.special.betainc')
        beta = self._module_format('scipy.special.beta')
        args = [self._print(arg) for arg in expr.args]
        # Built from two adjacent literals so no source indentation leaks
        # into the generated expression.
        return (f"({betainc}({args[0]}, {args[1]}, {args[3]}) - "
                f"{betainc}({args[0]}, {args[1]}, {args[2]}))"
                f" * {beta}({args[0]}, {args[1]})")

    def _print_betainc_regularized(self, expr):
        # betainc(a, b, x2) - betainc(a, b, x1)
        return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
            self._module_format('scipy.special.betainc'),
            self._print(expr.args[0]),
            self._print(expr.args[1]),
            self._print(expr.args[2]),
            self._print(expr.args[3]))

    def _print_fresnels(self, expr):
        return "{}({})[0]".format(
            self._module_format("scipy.special.fresnel"),
            self._print(expr.args[0]))

    def _print_fresnelc(self, expr):
        return "{}({})[1]".format(
            self._module_format("scipy.special.fresnel"),
            self._print(expr.args[0]))

    def _print_airyai(self, expr):
        return "{}({})[0]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airyaiprime(self, expr):
        return "{}({})[1]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airybi(self, expr):
        return "{}({})[2]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_airybiprime(self, expr):
        return "{}({})[3]".format(
            self._module_format("scipy.special.airy"),
            self._print(expr.args[0]))

    def _print_bernoulli(self, expr):
        # scipy's bernoulli is inconsistent with SymPy's so rewrite
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_harmonic(self, expr):
        return self._print(expr._eval_rewrite_as_zeta(*expr.args))

    def _print_Integral(self, e):
        integration_vars, limits = _unpack_integral_limits(e)

        if len(limits) == 1:
            # nicer (but not necessary) to prefer quad over nquad for 1D case
            module_str = self._module_format("scipy.integrate.quad")
            limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
        else:
            module_str = self._module_format("scipy.integrate.nquad")
            limit_str = "({})".format(", ".join(
                "(%s, %s)" % tuple(map(self._print, lim)) for lim in limits))

        # Index 0 of the quad/nquad result is the integral value.
        return "{}(lambda {}: {}, {})[0]".format(
            module_str,
            ", ".join(map(self._print, integration_vars)),
            self._print(e.args[0]),
            limit_str)

    def _print_Si(self, expr):
        return "{}({})[0]".format(
            self._module_format("scipy.special.sici"),
            self._print(expr.args[0]))

    def _print_Ci(self, expr):
        return "{}({})[1]".format(
            self._module_format("scipy.special.sici"),
            self._print(expr.args[0]))
# Attach a generic ``_print_<Name>`` method to SciPyPrinter for every entry
# of the SciPy function and constant tables.
for func in _scipy_known_functions:
    setattr(SciPyPrinter, '_print_' + func, _print_known_func)
for const in _scipy_known_constants:
    setattr(SciPyPrinter, '_print_' + const, _print_known_const)
# The NumPy function/constant tables, re-qualified with ``cupy.``.
_cupy_known_functions = {name: f"cupy.{target}"
                         for name, target in _known_functions_numpy.items()}
_cupy_known_constants = {name: f"cupy.{target}"
                         for name, target in _known_constants_numpy.items()}
class CuPyPrinter(NumPyPrinter):
    """
    CuPy printer which handles vectorized piecewise functions,
    logical operators, etc.

    All behaviour comes from NumPyPrinter; only the module name and the
    lookup tables are swapped for their ``cupy`` equivalents.
    """

    _module = 'cupy'
    _kf = _cupy_known_functions
    _kc = _cupy_known_constants

    def __init__(self, settings=None):
        # Nothing CuPy-specific to configure: NumPyPrinter.__init__ derives
        # the language name and print-method hook from ``_module``.
        super().__init__(settings=settings)
# Attach a generic ``_print_<Name>`` method to CuPyPrinter for every entry
# of the CuPy function and constant tables.
for func in _cupy_known_functions:
    setattr(CuPyPrinter, '_print_' + func, _print_known_func)
for const in _cupy_known_constants:
    setattr(CuPyPrinter, '_print_' + const, _print_known_const)
# The NumPy function/constant tables, re-qualified with ``jax.numpy.``.
_jax_known_functions = {name: f'jax.numpy.{target}'
                        for name, target in _known_functions_numpy.items()}
_jax_known_constants = {name: f'jax.numpy.{target}'
                        for name, target in _known_constants_numpy.items()}
class JaxPrinter(NumPyPrinter):
    """
    JAX printer which handles vectorized piecewise functions,
    logical operators, etc.
    """

    _module = "jax.numpy"
    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        # The hook NumPyPrinter derives from _module ('_jax.numpycode')
        # contains a dot, so it could never be a method name; replace it.
        self.printmethod = '_jaxcode'

    # These need specific override to allow for the lack of "jax.numpy.reduce":
    # the arguments are stacked into one array and reduced along axis 0.
    def _hprint_stacked_reduction(self, reducer, expr):
        """Return ``<module>.<reducer>(<module>.asarray([args]), axis=0)``."""
        return "{}({}.asarray([{}]), axis=0)".format(
            self._module_format(self._module + "." + reducer),
            self._module_format(self._module),
            ",".join(self._print(arg) for arg in expr.args),
        )

    def _print_And(self, expr):
        "Logical And printer"
        return self._hprint_stacked_reduction("all", expr)

    def _print_Or(self, expr):
        "Logical Or printer"
        return self._hprint_stacked_reduction("any", expr)
# Attach a generic ``_print_<Name>`` method to JaxPrinter for every entry
# of the JAX function and constant tables.
for func in _jax_known_functions:
    setattr(JaxPrinter, '_print_' + func, _print_known_func)
for const in _jax_known_constants:
    setattr(JaxPrinter, '_print_' + const, _print_known_const)