import os import sys import inspect sys.path.append(os.getcwd()) from main.library.speaker_diarization.speechbrain import fetch, run_on_main from main.library.speaker_diarization.features import DEFAULT_TRANSFER_HOOKS, DEFAULT_LOAD_HOOKS def get_default_hook(obj, default_hooks): for cls in inspect.getmro(type(obj)): if cls in default_hooks: return default_hooks[cls] return None class Pretrainer: def __init__(self, loadables=None, paths=None, custom_hooks=None, conditions=None): self.loadables = {} if loadables is not None: self.add_loadables(loadables) self.paths = {} if paths is not None: self.add_paths(paths) self.custom_hooks = {} if custom_hooks is not None: self.add_custom_hooks(custom_hooks) self.conditions = {} if conditions is not None: self.add_conditions(conditions) self.is_local = [] def add_loadables(self, loadables): self.loadables.update(loadables) def add_paths(self, paths): self.paths.update(paths) def add_custom_hooks(self, custom_hooks): self.custom_hooks.update(custom_hooks) def add_conditions(self, conditions): self.conditions.update(conditions) @staticmethod def split_path(path): def split(src): if "/" in src: return src.rsplit("/", maxsplit=1) else: return "./", src return split(path) def collect_files(self, default_source=None): loadable_paths = {} for name in self.loadables: if not self.is_loadable(name): continue save_filename = name + ".ckpt" if name in self.paths: source, filename = self.split_path(self.paths[name]) elif default_source is not None: filename = save_filename source = default_source else: raise ValueError fetch_kwargs = {"filename": filename, "source": source} path = None def run_fetch(**kwargs): nonlocal path path = fetch(**kwargs) run_on_main(run_fetch, kwargs=fetch_kwargs, post_func=run_fetch, post_kwargs=fetch_kwargs) loadable_paths[name] = path self.paths[name] = str(path) self.is_local.append(name) return loadable_paths def is_loadable(self, name): if name not in self.conditions: return True condition = self.conditions[name] if callable(condition): return condition() else: return bool(condition) def load_collected(self): paramfiles = {} for name in self.loadables: if not self.is_loadable(name): continue if name in self.is_local: paramfiles[name] = self.paths[name] else: raise ValueError self._call_load_hooks(paramfiles) def _call_load_hooks(self, paramfiles): for name, obj in self.loadables.items(): if not self.is_loadable(name): continue loadpath = paramfiles[name] if name in self.custom_hooks: self.custom_hooks[name](obj, loadpath) continue default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS) if default_hook is not None: default_hook(obj, loadpath) continue default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS) if default_hook is not None: end_of_epoch = False default_hook(obj, loadpath, end_of_epoch) continue raise RuntimeError