from typing import Optional, Tuple

import torch

from rstor.architecture.selector import load_architecture
from rstor.data.dataloader import get_data_loader
from rstor.properties import DEVICE, OPTIMIZER, PARAMS


def get_training_content(
    config: dict,
    training_mode: bool = False,
    device=DEVICE,
) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer], Optional[dict]]:
    """Build the model and, in training mode, its optimizer and data loaders.

    Args:
        config: Experiment configuration. It is passed to ``load_architecture``
            and ``get_data_loader``. In training mode, ``config[OPTIMIZER][PARAMS]``
            is unpacked as keyword arguments to ``torch.optim.Adam``.
        training_mode: When True, also create the Adam optimizer and the data
            loaders. When False, only the model is built.
        device: Device the model is moved to (defaults to ``DEVICE``). The move
            happens before the optimizer is constructed.

    Returns:
        ``(model, optimizer, dl_dict)``. ``dl_dict`` is whatever
        ``get_data_loader(config)`` returns (annotated here as a dict; presumably
        a mapping of split name to loader, so verify against ``get_data_loader``).
        ``optimizer`` and ``dl_dict`` are ``None`` when ``training_mode`` is False.
    """
    model = load_architecture(config)
    # The ``device`` argument was previously ignored. Move the model BEFORE
    # building the optimizer so that the optimizer holds references to the
    # parameters on the target device (``Module.to`` works in place).
    model.to(device)
    optimizer, dl_dict = None, None
    if training_mode:
        optimizer = torch.optim.Adam(model.parameters(), **config[OPTIMIZER][PARAMS])
        dl_dict = get_data_loader(config)
    return model, optimizer, dl_dict


if __name__ == "__main__":
    # Smoke test: build everything for the default experiment and show its config.
    from rstor.learning.experiments_definition import default_experiment

    config = default_experiment(1)
    model, optimizer, dl_dict = get_training_content(config, training_mode=True)
    print(config)