import pathlib
import shutil
from typing import Optional, List

from substra import Client, BackendType
from substra.sdk.schemas import (
    DatasetSpec,
    Permissions,
    DataSampleSpec
)
from substrafl.strategies import Strategy
from substrafl.dependency import Dependency
from substrafl.remote.register import add_metric
from substrafl.index_generator import NpIndexGenerator
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
from substrafl.nodes import TrainDataNode, AggregationNode, TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.experiment import execute_experiment
from substra.sdk.models import ComputePlan
from datasets import load_dataset, Dataset
from sklearn.metrics import accuracy_score
import numpy as np
import torch


class SubstraRunner:
    """Orchestrates a local (subprocess-backend) substra/substrafl federated
    learning experiment on MNIST.

    Intended call order (each method depends on state set by the previous
    ones): set_up_clients -> prepare_data -> register_data ->
    register_metric -> set_aggregation -> set_testing -> run_compute_plan.

    NOTE(review): ``self.algorithm`` and ``self.strategy`` are initialized to
    None and never assigned inside this class — presumably the caller sets
    them (e.g. via ``algo_generator`` below) before ``run_compute_plan``;
    confirm against the driver script.
    """

    def __init__(self):
        # Total organization count: one algo provider + (num_clients - 1)
        # data-holding clients. All per-client bookkeeping dicts below are
        # keyed by the client's organization id.
        self.num_clients = 3
        self.clients = {}
        self.algo_provider: Optional[Client] = None
        # One training shard per data-holding client (filled by prepare_data).
        self.datasets: List[Dataset] = []
        self.test_dataset: Optional[Dataset] = None
        # Directory containing this file; data and assets are resolved
        # relative to it.
        self.path = pathlib.Path(__file__).parent.resolve()
        self.dataset_keys = {}
        self.train_data_sample_keys = {}
        self.test_data_sample_keys = {}
        self.metric_key: Optional[str] = None
        # Local-batching parameters consumed by the substrafl index generator:
        # 100 optimizer updates per round, mini-batches of 32.
        NUM_UPDATES = 100
        BATCH_SIZE = 32
        self.index_generator = NpIndexGenerator(
            batch_size=BATCH_SIZE,
            num_updates=NUM_UPDATES,
        )
        # Set externally before run_compute_plan (see class docstring).
        self.algorithm: Optional[TorchFedAvgAlgo] = None
        self.strategy: Optional[Strategy] = None
        self.aggregation_node: Optional[AggregationNode] = None
        self.train_data_nodes = list()
        self.test_data_nodes = list()
        self.eval_strategy: Optional[EvaluationStrategy] = None
        self.NUM_ROUNDS = 3
        self.compute_plan: Optional[ComputePlan] = None
        self.experiment_folder = self.path / "experiment_summaries"

    def set_up_clients(self):
        """Create one algo-provider client plus num_clients - 1 data clients,
        all on the local subprocess backend.

        The data clients are stored keyed by their organization id.
        NOTE(review): this assumes each local Client gets a distinct
        organization id — otherwise the dict comprehension would collapse
        entries; verify with the substra local-backend docs.
        """
        self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
        self.clients = {
            c.organization_info().organization_id: c
            for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS)
                      for _ in range(self.num_clients - 1)]
        }

    def prepare_data(self):
        """Download MNIST, shard the shuffled train split across the data
        clients, and save each client's train shard plus a full copy of the
        test split under ``<path>/data/<client_id>/{train,test}``.

        Any pre-existing ``data`` directory is removed first so reruns start
        clean. Must run after set_up_clients (iterates self.clients).
        """
        dataset = load_dataset("mnist", split="train").shuffle()
        # One shard per data-holding client (num_clients - 1 of them),
        # matching the number of clients built in set_up_clients.
        self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i)
                         for i in range(self.num_clients - 1)]
        self.test_dataset = load_dataset("mnist", split="test")
        data_path = self.path / "data"
        if data_path.exists() and data_path.is_dir():
            shutil.rmtree(data_path)
        for i, client_id in enumerate(self.clients):
            ds = self.datasets[i]
            ds.save_to_disk(data_path / client_id / "train")
            # Every client gets the same full test set.
            self.test_dataset.save_to_disk(data_path / client_id / "test")

    def register_data(self):
        """Register each client's dataset (opener + description) and its
        train/test data samples with that client's organization.

        Datasets are private; only the algo provider's organization is
        authorized to process them. Keys are recorded in
        self.dataset_keys / self.train_data_sample_keys /
        self.test_data_sample_keys for later node wiring.
        Must run after prepare_data (reads the on-disk sample folders).
        """
        for client_id, client in self.clients.items():
            permissions_dataset = Permissions(public=False, authorized_ids=[
                self.algo_provider.organization_info().organization_id
            ])
            dataset = DatasetSpec(
                name="MNIST",
                type="npy",
                # Opener/description assets are expected alongside this file.
                data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
                description=self.path / pathlib.Path("dataset_assets/description.md"),
                permissions=permissions_dataset,
                logs_permission=permissions_dataset,
            )
            self.dataset_keys[client_id] = client.add_dataset(dataset)
            assert self.dataset_keys[client_id], "Missing dataset key"
            self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
                data_manager_keys=[self.dataset_keys[client_id]],
                path=self.path / "data" / client_id / "train",
            ))
            data_sample = DataSampleSpec(
                data_manager_keys=[self.dataset_keys[client_id]],
                path=self.path / "data" / client_id / "test",
            )
            self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)

    def register_metric(self):
        """Register an accuracy metric with the algo provider and store its
        key in self.metric_key.

        The metric is visible to the algo provider and all data clients.
        ``accuracy`` compares the ``label`` column of the data samples to the
        argmax of the predictions loaded from ``predictions_path``
        (presumably a .npy array of per-class scores — confirm against the
        opener/algo output format).
        """
        permissions_metric = Permissions(
            public=False,
            authorized_ids=[
                self.algo_provider.organization_info().organization_id
            ] + list(self.clients.keys())
        )
        # Pinned versions are installed in the metric's execution environment.
        metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

        def accuracy(datasamples, predictions_path):
            y_true = datasamples["label"]
            y_pred = np.load(predictions_path)
            return accuracy_score(y_true, np.argmax(y_pred, axis=1))

        self.metric_key = add_metric(
            client=self.algo_provider,
            metric_function=accuracy,
            permissions=permissions_metric,
            dependencies=metric_deps,
        )

    def set_aggregation(self):
        """Create the aggregation node (hosted by the algo provider's
        organization) and one TrainDataNode per data client.

        Despite the name, this also wires the training nodes, appending them
        to self.train_data_nodes. Must run after register_data (reads the
        dataset / train-sample keys).
        """
        self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
        for org_id in self.clients:
            train_data_node = TrainDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                data_sample_keys=[self.train_data_sample_keys[org_id]],
            )
            self.train_data_nodes.append(train_data_node)

    def set_testing(self):
        """Create one TestDataNode per data client (using the registered
        metric) and an EvaluationStrategy that evaluates every round
        (rounds=1 == evaluation frequency of one round).

        Must run after register_data and register_metric.
        """
        for org_id in self.clients:
            test_data_node = TestDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                test_data_sample_keys=[self.test_data_sample_keys[org_id]],
                metric_keys=[self.metric_key],
            )
            self.test_data_nodes.append(test_data_node)
        self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)

    def run_compute_plan(self):
        """Submit the experiment (NUM_ROUNDS rounds) to the algo provider and
        store the resulting ComputePlan in self.compute_plan.

        Requires self.algorithm and self.strategy to have been set by the
        caller, and all previous set-up methods to have run; experiment
        summaries are written under self.experiment_folder.
        """
        # Pinned versions installed in the training tasks' environments.
        algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
        self.compute_plan = execute_experiment(
            client=self.algo_provider,
            algo=self.algorithm,
            strategy=self.strategy,
            train_data_nodes=self.train_data_nodes,
            evaluation_strategy=self.eval_strategy,
            aggregation_node=self.aggregation_node,
            num_rounds=self.NUM_ROUNDS,
            experiment_folder=self.experiment_folder,
            dependencies=algo_deps,
        )


def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
    """Return a TorchFedAvgAlgo subclass whose no-argument constructor is
    bound (via closure) to the given model, criterion, optimizer, index
    generator, dataset class and seed.

    substrafl instantiates the algo class itself, so the construction
    arguments are captured here rather than passed at call time.
    """
    class MyAlgo(TorchFedAvgAlgo):
        def __init__(self):
            super().__init__(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                index_generator=index_generator,
                dataset=dataset,
                seed=seed,
            )
    return MyAlgo