import pathlib
import shutil
from typing import Optional, List
from substra import Client, BackendType
from substra.sdk.schemas import (
DatasetSpec,
Permissions,
DataSampleSpec
)
from substrafl.strategies import Strategy
from substrafl.dependency import Dependency
from substrafl.remote.register import add_metric
from substrafl.index_generator import NpIndexGenerator
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
from substrafl.nodes import TrainDataNode, AggregationNode, TestDataNode
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.experiment import execute_experiment
from substra.sdk.models import ComputePlan
from datasets import load_dataset, Dataset
from sklearn.metrics import accuracy_score
import numpy as np
import torch
class SubstraRunner:
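    """Orchestrate a simulated Substra federated-learning experiment on MNIST.

    One subprocess-mode client acts as the algorithm provider; the remaining
    clients act as data-provider organizations. The runner registers the data
    and an accuracy metric, wires the train/test/aggregation nodes, and
    launches the compute plan.
    """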
def __init__(self):
self.num_clients = 3
self.clients = {}
self.algo_provider: Optional[Client] = None
self.datasets: List[Dataset] = []
self.test_dataset: Optional[Dataset] = None
self.path = pathlib.Path(__file__).parent.resolve()
self.dataset_keys = {}
self.train_data_sample_keys = {}
self.test_data_sample_keys = {}
self.metric_key: Optional[str] = None
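        # Local training schedule: each round, a client performs NUM_UPDATES
        # mini-batch updates of size BATCH_SIZE, sampled by the index generator.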
NUM_UPDATES = 100
BATCH_SIZE = 32
self.index_generator = NpIndexGenerator(
batch_size=BATCH_SIZE,
num_updates=NUM_UPDATES,
)
self.algorithm: Optional[TorchFedAvgAlgo] = None
self.strategy: Optional[Strategy] = None
self.aggregation_node: Optional[AggregationNode] = None
        self.train_data_nodes: List[TrainDataNode] = []
        self.test_data_nodes: List[TestDataNode] = []
self.eval_strategy: Optional[EvaluationStrategy] = None
self.NUM_ROUNDS = 3
self.compute_plan: Optional[ComputePlan] = None
self.experiment_folder = self.path / "experiment_summaries"
def set_up_clients(self):
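        """Create one local subprocess-mode client for the algorithm provider
        and one per data-provider organization, keyed by organization id."""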
self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
self.clients = {
c.organization_info().organization_id: c
for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS) for _ in range(self.num_clients - 1)]
}
def prepare_data(self):
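        """Shuffle MNIST, shard the training split across the data-provider
        organizations, and save each shard plus the full test split to disk
        under ``data/<organization_id>/``."""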
dataset = load_dataset("mnist", split="train").shuffle()
self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i) for i in range(self.num_clients - 1)]
self.test_dataset = load_dataset("mnist", split="test")
data_path = self.path / "data"
if data_path.exists() and data_path.is_dir():
shutil.rmtree(data_path)
for i, client_id in enumerate(self.clients):
ds = self.datasets[i]
ds.save_to_disk(data_path / client_id / "train")
self.test_dataset.save_to_disk(data_path / client_id / "test")
def register_data(self):
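        """Register the dataset (opener + description) and the on-disk train
        and test data samples with each data-provider client. Access is
        restricted to the algorithm provider's organization."""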
for client_id, client in self.clients.items():
permissions_dataset = Permissions(public=False, authorized_ids=[
self.algo_provider.organization_info().organization_id
])
dataset = DatasetSpec(
name="MNIST",
type="npy",
                data_opener=self.path / "dataset_assets" / "opener.py",
                description=self.path / "dataset_assets" / "description.md",
permissions=permissions_dataset,
logs_permission=permissions_dataset,
)
self.dataset_keys[client_id] = client.add_dataset(dataset)
assert self.dataset_keys[client_id], "Missing dataset key"
self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
data_manager_keys=[self.dataset_keys[client_id]],
path=self.path / "data" / client_id / "train",
))
data_sample = DataSampleSpec(
data_manager_keys=[self.dataset_keys[client_id]],
path=self.path / "data" / client_id / "test",
)
self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)
def register_metric(self):
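        """Register an accuracy metric on the algorithm provider, readable by
        every data-provider organization. The metric compares ground-truth
        labels with the argmax of the saved prediction array."""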
permissions_metric = Permissions(
public=False,
authorized_ids=[
self.algo_provider.organization_info().organization_id
] + list(self.clients.keys())
)
metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])
def accuracy(datasamples, predictions_path):
y_true = datasamples["label"]
y_pred = np.load(predictions_path)
return accuracy_score(y_true, np.argmax(y_pred, axis=1))
self.metric_key = add_metric(
client=self.algo_provider,
metric_function=accuracy,
permissions=permissions_metric,
dependencies=metric_deps,
)
def set_aggregation(self):
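        """Host the aggregation node on the algorithm provider's organization
        and create one train data node per data-provider organization."""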
self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
for org_id in self.clients:
train_data_node = TrainDataNode(
organization_id=org_id,
data_manager_key=self.dataset_keys[org_id],
data_sample_keys=[self.train_data_sample_keys[org_id]],
)
self.train_data_nodes.append(train_data_node)
def set_testing(self):
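        """Create one test data node per data-provider organization and
        evaluate the registered metric after every round (``rounds=1``)."""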
for org_id in self.clients:
test_data_node = TestDataNode(
organization_id=org_id,
data_manager_key=self.dataset_keys[org_id],
test_data_sample_keys=[self.test_data_sample_keys[org_id]],
metric_keys=[self.metric_key],
)
self.test_data_nodes.append(test_data_node)
self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)
def run_compute_plan(self):
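        """Launch the federated experiment. ``self.algorithm`` and
        ``self.strategy`` are expected to be assigned by the caller (e.g. via
        ``algo_generator`` and a substrafl strategy) before this is called."""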
algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
self.compute_plan = execute_experiment(
client=self.algo_provider,
algo=self.algorithm,
strategy=self.strategy,
train_data_nodes=self.train_data_nodes,
evaluation_strategy=self.eval_strategy,
aggregation_node=self.aggregation_node,
num_rounds=self.NUM_ROUNDS,
experiment_folder=self.experiment_folder,
dependencies=algo_deps,
)
def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
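    """Build a ``TorchFedAvgAlgo`` subclass bound to the given model, loss,
    optimizer, index generator, dataset class and seed, so that substrafl can
    re-instantiate it without arguments."""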
class MyAlgo(TorchFedAvgAlgo):
def __init__(self):
super().__init__(
model=model,
criterion=criterion,
optimizer=optimizer,
index_generator=index_generator,
dataset=dataset,
seed=seed,
)
return MyAlgo
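

# Illustrative driver sketch: the expected call order, assuming
# ``runner.algorithm`` and ``runner.strategy`` are assigned before the
# compute plan is launched.
if __name__ == "__main__":
    runner = SubstraRunner()
    runner.set_up_clients()
    runner.prepare_data()
    runner.register_data()
    runner.register_metric()
    runner.set_aggregation()
    runner.set_testing()
    # runner.algorithm = algo_generator(model, criterion, optimizer,
    #                                   runner.index_generator, dataset, seed)()
    # runner.strategy = ...  # e.g. a substrafl FedAvg strategy
    runner.run_compute_plan()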