File size: 792 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from rstor.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE
from rstor.architecture.stacked_convolutions import StackedConvolutions
from rstor.architecture.nafnet import NAFNet, UNet
import torch


def load_architecture(config: dict) -> torch.nn.Module:
    """Instantiate the model named in ``config`` and record its statistics.

    The class is chosen by comparing ``config[MODEL][NAME]`` with the class
    names of ``StackedConvolutions``, ``NAFNet`` and ``UNet`` (in that order).
    The chosen class is constructed with ``config[MODEL][ARCHITECTURE]``
    unpacked as keyword arguments.

    Args:
        config: Configuration dictionary. It is modified in place:
            ``config[MODEL][N_PARAMS]`` and
            ``config[MODEL]["receptive_field"]`` are set from the values
            returned by the built model's ``count_parameters()`` and
            ``receptive_field()`` methods.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``config[MODEL][NAME]`` matches none of the supported
            class names.
    """
    # Read the constructor arguments first so a missing architecture entry
    # surfaces before any name matching happens.
    architecture_kwargs = config[MODEL][ARCHITECTURE]
    requested_name = config[MODEL][NAME]

    for candidate in (StackedConvolutions, NAFNet, UNet):
        if requested_name == candidate.__name__:
            network = candidate(**architecture_kwargs)
            break
    else:
        # No candidate matched the requested name.
        raise ValueError(f"Unknown model {config[MODEL][NAME]}")

    # Store model statistics back into the configuration for the caller.
    model_section = config[MODEL]
    model_section[N_PARAMS] = network.count_parameters()
    model_section["receptive_field"] = network.receptive_field()
    return network