File size: 7,474 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging

# Module-level logger; `log_hook` emits its "Ran pass ..." messages through it.
logger = logging.getLogger(__name__)

# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]

# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Whatever ``fn`` returns is discarded; the wrapped callable always returns
    the very object it was called with.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): calls ``fn`` and returns its
        (now modified) argument. ``functools.wraps`` copies ``fn``'s metadata
        (``__name__``, ``__doc__``, ``__wrapped__``, ...) onto it.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Run purely for the side effect; the pass's own result is dropped.
        fn(gm)
        return gm

    return wrapped_fn

def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Log the output of a callable every time it runs.

    Handy for tracing what each pass returns. ``inplace_wrapper`` replaces a
    pass's own return value with the modified object, so to log the pass's
    original output, apply ``log_hook`` first and wrap the result with
    ``inplace_wrapper``:

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose result is logged
        level: logging level handed to ``logger.log`` (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): calls ``fn``, logs ``fn`` together
        with its return value at ``level``, then returns that value unchanged.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return wrapped_fn



def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Wrap a pass so that it is applied repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified. With `n_iter`,
    the pass runs that many times back to back. With `predicate`, the pass
    keeps running while `predicate` is truthy for the current object, which
    is checked before every iteration (so the pass may run zero times).

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; the returned
            pass raises RuntimeError when called if this is not positive
        predicate (Callable[Object, bool], optional): loop condition,
            evaluated on the current object before each application

    Returns:
        Callable[Object, Object]: the looping pass, carrying `base_pass`'s
        metadata via `functools.wraps`.
    """
    # Same as XOR-ing the two "is not None" checks.
    assert (n_iter is None) != (
        predicate is None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def looped_pass(source):
        count_mode = n_iter is not None and n_iter > 0
        # A non-positive n_iter is only rejected here, at call time.
        if not count_mode and predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        current = source
        if count_mode:
            for _ in range(n_iter):
                current = base_pass(current)
            return current
        while predicate(current):
            current = base_pass(current)
        return current

    return looped_pass


# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Check every ordered pair of passes against ``constraint``.

    For each pair ``(a, b)`` where ``a`` appears earlier in ``passes`` than
    ``b``, ``constraint(a, b)`` is called; a falsy result means ``a`` must not
    precede ``b``. All pairs are visited, so the cost is quadratic in
    ``len(passes)``.

    Args:
        constraint: callable taking (earlier_pass, later_pass)
        passes: the pass schedule, in execution order

    Raises:
        RuntimeError: on the first violating pair. The message names the
            required order and both passes' absolute indices in ``passes``.
    """
    for i, a in enumerate(passes):
        # start=i + 1 makes j b's index in `passes`, not in the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            # `a` already precedes `b`, so the violated expectation is the
            # reverse order: `b` before `a`.
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )


def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a constraint requiring `this` to be scheduled before `that`.

    Passes are compared with ``==`` exactly as given; no unwrapping is done,
    so a wrapper around either pass is treated as a different pass.

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Callable, Callable], bool]): returns False only
        when its first argument is `that` and its second is `this`; True for
        every other pair.
    """

    def depends_on(a: Callable, b: Callable):
        # The single forbidden pairing is (that, this) in that order.
        return not (a == that and b == this)

    return depends_on


def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a constraint requiring `these` to be scheduled before `those`.

    Both candidate passes go through ``inspect.unwrap`` before comparison, so
    wrappers created with ``functools.wraps`` (such as those returned by
    ``loop_pass``) are matched against the passes they wrap.

    For example, the following pass list and constraint list would be invalid.

    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False when its
        first argument unwraps to `those` and its second unwraps to `these`;
        True otherwise.
    """

    def depends_on(a: Callable, b: Callable):
        first_is_late = unwrap(a) == those
        # `b` is only unwrapped when `a` matched (short-circuit `and`).
        return not (first_is_late and unwrap(b) == these)

    return depends_on


class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object.
            The list is copied, so later ``add_pass``/``remove_pass`` calls
            never mutate the caller's list.
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A
            is scheduled earlier than B, and returns False if that ordering
            is not allowed (True otherwise). See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cached outcome of `validate`; every mutator below resets it. Assigning
    # to `passes`/`constraints` directly does NOT reset it.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # Copy the inputs: previously a non-empty list was aliased (and then
        # mutated by add_pass/add_constraint) while an empty one was not.
        self.passes = list(passes) if passes else []
        self.constraints = list(constraints) if constraints else []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build an instance of ``cls`` from ``passes``."""
        # Use `cls` so subclasses get an instance of their own type.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an ordering constraint; checked on the next ``validate``."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Remove every pass whose ``__name__`` appears in ``_passes``.

        Passing ``None`` is a no-op.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Replace every pass whose ``__name__`` equals ``_target.__name__``
        with ``_replacement`` (matching is by name, not identity).
        """
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        The result is cached until the schedule or constraints are changed
        through one of this class's mutator methods.

        Raises:
            RuntimeError: if any constraint is violated.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then run each pass in order, feeding each
        pass the previous pass's output, and return the final result."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out