Spaces:
Running
Running
File size: 7,474 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging
# Module-level logger; `log_hook` uses it to report each pass's return value.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = [
"PassManager",
"inplace_wrapper",
"log_hook",
"loop_pass",
"this_before_that_pass_constraint",
"these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Whatever `fn` itself returns is discarded; the wrapper always returns the
    same object it was called with.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): callable that runs `fn` on its
            argument and then returns that argument
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Only the side effect on `gm` matters; the return value of `fn` is
        # intentionally ignored.
        fn(gm)
        return gm

    return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a callable so that its return value is logged on every call.

    Handy for recording what a pass returned. Because inplace_wrapper swaps
    the pass output for the modified object, apply this wrapper *inside*
    inplace_wrapper when the original output is what should be logged:

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose output is logged
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): callable returning exactly what
            `fn` returns
    """

    def logged_pass(gm):
        result = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    # Equivalent to decorating with @wraps(fn): copy fn's metadata onto the
    # wrapper and expose fn via __wrapped__.
    return wraps(fn)(logged_pass)
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Build a pass that applies `base_pass` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; must be
            positive, otherwise the returned pass raises when called
        predicate (Callable[Object, bool], optional): called on the current
            value before each iteration; looping continues while it is truthy

    Returns:
        new_pass (Callable[Object, Object]): the looping pass

    Raises:
        AssertionError: at construction, if both or neither of `n_iter` and
            `predicate` are given
        RuntimeError: when the returned pass is called and neither a positive
            `n_iter` nor a `predicate` is available
    """
    count_given = n_iter is not None
    predicate_given = predicate is not None
    assert count_given ^ predicate_given, "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        # Fixed-count mode.
        if n_iter is not None and n_iter > 0:
            current = source
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        # Predicate-driven mode: the predicate is checked before every run.
        if predicate is not None:
            current = source
            while predicate(current):
                current = base_pass(current)
            return current

        raise RuntimeError(
            f"loop_pass must be given positive int n_iter (given "
            f"{n_iter}) xor predicate (given {predicate})"
        )

    return new_pass
# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Check one ordering constraint against every ordered pair of passes.

    For each pair `(a, b)` in which `a` appears earlier in `passes` than `b`,
    `constraint(a, b)` is called; a falsy result means that ordering is not
    allowed.

    Args:
        constraint (Callable[[Callable, Callable], bool]): returns a truthy
            value when `a` is allowed to precede `b`
        passes (List[Callable]): the pass schedule, in execution order

    Raises:
        RuntimeError: on the first pair for which `constraint` is falsy. The
            message reports both passes with their positions in `passes`.
    """
    for i, a in enumerate(passes):
        # Start the inner counter at i + 1 so `j` is the real position of `b`
        # in `passes`, not its offset within the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            # `a` was found before `b`, which the constraint forbids, so the
            # expected order is the reverse.
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build an ordering constraint requiring `this` to run before `that`.

    The returned `depends_on(a, b)` is False only for the pair
    `(that, this)` and True for every other pair.
    """

    def depends_on(a: Callable, b: Callable):
        # The single forbidden combination is a == that together with
        # b == this.
        return not (a == that and b == this)

    return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build an ordering constraint requiring `these` to run before `those`,
    comparing passes after they have been 'unwrapped' with `inspect.unwrap`.

    For example, the following pass list and constraint list would be invalid.

    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): False only when the
            unwrapped `a` equals `those` and the unwrapped `b` equals `these`
    """

    def depends_on(a: Callable, b: Callable):
        # `b` is only unwrapped when the first comparison already matched.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable taking two passes (A, B); during
            validation it is called for every pair in which A is scheduled
            before B and must return True for the schedule to be valid. See
            implementation of `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Class-level default; set to True by `validate` and reset to False by
    # every method that changes the schedule or the constraints.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes only."""
        # Instantiate through `cls` so that subclasses get an instance of
        # their own type rather than a plain PassManager.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append `_pass` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional ordering constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: Optional[List[str]]):
        """
        Drop every pass whose `__name__` is listed in `_passes`.

        Passes without a `__name__` attribute (e.g. `functools.partial`
        objects) cannot match a name and are always kept. Passing None is a
        no-op.
        """
        if _passes is None:
            return
        passes_left = []
        for ps in self.passes:
            name = getattr(ps, "__name__", None)
            if name is None or name not in _passes:
                passes_left.append(ps)
        self.passes = passes_left
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Substitute `_replacement` for every pass matching `_target`.

        A pass matches when it is `_target` itself or when both have a
        `__name__` and the names are equal. Passes without a `__name__` are
        only matched by identity.
        """
        target_name = getattr(_target, "__name__", None)

        def matches(ps):
            if ps is _target:
                return True
            return (
                target_name is not None
                and getattr(ps, "__name__", None) == target_name
            )

        self.passes = [_replacement if matches(ps) else ps for ps in self.passes]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        # Skip the pairwise check if nothing changed since the last success.
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then feed `source` through every pass in order."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
|