File size: 6,426 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import logging
from typing import Dict, Optional, Tuple, Type
import sympy
from torch.utils._sympy.functions import FloorDiv
log = logging.getLogger(__name__)

# For every supported relational class, the class that expresses the same
# relation once its two operands are swapped:  `a OP b`  <=>  `b MIRROR[OP] a`.
# Eq and Ne are symmetric and map onto themselves; the ordering relations
# flip direction.
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = dict(
    [
        (sympy.Eq, sympy.Eq),
        (sympy.Ne, sympy.Ne),
        (sympy.Ge, sympy.Le),
        (sympy.Gt, sympy.Lt),
        (sympy.Le, sympy.Ge),
        (sympy.Lt, sympy.Gt),
    ]
)

# The ordering relations (strict and non-strict), i.e. every supported
# relation except Eq/Ne.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Look up the relational class equivalent to ``type`` with swapped operands.

    For instance ``sympy.Lt`` yields ``sympy.Gt`` (``a < b`` <=> ``b > a``), while
    ``sympy.Eq`` and ``sympy.Ne`` yield themselves.

    Returns:
        The mirrored relational class, or None when ``type`` is not one of the
        classes in the module's mirror table.
    """
    # The parameter shadows the builtin ``type``; it is kept as-is because
    # callers may pass it by keyword.
    try:
        return _MIRROR_REL_OP[type]
    except KeyError:
        return None
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    """Try to rewrite the relation ``expr`` so that ``thing`` stands alone on its left.

    Args:
        expr: the relation to solve. Anything that is not a ``sympy.Rel`` whose
            class has a mirror (Eq, Ne, Ge, Gt, Le, Lt) is rejected.
        thing: the sub-expression to isolate. It may appear on either side of
            ``expr``, but not on both.
        trials: maximum number of rewrite rounds applied to each candidate;
            rounds stop early once a round leaves the expression unchanged.
        floordiv_inequality: forwarded to the rewrite step; enables turning a
            ``FloorDiv`` on the left-hand side into conditions on its numerator.

    Returns:
        ``(solved, rhs)`` where ``solved`` is a relation whose left-hand side is
        exactly ``thing`` and ``rhs`` is that relation's right-hand side, or None
        when ``thing`` could not be isolated.
    """
    mirror = mirror_rel_op(type(expr))

    # Reject non-relational expressions, and relational classes that have no
    # mirror (guards against unexpected Rel subclasses).
    if not isinstance(expr, sympy.Rel) or mirror is None:
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    in_lhs = expr.lhs.has(thing)
    in_rhs = expr.rhs.has(thing)

    # The rewrites only move terms off the left-hand side, so 'thing' must
    # live on exactly one side of the relation.
    if in_lhs and in_rhs:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect the versions of 'expr' that have 'thing' on their left-hand side.
    # When 'thing' is on the right, swap the operands using the mirrored
    # relation:  a < b  ==>  b > a.
    candidates = []
    if in_lhs:
        candidates.append(expr)
    if in_rhs:
        candidates.append(mirror(expr.rhs, expr.lhs))

    for candidate in candidates:
        assert isinstance(candidate, sympy.Rel)

        current = candidate
        for _ in range(trials):
            step = _try_isolate_lhs(
                current, thing, floordiv_inequality=floordiv_inequality
            )
            # A round that changes nothing means no further progress is possible.
            if step == current:
                break
            current = step  # type: ignore[assignment]

        # Success only if the result is still a relation with 'thing' alone on
        # the left (a FloorDiv rewrite may have produced an And/Or instead).
        if isinstance(current, sympy.Rel) and current.lhs == thing:
            log.debug("solved: %s ---> %s", expr, current)
            return current, current.rhs

    return None
def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Apply one round of rewrites that move everything but ``thing`` off the LHS.

    The rewrites are attempted in order:

    1. Terms of a left-hand ``Add`` that do not contain ``thing`` are subtracted
       from both sides.
    2. Factors of a left-hand ``Mul`` that do not contain ``thing`` are divided
       out of both sides. For inequalities this is skipped when the sign of the
       divided factor is unknown, and a negative factor flips the relation.
    3. If ``floordiv_inequality`` is set and the left-hand side is
       ``FloorDiv(a, b)`` with ``b`` known positive and an integer right-hand
       side, the relation is rewritten as a condition on ``a`` alone. For
       Eq/Ne the result is a ``sympy.And``/``sympy.Or`` rather than a relation.

    Args:
        expr: expression to rewrite; non-relational inputs are returned as-is.
        thing: the sub-expression being isolated.
        floordiv_inequality: enables rewrite 3.

    Returns:
        The rewritten expression. When no rewrite applies, the result compares
        equal to ``expr``, which the caller uses to stop iterating.
    """
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move the left-hand terms that don't contain 'thing' to the right-hand
        # side:  (x + k) op r  ==>  x op (r - k).
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain 'thing'.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        # If we can't tell whether 'other' is negative or positive, we do
        # nothing, because we wouldn't know whether the inequality has to be
        # mirrored.
        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
            # NOTE(review): 'other' is never checked for being non-zero. For
            # Eq/Ne no sign check happens at all, so e.g. y*x == y*3 becomes
            # x == 3 even though y may be 0. Callers presumably rely on
            # nonzero/positive symbols; confirm before relying on soundness.
            lhs = lhs / other
            rhs = rhs / other

            # Dividing an inequality by a negative factor flips its direction.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the relation:  a // b  op  c
    # the rewrites below are only valid when:
    #   - b > 0
    #   - c is an integer
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args
        c = e.rhs

        # Dispatch on the type of 'e', NOT of 'expr': the division step above
        # may have mirrored the relation in this same round. For example,
        # -2 * (a // b) > 4 became a // b < -2, so the "<" rule must apply.

        # a // b == c  =>  a >= b * c  and  a < b * (c + 1)
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (c * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((c + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != c  =>  a < b * c  or  a >= b * (c + 1)
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (c * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((c + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b >  c  =>  a >= b * (c + 1)
        # a // b >= c  =>  a >= b * c
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = c if isinstance(e, sympy.Ge) else (c + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]
        # a // b <  c  =>  a < b * c
        # a // b <= c  =>  a < b * (c + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = c if isinstance(e, sympy.Lt) else (c + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return e
|