|
import csv
import math
import os
import random
import shutil

import gradio as gr
import pandas as pd
|
|
|
# Directory that receives the generated allocation CSV; `inference` deletes
# and recreates it on every request, so nothing else should be stored here.
DATA_DIR = "./data"
|
|
|
|
|
def list_to_csv(list_of_dicts: list, filename: str):
    """Write a list of row dictionaries to *filename* as a UTF-8 CSV file.

    The column order is taken from the keys of the first row. An empty list
    produces an empty file instead of raising ``IndexError``.

    Args:
        list_of_dicts: Rows to write. Each item is converted with ``dict()``
            and is expected to use the same keys as the first row.
        filename: Path of the CSV file to create or overwrite.

    Returns:
        The ``filename`` that was passed in.

    Raises:
        ValueError: If a later row has a key that the first row lacks
            (raised by ``csv.DictWriter``).
    """
    rows = [dict(row) for row in list_of_dicts]

    # newline="" lets the csv module control line endings itself.
    with open(filename, "w", newline="", encoding="utf-8") as csvfile:
        if rows:
            writer = csv.DictWriter(csvfile, fieldnames=list(rows[0]))
            writer.writeheader()
            writer.writerows(rows)

    return filename
|
|
|
|
|
def random_allocation(participants: int, ratio: list):
    """Randomly assign participants ``1..participants`` to groups by ratio.

    Participant ids are shuffled and cut into consecutive slices, one per
    entry of ``ratio``; group numbers start at 1 and follow the order of
    ``ratio``. Slice boundaries are the rounded cumulative shares, so the
    rounding error is spread over the groups instead of piling up in the
    last one. The result is written to ``{DATA_DIR}/output.csv``.

    Args:
        participants: Number of participants to allocate (at least 1).
        ratio: Non-negative weights, one per group, e.g. ``[8, 1, 1]``.

    Returns:
        A ``(filename, dataframe)`` tuple: the path of the written CSV and a
        ``pandas.DataFrame`` with the columns ``id`` and ``group``, sorted
        by ``id``.

    Raises:
        ValueError: If ``participants`` is smaller than 1, or ``ratio`` is
            empty, contains a negative weight, or sums to zero.
    """
    if participants < 1:
        raise ValueError("participants must be a positive integer")
    if not ratio or any(weight < 0 for weight in ratio):
        raise ValueError("ratio must be a non-empty list of non-negative weights")
    total = sum(ratio)
    if total <= 0:
        raise ValueError("ratio must contain at least one positive weight")

    boundaries = [0]
    cumulative = 0
    for weight in ratio:
        cumulative += weight
        # Multiply before dividing: 29 / 100 * 100 evaluates to
        # 28.999999999999996, which truncation would turn into 28.
        boundary = round(cumulative * participants / total)
        boundaries.append(min(boundary, participants))
    # Guarantee that every participant is assigned despite rounding.
    boundaries[-1] = participants

    shuffled = list(range(1, participants + 1))
    random.shuffle(shuffled)

    allocation = []
    for group, (start, end) in enumerate(zip(boundaries, boundaries[1:]), start=1):
        allocation.extend(
            {"id": participant, "group": group} for participant in shuffled[start:end]
        )
    allocation.sort(key=lambda row: row["id"])

    # The output directory may not exist when this is called directly.
    os.makedirs(DATA_DIR, exist_ok=True)
    filename = list_to_csv(allocation, f"{DATA_DIR}/output.csv")
    return filename, pd.DataFrame(allocation)
|
|
|
|
|
def inference(participants: float, ratios: str):
    """Gradio handler: validate the inputs and build a random allocation.

    The ratio string is parsed first; only when it is valid is ``DATA_DIR``
    wiped and recreated, so a bad request no longer destroys the previous
    output or silently continues with a partially parsed ratio.

    Args:
        participants: Number of participants as delivered by the number
            input; it is truncated to an integer.
        ratios: Colon-separated weights such as ``"8:1:1"``. Entries that
            are zero or negative are skipped, as before.

    Returns:
        The ``(filename, dataframe)`` tuple returned by
        ``random_allocation``.

    Raises:
        ValueError: If ``participants`` is missing, or ``ratios`` contains a
            non-numeric or non-finite entry, or no positive entry at all.
    """
    if participants is None:
        raise ValueError("Number of participants is required")

    ratio = []
    for part in ratios.split(":"):
        text = part.strip()
        try:
            weight = float(text)
        except ValueError as err:
            raise ValueError(f"Invalid input of ratio: {text!r}") from err
        # float() accepts "inf" and "nan", which cannot be used as weights.
        if not math.isfinite(weight):
            raise ValueError(f"Invalid input of ratio: {text!r}")
        if weight > 0:
            ratio.append(weight)

    if not ratio:
        raise ValueError("Invalid input of ratio: no positive value found")

    # Start every valid request from an empty output directory.
    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    return random_allocation(int(participants), ratio)
|
|
|
|
|
if __name__ == "__main__":
    # Build the web UI around `inference` and start the Gradio server.
    gr.Interface(
        fn=inference,
        inputs=[
            gr.Number(
                label="输入参与者数量 (Number of participants)",
                value=10,
            ),
            gr.Textbox(label="输入分组比率 (Grouping ratio)", value="8:1:1"),
        ],
        outputs=[
            gr.components.File(label="下载随机分组数据 CSV (Download data CSV)"),
            gr.Dataframe(label="随机分组数据预览 (Data preview)"),
        ],
        title="随机对照试验随机数生成器<br>Randomized Controlled Trial Generator",
        description="输入参与者数量和分组比率,格式为用:隔开的数字,生成随机分组数据。<br>Enter the number of participants and the grouping ratio in the format of numbers separated by : to generate randomized grouping data.",
        # The boolean form is deprecated; "never" is the string equivalent of
        # False. NOTE(review): newer Gradio releases rename this option to
        # `flagging_mode` - confirm against the installed version.
        allow_flagging="never",
    ).launch()
|
|