import os
import csv
import math
import random
import shutil
import pandas as pd
import gradio as gr

DATA_DIR = "./data"


def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of row dicts to ``filename`` as UTF-8 CSV and return the path.

    The header is taken from the keys of the first row. An empty list produces
    an empty file instead of raising ``IndexError``.
    """
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if not list_of_dicts:
            return filename
        writer = csv.DictWriter(csvfile, fieldnames=list(list_of_dicts[0]))
        writer.writeheader()
        writer.writerows(list_of_dicts)
    return filename


def _group_sizes(participants: int, ratio: list) -> list[int]:
    """Split ``participants`` into group sizes proportional to ``ratio``.

    Boundaries are rounded from the *cumulative* share so rounding error is
    spread across groups instead of piling up in the last one (e.g. 11
    participants at 1:1:1:1:1:1 gives 2,2,2,1,2,2 rather than 1,1,1,1,1,6).
    """
    total = sum(ratio)
    bounds = [0]
    running = 0.0
    for weight in ratio:
        running += weight
        bounds.append(round(running / total * participants))
    # Guard against float drift so every participant is always allocated.
    bounds[-1] = participants
    return [end - start for start, end in zip(bounds, bounds[1:])]


def random_allocation(participants: int, ratio: list, rng: random.Random | None = None):
    """Randomly assign participants ``1..participants`` to groups by ``ratio``.

    Args:
        participants: number of participants; must be >= 1.
        ratio: positive, finite relative group weights; group ``i`` (1-based)
            receives a share proportional to ``ratio[i - 1]``.
        rng: optional random generator (e.g. ``random.Random(seed)``) for a
            reproducible allocation; defaults to the module-level ``random``.

    Returns:
        ``(csv_path, dataframe)`` with one ``{"id", "group"}`` row per
        participant, ordered by id. The CSV is written to
        ``DATA_DIR/output.csv`` (overwriting any previous run).

    Raises:
        ValueError: if ``participants`` < 1, ``ratio`` is empty, or any weight
            is not a positive finite number.
    """
    if participants < 1:
        raise ValueError("Number of participants must be at least 1.")
    if not ratio:
        raise ValueError("Grouping ratio must contain at least one group.")
    if not all(math.isfinite(r) and r > 0 for r in ratio):
        raise ValueError("Every ratio value must be a positive finite number.")
    if rng is None:
        rng = random

    labels = [
        group
        for group, size in enumerate(_group_sizes(participants, ratio), start=1)
        for _ in range(size)
    ]
    # Shuffling the multiset of group labels gives the same uniformly random
    # fixed-size assignment as shuffling ids, and rows come out ordered by id.
    rng.shuffle(labels)
    rows = [{"id": pid, "group": group} for pid, group in enumerate(labels, start=1)]

    os.makedirs(DATA_DIR, exist_ok=True)
    filename = list_to_csv(rows, f"{DATA_DIR}/output.csv")
    return filename, pd.DataFrame(rows)


def _parse_ratio(ratios: str) -> list[float]:
    """Parse a colon-separated ratio string such as ``"8:1:1"``.

    Zero and negative values are skipped and do not create a group.

    Raises:
        ValueError: if any part is empty, not a number, or not finite.
    """
    ratio = []
    for part in (ratios or "").split(":"):
        text = part.strip()
        value = float(text)  # raises ValueError for "", "abc", ...
        if not math.isfinite(value):
            raise ValueError(f"{text!r} is not a finite number.")
        if value > 0:
            ratio.append(value)
    return ratio


def infer(participants: float | None, ratios: str):
    """Gradio handler: validate the form inputs and run the allocation.

    Invalid input is reported to the user through ``gr.Error`` rather than
    printed to the server console and silently ignored.
    """
    if (
        participants is None
        or not math.isfinite(participants)
        or participants < 1
        or participants != int(participants)
    ):
        raise gr.Error("Number of participants must be a whole number of at least 1.")

    try:
        ratio = _parse_ratio(ratios)
    except ValueError as err:
        raise gr.Error(f"Invalid input of ratio! {err}") from err
    if not ratio:
        raise gr.Error("Grouping ratio needs at least one positive number, e.g. 8:1:1.")

    return random_allocation(int(participants), ratio)


if __name__ == "__main__":
    gr.Interface(
        fn=infer,
        inputs=[
            gr.Number(
                label="Number of participants",
                value=10,
                precision=0,  # deliver an int to infer(), not a float
                minimum=1,
            ),
            gr.Textbox(
                label="Grouping ratio",
                value="8:1:1",
                show_copy_button=True,
            ),
        ],
        outputs=[
            gr.File(label="Download data CSV"),
            gr.Dataframe(label="Data preview"),
        ],
        title="Randomized Controlled Trial Generator",
        description="Enter the number of participants and the grouping ratio in the format of numbers separated by : to generate randomized grouping data.",
        flagging_mode="never",
    ).launch()