File size: 3,033 Bytes
2cdd41c
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
 
 
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
1615d09
 
 
 
2cdd41c
 
1615d09
2cdd41c
 
 
 
 
 
1615d09
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
1615d09
2cdd41c
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
1615d09
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
2cdd41c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import importlib
import inspect
from copy import deepcopy
from functools import wraps

import torch.nn as nn


def serialize(init):
    """Decorate a model's ``__init__`` so each instance records how it was built.

    Before the wrapped ``__init__`` runs, the instance gets a ``_config`` dict::

        {"class": "<module>.<qualname>",
         "params": {name: {"type": "builtin" | "class",
                           "value": <value, or dotted class name>,
                           "specified": <True if passed by the caller>}}}

    Positional arguments are mapped onto parameter names by position. Every
    parameter that has a default somewhere in the class MRO and was not passed
    by the caller is recorded with its default and ``"specified": False``.
    Class-valued arguments are stored as dotted names so they can be
    re-imported later. Keyword values are deep-copied; positional values are
    stored by reference.

    When several classes of one hierarchy are decorated, only the wrapper of
    the most-derived decorated ``__init__`` records the config. This keeps a
    ``super().__init__(...)`` call from overwriting it with the arguments the
    subclass forwarded to its parent.
    """
    parameters = list(inspect.signature(init).parameters)

    @wraps(init)
    def new_init(self, *args, **kwargs):
        owner = _first_serialized_init(self.__class__)
        # ``owner is None``: the wrapper is not installed as a class __init__
        # in the usual way, so keep the unconditional-recording behaviour.
        if owner is None or owner is init:
            config = _build_config(self.__class__, parameters, args, kwargs)
            setattr(self, "_config", config)
        init(self, *args, **kwargs)

    # Tag the wrapper with the function it decorates. functools.wraps copies
    # __dict__, so decorators stacked on top of this one carry the tag too.
    new_init._serialized_init = init
    return new_init


def _first_serialized_init(cls):
    """Return the undecorated ``__init__`` of the most-derived serialized class in ``cls``'s MRO, else ``None``."""
    for klass in cls.__mro__:
        marker = getattr(klass.__dict__.get("__init__"), "_serialized_init", None)
        if marker is not None:
            return marker
    return None


def _build_config(cls, parameters, args, kwargs):
    """Build the ``_config`` dict for an instance of ``cls`` constructed with ``args``/``kwargs``."""
    params = deepcopy(kwargs)
    # parameters[0] is ``self``; map positional arguments onto the remaining names.
    for pname, value in zip(parameters[1:], args):
        params[pname] = value

    config = {"class": get_classname(cls), "params": dict()}
    specified_params = set(params.keys())

    for pname, param in get_default_params(cls).items():
        if pname not in params:
            params[pname] = param.default

    for name, value in params.items():
        param_type = "builtin"
        if inspect.isclass(value):
            # Store classes by dotted path so the config stays plain data.
            param_type = "class"
            value = get_classname(value)

        config["params"][name] = {
            "type": param_type,
            "value": value,
            "specified": name in specified_params,
        }
    return config


def load_model(config, **kwargs):
    """Instantiate the model described by a config produced by ``serialize``.

    Args:
        config: Dict with ``"class"`` (dotted class path) and ``"params"``
            (name -> {"type", "value", "specified"}), as stored in an
            instance's ``_config``.
        **kwargs: Constructor arguments that override the saved ones.

    Returns:
        A new instance of the saved class, built as ``model_class(**args)``.

    Parameters the caller explicitly passed when the model was created are
    always forwarded, including required parameters that have no default. If
    the class no longer accepts one of them, the constructor itself rejects
    it. Saved defaults are forwarded only when the class still has that
    parameter and its current default differs from the saved value. The model
    is therefore rebuilt with its original values even if defaults changed.
    """
    model_class = get_class_from_str(config["class"])
    model_default_params = get_default_params(model_class)

    model_args = dict()
    for pname, param in config["params"].items():
        default_param = model_default_params.get(pname)
        # A recorded default for a parameter the class no longer has: drop it
        # before resolving its value (its class may no longer exist either).
        if not param["specified"] and default_param is None:
            continue

        value = param["value"]
        if param["type"] == "class":
            value = get_class_from_str(value)

        # Unchanged default: let the constructor apply it itself.
        if not param["specified"] and _matches_default(default_param.default, value):
            continue
        model_args[pname] = value

    model_args.update(kwargs)

    return model_class(**model_args)


def _matches_default(default, value):
    """Return True if ``value`` equals ``default``; values that cannot be compared to a plain bool count as different."""
    if default is value:
        return True
    try:
        return bool(default == value)
    except (TypeError, ValueError, RuntimeError):
        # e.g. elementwise comparison of arrays/tensors: forwarding the saved
        # value explicitly is always safe.
        return False


def get_config_repr(config):
    """Render a serialized model config as a human-readable, multi-line string.

    The first line names the model class. Each following line shows one
    parameter, left-aligned in fixed-width columns. Class-valued parameters
    are shown by their bare class name, and parameters that were not passed
    explicitly are suffixed with ``(default)``. Every line, including the
    last, ends with a newline.
    """
    lines = [f'Model: {config["class"]}']
    for pname, param in config["params"].items():
        shown = param["value"]
        if param["type"] == "class":
            # Keep only the final component of the dotted class path.
            shown = shown.rsplit(".", 1)[-1]
        suffix = "" if param["specified"] else " (default)"
        lines.append(f"{pname:<22} = {str(shown):<12}{suffix}")
    return "\n".join(lines) + "\n"


def get_default_params(some_class):
    """Collect constructor parameters that have default values across the MRO.

    Walks ``some_class.mro()`` from the most-derived class upward, skipping
    ``nn.Module`` and ``object``, and reads each class's ``__init__``
    signature. A name is recorded the first time it is seen, so a subclass's
    default takes precedence over one declared by a base class. Parameters
    without a default (including ``*args``/``**kwargs``) are omitted.

    Args:
        some_class: The class to inspect.

    Returns:
        Dict mapping parameter name to its ``inspect.Parameter`` (the default
        value is ``.default``), in discovery order.
    """
    params = dict()
    for mclass in some_class.mro():
        if mclass is nn.Module or mclass is object:
            continue

        mclass_params = inspect.signature(mclass.__init__).parameters
        for pname, param in mclass_params.items():
            # Identity test against the sentinel: ``!=`` would call the
            # default's own comparison, which for tensors/arrays is
            # elementwise and cannot be used as a bool.
            if param.default is not param.empty and pname not in params:
                params[pname] = param

    return params


def get_classname(cls):
    """Return the dotted path ``<module>.<qualname>`` that identifies ``cls``.

    The module prefix is dropped only when ``cls.__module__`` is ``None`` or
    equals ``"__builtin__"`` (the Python 2 module name). On Python 3, built-in
    classes report ``"builtins"`` and so keep their prefix, e.g.
    ``"builtins.int"``.
    """
    module_name = cls.__module__
    qualname = cls.__qualname__
    if module_name is None or module_name == "__builtin__":
        return qualname
    return module_name + "." + qualname


def get_class_from_str(class_str):
    """Resolve a dotted ``module.QualName`` path (as built by ``get_classname``).

    The longest importable module prefix is imported, and the remaining
    components are looked up as attributes. This handles nested classes,
    whose qualified name itself contains dots (``pkg.mod.Outer.Inner``).

    Args:
        class_str: Dotted path containing at least one ``"."``.

    Returns:
        The object the path refers to.

    Raises:
        ValueError: If ``class_str`` has no module component.
        ModuleNotFoundError: If no prefix of the path is an importable module,
            or if the module itself fails to import one of its dependencies.
        AttributeError: If a component after the module is missing.
    """
    components = class_str.split(".")
    if len(components) < 2:
        raise ValueError(
            f"Expected a dotted 'module.ClassName' path, got {class_str!r}"
        )

    last_error = None
    for split in range(len(components) - 1, 0, -1):
        module_name = ".".join(components[:split])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Fall back to a shorter prefix only if the missing module is this
            # prefix (or a parent of it). A dependency missing *inside* a real
            # module must propagate instead of being masked.
            missing = err.name
            if missing is None or not (
                module_name == missing or module_name.startswith(missing + ".")
            ):
                raise
            last_error = err
            continue
        for attr in components[split:]:
            obj = getattr(obj, attr)
        return obj
    raise last_error