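# Hosting-page residue (avatar caption, commit message, commit hash); not part of the program.
# adamnarozniak's picture
# Update app.py
# da08cce verified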
import json

import gradio as gr
import matplotlib.pyplot as plt
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import (
    DirichletPartitioner,
    IidPartitioner,
    PathologicalPartitioner,
    ShardPartitioner,
    LinearPartitioner,
    SquarePartitioner,
    ExponentialPartitioner,
    NaturalIdPartitioner
)
from flwr_datasets.visualization import plot_label_distributions
# Maps each partitioner name to the flwr_datasets class instantiated for it.
# The keys are the choices offered by the "Partitioner" dropdown, and
# partition_and_plot looks the selected name up here to build the partitioner.
partitioner_types = {
"DirichletPartitioner": DirichletPartitioner,
"IidPartitioner": IidPartitioner,
"PathologicalPartitioner": PathologicalPartitioner,
"ShardPartitioner": ShardPartitioner,
"LinearPartitioner": LinearPartitioner,
"SquarePartitioner": SquarePartitioner,
"ExponentialPartitioner": ExponentialPartitioner,
"NaturalIdPartitioner": NaturalIdPartitioner,
}
# Names of the parameter widgets that apply to each partitioner.
# update_parameter_visibility reads this table to decide which inputs to show
# for the partitioner currently selected in the dropdown.
partitioner_parameters = {
    "DirichletPartitioner": [
        "num_partitions",
        "alpha",
        "partition_by",
        "min_partition_size",
        "self_balancing",
    ],
    "IidPartitioner": ["num_partitions"],
    "PathologicalPartitioner": [
        "num_partitions",
        "partition_by",
        "num_classes_per_partition",
        "class_assignment_mode",
    ],
    "ShardPartitioner": [
        "num_partitions",
        "partition_by",
        "num_shards_per_partition",
        "shard_size",
        "keep_incomplete_shard",
    ],
    "NaturalIdPartitioner": ["partition_by"],
    # These three take only ``num_partitions``; each still gets its own list.
    **{
        name: ["num_partitions"]
        for name in (
            "LinearPartitioner",
            "SquarePartitioner",
            "ExponentialPartitioner",
        )
    },
}
def update_parameter_visibility(partitioner_type):
    """Return one visibility update per partitioner-parameter widget.

    Args:
        partitioner_type: Name of the selected partitioner, used as a key
            into ``partitioner_parameters``. An unknown name yields an empty
            parameter list, so every widget is hidden.

    Returns:
        A list of ten ``gr.update(visible=...)`` values, one per parameter
        widget, with ``visible=True`` only for the parameters listed for the
        selected partitioner.
    """
    # The position of each name fixes which widget the update applies to, so
    # this order must stay in sync with the ``outputs`` list of the
    # partitioner dropdown's ``change`` handler.
    widget_order = (
        "num_partitions",
        "alpha",
        "partition_by",
        "min_partition_size",
        "self_balancing",
        "num_classes_per_partition",
        "class_assignment_mode",
        "num_shards_per_partition",
        "shard_size",
        "keep_incomplete_shard",
    )
    required_params = partitioner_parameters.get(partitioner_type, [])
    return [
        gr.update(visible=(name in required_params)) for name in widget_order
    ]
def _as_python_literal(value):
    """Render ``value`` as source text for the generated code snippet."""
    if isinstance(value, str):
        # json.dumps produces a double-quoted literal with quotes, backslashes
        # and control characters escaped, so user-entered text (for example a
        # title containing `"`) cannot break the generated code.
        return json.dumps(value, ensure_ascii=False)
    return repr(value)


def partition_and_plot(
    dataset,
    partitioner_type,
    num_partitions,
    alpha,
    partition_by,
    min_partition_size,
    self_balancing,
    num_classes_per_partition,
    class_assignment_mode,
    num_shards_per_partition,
    shard_size,
    keep_incomplete_shard,
    label_name,
    title,
    legend,
    verbose_labels,
    size_unit,
    partition_id_axis,
):
    """Partition a dataset, plot its label distribution and emit the code.

    Only the arguments relevant to ``partitioner_type`` are used to build the
    partitioner; the others are ignored.

    Args:
        dataset: Dataset name passed to ``FederatedDataset``.
        partitioner_type: Key into ``partitioner_types``.
        num_partitions, alpha, partition_by, min_partition_size,
        self_balancing, num_classes_per_partition, class_assignment_mode,
        num_shards_per_partition, shard_size: Partitioner arguments; numeric
            ones are converted with ``int``/``float`` before use.
        keep_incomplete_shard: The string ``"True"`` or ``"False"``; only
            ``"True"`` enables the option.
        label_name, title, legend, verbose_labels, size_unit,
        partition_id_axis: Forwarded to ``plot_label_distributions``.

    Returns:
        ``(plot_filename, code)`` on success, where ``plot_filename`` is the
        saved PNG and ``code`` is a Python snippet reproducing the plot.
        On any failure, ``(None, "Error: <message>")``. Both paths return
        exactly two values because the click handler has two outputs.
    """
    partitioner_params = {}
    try:
        if partitioner_type == "DirichletPartitioner":
            partitioner_params = {
                "num_partitions": int(num_partitions),
                "partition_by": partition_by,
                "alpha": float(alpha),
                "min_partition_size": int(min_partition_size),
                "self_balancing": self_balancing,
            }
        elif partitioner_type in (
            "IidPartitioner",
            "LinearPartitioner",
            "SquarePartitioner",
            "ExponentialPartitioner",
        ):
            partitioner_params = {
                "num_partitions": int(num_partitions),
            }
        elif partitioner_type == "PathologicalPartitioner":
            partitioner_params = {
                "num_partitions": int(num_partitions),
                "partition_by": partition_by,
                "num_classes_per_partition": int(num_classes_per_partition),
                "class_assignment_mode": class_assignment_mode,
            }
        elif partitioner_type == "ShardPartitioner":
            partitioner_params = {
                "num_partitions": int(num_partitions),
                "partition_by": partition_by,
                "num_shards_per_partition": int(num_shards_per_partition),
                "shard_size": int(shard_size),
                # The radio widget supplies the strings "True"/"False".
                "keep_incomplete_shard": keep_incomplete_shard == "True",
            }
        elif partitioner_type == "NaturalIdPartitioner":
            partitioner_params = {
                "partition_by": partition_by,
            }

        # Report an unknown name clearly instead of surfacing a bare KeyError.
        if partitioner_type not in partitioner_types:
            raise ValueError(f"Unknown partitioner type: {partitioner_type!r}")
        partitioner = partitioner_types[partitioner_type](**partitioner_params)

        # NOTE(review): ``dataset`` is user-supplied and trust_remote_code=True
        # allows dataset-provided code to run; confirm this is acceptable for
        # the deployment.
        fds = FederatedDataset(
            dataset=dataset,
            partitioners={
                "train": partitioner,
            },
            trust_remote_code=True,
        )
        # Presumably reading ``fds.partitioners`` is what binds the loaded
        # split to the partitioner; verify against the flwr-datasets docs.
        partitioner = fds.partitioners["train"]
        figure, _axis, _dataframe = plot_label_distributions(
            partitioner=partitioner,
            label_name=label_name,
            title=title,
            legend=legend,
            verbose_labels=verbose_labels,
            size_unit=size_unit,
            partition_id_axis=partition_id_axis,
        )

        # Save the plot to a file, then always release the figure so repeated
        # clicks do not accumulate open Matplotlib figures.
        plot_filename = "label_distribution.png"
        try:
            figure.savefig(plot_filename, bbox_inches="tight")
        finally:
            plt.close(figure)

        # Generate the code: one "\t<name> = <value>" line per partitioner
        # argument, comma-separated, with no comma after the last one.
        param_lines = [
            f"\t{name} = {_as_python_literal(value)}"
            for name, value in partitioner_params.items()
        ]
        partitioner_params_str = (
            "\n" + ",\n".join(param_lines) + ("\n" if param_lines else "")
        )
        # NOTE(review): the template lines carry no leading whitespace in the
        # copy this was reviewed from; kept as-is, confirm the intended layout.
        code = f"""
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import {partitioner_type}
from flwr_datasets.visualization import plot_label_distributions
partitioner = {partitioner_type}({partitioner_params_str})
fds = FederatedDataset(
dataset={_as_python_literal(dataset)},
partitioners={{
"train": partitioner,
}},
trust_remote_code=True,
)
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
partitioner=partitioner,
label_name={_as_python_literal(label_name)},
title={_as_python_literal(title)},
legend={legend},
verbose_labels={verbose_labels},
size_unit={_as_python_literal(size_unit)},
partition_id_axis={_as_python_literal(partition_id_axis)},
)
"""
        return plot_filename, code
    except Exception as e:
        # UI boundary: show the failure in the code panel and clear the plot.
        return None, f"Error: {e}"
# Builds the Gradio UI: partitioner and plot parameter inputs, the plot and
# generated-code outputs, example presets, and the event wiring between them.
# NOTE(review): the indentation of this block is missing in this copy, so
# which statements belong to each `with` container cannot be determined from
# here; restore the nesting from the original file before running.
with gr.Blocks() as demo:
gr.Markdown("# Federated Dataset: Partitioning Visualization")
gr.Markdown("See partitioned datasets for Federated Learning experiments. The partitioning and visualization were created using `flwr-datasets`. To open in a new tab, click the [link](https://huggingface.co/spaces/flwrlabs/federated-learning-datasets-by-flwr-datasets).")
with gr.Row():
with gr.Column(scale=1):
# gr.Markdown("## Federated Dataset Parameters")
with gr.Accordion("Federated Dataset Parameters", open=True):
dataset_input = gr.Textbox(label="Dataset", value="cifar10")
partitioner_type_input = gr.Dropdown(label="Partitioner", choices=list(partitioner_types.keys()), value="DirichletPartitioner")
# Initial visibility matches the default DirichletPartitioner selection:
# its five parameters start visible, the remaining widgets start hidden.
num_partitions_input = gr.Number(label="num_partitions", value=10, visible=True)
alpha_input = gr.Number(label="alpha", value=0.3, visible=True)
partition_by_input = gr.Textbox(label="partition_by", value="label", visible=True)
min_partition_size_input = gr.Number(label="min_partition_size", value=0, visible=True)
self_balancing_input = gr.Radio(label="self_balancing", choices=[True, False], value=False, visible=True)
num_classes_per_partition_input = gr.Number(label="num_classes_per_partition", value=2, visible=False)
class_assignment_mode_input = gr.Dropdown(label="class_assignment_mode", choices=["random", "first-deterministic", "deterministic"], value="first-deterministic", visible=False)
num_shards_per_partition_input = gr.Number(label="num_shards_per_partition", value=2, visible=False)
shard_size_input = gr.Number(label="shard_size", value=0, visible=False)
# This radio uses the strings "True"/"False" (not booleans);
# partition_and_plot converts it by comparing against "True".
keep_incomplete_shard_input = gr.Radio(label="keep_incomplete_shard", choices=["True", "False"], value="True", visible=False)
with gr.Accordion("Plot Parameters", open=False):
label_name = gr.Textbox(label="label_name", value="label")
title = gr.Textbox(label="title", value="Per Partition Label Distribution")
# legend_title = gr.Textbox(label="legend_title", value=None)
legend = gr.Radio(label="legend", choices=[True, False], value=True)
verbose_labels = gr.Radio(label="verbose_labels", choices=[True, False], value=True)
size_unit = gr.Radio(label="size_unit", choices=["absolute", "percent"], value="absolute")
partition_id_axis = gr.Radio(label="partition_id_axis", choices=["x", "y"], value="x")
# Update parameter visibility when partitioner_type_input changes
# The outputs below must stay in the same order as the updates returned by
# update_parameter_visibility, which are matched to widgets by position.
partitioner_type_input.change(
fn=update_parameter_visibility,
inputs=[partitioner_type_input],
outputs=[
num_partitions_input,
alpha_input,
partition_by_input,
min_partition_size_input,
self_balancing_input,
num_classes_per_partition_input,
class_assignment_mode_input,
num_shards_per_partition_input,
shard_size_input,
keep_incomplete_shard_input
]
)
with gr.Column(scale=3, min_width=480):
gr.Markdown("## Label Distribution Plot")
plot_output = gr.Image(label="Label Distribution Plot")
submit_button = gr.Button("Partition and Plot", variant="primary")
# download_button = gr.DownloadButton(label="Download Plot", value="label_distribution.png")
gr.Markdown("## Code")
code_output = gr.Code(label="Code", language="python")
# Uncomment to show dataframe (note that it only works with header that is of type "string")
# gr.Markdown("## Partitioning DataFrame")
# dataframe_output = gr.Dataframe(label="Partitioning DataFrame")
# Example presets: the values in each example row line up positionally with
# the components listed in that group's `inputs`.
size_skew_examples = gr.Examples(
examples=[
["cifar10", "IidPartitioner", 10],
["cifar10", "LinearPartitioner", 10],
["cifar10", "SquarePartitioner", 10],
["cifar10", "ExponentialPartitioner", 10],
],
inputs=[
dataset_input,
partitioner_type_input,
num_partitions_input,
],
label="Size Skew Examples",
)
# The two Dirichlet presets differ only in the last value (size_unit).
dirichlet_examples = gr.Examples(
examples=[
["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "absolute"],
["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "percent"],
],
inputs=[
dataset_input,
partitioner_type_input,
num_partitions_input,
alpha_input,
partition_by_input,
min_partition_size_input,
self_balancing_input,
size_unit,
],
label="Dirichlet Examples",
)
pathological_examples = gr.Examples(
examples=[
["cifar10", "PathologicalPartitioner", 10, 2, "first-deterministic", "label"],
["cifar10", "PathologicalPartitioner", 10, 3, "deterministic", "label"],
],
inputs=[
dataset_input,
partitioner_type_input,
num_partitions_input,
num_classes_per_partition_input,
class_assignment_mode_input,
partition_by_input,
],
label="Pathological Examples",
)
markdown = gr.Markdown("See more tutorial, examples and documentation on [https://flower.ai/docs/datasets/index.html](https://flower.ai/docs/datasets/index.html).")
# Set up the event handler for the submit_button
# `inputs` follows the parameter order of partition_and_plot; `outputs` has
# two components, so the handler must return two values (plot, code).
submit_button.click(
fn=partition_and_plot,
inputs=[
dataset_input,
partitioner_type_input,
num_partitions_input,
alpha_input,
partition_by_input,
min_partition_size_input,
self_balancing_input,
num_classes_per_partition_input,
class_assignment_mode_input,
num_shards_per_partition_input,
shard_size_input,
keep_incomplete_shard_input,
label_name,
title,
legend,
verbose_labels,
size_unit,
partition_id_axis,
],
outputs=[
plot_output,
code_output,
# dataframe_output,
# download_button
]
)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
demo.launch()