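# MMDT-radar / app.py
# Page-header residue from the hosting site, kept as a comment so the file parses:
# author avatar "polaris73's picture", last commit message "update font", commit id bf25481.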
import gradio as gr
from generate_plot import generate_main_plot, generate_sub_plot
from utils.score_extract.ood_agg import ood_t2i_agg, ood_i2t_agg
from utils.score_extract.hallucination_agg import hallucination_t2i_agg, hallucination_i2t_agg
from utils.score_extract.safety_agg import safety_t2i_agg, safety_i2t_agg
from utils.score_extract.adversarial_robustness_agg import adversarial_robustness_t2i_agg, adversarial_robustness_i2t_agg
from utils.score_extract.fairness_agg import fairness_t2i_agg, fairness_i2t_agg
from utils.score_extract.privacy_agg import privacy_t2i_agg, privacy_i2t_agg
# Text-to-image models plotted by this app.  The trailing numbers are carried
# over from the original listing, which labels them "average time spent
# running the following example" (unit not stated).
t2i_models = [
    "dall-e-2",
    "dall-e-3",
    "DeepFloyd/IF-I-M-v1.0",  # 15.372
    "dreamlike-art/dreamlike-photoreal-2.0",  # 3.526
    "prompthero/openjourney-v4",  # 4.981
    "stabilityai/stable-diffusion-xl-base-1.0",  # 7.463
]

# Image-to-text models plotted by this app.
i2t_models = [
    "gpt-4-vision-preview",
    "gpt-4o-2024-05-13",
    "llava-hf/llava-v1.6-vicuna-7b-hf",
]

# Evaluation perspectives.  Each name is matched against an aggregation helper
# when scores are loaded and is offered as a scenario in the UI dropdown.
perspectives = [
    "Safety",
    "Fairness",
    "Hallucination",
    "Privacy",
    "Adv",
    "OOD",
]
# Directory handed to every *_agg helper as its second argument.
_RESULT_DIR = "./data/results"

# Perspective name -> aggregation helper for text-to-image models.
_T2I_AGGREGATORS = {
    "Hallucination": hallucination_t2i_agg,
    "Safety": safety_t2i_agg,
    "Adv": adversarial_robustness_t2i_agg,
    "Fairness": fairness_t2i_agg,
    "Privacy": privacy_t2i_agg,
    "OOD": ood_t2i_agg,
}

# Perspective name -> aggregation helper for image-to-text models.
_I2T_AGGREGATORS = {
    "Hallucination": hallucination_i2t_agg,
    "Safety": safety_i2t_agg,
    "Adv": adversarial_robustness_i2t_agg,
    "Fairness": fairness_i2t_agg,
    "Privacy": privacy_i2t_agg,
    "OOD": ood_i2t_agg,
}


def _collect_scores(models, aggregators):
    """Load the main and per-subscenario scores for every model.

    Args:
        models: Model identifiers; only the part after the last "/" is used
            as the lookup name and as the dictionary key.
        aggregators: Mapping from perspective name to a callable invoked as
            ``aggregate(model_name, result_dir)`` whose result is indexed
            with the keys ``"score"`` and ``"subscenarios"``.

    Returns:
        A ``(main_scores, sub_scores)`` pair where
        ``main_scores[model][perspective]`` holds the ``"score"`` entry and
        ``sub_scores[perspective][model]`` holds the ``"subscenarios"`` entry.

    Raises:
        ValueError: If a name in ``perspectives`` has no aggregator.
    """
    main_scores = {}
    sub_scores = {}
    for model in models:
        short_name = model.split("/")[-1]
        main_scores[short_name] = {}
        for perspective in perspectives:
            sub_scores.setdefault(perspective, {})
            try:
                aggregate = aggregators[perspective]
            except KeyError:
                raise ValueError("Invalid perspective") from None
            # Call the aggregator once and reuse the result for both tables
            # (previously each helper ran twice per model/perspective).
            # NOTE(review): assumes the helpers are deterministic — confirm.
            result = aggregate(short_name, _RESULT_DIR)
            main_scores[short_name][perspective] = result["score"]
            sub_scores[perspective][short_name] = result["subscenarios"]
    return main_scores, sub_scores


main_scores_t2i, sub_scores_t2i = _collect_scores(t2i_models, _T2I_AGGREGATORS)
main_scores_i2t, sub_scores_i2t = _collect_scores(i2t_models, _I2T_AGGREGATORS)
# Gradio UI: two dropdowns (scenario and model type) driving a single plot.
# NOTE(review): the source had lost its indentation; the nesting below
# (dropdowns inside the Row, heading and plot inside the Column, callback and
# event wiring directly under Blocks) is a reconstruction — confirm the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Shared choice lists so the initial widgets and the widgets returned by
    # the callback always present the options in the same order.
    scenario_choices = ["Main Figure"] + perspectives
    model_type_choices = ["T2I", "I2T"]

    with gr.Column(visible=True) as output_col:
        with gr.Row(visible=True) as report_col:
            curr_select = gr.Dropdown(
                choices=scenario_choices,
                label="Select Scenario",
                value="Main Figure",
            )
            select_model_type = gr.Dropdown(
                choices=model_type_choices,
                label="Select Model Type",
                value="T2I",
            )
        gr.Markdown("# Overall statistics")
        plot = gr.Plot(value=generate_main_plot(t2i_models, main_scores_t2i))

    def radar(model_type, perspective):
        """Rebuild the plot and both dropdowns for the current selection.

        Args:
            model_type: "T2I" selects the text-to-image data; any other value
                selects the image-to-text data.
            perspective: Selected scenario.  An empty/None value or
                "Main Figure" shows the main plot; anything else is passed
                to ``generate_sub_plot``.

        Returns:
            A dict mapping the ``plot``, ``curr_select`` and
            ``select_model_type`` components to their new values.
        """
        if model_type == "T2I":
            models = t2i_models
            main_scores = main_scores_t2i
            sub_scores = sub_scores_t2i
        else:
            models = i2t_models
            main_scores = main_scores_i2t
            sub_scores = sub_scores_i2t

        # Truthiness check instead of len(): a None selection must fall back
        # to the main figure rather than raise TypeError.
        if not perspective or perspective == "Main Figure":
            selected = "Main Figure"
            fig = generate_main_plot(models, main_scores)
        else:
            selected = perspective
            fig = generate_sub_plot(models, sub_scores, perspective)

        select = gr.Dropdown(
            choices=scenario_choices,
            value=selected,
            label="Select Scenario",
        )
        type_dropdown = gr.Dropdown(
            choices=model_type_choices,
            label="Select Model Type",
            value=model_type,
        )
        return {plot: fig, curr_select: select, select_model_type: type_dropdown}

    # Re-render whenever either dropdown changes.
    gr.on(
        triggers=[curr_select.change, select_model_type.change],
        fn=radar,
        inputs=[select_model_type, curr_select],
        outputs=[plot, curr_select, select_model_type],
    )


if __name__ == "__main__":
    demo.queue().launch()