File size: 2,152 Bytes
351e3b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import importlib.util
import os
import sys
import warnings
from pathlib import Path

def import_model_module(file_path: os.PathLike):
    """Import a Python source file as a module and register it in ``sys.modules``.

    The module name is derived from the file's path relative to the current working
    directory, with path separators turned into dots and the ``.py`` suffix dropped
    (``models/resnet/model.py`` -> ``models.resnet.model``). If the file is not under
    the working directory, the absolute path (minus its drive/root) is used instead,
    rather than failing.

    Args:
        file_path: Path to the ``.py`` file to import.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no loader can be created for ``file_path``.
        Exception: Anything raised while executing the module body is propagated; the
            partially initialised module is removed from ``sys.modules`` first.
    """
    path = Path(file_path).resolve()
    try:
        relative = path.relative_to(Path.cwd().resolve())
    except ValueError:
        # File lives outside the working directory: fall back to its absolute location.
        relative = path.relative_to(path.anchor)
    # Drop ".py" so the parent package (``__package__``) is computed correctly.
    module_name = ".".join(relative.with_suffix("").parts)

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can import itself / be found during exec.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        sys.modules.pop(module_name, None)
        raise
    return module


# Registry of discovered models: declared ``name`` -> ``get_model`` factory.
models = {}

# Directory containing this file; each sub-directory holding a ``model.py`` is treated
# as a model plugin.
file_dir = Path(__file__).parents[0]

_NAME_KEY = "name"
_GET_MODEL_KEY = "get_model"


def _check_attr_exists(module_path, attr_name, attr):
    """Return True if ``attr`` is not None; otherwise warn (naming ``module_path``) and return False."""
    if attr is None:
        warnings.warn(
            f"Module {module_path} should define attribute '{attr_name}'"
        )
        return False
    return True


def _check_attr_type(attr_name, attr, expected_type):
    """Return True if ``attr`` is an instance of ``expected_type``; otherwise warn and return False."""
    if isinstance(attr, expected_type):
        return True
    # The parameter is deliberately not called ``type`` so the builtin below reports
    # the attribute's actual type rather than converting the value.
    warnings.warn(
        f"'{attr_name}' should be of type {expected_type}, but it is of type {type(attr)}"
    )
    return False


def _check_attr_callable(attr_name, attr):
    """Return True if ``attr`` is callable; otherwise warn and return False."""
    if callable(attr):
        return True
    warnings.warn(f"'{attr_name}' should be callable")
    return False


def _discover_models(directory):
    """Import ``<directory>/<sub>/model.py`` for each sub-directory and collect its factory.

    Sub-directories are visited in sorted order so that registration order, and which
    module wins when two declare the same ``name``, are deterministic. A module is
    skipped with a warning when it lacks ``name`` or ``get_model``, when ``name`` is not
    a ``str``, or when ``get_model`` is not callable. A later module with a duplicate
    ``name`` replaces the earlier one and a warning is emitted. Errors raised while
    importing a model module propagate to the caller.

    Args:
        directory: ``Path`` whose immediate sub-directories are scanned.

    Returns:
        dict mapping each valid ``name`` to its ``get_model`` callable.
    """
    found = {}
    for path in sorted(directory.glob("*")):
        model_file_path = path / "model.py"
        if not (path.is_dir() and model_file_path.is_file()):
            continue
        module = import_model_module(model_file_path)
        name = getattr(module, _NAME_KEY, None)
        factory = getattr(module, _GET_MODEL_KEY, None)
        # Checks short-circuit in the same order as before: existence first, then kinds.
        if not (
            _check_attr_exists(model_file_path, _NAME_KEY, name)
            and _check_attr_exists(model_file_path, _GET_MODEL_KEY, factory)
            and _check_attr_type(_NAME_KEY, name, str)
            and _check_attr_callable(_GET_MODEL_KEY, factory)
        ):
            continue
        if name in found:
            warnings.warn(
                f"Model name '{name}' from {model_file_path} overrides an earlier model with the same name"
            )
        found[name] = factory
    return found


models.update(_discover_models(file_dir))


def get_model(name: str):
    """Build and return the model registered under ``name``.

    Looks ``name`` up in the module-level ``models`` registry and calls the stored
    factory with no arguments.

    Raises:
        KeyError: If no model with that name has been registered.
    """
    if name in models:
        return models[name]()
    raise KeyError(f"No model with name {name}")


def get_all_model_names():
    """Return a new list of every registered model name, in registration order."""
    return [*models]