|
|
|
from functools import wraps |
|
from inspect import unwrap |
|
from typing import Callable, List, Optional |
|
import logging |
|
|
|
# Module-level logger named after this module; `log_hook` writes to it.
logger = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]
|
|
|
|
|
|
|
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    The return value of ``fn`` is discarded; the wrapped callable always
    returns the very object it was called with.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): callable that runs ``fn`` on
            its argument and then returns that same argument
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Run the pass purely for its side effect on `gm`; whatever it
        # returns is intentionally ignored.
        fn(gm)
        return gm

    return wrapped_fn
|
|
|
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a callable so that each invocation is logged together with its
    return value.

    The wrapped callable returns exactly what ``fn`` returned. Because
    ``inplace_wrapper`` replaces a pass's output with the modified object,
    apply ``log_hook`` first (innermost) when the original output is the
    thing you want to see in the log.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose result should be logged
        level: logging level (e.g. logging.INFO) used for the log record

    Returns:
        wrapped_fn (Callable[Type1, Type2]): logging version of ``fn``
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return wrapped_fn
|
|
|
|
|
|
|
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Build a pass that applies ``base_pass`` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop;
            each iteration receives the previous iteration's output
        n_iter (int, optional): number of times to loop pass; must be
            positive, otherwise the returned pass raises when called
        predicate (Callable[Object, bool], optional): checked on the current
            output before every iteration; looping continues for as long as
            it returns a truthy value

    Returns:
        new_pass (Callable[Object, Object]): the looping pass, wrapped with
            ``functools.wraps(base_pass)``
    """
    has_count = n_iter is not None
    has_predicate = predicate is not None
    assert (
        has_count != has_predicate
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        current = source

        # Fixed-count mode.
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        # Predicate-driven mode.
        if predicate is not None:
            while predicate(current):
                current = base_pass(current)
            return current

        # Neither a usable count nor a predicate was supplied.
        raise RuntimeError(
            f"loop_pass must be given positive int n_iter (given "
            f"{n_iter}) xor predicate (given {predicate})"
        )

    return new_pass
|
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_pass_schedule_constraint( |
|
constraint: Callable[[Callable, Callable], bool], passes: List[Callable] |
|
): |
|
for i, a in enumerate(passes): |
|
for j, b in enumerate(passes[i + 1 :]): |
|
if constraint(a, b): |
|
continue |
|
raise RuntimeError( |
|
f"pass schedule constraint violated. Expected {a} before {b}" |
|
f" but found {a} at index {i} and {b} at index{j} in pass" |
|
f" list." |
|
) |
|
|
|
|
|
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a constraint ('depends on' function) requiring `this` to occur
    before `that`.

    The returned callable takes two passes ``(a, b)`` and returns False only
    for the forbidden pair, i.e. when ``a == that`` and ``b == this``; every
    other pair yields True.
    """

    def depends_on(a: Callable, b: Callable):
        wrong_order = a == that and b == this
        return not wrong_order

    return depends_on
|
|
|
|
|
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a constraint ('depends on' function) requiring `these` to occur
    before `those`, comparing passes after they have been 'unwrapped'.

    Each pass handed to the returned callable is run through
    ``inspect.unwrap`` before being compared, so wrappers created with
    ``functools.wraps`` (such as ``loop_pass``) are seen through. `these`
    and `those` themselves are compared as given.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only
            when the first argument unwraps to `those` and the second
            unwraps to `these`; True otherwise
    """

    def depends_on(a: Callable, b: Callable):
        wrong_order = unwrap(a) == those and unwrap(b) == these
        return not wrong_order

    return depends_on
|
|
|
|
|
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), with A
            scheduled before B, and returns True if that ordering is
            acceptable and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    # Ordered pass schedule; `__call__` runs these front to back.
    passes: List[Callable]
    # Ordering constraints checked against `passes` by `validate`.
    constraints: List[Callable]
    # True once the current schedule has passed `validate`; every method
    # that changes the schedule or the constraints resets it to False.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # A non-empty list argument is stored as-is (not copied); None or an
        # empty list is replaced by a fresh empty list.
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """
        Alternate constructor: build a manager from a list of passes.

        Instantiates ``cls`` so that subclasses get an instance of their own
        type rather than a plain ``PassManager``.
        """
        return cls(passes)

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional ordering constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Drop every pass whose ``__name__`` is contained in ``_passes``.

        Passes without a ``__name__`` attribute (e.g. ``functools.partial``
        objects) cannot match a name and are always kept. Passing ``None``
        is a no-op.
        """
        if _passes is None:
            return
        passes_left = []
        for ps in self.passes:
            name = getattr(ps, "__name__", None)
            if name is None or name not in _passes:
                passes_left.append(ps)
        self.passes = passes_left
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Replace every pass matching ``_target`` with ``_replacement``.

        A pass matches when its ``__name__`` equals ``_target.__name__``.
        Callables lacking ``__name__`` are matched by identity instead, so
        they neither raise nor get replaced by accident.
        """
        target_name = getattr(_target, "__name__", None)

        def matches(ps):
            if ps is _target:
                return True
            if target_name is None:
                return False
            return getattr(ps, "__name__", None) == target_name

        self.passes = [
            _replacement if matches(ps) else ps for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        The result is cached in ``_validated``: once a schedule has been
        validated, the checks are skipped until the schedule or constraints
        are changed through one of the mutator methods.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """
        Validate the schedule, then run every pass in order.

        Each pass receives the previous pass's return value; the last
        return value (or ``source`` itself if there are no passes) is
        returned.
        """
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
|
|