|
import gradio as gr |
|
from flwr_datasets import FederatedDataset |
|
from flwr_datasets.partitioner import ( |
|
DirichletPartitioner, |
|
IidPartitioner, |
|
PathologicalPartitioner, |
|
ShardPartitioner, |
|
LinearPartitioner, |
|
SquarePartitioner, |
|
ExponentialPartitioner, |
|
NaturalIdPartitioner |
|
) |
|
from flwr_datasets.visualization import plot_label_distributions |
|
import matplotlib.pyplot as plt |
|
|
|
# Lookup from the partitioner name offered in the UI dropdown to the class
# that implements it. Insertion order is preserved, so the tuple order below
# is also the order in which the names are listed.
# NOTE(review): keys are taken from ``__name__``; this assumes every class is
# defined under the same name it is imported by — confirm.
partitioner_types = {
    partitioner_cls.__name__: partitioner_cls
    for partitioner_cls in (
        DirichletPartitioner,
        IidPartitioner,
        PathologicalPartitioner,
        ShardPartitioner,
        LinearPartitioner,
        SquarePartitioner,
        ExponentialPartitioner,
        NaturalIdPartitioner,
    )
}
|
|
|
# For every partitioner name, the parameter names whose input widgets should
# be shown when that partitioner is selected.
partitioner_parameters = {
    "DirichletPartitioner": [
        "num_partitions",
        "alpha",
        "partition_by",
        "min_partition_size",
        "self_balancing",
    ],
    "IidPartitioner": ["num_partitions"],
    "PathologicalPartitioner": [
        "num_partitions",
        "partition_by",
        "num_classes_per_partition",
        "class_assignment_mode",
    ],
    "ShardPartitioner": [
        "num_partitions",
        "partition_by",
        "num_shards_per_partition",
        "shard_size",
        "keep_incomplete_shard",
    ],
    "NaturalIdPartitioner": ["partition_by"],
    # The size-skew partitioners only need the partition count; each entry
    # gets its own list object, exactly like separate literals would.
    **{
        size_skew_name: ["num_partitions"]
        for size_skew_name in (
            "LinearPartitioner",
            "SquarePartitioner",
            "ExponentialPartitioner",
        )
    },
}
|
|
|
def update_parameter_visibility(partitioner_type):
    """Build the visibility updates for the partitioner parameter widgets.

    Args:
        partitioner_type: Name of the selected partitioner. Names missing
            from ``partitioner_parameters`` are treated as needing no
            parameters, so every widget is hidden.

    Returns:
        A list of ten ``gr.update(visible=...)`` objects, one per parameter
        in the fixed order listed in ``widget_order``. The order has to match
        the ``outputs`` list of the event this function is attached to.
    """
    needed = partitioner_parameters.get(partitioner_type, [])
    widget_order = (
        "num_partitions",
        "alpha",
        "partition_by",
        "min_partition_size",
        "self_balancing",
        "num_classes_per_partition",
        "class_assignment_mode",
        "num_shards_per_partition",
        "shard_size",
        "keep_incomplete_shard",
    )
    # A widget is visible exactly when the selected partitioner lists it.
    return [gr.update(visible=(name in needed)) for name in widget_order]
|
|
|
def _build_partitioner_params(
    partitioner_type,
    num_partitions,
    alpha,
    partition_by,
    min_partition_size,
    self_balancing,
    num_classes_per_partition,
    class_assignment_mode,
    num_shards_per_partition,
    shard_size,
    keep_incomplete_shard,
):
    """Return the constructor kwargs used for the selected partitioner type.

    Numeric widget values are coerced with ``int``/``float`` (raising
    ``ValueError``/``TypeError`` when they cannot be converted). An
    unrecognized ``partitioner_type`` yields an empty dict.
    """
    if partitioner_type == "DirichletPartitioner":
        return {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "alpha": float(alpha),
            "min_partition_size": int(min_partition_size),
            "self_balancing": self_balancing,
        }
    if partitioner_type == "PathologicalPartitioner":
        return {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "num_classes_per_partition": int(num_classes_per_partition),
            "class_assignment_mode": class_assignment_mode,
        }
    if partitioner_type == "ShardPartitioner":
        return {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "num_shards_per_partition": int(num_shards_per_partition),
            "shard_size": int(shard_size),
            # The widget delivers the strings "True"/"False", not booleans.
            "keep_incomplete_shard": keep_incomplete_shard == "True",
        }
    if partitioner_type == "NaturalIdPartitioner":
        return {"partition_by": partition_by}
    if partitioner_type in (
        "IidPartitioner",
        "LinearPartitioner",
        "SquarePartitioner",
        "ExponentialPartitioner",
    ):
        return {"num_partitions": int(num_partitions)}
    return {}


def _render_usage_snippet(
    dataset,
    partitioner_type,
    partitioner_params,
    label_name,
    title,
    legend,
    verbose_labels,
    size_unit,
    partition_id_axis,
):
    """Render a Python snippet that reproduces the current configuration."""
    if partitioner_params:
        # One tab-indented ``name = value`` per line; string values are
        # wrapped in double quotes and every line but the last ends in a comma.
        rendered_args = (
            f'\t{key} = "{value}"' if isinstance(value, str) else f"\t{key} = {value}"
            for key, value in partitioner_params.items()
        )
        partitioner_params_str = "\n" + ",\n".join(rendered_args) + "\n"
    else:
        partitioner_params_str = "\n"

    return f"""
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import {partitioner_type}
from flwr_datasets.visualization import plot_label_distributions

partitioner = {partitioner_type}({partitioner_params_str})
fds = FederatedDataset(
    dataset="{dataset}",
    partitioners={{
        "train": partitioner,
    }},
    trust_remote_code=True,
)
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name="{label_name}",
    title="{title}",
    legend={legend},
    verbose_labels={verbose_labels},
    size_unit="{size_unit}",
    partition_id_axis="{partition_id_axis}",
)
"""


def partition_and_plot(
    dataset,
    partitioner_type,
    num_partitions,
    alpha,
    partition_by,
    min_partition_size,
    self_balancing,
    num_classes_per_partition,
    class_assignment_mode,
    num_shards_per_partition,
    shard_size,
    keep_incomplete_shard,
    label_name,
    title,
    legend,
    verbose_labels,
    size_unit,
    partition_id_axis,
):
    """Partition a dataset, plot its label distribution and emit example code.

    Args:
        dataset: Dataset name handed to ``FederatedDataset``.
        partitioner_type: Key into ``partitioner_types`` selecting the
            partitioner class.
        num_partitions, alpha, partition_by, min_partition_size,
        self_balancing, num_classes_per_partition, class_assignment_mode,
        num_shards_per_partition, shard_size, keep_incomplete_shard:
            Raw widget values; only those relevant to ``partitioner_type``
            are used. ``keep_incomplete_shard`` is the string ``"True"`` or
            ``"False"``.
        label_name, title, legend, verbose_labels, size_unit,
        partition_id_axis: Forwarded to ``plot_label_distributions``.

    Returns:
        A 2-tuple ``(plot_path, text)``. On success ``plot_path`` is the path
        of the saved PNG and ``text`` is a code snippet reproducing the run.
        On any failure the tuple is ``(None, "Error: <message>")``. The arity
        is always two because the handler is wired to exactly two outputs.
    """
    try:
        partitioner_params = _build_partitioner_params(
            partitioner_type,
            num_partitions,
            alpha,
            partition_by,
            min_partition_size,
            self_balancing,
            num_classes_per_partition,
            class_assignment_mode,
            num_shards_per_partition,
            shard_size,
            keep_incomplete_shard,
        )
        # An unknown partitioner name raises KeyError here and is reported
        # through the error path below.
        partitioner_class = partitioner_types[partitioner_type]
        partitioner = partitioner_class(**partitioner_params)
        # NOTE(review): ``trust_remote_code=True`` is combined with a
        # user-typed dataset name; presumably this lets a dataset's own
        # loading script run in this process — confirm that is acceptable
        # for the deployment.
        fds = FederatedDataset(
            dataset=dataset,
            partitioners={"train": partitioner},
            trust_remote_code=True,
        )
        partitioner = fds.partitioners["train"]
        figure, _axis, _dataframe = plot_label_distributions(
            partitioner=partitioner,
            label_name=label_name,
            title=title,
            legend=legend,
            verbose_labels=verbose_labels,
            size_unit=size_unit,
            partition_id_axis=partition_id_axis,
        )

        plot_filename = "label_distribution.png"
        try:
            figure.savefig(plot_filename, bbox_inches="tight")
        finally:
            # Release the figure so repeated requests do not pile up open
            # Matplotlib figures in a long-running server process.
            plt.close(figure)

        code = _render_usage_snippet(
            dataset,
            partitioner_type,
            partitioner_params,
            label_name,
            title,
            legend,
            verbose_labels,
            size_unit,
            partition_id_axis,
        )
        return plot_filename, code
    except Exception as e:
        # UI boundary: report the failure in the code panel instead of
        # crashing the handler.
        return None, f"Error: {e}"
|
|
|
# Gradio page: parameter widgets on the left, plot and generated code on the
# right, followed by clickable example presets.
# NOTE(review): the flattened source carried no indentation, so the nesting of
# the layout containers below is a reconstruction — confirm against the
# rendered page.
with gr.Blocks() as demo:
    gr.Markdown("# Federated Dataset: Partitioning Visualization")
    gr.Markdown("See partitioned datasets for Federated Learning experiments. The partitioning and visualization were created using `flwr-datasets`. To open in a new tab, click the [link](https://huggingface.co/spaces/flwrlabs/federated-learning-datasets-by-flwr-datasets).")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Federated Dataset Parameters", open=True):
                dataset_input = gr.Textbox(label="Dataset", value="cifar10")
                partitioner_type_input = gr.Dropdown(
                    label="Partitioner",
                    choices=list(partitioner_types.keys()),
                    value="DirichletPartitioner",
                )
                # Widgets used by the default (Dirichlet) partitioner start
                # out visible.
                num_partitions_input = gr.Number(
                    label="num_partitions", value=10, visible=True
                )
                alpha_input = gr.Number(label="alpha", value=0.3, visible=True)
                partition_by_input = gr.Textbox(
                    label="partition_by", value="label", visible=True
                )
                min_partition_size_input = gr.Number(
                    label="min_partition_size", value=0, visible=True
                )
                self_balancing_input = gr.Radio(
                    label="self_balancing",
                    choices=[True, False],
                    value=False,
                    visible=True,
                )

                # Widgets for the other partitioners start out hidden.
                num_classes_per_partition_input = gr.Number(
                    label="num_classes_per_partition", value=2, visible=False
                )
                class_assignment_mode_input = gr.Dropdown(
                    label="class_assignment_mode",
                    choices=["random", "first-deterministic", "deterministic"],
                    value="first-deterministic",
                    visible=False,
                )
                num_shards_per_partition_input = gr.Number(
                    label="num_shards_per_partition", value=2, visible=False
                )
                shard_size_input = gr.Number(
                    label="shard_size", value=0, visible=False
                )
                keep_incomplete_shard_input = gr.Radio(
                    label="keep_incomplete_shard",
                    choices=["True", "False"],
                    value="True",
                    visible=False,
                )
            with gr.Accordion("Plot Parameters", open=False):
                label_name = gr.Textbox(label="label_name", value="label")
                title = gr.Textbox(
                    label="title", value="Per Partition Label Distribution"
                )

                legend = gr.Radio(
                    label="legend", choices=[True, False], value=True
                )
                verbose_labels = gr.Radio(
                    label="verbose_labels", choices=[True, False], value=True
                )
                size_unit = gr.Radio(
                    label="size_unit",
                    choices=["absolute", "percent"],
                    value="absolute",
                )
                partition_id_axis = gr.Radio(
                    label="partition_id_axis", choices=["x", "y"], value="x"
                )

            # Order matters: it is both the order of the updates returned by
            # update_parameter_visibility and the positional order of the
            # matching partition_and_plot arguments.
            partitioner_param_inputs = [
                num_partitions_input,
                alpha_input,
                partition_by_input,
                min_partition_size_input,
                self_balancing_input,
                num_classes_per_partition_input,
                class_assignment_mode_input,
                num_shards_per_partition_input,
                shard_size_input,
                keep_incomplete_shard_input,
            ]
            plot_param_inputs = [
                label_name,
                title,
                legend,
                verbose_labels,
                size_unit,
                partition_id_axis,
            ]

            # Show only the widgets relevant to the selected partitioner.
            partitioner_type_input.change(
                fn=update_parameter_visibility,
                inputs=[partitioner_type_input],
                outputs=partitioner_param_inputs,
            )
        with gr.Column(scale=3, min_width=480):
            gr.Markdown("## Label Distribution Plot")
            plot_output = gr.Image(label="Label Distribution Plot")
            submit_button = gr.Button("Partition and Plot", variant="primary")

            gr.Markdown("## Code")
            code_output = gr.Code(label="Code", language="python")

    size_skew_examples = gr.Examples(
        examples=[
            ["cifar10", size_skew_partitioner, 10]
            for size_skew_partitioner in (
                "IidPartitioner",
                "LinearPartitioner",
                "SquarePartitioner",
                "ExponentialPartitioner",
            )
        ],
        inputs=[
            dataset_input,
            partitioner_type_input,
            num_partitions_input,
        ],
        label="Size Skew Examples",
    )

    dirichlet_examples = gr.Examples(
        examples=[
            ["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "absolute"],
            ["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "percent"],
        ],
        inputs=[
            dataset_input,
            partitioner_type_input,
            num_partitions_input,
            alpha_input,
            partition_by_input,
            min_partition_size_input,
            self_balancing_input,
            size_unit,
        ],
        label="Dirichlet Examples",
    )

    pathological_examples = gr.Examples(
        examples=[
            ["cifar10", "PathologicalPartitioner", 10, 2, "first-deterministic", "label"],
            ["cifar10", "PathologicalPartitioner", 10, 3, "deterministic", "label"],
        ],
        inputs=[
            dataset_input,
            partitioner_type_input,
            num_partitions_input,
            num_classes_per_partition_input,
            class_assignment_mode_input,
            partition_by_input,
        ],
        label="Pathological Examples",
    )
    markdown = gr.Markdown("See more tutorial, examples and documentation on [https://flower.ai/docs/datasets/index.html](https://flower.ai/docs/datasets/index.html).")

    # Run the partitioning and fill the plot and code panels.
    submit_button.click(
        fn=partition_and_plot,
        inputs=[
            dataset_input,
            partitioner_type_input,
            *partitioner_param_inputs,
            *plot_param_inputs,
        ],
        outputs=[
            plot_output,
            code_output,
        ],
    )

if __name__ == "__main__":
    demo.launch()