# NOTE(review): the original file opened with non-code residue
# ("Spaces:", "Runtime error" x2) — copy/paste or extraction artifacts,
# not program text; replaced with this comment so the module stays importable.
import pathlib
import shutil
from typing import List, Optional

import numpy as np
import torch
from datasets import Dataset, load_dataset
from sklearn.metrics import accuracy_score
from substra import BackendType, Client
from substra.sdk.models import ComputePlan
from substra.sdk.schemas import (
    DataSampleSpec,
    DatasetSpec,
    Permissions,
)
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
from substrafl.dependency import Dependency
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.experiment import execute_experiment
from substrafl.index_generator import NpIndexGenerator
from substrafl.nodes import AggregationNode, TestDataNode, TrainDataNode
from substrafl.remote.register import add_metric
from substrafl.strategies import Strategy
class SubstraRunner:
    """Drives an end-to-end SubstraFL federated-learning experiment on MNIST.

    Runs entirely on Substra's local subprocess backend: one "algo provider"
    organization plus (num_clients - 1) data-holding organizations. Intended
    call order: set_up_clients -> prepare_data -> register_data ->
    register_metric -> set_aggregation -> set_testing -> run_compute_plan.
    """

    def __init__(self):
        """Initialize experiment state; no SDK or network calls happen here."""
        # Total organization count: 1 algo provider + 2 data-holding clients.
        self.num_clients = 3
        # Maps organization_id -> substra.Client for the data-holding orgs.
        self.clients = {}
        self.algo_provider: Optional[Client] = None
        # One training shard per data-holding client (filled by prepare_data).
        self.datasets: List[Dataset] = []
        self.test_dataset: Optional[Dataset] = None
        # Directory containing this file; data and assets live beneath it.
        self.path = pathlib.Path(__file__).parent.resolve()
        # Per-client keys returned by the Substra registration calls.
        self.dataset_keys = {}
        self.train_data_sample_keys = {}
        self.test_data_sample_keys = {}
        self.metric_key: Optional[str] = None
        # Local-training schedule shared by every client in every round.
        NUM_UPDATES = 100
        BATCH_SIZE = 32
        self.index_generator = NpIndexGenerator(
            batch_size=BATCH_SIZE,
            num_updates=NUM_UPDATES,
        )
        # NOTE(review): algorithm and strategy are never assigned inside this
        # class — the caller must set them before run_compute_plan.
        self.algorithm: Optional[TorchFedAvgAlgo] = None
        self.strategy: Optional[Strategy] = None
        self.aggregation_node: Optional[AggregationNode] = None
        self.train_data_nodes = list()
        self.test_data_nodes = list()
        self.eval_strategy: Optional[EvaluationStrategy] = None
        self.NUM_ROUNDS = 3
        self.compute_plan: Optional[ComputePlan] = None
        # Destination for execute_experiment's experiment summaries.
        self.experiment_folder = self.path / "experiment_summaries"

    def set_up_clients(self):
        """Create the algo-provider client and the data-holding clients.

        All clients use the in-process local backend (no deployed network);
        self.clients is keyed by each client's organization id.
        """
        self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
        # num_clients - 1 data organizations; the remaining one is the provider.
        self.clients = {
            c.organization_info().organization_id: c
            for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS) for _ in range(self.num_clients - 1)]
        }

    def prepare_data(self):
        """Download MNIST, shard the train split across clients, save to disk.

        Each client receives one shard of the shuffled train split plus the
        full test split, under <this dir>/data/<org_id>/{train,test}. Any
        existing data directory is removed first so reruns start clean.
        """
        dataset = load_dataset("mnist", split="train").shuffle()
        self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i) for i in range(self.num_clients - 1)]
        self.test_dataset = load_dataset("mnist", split="test")
        data_path = self.path / "data"
        if data_path.exists() and data_path.is_dir():
            shutil.rmtree(data_path)
        # Relies on dict preserving insertion order: shard i goes to the
        # i-th registered client.
        for i, client_id in enumerate(self.clients):
            ds = self.datasets[i]
            ds.save_to_disk(data_path / client_id / "train")
            # Every client evaluates on the same full test set.
            self.test_dataset.save_to_disk(data_path / client_id / "test")

    def register_data(self):
        """Register each client's dataset (opener + description) and samples.

        Grants the algo provider read access on every dataset and records the
        returned dataset/sample keys for later node construction.
        """
        for client_id, client in self.clients.items():
            # Private dataset, readable only by the algo-provider organization.
            permissions_dataset = Permissions(public=False, authorized_ids=[
                self.algo_provider.organization_info().organization_id
            ])
            dataset = DatasetSpec(
                name="MNIST",
                type="npy",
                data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
                description=self.path / pathlib.Path("dataset_assets/description.md"),
                permissions=permissions_dataset,
                logs_permission=permissions_dataset,
            )
            self.dataset_keys[client_id] = client.add_dataset(dataset)
            assert self.dataset_keys[client_id], "Missing dataset key"
            self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
                data_manager_keys=[self.dataset_keys[client_id]],
                path=self.path / "data" / client_id / "train",
            ))
            data_sample = DataSampleSpec(
                data_manager_keys=[self.dataset_keys[client_id]],
                path=self.path / "data" / client_id / "test",
            )
            self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)

    def register_metric(self):
        """Register an accuracy metric owned by the algo provider.

        The metric compares argmax of the saved prediction array against the
        "label" column of the datasamples; the key is kept in self.metric_key.
        """
        # The provider and every data organization may use the metric.
        permissions_metric = Permissions(
            public=False,
            authorized_ids=[
                self.algo_provider.organization_info().organization_id
            ] + list(self.clients.keys())
        )
        # Pinned versions so the remote execution environment is reproducible.
        metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])

        def accuracy(datasamples, predictions_path):
            # Predictions are loaded as an array; argmax over axis=1 implies a
            # (n_samples, n_classes) score layout — TODO confirm with the algo.
            y_true = datasamples["label"]
            y_pred = np.load(predictions_path)
            return accuracy_score(y_true, np.argmax(y_pred, axis=1))

        self.metric_key = add_metric(
            client=self.algo_provider,
            metric_function=accuracy,
            permissions=permissions_metric,
            dependencies=metric_deps,
        )

    def set_aggregation(self):
        """Create the aggregation node and one TrainDataNode per client.

        NOTE(review): despite its name, this method also builds the train
        data nodes, so it must run before run_compute_plan.
        """
        self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
        for org_id in self.clients:
            train_data_node = TrainDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                data_sample_keys=[self.train_data_sample_keys[org_id]],
            )
            self.train_data_nodes.append(train_data_node)

    def set_testing(self):
        """Create one TestDataNode per client and the evaluation strategy.

        Requires register_data and register_metric to have run first.
        rounds=1 presumably means "evaluate every round" — verify against the
        installed substrafl EvaluationStrategy semantics.
        """
        for org_id in self.clients:
            test_data_node = TestDataNode(
                organization_id=org_id,
                data_manager_key=self.dataset_keys[org_id],
                test_data_sample_keys=[self.test_data_sample_keys[org_id]],
                metric_keys=[self.metric_key],
            )
            self.test_data_nodes.append(test_data_node)
        self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)

    def run_compute_plan(self):
        """Launch the federated experiment; store the resulting ComputePlan.

        Assumes self.algorithm and self.strategy were assigned by the caller
        (they start as None and are never set inside this class).
        """
        # Runtime dependencies for the algo's remote execution environment.
        algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
        self.compute_plan = execute_experiment(
            client=self.algo_provider,
            algo=self.algorithm,
            strategy=self.strategy,
            train_data_nodes=self.train_data_nodes,
            evaluation_strategy=self.eval_strategy,
            aggregation_node=self.aggregation_node,
            num_rounds=self.NUM_ROUNDS,
            experiment_folder=self.experiment_folder,
            dependencies=algo_deps,
        )
def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
    """Return a zero-argument TorchFedAvgAlgo subclass bound to these pieces.

    SubstraFL instantiates the algo class itself, so the training components
    (model, criterion, optimizer, index generator, dataset wrapper, seed) are
    captured here by closure and forwarded when the class is constructed.
    """
    # Freeze the configuration once; MyAlgo.__init__ simply forwards it.
    bound_kwargs = dict(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        index_generator=index_generator,
        dataset=dataset,
        seed=seed,
    )

    class MyAlgo(TorchFedAvgAlgo):
        """FedAvg algo pre-configured with algo_generator's arguments."""

        def __init__(self):
            super().__init__(**bound_kwargs)

    return MyAlgo