"""Gradio app that partitions a Hugging Face dataset with flwr-datasets.

It plots the per-partition label distribution and also shows a standalone code
snippet that reproduces the plot.
"""

import json

import gradio as gr
import matplotlib.pyplot as plt
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import (
    DirichletPartitioner,
    ExponentialPartitioner,
    IidPartitioner,
    LinearPartitioner,
    NaturalIdPartitioner,
    PathologicalPartitioner,
    ShardPartitioner,
    SquarePartitioner,
)
from flwr_datasets.visualization import plot_label_distributions

# File the rendered plot is written to before its path is handed to gr.Image.
PLOT_FILENAME = "label_distribution.png"

# Partitioner name -> class. Insertion order is the order of the dropdown choices.
partitioner_types = {
    "DirichletPartitioner": DirichletPartitioner,
    "IidPartitioner": IidPartitioner,
    "PathologicalPartitioner": PathologicalPartitioner,
    "ShardPartitioner": ShardPartitioner,
    "LinearPartitioner": LinearPartitioner,
    "SquarePartitioner": SquarePartitioner,
    "ExponentialPartitioner": ExponentialPartitioner,
    "NaturalIdPartitioner": NaturalIdPartitioner,
}

# Constructor parameters each partitioner exposes in the UI.
partitioner_parameters = {
    "DirichletPartitioner": ["num_partitions", "alpha", "partition_by", "min_partition_size", "self_balancing"],
    "IidPartitioner": ["num_partitions"],
    "PathologicalPartitioner": ["num_partitions", "partition_by", "num_classes_per_partition", "class_assignment_mode"],
    "ShardPartitioner": ["num_partitions", "partition_by", "num_shards_per_partition", "shard_size", "keep_incomplete_shard"],
    "NaturalIdPartitioner": ["partition_by"],
    "LinearPartitioner": ["num_partitions"],
    "SquarePartitioner": ["num_partitions"],
    "ExponentialPartitioner": ["num_partitions"],
}

# Order of the optional parameter widgets. The `outputs` list of the
# `partitioner_type_input.change` listener is built from this list, so the
# visibility updates and the widgets they target always line up.
PARAMETER_WIDGET_ORDER = [
    "num_partitions",
    "alpha",
    "partition_by",
    "min_partition_size",
    "self_balancing",
    "num_classes_per_partition",
    "class_assignment_mode",
    "num_shards_per_partition",
    "shard_size",
    "keep_incomplete_shard",
]


def update_parameter_visibility(partitioner_type):
    """Show only the parameter widgets the selected partitioner accepts.

    Args:
        partitioner_type: Key into ``partitioner_parameters``. An unknown key
            hides every widget.

    Returns:
        One ``gr.update(visible=...)`` per widget, in ``PARAMETER_WIDGET_ORDER``.
    """
    required_params = set(partitioner_parameters.get(partitioner_type, []))
    return [gr.update(visible=name in required_params) for name in PARAMETER_WIDGET_ORDER]


def _to_bool(value):
    """Interpret a Radio value that may arrive either as a bool or as "True"/"False"."""
    if isinstance(value, bool):
        return value
    return str(value) == "True"


def _py_literal(value):
    """Render ``value`` as Python source for the generated snippet.

    Strings become escaped, double-quoted literals: a JSON string produced with
    ``ensure_ascii=False`` is also a valid Python string literal. Other values
    (ints, floats, bools, None) use ``repr``.
    """
    if isinstance(value, str):
        return json.dumps(value, ensure_ascii=False)
    return repr(value)


def _build_partitioner_params(
    partitioner_type,
    num_partitions,
    alpha,
    partition_by,
    min_partition_size,
    self_balancing,
    num_classes_per_partition,
    class_assignment_mode,
    num_shards_per_partition,
    shard_size,
    keep_incomplete_shard,
):
    """Convert raw widget values into constructor kwargs for ``partitioner_type``.

    Only the parameters relevant to the chosen partitioner are included.
    Numeric widgets are cast to ``int`` or ``float``.

    Raises:
        KeyError: If ``partitioner_type`` is not a known partitioner.
        TypeError, ValueError: If a required numeric field is empty or invalid.
    """
    if partitioner_type == "DirichletPartitioner":
        return {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "alpha": float(alpha),
            "min_partition_size": int(min_partition_size),
            "self_balancing": _to_bool(self_balancing),
        }
    if partitioner_type == "PathologicalPartitioner":
        return {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "num_classes_per_partition": int(num_classes_per_partition),
            "class_assignment_mode": class_assignment_mode,
        }
    if partitioner_type == "ShardPartitioner":
        params = {
            "num_partitions": int(num_partitions),
            "partition_by": partition_by,
            "num_shards_per_partition": int(num_shards_per_partition),
        }
        # The UI defaults shard_size to 0. ShardPartitioner validates shard_size
        # as a positive int or None, so a blank or non-positive value is omitted
        # and the partitioner derives the shard size itself.
        if shard_size is not None and int(shard_size) > 0:
            params["shard_size"] = int(shard_size)
        params["keep_incomplete_shard"] = _to_bool(keep_incomplete_shard)
        return params
    if partitioner_type == "NaturalIdPartitioner":
        return {"partition_by": partition_by}
    if partitioner_type in (
        "IidPartitioner",
        "LinearPartitioner",
        "SquarePartitioner",
        "ExponentialPartitioner",
    ):
        return {"num_partitions": int(num_partitions)}
    raise KeyError(partitioner_type)


def _generate_code(
    dataset,
    partitioner_type,
    partitioner_params,
    label_name,
    title,
    legend,
    verbose_labels,
    size_unit,
    partition_id_axis,
):
    """Return a standalone Python snippet that reproduces the current plot."""
    # One "    key=value," line per constructor argument.
    params_block = "".join(
        f"\n    {key}={_py_literal(value)}," for key, value in partitioner_params.items()
    )
    return f"""
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import {partitioner_type}
from flwr_datasets.visualization import plot_label_distributions

partitioner = {partitioner_type}({params_block}
)
fds = FederatedDataset(
    dataset={_py_literal(dataset)},
    partitioners={{
        "train": partitioner,
    }},
    trust_remote_code=True,
)
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name={_py_literal(label_name)},
    title={_py_literal(title)},
    legend={_py_literal(legend)},
    verbose_labels={_py_literal(verbose_labels)},
    size_unit={_py_literal(size_unit)},
    partition_id_axis={_py_literal(partition_id_axis)},
)
"""


def partition_and_plot(
    dataset,
    partitioner_type,
    num_partitions,
    alpha,
    partition_by,
    min_partition_size,
    self_balancing,
    num_classes_per_partition,
    class_assignment_mode,
    num_shards_per_partition,
    shard_size,
    keep_incomplete_shard,
    label_name,
    title,
    legend,
    verbose_labels,
    size_unit,
    partition_id_axis,
):
    """Partition ``dataset``, plot its label distribution and build a repro snippet.

    Returns:
        A 2-tuple matching the click listener's outputs
        ``[plot_output, code_output]``:
        - on success: ``(PLOT_FILENAME, code_snippet)``;
        - on any failure: ``(None, "Error: <message>")``, so the message is
          shown in the code box instead of crashing the UI.
    """
    try:
        partitioner_params = _build_partitioner_params(
            partitioner_type=partitioner_type,
            num_partitions=num_partitions,
            alpha=alpha,
            partition_by=partition_by,
            min_partition_size=min_partition_size,
            self_balancing=self_balancing,
            num_classes_per_partition=num_classes_per_partition,
            class_assignment_mode=class_assignment_mode,
            num_shards_per_partition=num_shards_per_partition,
            shard_size=shard_size,
            keep_incomplete_shard=keep_incomplete_shard,
        )
        legend = _to_bool(legend)
        verbose_labels = _to_bool(verbose_labels)

        partitioner = partitioner_types[partitioner_type](**partitioner_params)
        # NOTE(review): `dataset` is free-form user input. trust_remote_code=True
        # allows the `datasets` library to run loading scripts shipped with that
        # repository on this server. Consider an allow-list or disabling it.
        fds = FederatedDataset(
            dataset=dataset,
            partitioners={"train": partitioner},
            trust_remote_code=True,
        )
        # Accessing `partitioners` yields the partitioner with the split assigned.
        partitioner = fds.partitioners["train"]
        figure, _axis, _dataframe = plot_label_distributions(
            partitioner=partitioner,
            label_name=label_name,
            title=title,
            legend=legend,
            verbose_labels=verbose_labels,
            size_unit=size_unit,
            partition_id_axis=partition_id_axis,
        )
        try:
            figure.savefig(PLOT_FILENAME, bbox_inches="tight")
        finally:
            # Release the figure so repeated clicks don't accumulate open figures.
            plt.close(figure)

        code = _generate_code(
            dataset=dataset,
            partitioner_type=partitioner_type,
            partitioner_params=partitioner_params,
            label_name=label_name,
            title=title,
            legend=legend,
            verbose_labels=verbose_labels,
            size_unit=size_unit,
            partition_id_axis=partition_id_axis,
        )
        return PLOT_FILENAME, code
    except Exception as e:  # UI boundary: surface any failure to the user.
        # Exactly one value per output component (plot_output, code_output).
        return None, f"Error: {e}"


with gr.Blocks() as demo:
    gr.Markdown("# Federated Dataset: Partitioning Visualization")
    gr.Markdown("See partitioned datasets for Federated Learning experiments. The partitioning and visualization were created using `flwr-datasets`. To open in a new tab, click the [link](https://huggingface.co/spaces/flwrlabs/federated-learning-datasets-by-flwr-datasets).")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Federated Dataset Parameters", open=True):
                dataset_input = gr.Textbox(label="Dataset", value="cifar10")
                partitioner_type_input = gr.Dropdown(
                    label="Partitioner",
                    choices=list(partitioner_types.keys()),
                    value="DirichletPartitioner",
                )
                num_partitions_input = gr.Number(label="num_partitions", value=10, visible=True)
                alpha_input = gr.Number(label="alpha", value=0.3, visible=True)
                partition_by_input = gr.Textbox(label="partition_by", value="label", visible=True)
                min_partition_size_input = gr.Number(label="min_partition_size", value=0, visible=True)
                self_balancing_input = gr.Radio(label="self_balancing", choices=[True, False], value=False, visible=True)
                num_classes_per_partition_input = gr.Number(label="num_classes_per_partition", value=2, visible=False)
                class_assignment_mode_input = gr.Dropdown(
                    label="class_assignment_mode",
                    choices=["random", "first-deterministic", "deterministic"],
                    value="first-deterministic",
                    visible=False,
                )
                num_shards_per_partition_input = gr.Number(label="num_shards_per_partition", value=2, visible=False)
                # 0 (the default) means "not set": the partitioner derives the shard size.
                shard_size_input = gr.Number(label="shard_size", value=0, visible=False)
                keep_incomplete_shard_input = gr.Radio(
                    label="keep_incomplete_shard",
                    choices=["True", "False"],
                    value="True",
                    visible=False,
                )

            with gr.Accordion("Plot Parameters", open=False):
                label_name = gr.Textbox(label="label_name", value="label")
                title = gr.Textbox(label="title", value="Per Partition Label Distribution")
                legend = gr.Radio(label="legend", choices=[True, False], value=True)
                verbose_labels = gr.Radio(label="verbose_labels", choices=[True, False], value=True)
                size_unit = gr.Radio(label="size_unit", choices=["absolute", "percent"], value="absolute")
                partition_id_axis = gr.Radio(label="partition_id_axis", choices=["x", "y"], value="x")

            # Show only the widgets relevant to the selected partitioner.
            parameter_widgets = {
                "num_partitions": num_partitions_input,
                "alpha": alpha_input,
                "partition_by": partition_by_input,
                "min_partition_size": min_partition_size_input,
                "self_balancing": self_balancing_input,
                "num_classes_per_partition": num_classes_per_partition_input,
                "class_assignment_mode": class_assignment_mode_input,
                "num_shards_per_partition": num_shards_per_partition_input,
                "shard_size": shard_size_input,
                "keep_incomplete_shard": keep_incomplete_shard_input,
            }
            partitioner_type_input.change(
                fn=update_parameter_visibility,
                inputs=[partitioner_type_input],
                outputs=[parameter_widgets[name] for name in PARAMETER_WIDGET_ORDER],
            )

        with gr.Column(scale=3, min_width=480):
            gr.Markdown("## Label Distribution Plot")
            plot_output = gr.Image(label="Label Distribution Plot")
            submit_button = gr.Button("Partition and Plot", variant="primary")
            gr.Markdown("## Code")
            code_output = gr.Code(label="Code", language="python")
            # A partitioning DataFrame output could be added here (gr.Dataframe);
            # note that it only works when the DataFrame header is of type string.

            gr.Examples(
                examples=[
                    ["cifar10", "IidPartitioner", 10],
                    ["cifar10", "LinearPartitioner", 10],
                    ["cifar10", "SquarePartitioner", 10],
                    ["cifar10", "ExponentialPartitioner", 10],
                ],
                inputs=[
                    dataset_input,
                    partitioner_type_input,
                    num_partitions_input,
                ],
                label="Size Skew Examples",
            )
            gr.Examples(
                examples=[
                    ["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "absolute"],
                    ["cifar10", "DirichletPartitioner", 10, 0.1, "label", 0, False, "percent"],
                ],
                inputs=[
                    dataset_input,
                    partitioner_type_input,
                    num_partitions_input,
                    alpha_input,
                    partition_by_input,
                    min_partition_size_input,
                    self_balancing_input,
                    size_unit,
                ],
                label="Dirichlet Examples",
            )
            gr.Examples(
                examples=[
                    ["cifar10", "PathologicalPartitioner", 10, 2, "first-deterministic", "label"],
                    ["cifar10", "PathologicalPartitioner", 10, 3, "deterministic", "label"],
                ],
                inputs=[
                    dataset_input,
                    partitioner_type_input,
                    num_partitions_input,
                    num_classes_per_partition_input,
                    class_assignment_mode_input,
                    partition_by_input,
                ],
                label="Pathological Examples",
            )
            gr.Markdown("See more tutorials, examples, and documentation at [https://flower.ai/docs/datasets/index.html](https://flower.ai/docs/datasets/index.html).")

    # Wire the submit button; inputs follow partition_and_plot's parameter order.
    submit_button.click(
        fn=partition_and_plot,
        inputs=[
            dataset_input,
            partitioner_type_input,
            num_partitions_input,
            alpha_input,
            partition_by_input,
            min_partition_size_input,
            self_balancing_input,
            num_classes_per_partition_input,
            class_assignment_mode_input,
            num_shards_per_partition_input,
            shard_size_input,
            keep_incomplete_shard_input,
            label_name,
            title,
            legend,
            verbose_labels,
            size_unit,
            partition_id_axis,
        ],
        outputs=[plot_output, code_output],
    )


if __name__ == "__main__":
    demo.launch()