File size: 3,637 Bytes
96134ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys
import inspect

sys.path.append(os.getcwd())

from main.library.speaker_diarization.speechbrain import fetch, run_on_main
from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS


def get_default_hook(obj, default_hooks):
    """Look up the hook registered for the most specific class of ``obj``.

    The method resolution order of ``type(obj)`` is scanned from the most
    derived class to the most basic one, and the hook of the first class
    that appears as a key in ``default_hooks`` is returned.

    Arguments
    ---------
    obj : object
        Instance whose class hierarchy is searched.
    default_hooks : dict
        Mapping from class to hook.

    Returns
    -------
    The matching value from ``default_hooks``, or ``None`` when no class in
    the hierarchy is registered.
    """
    lineage = inspect.getmro(type(obj))
    matches = (default_hooks[klass] for klass in lineage if klass in default_hooks)
    return next(matches, None)

class Pretrainer:
    """Fetch parameter files for named objects and load them through hooks.

    Typical use is two steps: ``collect_files`` resolves and fetches one file
    per loadable, then ``load_collected`` hands every fetched file to the
    hook responsible for the matching object.

    Arguments
    ---------
    loadables : dict, optional
        Mapping from name to the object that should receive parameters.
    paths : dict, optional
        Mapping from name to the location of its parameter file. The part
        before the last ``/`` is used as the fetch source, the rest as the
        filename.
    custom_hooks : dict, optional
        Mapping from name to a callable ``hook(obj, path)`` that takes
        precedence over the default hooks for that name.
    conditions : dict, optional
        Mapping from name to a value or zero-argument callable. A loadable
        whose condition evaluates falsy is skipped everywhere.
    """

    def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None):
        # Create every container first so the instance is fully formed before
        # any of the add_* methods run.
        self.loadables = {}
        self.paths = {}
        self.custom_hooks = {}
        self.conditions = {}
        # Names whose file has already been fetched by collect_files.
        self.is_local = []

        if loadables is not None:
            self.add_loadables(loadables)
        if paths is not None:
            self.add_paths(paths)
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)
        if conditions is not None:
            self.add_conditions(conditions)

    def add_loadables(self, loadables):
        """Register (or overwrite) objects to load, keyed by name."""
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Register (or overwrite) parameter file locations, keyed by name."""
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Register (or overwrite) per-name loading hooks ``hook(obj, path)``."""
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Register (or overwrite) per-name loading conditions."""
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Split ``path`` at its last ``/`` into ``(source, filename)``.

        A path without any ``/`` is treated as a bare filename whose source
        is ``"./"``. The result is a two-item sequence meant to be unpacked.
        """
        if "/" in path:
            return path.rsplit("/", maxsplit=1)
        return "./", path

    def collect_files(self, default_source=None):
        """Fetch the parameter file of every loadable that passes its condition.

        Arguments
        ---------
        default_source : optional
            Source used for loadables without an entry in ``self.paths``;
            their filename is then ``<name>.ckpt``.

        Returns
        -------
        dict
            Mapping from loadable name to whatever ``fetch`` returned for it.

        Raises
        ------
        ValueError
            If a loadable has no registered path and ``default_source`` is
            not given.
        """
        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue

            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = name + ".ckpt"
                source = default_source
            else:
                raise ValueError(
                    f"Path not specified for '{name}', and no default_source given!"
                )

            fetch_kwargs = {"filename": filename, "source": source}
            path = None

            def run_fetch(**kwargs):
                # run_on_main does not hand back a return value here, so the
                # result is written into the enclosing scope instead.
                nonlocal path

                path = fetch(**kwargs)

            # The same fetch is passed as both the main and the post function.
            # NOTE(review): presumably so that every process ends up with the
            # path after the main one has fetched - confirm against run_on_main.
            run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs)

            loadable_paths[name] = path
            self.paths[name] = str(path)
            # Repeated calls must not pile up duplicate entries.
            if name not in self.is_local:
                self.is_local.append(name)

        return loadable_paths

    def is_loadable(self, name):
        """Tell whether ``name`` should be processed.

        Names without a registered condition are always loadable. A callable
        condition is called with no arguments and its result is returned
        as-is; any other condition is converted with ``bool``.
        """
        if name not in self.conditions:
            return True
        condition = self.conditions[name]

        if callable(condition):
            return condition()
        return bool(condition)

    def load_collected(self):
        """Load every collected parameter file into its object.

        Raises
        ------
        ValueError
            If a loadable that passes its condition has not been fetched by
            ``collect_files`` yet.
        """
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue

            if name not in self.is_local:
                raise ValueError(
                    f"Loadable '{name}' has not been collected; "
                    "call collect_files() before load_collected()."
                )
            paramfiles[name] = self.paths[name]

        self._call_load_hooks(paramfiles)

    def _call_load_hooks(self, paramfiles):
        """Dispatch each loadable to the first hook that applies to it.

        Precedence: the custom hook registered for the name, then the default
        transfer hook for the object's class, then the default load hook for
        the object's class (called with ``end_of_epoch=False``).

        Raises
        ------
        RuntimeError
            If no hook of any kind applies to an object.
        """
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]

            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath)
                continue

            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath)
                continue

            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch)
                continue

            raise RuntimeError(
                f"Don't know how to load '{name}' of type {type(obj).__name__}. "
                "Register a default hook for its class or add a custom hook."
            )