File size: 3,935 Bytes
c64fb9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Dict, Text, Callable, List
from collections import defaultdict


class HookManager(object):
    """Registry of named hook callables with support for dotted-name sub-managers.

    Hooks are stored under a string name and run by calling the manager with that
    name.  A manager can be *forked*: the fork is a child ``HookManager`` that
    receives every hook whose name starts with the fork's prefix (with the prefix
    stripped).  After a fork exists, ``register``/``unregister`` calls on the
    parent whose name targets that fork are forwarded to the child.

    Attributes:
        hook_dict: mapping of hook name -> list of callables registered locally.
        called: mapping of hook name -> number of times the hook was invoked.
        forks: mapping of fork header (``name`` or ``name.<iteration>``) -> child
            manager.
    """

    def __init__(self, hook_dict: Dict[Text, List[Callable]] = None):
        """Create a manager, optionally seeded with an existing name -> hooks mapping.

        Args:
            hook_dict: initial hooks.  When omitted a fresh ``defaultdict(list)``
                is used.  A supplied mapping is used as-is (not copied), even
                when it is empty.
        """
        # Test against None instead of truthiness so that an empty mapping handed
        # in by the caller is kept rather than silently replaced.
        self.hook_dict = defaultdict(list) if hook_dict is None else hook_dict
        self.called = defaultdict(int)
        self.forks = dict()

    def _delegates(self, name: Text):
        """Return ``(child, sub_name)`` pairs for every fork that *name* targets.

        Routing rules, applied per fork header (the part of the header before the
        first dot is called the base):

        * ``base.<digits>.rest`` goes to the fork whose header is exactly
          ``base.<digits>``, as ``rest``.
        * ``base.*.rest`` goes to every fork sharing that base, as ``rest``.
        * anything else goes to a fork only when *name* starts with
          ``header + '.'``; the remainder is the sub-name.

        Forks for which the resulting sub-name would be empty are skipped, so the
        caller handles such names locally.
        """
        targets = []
        for header, child in self.forks.items():
            prefix = header.split('.')[0] + '.'
            if not name.startswith(prefix):
                continue
            segment = name[len(prefix):].split('.')[0]
            if segment.isnumeric():
                # An explicit iteration number only addresses that iteration's fork.
                if prefix + segment != header:
                    continue
                sub_name = name[len(header) + 1:]
            elif segment == '*':
                # Skip "<base>.*." -> len(prefix) covers "<base>.", +2 covers "*.".
                sub_name = name[len(prefix) + 2:]
            elif name.startswith(header + '.'):
                sub_name = name[len(header) + 1:]
            else:
                # Shares only the base with this fork (e.g. header "a.1", name
                # "a.b"); slicing by len(header) here would yield a wrong or
                # empty sub-name.
                continue
            if sub_name:
                targets.append((child, sub_name))
        return targets

    def _collect(self, prefixes) -> Dict[Text, List[Callable]]:
        """Gather local hooks whose name starts with any of *prefixes*.

        Returns a new ``defaultdict(list)`` keyed by the name with the matching
        prefix removed.  Prefixes are processed in order, so hooks matched by an
        earlier prefix come first in each list.  The lists are copies; later
        changes to the fork do not affect this manager's own lists.
        """
        collected = defaultdict(list)
        for prefix in prefixes:
            for key, value in self.hook_dict.items():
                if not key.startswith(prefix):
                    continue
                sub_name = key[len(prefix):]
                if isinstance(value, list):
                    collected[sub_name].extend(value)
                else:
                    collected[sub_name].append(value)
        return collected

    def register(self, name: Text, func: Callable):
        """Register *func* under *name*.

        If one or more existing forks are addressed by *name* the function is
        registered on those forks (under the stripped sub-name) and NOT stored
        locally; otherwise it is appended to this manager's own list.
        """
        assert name
        targets = self._delegates(name)
        for child, sub_name in targets:
            child.register(sub_name, func)
        if not targets:
            # setdefault keeps this working when hook_dict is a plain dict.
            self.hook_dict.setdefault(name, []).append(func)

    def unregister(self, name: Text, func: Callable):
        """Remove one registration of *func* under *name*, if present.

        Uses the same routing as ``register``: when forks are addressed the
        removal is forwarded to them; otherwise the local list is updated.
        Unknown names or functions are ignored.
        """
        assert name
        targets = self._delegates(name)
        for child, sub_name in targets:
            child.unregister(sub_name, func)
        if targets:
            return
        # .get() avoids the defaultdict side effect of creating an empty entry
        # for a name that was never registered.
        funcs = self.hook_dict.get(name)
        if funcs and func in funcs:
            funcs.remove(func)
            if not funcs:
                # Drop the empty entry so __call__ falls back to ``ret`` and
                # finalize() does not report a hook that no longer exists.
                del self.hook_dict[name]

    def __call__(self, name: Text, **kwargs):
        """Run every hook registered locally under *name* with ``**kwargs``.

        Returns:
            The return value of the last hook that ran.  When no hook is
            registered under *name*, ``kwargs['ret']`` is returned unchanged.

        Raises:
            KeyError: if no hook is registered and no ``ret`` keyword was given.
        """
        hooks = self.hook_dict.get(name)
        if not hooks:
            return kwargs['ret']
        self.called[name] += 1
        for function in hooks:
            ret = function(**kwargs)
        if len(hooks) > 1:
            last = hooks[-1]
            print(f'The last returned value comes from func {last}')
        return ret

    def fork(self, name):
        """Create a child manager holding the local hooks named ``name.<rest>``.

        The child is stored in ``self.forks`` under *name* and returned.

        Raises:
            ValueError: if a fork with the same name already exists.
        """
        if name in self.forks:
            raise ValueError(f'Forking with the same name is not allowed. Already forked with {name}.')
        new_hook = HookManager(self._collect([name + '.']))
        self.forks[name] = new_hook
        return new_hook

    def fork_iterative(self, name, iteration):
        """Create a child manager for one iteration of *name*.

        The child receives the local hooks named ``name.<iteration>.<rest>``
        followed by the wildcard hooks named ``name.*.<rest>``.  It is stored in
        ``self.forks`` under ``name.<iteration>`` (replacing any fork already
        stored under that key) and returned.
        """
        header = name + '.' + str(iteration)
        new_hook = HookManager(self._collect([header + '.', name + '.*.']))
        self.forks[header] = new_hook
        return new_hook

    def finalize(self):
        """Check that every locally registered hook has been invoked at least once.

        Raises:
            ValueError: for the first hook name that has registered functions
                but a call count of zero.
        """
        for name, funcs in self.hook_dict.items():
            if funcs and self.called.get(name, 0) == 0:
                raise ValueError(f'Hook {name} was registered but never used!')