File size: 864 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from rstor.properties import DEVICE, OPTIMIZER, PARAMS
from rstor.architecture.selector import load_architecture
from rstor.data.dataloader import get_data_loader
from typing import Tuple
import torch


def get_training_content(
        config: dict,
        training_mode: bool = False,
        device=DEVICE) -> Tuple[torch.nn.Module, torch.optim.Optimizer, dict]:
    """Assemble the model and, when training, its optimizer and data loaders.

    Args:
        config: Experiment configuration. It is forwarded to
            ``load_architecture`` and, in training mode, to
            ``get_data_loader``; ``config[OPTIMIZER][PARAMS]`` supplies the
            keyword arguments for the Adam optimizer.
        training_mode: When False (default) only the model is created.
        device: Currently not used inside this function; the model is
            returned as ``load_architecture`` produced it.
            NOTE(review): callers presumably move the model to the device
            themselves -- confirm.

    Returns:
        A ``(model, optimizer, dl_dict)`` tuple. ``optimizer`` and
        ``dl_dict`` are both ``None`` when ``training_mode`` is False.
        Otherwise ``optimizer`` is a ``torch.optim.Adam`` over the model
        parameters and ``dl_dict`` is whatever ``get_data_loader(config)``
        returns (presumably a dict of data loaders -- verify against
        ``rstor.data.dataloader``).
    """
    network = load_architecture(config)

    # Inference / evaluation only needs the network itself.
    if not training_mode:
        return network, None, None

    trainable_params = network.parameters()
    adam_kwargs = config[OPTIMIZER][PARAMS]
    adam = torch.optim.Adam(trainable_params, **adam_kwargs)
    loaders = get_data_loader(config)
    return network, adam, loaders


if __name__ == "__main__":
    # Manual smoke run: build everything needed for training from a
    # predefined experiment and show the configuration that was used.
    from rstor.learning.experiments_definition import default_experiment

    # NOTE(review): 1 is presumably an experiment identifier -- confirm
    # against default_experiment.
    config = default_experiment(1)
    training_content = get_training_content(config, training_mode=True)
    model, optimizer, dl_dict = training_content
    print(config)