File size: 6,426 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import logging

from typing import Dict, Optional, Tuple, Type

import sympy

from torch.utils._sympy.functions import FloorDiv

log = logging.getLogger(__name__)

# Maps each supported relational class to the class expressing the same
# relation once its two operands are swapped:
#   a OP b  <=>  b _MIRROR_REL_OP[OP] a
# Eq and Ne are symmetric, so they map to themselves. Classes not listed here
# are treated as unsupported by 'mirror_rel_op' / 'try_solve'.
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

# The inequality relation classes (i.e. every supported relation except Eq and
# Ne). For these, dividing both sides by a factor requires knowing the factor's
# sign, since a negative factor flips the direction of the relation.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)


def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
    """Return the relational class obtained by swapping the operands of ``type``.

    For example ``Gt`` maps to ``Lt`` (``a > b`` is ``b < a``), while the
    symmetric ``Eq`` and ``Ne`` map to themselves.

    Returns ``None`` when ``type`` has no registered mirror, which includes
    every non-relational class.
    """
    if type in _MIRROR_REL_OP:
        return _MIRROR_REL_OP[type]
    return None


def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
    """Try to simplify 'expr' so that only 'thing' is left on the left-hand side.

    Each orientation of 'expr' that has 'thing' on its left-hand side is tried:
    the expression as given, and its mirror (e.g. 'a < b' ==> 'b > a') when
    'thing' is on the right-hand side. Each orientation is simplified by at most
    'trials' passes of '_try_isolate_lhs', stopping early when a pass makes no
    change.

    Args:
        expr: the relational expression to solve.
        thing: the sub-expression to isolate on the left-hand side.
        trials: maximum number of simplification passes per orientation.
        floordiv_inequality: flag to enable conversion of 'FloorDiv' on the
            left-hand side into inequalities on its numerator.

    Returns:
        A tuple of:
          1. the simplified relation, whose left-hand side is exactly 'thing';
          2. that relation's right-hand side.
        Returns 'None' if 'expr' is not a supported relational, if 'thing'
        appears on both sides (or on neither), or if no orientation reaches a
        state where the only thing on the left-hand side is 'thing'.
    """
    mirror = mirror_rel_op(type(expr))

    # Ignore unsupported expressions:
    #   - Those that are not relational operations
    #   - Those that don't have a mirror (just avoiding unexpected classes)
    if not isinstance(expr, sympy.Rel) or mirror is None:
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    lhs_has_thing = expr.lhs.has(thing)
    rhs_has_thing = expr.rhs.has(thing)

    # Give up when 'thing' appears on both sides of the relational expression:
    # we only support isolating 'thing' when it occurs on exactly one side.
    if lhs_has_thing and rhs_has_thing:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Try considering both LHS and RHS by mirroring the original expression:
    # a < b ==> b > a
    # Only orientations that have 'thing' on their left-hand side are kept.
    expressions = []
    if lhs_has_thing:
        expressions.append(expr)
    if rhs_has_thing:
        expressions.append(mirror(expr.rhs, expr.lhs))

    for e in expressions:
        # Constructing the mirrored relation re-evaluates it, and sympy can fold
        # a decidable relation into a boolean atom (e.g. 'sympy.true') rather
        # than returning a relational. There is nothing to isolate then, so skip
        # it instead of failing an assertion.
        if not isinstance(e, sympy.Rel):
            continue

        for _ in range(trials):
            trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
            # Stop if there was no change in this trial.
            if trial == e:
                break
            e = trial  # type: ignore[assignment]

        # Return if we were able to isolate 'thing' on the left-hand side.
        # (A pass may also produce an And/Or, which is not a relational.)
        if isinstance(e, sympy.Rel) and e.lhs == thing:
            log.debug("solved: %s ---> %s", expr, e)
            return e, e.rhs

    return None


def _try_isolate_lhs(
    expr: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
    """Run one simplification pass moving everything but 'thing' off the LHS.

    The pass applies, in order:
      1. Additive terms of the left-hand side that don't contain 'thing' are
         subtracted from both sides.
      2. Multiplicative factors of the left-hand side that don't contain
         'thing' are divided out of both sides. For inequalities this only
         happens when the factor's sign is known, and the relation is mirrored
         when the factor is negative.
      3. If 'floordiv_inequality' is set, a relation 'a // b op c' with 'b'
         known positive and 'c' known integer is rewritten into an equivalent
         relation on 'a' (an And/Or of two relations for Eq/Ne).

    Args:
        expr: the expression to simplify. Anything that is not a sympy
            relational is returned unchanged.
        thing: the sub-expression being isolated.
        floordiv_inequality: whether step 3 is enabled.

    Returns:
        The rewritten expression (possibly equal to 'expr' if no step applied).
    """
    e = expr
    op = type(expr)

    if isinstance(e, sympy.Rel):
        # Move any constants in the left-hand side to the right-hand side.
        lhs_not_thing = (
            sum(a for a in e.lhs.args if not a.has(thing))
            if isinstance(e.lhs, sympy.Add)
            else 0
        )
        e = op(expr.lhs - lhs_not_thing, expr.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        # If we can't tell whether 'other' is negative or positive, we do nothing.
        # That is because we don't know whether we have to mirror the operation
        # or not. Eq/Ne don't depend on the sign, so they are always divided.
        # NOTE(review): a factor whose sign is known but which may be zero (e.g.
        # a nonnegative symbol has 'is_negative == False') is still divided
        # out; that is only sound when the factor is nonzero — confirm callers
        # rely on this only for nonzero factors.
        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # If 'e' is an inequality and 'other' is negative, we have to
            # mirror the expression.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    #
    # The rule is selected from the type of 'e' (the relation produced by the
    # steps above), NOT the type of the input 'expr': dividing by a negative
    # factor mirrors the relation, e.g. '-(a // b) > c' becomes 'a // b < -c',
    # which must be handled by the '<' rule.
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        numerator, denominator = e.lhs.args
        # a // b == expr
        # => a >= (b * expr) and a < (b * (expr + 1))
        if isinstance(e, sympy.Eq):
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != expr
        # => a < (b * expr) or a >= (b * (expr + 1))
        if isinstance(e, sympy.Ne):
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # The transformations below rely on b being positive (checked above).
        # a // b > expr  => a >= b * (expr + 1)
        # a // b >= expr => a >= b * expr
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(numerator, (quotient * denominator))  # type: ignore[arg-type]
        # a // b < expr  => a < b * expr
        # a // b <= expr => a < b * (expr + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(numerator, (quotient * denominator))  # type: ignore[arg-type]

    return e