from rstor.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE from rstor.architecture.stacked_convolutions import StackedConvolutions from rstor.architecture.nafnet import NAFNet, UNet import torch def load_architecture(config: dict) -> torch.nn.Module: conf_model = config[MODEL][ARCHITECTURE] if config[MODEL][NAME] == StackedConvolutions.__name__: model = StackedConvolutions(**conf_model) elif config[MODEL][NAME] == NAFNet.__name__: model = NAFNet(**conf_model) elif config[MODEL][NAME] == UNet.__name__: model = UNet(**conf_model) else: raise ValueError(f"Unknown model {config[MODEL][NAME]}") config[MODEL][N_PARAMS] = model.count_parameters() config[MODEL]["receptive_field"] = model.receptive_field() return model