import gradio as gr
from generate_plot import generate_main_plot, generate_sub_plot
from utils.score_extract.ood_agg import ood_t2i_agg, ood_i2t_agg
from utils.score_extract.hallucination_agg import hallucination_t2i_agg, hallucination_i2t_agg
from utils.score_extract.safety_agg import safety_t2i_agg, safety_i2t_agg
from utils.score_extract.adversarial_robustness_agg import adversarial_robustness_t2i_agg, adversarial_robustness_i2t_agg
from utils.score_extract.fairness_agg import fairness_t2i_agg, fairness_i2t_agg
from utils.score_extract.privacy_agg import privacy_t2i_agg, privacy_i2t_agg
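
# Leaderboard data: text-to-image (T2I) and image-to-text (I2T) models evaluated
# across six perspectives (Safety, Fairness, Hallucination, Privacy, Adv, OOD).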
t2i_models = [  # Average time spent running one example
    "dall-e-2",
    "dall-e-3",
    "DeepFloyd/IF-I-M-v1.0",  # 15.372
    "dreamlike-art/dreamlike-photoreal-2.0",  # 3.526
    "prompthero/openjourney-v4",  # 4.981
    "stabilityai/stable-diffusion-xl-base-1.0",  # 7.463
]
i2t_models = [  # Average time spent running one example
    "gpt-4-vision-preview",
    "gpt-4o-2024-05-13",
    "llava-hf/llava-v1.6-vicuna-7b-hf",
]
perspectives = ["Safety", "Fairness", "Hallucination", "Privacy", "Adv", "OOD"]

main_scores_t2i = {}
main_scores_i2t = {}
sub_scores_t2i = {}
sub_scores_i2t = {}
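
# Each *_agg helper reads result files under ./data/results and returns a dict
# with an overall "score" and a per-"subscenarios" breakdown.
# main_scores_*: {model: {perspective: score}}
# sub_scores_*:  {perspective: {model: subscenario scores}}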
for model in t2i_models:
    model = model.split("/")[-1]
    main_scores_t2i[model] = {}
    for perspective in perspectives:
        if perspective not in sub_scores_t2i:
            sub_scores_t2i[perspective] = {}
        if perspective == "Hallucination":
            agg = hallucination_t2i_agg(model, "./data/results")
        elif perspective == "Safety":
            agg = safety_t2i_agg(model, "./data/results")
        elif perspective == "Adv":
            agg = adversarial_robustness_t2i_agg(model, "./data/results")
        elif perspective == "Fairness":
            agg = fairness_t2i_agg(model, "./data/results")
        elif perspective == "Privacy":
            agg = privacy_t2i_agg(model, "./data/results")
        elif perspective == "OOD":
            agg = ood_t2i_agg(model, "./data/results")
        else:
            raise ValueError("Invalid perspective")
        main_scores_t2i[model][perspective] = agg["score"]
        sub_scores_t2i[perspective][model] = agg["subscenarios"]
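
# Same aggregation for the image-to-text models.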
for model in i2t_models:
    model = model.split("/")[-1]
    main_scores_i2t[model] = {}
    for perspective in perspectives:
        if perspective not in sub_scores_i2t:
            sub_scores_i2t[perspective] = {}
        if perspective == "Hallucination":
            agg = hallucination_i2t_agg(model, "./data/results")
        elif perspective == "Safety":
            agg = safety_i2t_agg(model, "./data/results")
        elif perspective == "Adv":
            agg = adversarial_robustness_i2t_agg(model, "./data/results")
        elif perspective == "Fairness":
            agg = fairness_i2t_agg(model, "./data/results")
        elif perspective == "Privacy":
            agg = privacy_i2t_agg(model, "./data/results")
        elif perspective == "OOD":
            agg = ood_i2t_agg(model, "./data/results")
        else:
            raise ValueError("Invalid perspective")
        main_scores_i2t[model][perspective] = agg["score"]
        sub_scores_i2t[perspective][model] = agg["subscenarios"]
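
# Gradio UI: two dropdowns (scenario and model type) above a plot produced by
# generate_main_plot / generate_sub_plot. Changing either dropdown regenerates
# the plot via the radar() callback defined below.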
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column(visible=True) as output_col:
        with gr.Row(visible=True) as report_col:
            curr_select = gr.Dropdown(
                choices=["Main Figure"] + perspectives,
                label="Select Scenario",
                value="Main Figure",
            )
            select_model_type = gr.Dropdown(
                choices=["T2I", "I2T"],
                label="Select Model Type",
                value="T2I",
            )
        gr.Markdown("# Overall statistics")
        plot = gr.Plot(value=generate_main_plot(t2i_models, main_scores_t2i))
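
    # Callback: pick the models and scores for the selected model type, then
    # return either the main figure or the per-perspective sub-plot, together
    # with refreshed dropdown components.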
    def radar(model_type, perspective):
        perspectives_name = perspectives + ["Main Figure"]
        if model_type == "T2I":
            models = t2i_models
            main_scores = main_scores_t2i
            sub_scores = sub_scores_t2i
        else:
            models = i2t_models
            main_scores = main_scores_i2t
            sub_scores = sub_scores_i2t
        if len(perspective) == 0 or perspective == "Main Figure":
            fig = generate_main_plot(models, main_scores)
            select = gr.Dropdown(choices=perspectives_name, value="Main Figure", label="Select Scenario")
        else:
            fig = generate_sub_plot(models, sub_scores, perspective)
            select = gr.Dropdown(choices=perspectives_name, value=perspective, label="Select Scenario")
        type_dropdown = gr.Dropdown(choices=["T2I", "I2T"], label="Select Model Type", value=model_type)
        return {plot: fig, curr_select: select, select_model_type: type_dropdown}

    gr.on(
        triggers=[curr_select.change, select_model_type.change],
        fn=radar,
        inputs=[select_model_type, curr_select],
        outputs=[plot, curr_select, select_model_type],
    )

if __name__ == "__main__":
    demo.queue().launch()