"""Launch a federated training run of a CNN with the FedAvg strategy.

The script seeds torch, then builds the model, an Adam optimizer (lr=0.001)
and a cross-entropy loss. It drives a ``SubstraRunner`` through its setup
steps, plugs in an algorithm built by ``algo_generator`` and a ``FedAvg``
strategy, and finally calls ``run_compute_plan()``.

Everything runs at import time; this module is meant to be executed as a
script.
"""
import torch
from substrafl.strategies import FedAvg

from substra_helpers.dataset import TorchDataset
from substra_helpers.model import CNN
from substra_helpers.substra_runner import SubstraRunner, algo_generator

seed = 42
# Seed before constructing the model so its weight initialisation is
# reproducible across runs. The same seed is also handed to algo_generator.
torch.manual_seed(seed)

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

runner = SubstraRunner()
# Setup steps, called in this order. Their exact semantics live in
# SubstraRunner (not visible here); the names suggest client setup, data
# preparation, and data/metric registration.
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# algo_generator returns a callable (presumably an algorithm class) that is
# invoked immediately with no arguments; note the trailing `()`.
runner.algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed,
)()

runner.strategy = FedAvg()
runner.set_aggregation()
runner.set_testing()
runner.run_compute_plan()