Spaces:
Running
Running
from functools import wraps | |
from inspect import unwrap | |
from typing import Callable, List, Optional | |
import logging | |
# Module-level logger; log_hook emits its messages through it.
logger = logging.getLogger(__name__)

# Names exported by ``from <this module> import *``.
__all__ = [
    "PassManager",
    # pass wrappers
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    # pass-schedule constraint factories
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    The wrapper is decorated with ``functools.wraps`` so it keeps ``fn``'s
    ``__name__`` (PassManager.remove_pass / replace_pass match passes by
    name) and exposes ``fn`` as ``__wrapped__`` (so ``inspect.unwrap`` can
    recover the original pass).

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument; its
            return value is discarded.

    Returns:
        wrapped_fn (Callable[Object, Object]): calls ``fn(gm)`` and returns
            ``gm`` itself.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's own result is intentionally dropped; callers get the
        # (mutated) input object back.
        fn(gm)
        return gm

    return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): pass whose return value is logged
        level: logging level passed to ``logger.log`` (default
            ``logging.INFO``)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): calls ``fn``, logs the pass and
            its return value on the module logger, and returns that value
            unchanged. ``functools.wraps`` preserves ``fn``'s ``__name__``
            and sets ``__wrapped__`` so name-based pass management and
            ``inspect.unwrap`` still see the original pass.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        val = fn(gm)
        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
        return val

    return wrapped_fn
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; must be a
            positive int, otherwise calling the returned pass raises
            ``RuntimeError``.
        predicate (Callable[Object, bool], optional): ``base_pass`` is
            re-applied while ``predicate(output)`` is truthy. It is checked
            before the first application, so the pass may run zero times.

    Returns:
        Callable[Object, Object]: the looping pass. It is wrapped with
        ``functools.wraps(base_pass)`` so it keeps ``base_pass``'s name and
        ``inspect.unwrap`` recovers ``base_pass``.

    Raises:
        AssertionError: if both or neither of `n_iter` / `predicate` are
            given (note: ``assert`` is skipped under ``python -O``).
    """
    assert (n_iter is not None) ^ (
        predicate is not None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        output = source
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                output = base_pass(output)
        elif predicate is not None:
            while predicate(output):
                output = base_pass(output)
        else:
            # Reached for a non-positive n_iter (or, under -O, when neither
            # argument was supplied).
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return output

    return new_pass
# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Check every ordered pair of ``passes`` against ``constraint``.

    ``constraint(a, b)`` is called for each pair where ``a`` appears strictly
    before ``b`` in ``passes``; a falsy result means ``b`` was required to
    come before ``a``.

    Raises:
        RuntimeError: on the first violating pair, reporting both passes and
            their absolute indices in ``passes``.
    """
    for i, a in enumerate(passes):
        # start=i + 1 so that ``j`` is b's index in ``passes``, not in the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' constraint requiring `this` to run before `that`.

    The returned ``depends_on(a, b)`` answers False only for the forbidden
    pairing ``a == that`` and ``b == this`` and True for everything else.
    Passes are compared with ``==`` as given (no unwrapping).
    """

    def depends_on(a: Callable, b: Callable):
        forbidden_order = a == that and b == this
        return not forbidden_order

    return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Defines a partial order ('depends on' function) where `these` must occur
    before `those`. Both arguments of the returned function are passed through
    ``inspect.unwrap`` before being compared, so wrapped passes are matched by
    the callable they wrap.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): False only when the
            unwrapped ``a`` is `those` and the unwrapped ``b`` is `these`.
    """

    def depends_on(a: Callable, b: Callable):
        # Any earlier pass other than `those` can never violate the order.
        if not (unwrap(a) == those):
            return True
        return not (unwrap(b) == these)

    return depends_on
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object.
            The list is copied, so later add/remove/replace calls never
            mutate the caller's list.
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is called as ``constraint(a, b)`` for every pair of
            passes where ``a`` is scheduled before ``b`` and must return a
            falsy value when that ordering is invalid. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cached result of validate(); reset by every mutating method below.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # Copy to avoid aliasing the caller's list (add_pass appends in place).
        self.passes = list(passes or [])
        self.constraints = list(constraints or [])

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes only."""
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional schedule constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: Optional[List[str]]):
        """
        Remove every pass whose ``__name__`` is in ``_passes``.

        Passes are matched by name, not identity. ``None`` is a no-op.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Replace every pass whose ``__name__`` equals ``_target.__name__``
        with ``_replacement``.
        """
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        The outcome is cached until passes or constraints are changed through
        this class's methods; raises RuntimeError on the first violation.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then run each pass in order on the previous output."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out