Spaces:
Running
Running
File size: 3,935 Bytes
c64fb9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from typing import Dict, Text, Callable, List
from collections import defaultdict
class HookManager(object):
    """Registry of named callbacks ("hooks") that can be split into child managers.

    Hook names are dot-separated strings. ``fork(name)`` creates a child
    manager that receives every hook registered under ``name.<rest>`` as
    ``<rest>``; ``fork_iterative(name, i)`` does the same for
    ``name.<i>.<rest>`` plus the wildcard ``name.*.<rest>``. Hooks registered
    or unregistered on this manager after a fork are forwarded to the
    matching fork(s) instead of being stored here.

    Routing assumes fork base names contain no dots: only the first dotted
    segment of a fork key is treated as its base name.
    """

    def __init__(self, hook_dict: Dict[Text, List[Callable]] = None):
        """Create a manager.

        Args:
            hook_dict: Optional initial mapping from hook name to the list of
                callables registered under it. When falsy (``None`` or
                empty) a fresh ``defaultdict(list)`` is used instead.
        """
        self.hook_dict = hook_dict or defaultdict(list)
        # Number of times each hook name has been invoked via __call__.
        self.called = defaultdict(int)
        # Fork key ("name" or "name.<iteration>") -> child HookManager.
        self.forks = dict()

    def _fork_targets(self, name: Text) -> List:
        """Return the (child_manager, child_name) pairs that should receive `name`."""
        targets = []
        for header, child in self.forks.items():
            base = header.split('.')[0]
            prefix = base + '.'
            if not name.startswith(prefix):
                continue
            segment = name[len(prefix):].split('.')[0]
            if segment.isnumeric():
                # Iteration-specific name: only the fork for that exact iteration.
                if prefix + segment == header:
                    targets.append((child, name[len(header) + 1:]))
            elif segment == '*':
                # Wildcard iteration: every fork sharing this base name.
                targets.append((child, name[len(base + '.*') + 1:]))
            elif name.startswith(header + '.'):
                # Plain sub-name of the fork. The prefix check keeps names like
                # "loop.x" away from iteration forks such as "loop.0", where
                # slicing off the header would yield an empty or garbled name.
                targets.append((child, name[len(header) + 1:]))
        return targets

    def register(self, name: Text, func: Callable):
        """Register `func` under hook `name`.

        If `name` falls under an existing fork, it is forwarded (with the
        fork prefix stripped) to that fork instead of being stored here.
        """
        assert name
        targets = self._fork_targets(name)
        for child, child_name in targets:
            child.register(child_name, func)
        if not targets:
            # setdefault so a plain dict passed to __init__ also accepts new names.
            self.hook_dict.setdefault(name, []).append(func)

    def unregister(self, name: Text, func: Callable):
        """Remove `func` from hook `name`, routing exactly like `register`.

        Unknown names or functions are ignored. When the last function of a
        name is removed, the name itself is dropped so it no longer counts as
        registered (for `__call__` and `finalize`).
        """
        assert name
        targets = self._fork_targets(name)
        for child, child_name in targets:
            child.unregister(child_name, func)
        if targets:
            return
        # .get() avoids creating an empty entry in a defaultdict.
        funcs = self.hook_dict.get(name)
        if funcs and func in funcs:
            funcs.remove(func)
            if not funcs:
                del self.hook_dict[name]

    def __call__(self, name: Text, **kwargs):
        """Invoke every hook registered under `name` with the same `kwargs`.

        Returns the value returned by the last hook. When no hook is
        registered, returns ``kwargs['ret']`` unchanged, so callers must pass
        ``ret=`` as the default value (KeyError otherwise).
        """
        funcs = self.hook_dict.get(name)
        if not funcs:
            return kwargs['ret']
        self.called[name] += 1
        # Iterate a snapshot so a hook that (un)registers hooks cannot disturb the loop.
        funcs = list(funcs)
        for function in funcs:
            ret = function(**kwargs)
        if len(funcs) > 1:
            last = funcs[-1]
            print(f'The last returned value comes from func {last}')
        return ret

    def _collect(self, prefixes) -> defaultdict:
        """Gather hooks whose names start with any of `prefixes`, prefix stripped.

        Prefixes are scanned in order, so for a shared sub-name the functions
        found under earlier prefixes come first.
        """
        collected = defaultdict(list)
        for prefix in prefixes:
            for key, funcs in self.hook_dict.items():
                if not key.startswith(prefix):
                    continue
                sub_name = key[len(prefix):]
                if isinstance(funcs, list):
                    collected[sub_name].extend(funcs)
                else:
                    collected[sub_name].append(funcs)
        return collected

    def fork(self, name):
        """Create a child manager holding the hooks registered under ``name.``.

        Raises:
            ValueError: if a fork with this exact key already exists.
        """
        if name in self.forks:
            raise ValueError(f'Forking with the same name is not allowed. Already forked with {name}.')
        new_hook = HookManager(self._collect([name + '.']))
        self.forks[name] = new_hook
        return new_hook

    def fork_iterative(self, name, iteration):
        """Create a child manager for iteration `iteration` of `name`.

        The child receives hooks under ``name.<iteration>.`` followed by those
        under the wildcard ``name.*.``. It is stored under the key
        ``name.<iteration>``; an existing fork with that key is replaced
        without error.
        """
        header = name + '.' + str(iteration)
        new_hook = HookManager(self._collect([header + '.', name + '.*.']))
        self.forks[header] = new_hook
        return new_hook

    def finalize(self):
        """Raise ValueError for the first hook name that has functions but was never called.

        NOTE: only calls made through this manager's own `__call__` count.
        Hooks that were copied into a fork and invoked only there are still
        "never used" from this manager's point of view.
        """
        for name, funcs in self.hook_dict.items():
            if funcs and self.called[name] == 0:
                raise ValueError(f'Hook {name} was registered but never used!')