|
import logging |
|
|
|
from typing import Dict, Optional, Tuple, Type |
|
|
|
import sympy |
|
|
|
from torch.utils._sympy.functions import FloorDiv |
|
|
|
log = logging.getLogger(__name__)

# Relational class obtained by swapping the two sides of a relation:
# ``a op b`` holds exactly when ``b mirror(op) a`` holds. Equality and
# inequality are symmetric, so they mirror to themselves.
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = dict(
    [
        (sympy.Eq, sympy.Eq),
        (sympy.Ne, sympy.Ne),
        (sympy.Ge, sympy.Le),
        (sympy.Gt, sympy.Lt),
        (sympy.Le, sympy.Ge),
        (sympy.Lt, sympy.Gt),
    ]
)

# Order-based relational classes. Dividing both sides of one of these by a
# negative value requires mirroring the operator to keep the relation valid.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
|
|
|
|
|
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Return the relational class obtained by swapping the sides of ``type``.

    For example ``sympy.Ge`` maps to ``sympy.Le``, and ``sympy.Eq`` maps to
    itself.

    Args:
        type: A class, typically a ``sympy.Rel`` subclass.

    Returns:
        The mirrored relational class from ``_MIRROR_REL_OP``, or ``None``
        when ``type`` has no entry there (e.g. it is not a supported relation).
    """
    try:
        return _MIRROR_REL_OP[type]
    except KeyError:
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    """Try to rewrite relation ``expr`` so that ``thing`` stands alone on its LHS.

    If ``thing`` only appears on the right-hand side, the relation is first
    mirrored (sides swapped, operator flipped via ``mirror_rel_op``). Then
    ``_try_isolate_lhs`` is applied repeatedly, at most ``trials`` times,
    stopping early once a round leaves the expression unchanged.

    Args:
        expr: The relational expression to solve.
        thing: The sub-expression to isolate.
        trials: Maximum number of rewriting rounds.
        floordiv_inequality: Forwarded to ``_try_isolate_lhs``; enables the
            rewrites of ``FloorDiv`` on the left-hand side.

    Returns:
        ``(solved_relation, solved_relation.rhs)`` when the final relation has
        exactly ``thing`` as its left-hand side. Otherwise ``None``. This covers
        an unsupported ``expr`` type, ``thing`` appearing on both sides or on
        neither side, and failure to isolate it.
    """
    mirror = mirror_rel_op(type(expr))

    # Only relational expressions with a known mirrored operator are handled.
    if mirror is None or not isinstance(expr, sympy.Rel):
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    in_lhs = expr.lhs.has(thing)
    in_rhs = expr.rhs.has(thing)

    # No rule here can separate 'thing' when it occurs on both sides.
    if in_lhs and in_rhs:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Work on a relation that has 'thing' on its left-hand side.
    if in_lhs:
        current = expr
    elif in_rhs:
        current = mirror(expr.rhs, expr.lhs)
    else:
        return None

    assert isinstance(current, sympy.Rel)

    for _ in range(trials):
        rewritten = _try_isolate_lhs(
            current, thing, floordiv_inequality=floordiv_inequality
        )
        # A round that changes nothing means no further progress is possible.
        if rewritten == current:
            break
        current = rewritten

    # Rewriting may produce a non-relation (e.g. And/Or from FloorDiv rules);
    # only a relation with exactly 'thing' on the left counts as solved.
    if isinstance(current, sympy.Rel) and current.lhs == thing:
        log.debug("solved: %s ---> %s", expr, current)
        return current, current.rhs

    return None
|
|
|
|
|
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Apply one round of rewrites that move ``thing`` towards the LHS alone.

    The round runs three steps in order:

    1. If the left-hand side is an ``Add``, subtract the terms that do not
       contain ``thing`` from both sides.
    2. If the left-hand side is a ``Mul``, divide both sides by the product of
       the factors that do not contain ``thing``. For inequalities this is
       skipped when the sign of that product is unknown. When the product is
       negative, the operator is mirrored.
    3. If ``floordiv_inequality`` is set and the relation is ``a // b op c``
       with ``b`` known positive and ``c`` known integer, rewrite it into a
       relation (or an ``And``/``Or`` of relations) on ``a``.

    Args:
        expr: The expression to rewrite. Non-relational inputs are returned
            unchanged.
        thing: The sub-expression being isolated.
        floordiv_inequality: Whether step 3 is enabled.

    Returns:
        The rewritten expression. It may equal ``expr`` when no rule applied,
        or be a ``sympy.And``/``sympy.Or`` after a ``FloorDiv`` rewrite.
    """
    e = expr
    op = type(expr)

    # Step 1: move additive terms without 'thing' to the right-hand side.
    if isinstance(e, sympy.Rel):
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)

    # Step 2: divide both sides by the factors that don't contain 'thing'.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        # For inequalities we must know whether 'other' is negative in order
        # to decide whether to mirror the operator; if unknown, do nothing.
        # NOTE(review): Eq/Ne are divided unconditionally, even if 'other'
        # could be zero — confirm callers never depend on that case.
        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
            lhs = lhs / other
            rhs = rhs / other

            # Dividing an inequality by a negative value flips its direction.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)

            assert op is not None
            e = op(lhs, rhs)

    # Step 3: left-hand side is FloorDiv.
    #
    # Given ``a // b op c``, these rules are only valid when b > 0 and c is an
    # integer (only known here when SymPy can prove it, e.g. for constants).
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args
        quotient = e.rhs

        # Dispatch on 'e', not on the original 'expr': step 2 may have mirrored
        # the operator (e.g. ``-(x // 3) < 2`` became ``x // 3 > -2``), and the
        # rule must match the relation actually being rewritten.

        # a // b == c  =>  a >= b * c  and  a < b * (c + 1)
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, quotient * denominator),
                sympy.Lt(numerator, (quotient + 1) * denominator),
            )

        # a // b != c  =>  a < b * c  or  a >= b * (c + 1)
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, quotient * denominator),
                sympy.Ge(numerator, (quotient + 1) * denominator),
            )

        # a // b >= c  =>  a >= b * c
        # a // b >  c  =>  a >= b * (c + 1)
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            bound = quotient if isinstance(e, sympy.Ge) else (quotient + 1)
            return sympy.Ge(numerator, bound * denominator)

        # a // b <  c  =>  a < b * c
        # a // b <= c  =>  a < b * (c + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            bound = quotient if isinstance(e, sympy.Lt) else (quotient + 1)
            return sympy.Lt(numerator, bound * denominator)

    return e
|
|