Spaces:
Runtime error
Runtime error
import inspect | |
from copy import deepcopy | |
from functools import wraps | |
import torch.nn as nn | |
def serialize(init): | |
parameters = list(inspect.signature(init).parameters) | |
def new_init(self, *args, **kwargs): | |
params = deepcopy(kwargs) | |
for pname, value in zip(parameters[1:], args): | |
params[pname] = value | |
config = {"class": get_classname(self.__class__), "params": dict()} | |
specified_params = set(params.keys()) | |
for pname, param in get_default_params(self.__class__).items(): | |
if pname not in params: | |
params[pname] = param.default | |
for name, value in list(params.items()): | |
param_type = "builtin" | |
if inspect.isclass(value): | |
param_type = "class" | |
value = get_classname(value) | |
config["params"][name] = { | |
"type": param_type, | |
"value": value, | |
"specified": name in specified_params, | |
} | |
setattr(self, "_config", config) | |
init(self, *args, **kwargs) | |
return new_init | |
def load_model(config, **kwargs): | |
model_class = get_class_from_str(config["class"]) | |
model_default_params = get_default_params(model_class) | |
model_args = dict() | |
for pname, param in config["params"].items(): | |
value = param["value"] | |
if param["type"] == "class": | |
value = get_class_from_str(value) | |
if pname not in model_default_params and not param["specified"]: | |
continue | |
assert pname in model_default_params | |
if not param["specified"] and model_default_params[pname].default == value: | |
continue | |
model_args[pname] = value | |
model_args.update(kwargs) | |
return model_class(**model_args) | |
def get_config_repr(config): | |
config_str = f'Model: {config["class"]}\n' | |
for pname, param in config["params"].items(): | |
value = param["value"] | |
if param["type"] == "class": | |
value = value.split(".")[-1] | |
param_str = f"{pname:<22} = {str(value):<12}" | |
if not param["specified"]: | |
param_str += " (default)" | |
config_str += param_str + "\n" | |
return config_str | |
def get_default_params(some_class): | |
params = dict() | |
for mclass in some_class.mro(): | |
if mclass is nn.Module or mclass is object: | |
continue | |
mclass_params = inspect.signature(mclass.__init__).parameters | |
for pname, param in mclass_params.items(): | |
if param.default != param.empty and pname not in params: | |
params[pname] = param | |
return params | |
def get_classname(cls): | |
module = cls.__module__ | |
name = cls.__qualname__ | |
if module is not None and module != "__builtin__": | |
name = module + "." + name | |
return name | |
def get_class_from_str(class_str): | |
components = class_str.split(".") | |
mod = __import__(".".join(components[:-1])) | |
for comp in components[1:]: | |
mod = getattr(mod, comp) | |
return mod | |