|
""" Generic SymPy-Independent Strategies """ |
|
from __future__ import annotations |
|
from collections.abc import Callable, Mapping |
|
from typing import TypeVar |
|
from sys import stdout |
|
|
|
|
|
_S = TypeVar('_S') |
|
_T = TypeVar('_T') |
|
|
|
|
|
def identity(x: _T) -> _T:
    """ The no-op rule: hand back the argument itself, unchanged """
    unchanged = x
    return unchanged
|
|
|
|
|
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Apply a rule repeatedly until it has no effect

    The returned rule feeds each output of ``rule`` back into ``rule``
    until two consecutive values compare equal, and returns that value.
    It never terminates if ``rule`` does not reach such a fixed point.
    """
    def exhaustive_rl(expr: _T) -> _T:
        previous = expr
        current = rule(previous)
        # Keep stepping while the last application still changed something.
        while current != previous:
            previous, current = current, rule(current)
        return current

    return exhaustive_rl
|
|
|
|
|
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
    """ Memoized version of a rule

    Each distinct input is passed to ``rule`` at most once; the result
    (including ``None``) is stored and returned on later calls with an
    equal input. Inputs must therefore be hashable.

    Notes
    =====

    The cache is an unbounded dictionary, so it can grow without limit.
    Prefer ``functools.lru_cache`` unless the rule is very expensive
    to compute.
    """
    cache: dict[_S, _T] = {}
    # Private marker distinguishing "not cached" from a cached ``None``.
    missing = object()

    def memoized_rl(expr: _S) -> _T:
        hit = cache.get(expr, missing)
        if hit is not missing:
            return hit
        computed = cache[expr] = rule(expr)
        return computed

    return memoized_rl
|
|
|
|
|
def condition(
    cond: Callable[[_T], bool], rule: Callable[[_T], _T]
) -> Callable[[_T], _T]:
    """ Only apply rule if condition is true

    Expressions for which ``cond`` is falsy pass through untouched.
    """
    def conditioned_rl(expr: _T) -> _T:
        return rule(expr) if cond(expr) else expr

    return conditioned_rl
|
|
|
|
|
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """
    Compose a sequence of rules so that they apply to the expr sequentially

    The output of each rule is the input of the next, left to right.
    With no rules, the expression is returned unchanged.
    """
    def chain_rl(expr: _T) -> _T:
        current = expr
        for step in rules:
            current = step(current)
        return current

    return chain_rl
|
|
|
|
|
def debug(rule, file=None):
    """ Print out before and after expressions each time rule is used

    Parameters
    ==========

    rule : callable
        The rule to wrap. Its first positional argument is treated as the
        input expression; all arguments are forwarded to it unchanged.
    file : writable text stream, optional
        Destination for the trace. When omitted, whatever ``sys.stdout``
        is at the moment of writing is used, so
        ``contextlib.redirect_stdout`` and output capture are honoured.

    Nothing is written when the rule leaves the expression unchanged.
    """
    import sys

    def debug_rl(*args, **kwargs):
        expr = args[0]
        result = rule(*args, **kwargs)
        if result != expr:
            # Look up stdout lazily instead of using the object bound at
            # import time, which would bypass any later redirection.
            out = sys.stdout if file is None else file
            # functools.partial objects and callable instances have no
            # __name__; fall back to their repr instead of crashing.
            name = getattr(rule, '__name__', None)
            if name is None:
                name = repr(rule)
            out.write("Rule: %s\n" % name)
            out.write("In: %s\nOut: %s\n\n" % (expr, result))
        return result

    return debug_rl
|
|
|
|
|
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
    """ Return original expr if rule returns None

    Only ``None`` triggers the fallback; other falsy results (``0``,
    ``""``, empty containers) are returned as-is.
    """
    def null_safe_rl(expr: _T) -> _T:
        outcome = rule(expr)
        return expr if outcome is None else outcome

    return null_safe_rl
|
|
|
|
|
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
    """ Return original expr if rule raises exception

    ``exception`` is anything accepted by an ``except`` clause (an
    exception class or a tuple of them). Other exceptions propagate.
    """
    def try_rl(expr: _T) -> _T:
        try:
            outcome = rule(expr)
        except exception:
            outcome = expr
        return outcome

    return try_rl
|
|
|
|
|
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Try each of the rules until one works. Then stop.

    A rule "works" when its output compares unequal to its input. The
    first such output is returned and later rules are not called. If no
    rule changes the expression, it is returned unchanged.
    """
    def do_one_rl(expr: _T) -> _T:
        # Lazy generator: rules are only invoked until the first change.
        outcomes = (candidate(expr) for candidate in rules)
        return next((out for out in outcomes if out != expr), expr)

    return do_one_rl
|
|
|
|
|
def switch(
    key: Callable[[_S], _T],
    ruledict: Mapping[_T, Callable[[_S], _S]]
) -> Callable[[_S], _S]:
    """ Select a rule based on the result of key called on the expression

    ``key(expr)`` is looked up in ``ruledict`` and the matching rule is
    applied to ``expr``. When no rule is registered for that key, the
    expression is returned unchanged.
    """
    def switch_rl(expr: _S) -> _S:
        selector = key(expr)
        if selector not in ruledict:
            # No rule for this key: behave like the identity rule.
            return expr
        return ruledict[selector](expr)

    return switch_rl
|
|
|
|
|
|
|
|
|
def _identity(x): |
|
return x |
|
|
|
|
|
def minimize(
    *rules: Callable[[_S], _T],
    objective=_identity
) -> Callable[[_S], _T]:
    """ Select result of rules that minimizes objective

    Every rule is applied to the same input expression, and the result
    with the smallest ``objective`` value is returned; ties go to the
    earliest rule. With no rules the expression is returned unchanged,
    as ``chain()`` and ``do_one()`` do.

    >>> from sympy.strategies import minimize
    >>> inc = lambda x: x + 1
    >>> dec = lambda x: x - 1
    >>> rl = minimize(inc, dec)
    >>> rl(4)
    3

    >>> rl = minimize(inc, dec, objective=lambda x: -x)  # maximize
    >>> rl(4)
    5

    >>> minimize()(4)
    4
    """
    def minrule(expr: _S) -> _T:
        candidates = [rule(expr) for rule in rules]
        # ``default`` makes an empty rule set act as the identity instead
        # of raising ``ValueError`` from ``min`` on an empty sequence.
        return min(candidates, key=objective, default=expr)

    return minrule
|
|