Spaces:
Sleeping
Sleeping
File size: 5,910 Bytes
201434d e2e5007 da6fc76 201434d 05dae02 201434d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import pandas as pd
# Raw model identifiers as they appear in the collected responses, mapped to
# the short display names shown in the UI. Several identifiers intentionally
# share one display name (e.g. the Mistral/Mixtral variants).
MODEL_MAPPINGS = {
    # OpenAI
    "gpt-4o-2024-05-13": "GPT-4o",
    "gpt-4-0613": "GPT-4",
    "gpt-4-turbo-2024-04-09": "GPT-4 Turbo",
    "gpt-4-0125-preview": "GPT-4 Preview",
    "gpt-3.5": "GPT-3.5",
    "gpt-3.5-turbo-0125": "GPT-3.5 Turbo",
    # Anthropic
    "claude-3-opus-20240229": "Claude-3 O",
    "claude-3-sonnet-20240229": "Claude-3 S",
    "claude-3-haiku-20240307": "Claude-3 H",
    "claude-3-5-sonnet-20240620": "Claude-3.5 S",
    # Meta
    "llama-2-70b-chat": "Llama-2 70b",
    "llama-2-13b-chat": "Llama-2 13b",
    "llama-2-7b-chat": "Llama-2 7b",
    "llama-3-8b-chat": "Llama-3 8b",
    "llama-3-70b-chat": "Llama-3 70b",
    "codellama-70b-instruct": "Codellama 70b",
    # Mistral
    "mistral-large-2402": "Mistral Large",
    "mistral-medium-2312": "Mistral Medium",
    "open-mixtral-8x22b-instruct-v0.1": "Mixtral 8x22b",
    "open-mixtral-8x7b-instruct": "Mixtral 8x7b",
    "open-mistral-7b-instruct": "Mistral 7b",
    "open-mistral-7b": "Mistral 7b",
    "open-mixtral-8x22b": "Mixtral 8x22b",
    "open-mixtral-8x7b": "Mixtral 8x7b",
    "open-mistral-7b-instruct-v0.1": "Mistral 7b",
    # Databricks / Cohere
    "dbrx-instruct": "DBRX",
    "command-r-plus": "Command R Plus",
    # Google
    "gemma-7b-it": "Gemma 7b",
    "gemma-2b-it": "Gemma 2b",
    "gemini-1.5-pro-latest": "Gemini 1.5",
    "gemini-pro": "Gemini 1.0",
    # Alibaba
    "qwen1.5-7b-chat": "Qwen 1.5 7b",
    "qwen1.5-14b-chat": "Qwen 1.5 14b",
    "qwen1.5-32b-chat": "Qwen 1.5 32b",
    "qwen1.5-72b-chat": "Qwen 1.5 72b",
    "qwen1.5-0.5b-chat": "Qwen 1.5 0.5b",
    "qwen1.5-1.8b-chat": "Qwen 1.5 1.8b",
    "qwen2-72b-instruct": "Qwen 2 72b",
    # Mistral (code model)
    "codestral-2405": "Codestral",
}
# JSON Lines file of collected model responses from the AIW repository.
resp_url = 'https://github.com/LAION-AI/AIW/raw/main/collected_responses/responses.jsonl'
df = pd.read_json(resp_url, lines=True)

# Replace raw model identifiers with display names. Identifiers missing from
# MODEL_MAPPINGS keep their raw name instead of becoming NaN (NaN models would
# silently vanish from every groupby and plot downstream).
df['model'] = df['model'].map(MODEL_MAPPINGS).fillna(df['model'])

# Label each prompt as "<prompt text> [<prompt_id>]" so the UI can show and
# filter on prompt text and ID together.
df['prompt'] = df['prompt'].astype(str) + ' [' + df['prompt_id'].astype(str) + ']'

model_list = df['model'].unique()

# Unique labelled prompts ordered by numeric prompt ID. Sort on the prompt_id
# column rather than re-parsing the label: prompt text may itself contain '['.
# A stable sort keeps first-appearance order for prompts sharing an ID.
prompt_id_list = (
    df.drop_duplicates('prompt')
    .sort_values('prompt_id', kind='stable', key=lambda ids: ids.astype(int))['prompt']
    .tolist()
)
def response(num_responses, model, correct, prompt_ids):
    """Return a random sample of collected responses matching the UI filters.

    Args:
        num_responses: Maximum number of responses to show per model. The
            Gradio slider may deliver a float, so the value is truncated to int.
        model: Display names of models to keep; empty/None keeps all models.
        correct: Correctness values (True/False) to keep; empty/None keeps all.
        prompt_ids: Labelled prompts ("<prompt> [<id>]") to keep; empty/None
            keeps all prompts.

    Returns:
        DataFrame with columns ``model``, ``prompt``, ``model_response`` and
        ``correct``. It holds up to ``num_responses`` randomly chosen rows per
        model, grouped by model.
    """
    # DataFrame.sample/head need an integer count; slider values can be floats.
    n = max(int(num_responses), 0)

    responses = df
    if model:
        responses = responses[responses['model'].isin(model)]
    if correct:
        responses = responses[responses['correct'].isin(correct)]
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]

    # Shuffle once, then keep the first n rows of each model. This draws up to
    # n random rows per model (all rows if a model has fewer) without
    # groupby.apply. apply's group-key handling varies between pandas versions
    # and can make reset_index() collide with the existing 'model' column.
    sampled = (
        responses.sample(frac=1)
        .groupby('model')
        .head(n)
        .sort_values('model', kind='stable')
        .reset_index(drop=True)
    )
    return sampled[['model', 'prompt', 'model_response', 'correct']]
def barplot_for_prompt_id(prompt_ids, models):
    """Build a bar plot of mean correctness per model and prompt ID.

    Args:
        prompt_ids: Labelled prompts ("<prompt> [<id>]") to include;
            empty/None includes all prompts.
        models: Display names of models to include; empty/None includes all.

    Returns:
        A ``gr.BarPlot`` of the mean of the ``correct`` column for each
        (model, prompt_id) pair. Bars are grouped by model and coloured by
        prompt ID.
    """
    responses = df
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    if models:
        responses = responses[responses['model'].isin(models)]

    means = (
        responses.groupby(['model', 'prompt_id'])['correct']
        .mean()
        .reset_index()
    )

    # Collect the IDs actually plotted, sorted numerically before they are
    # stringified. The previous set()-based collection gave an arbitrary order
    # in the title.
    shown_ids = sorted(means['prompt_id'].unique())
    # The plot treats prompt_id as a categorical axis, so use strings.
    means['prompt_id'] = means['prompt_id'].astype(str)
    prompt_ids_str = ', '.join(str(pid) for pid in shown_ids)

    return gr.BarPlot(
        means,
        x='prompt_id',
        y='correct',
        group='model',
        color='prompt_id',
        group_title="",
        title=f'Correctness for Prompt IDs: {prompt_ids_str}',
        x_title="",
    )
title = "🎩🐇 Alice in Wonderland: Simple Tasks Showing Complete Reasoning Breakdown in State-Of-the-Art Large Language Models"

# Two-tab UI: a "Responses" browser and a "Robustness plot" of correctness per
# prompt ID, both backed by the module-level response data.
with gr.Blocks() as demo:
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"""<h1 style='font-size: 30px; font-weight: bold; text-align: center;'>{title}</h1>
<h4 align="center"><a href="https://marianna13.github.io/aiw/" target="_blank">🌐Homepage</a> | <a href="https://arxiv.org/pdf/2406.02061" target="_blank"> 📝Paper</a> | <a href="https://github.com/LAION-AI/AIW" target="_blank">🛠️Code</a></h4>
<p style='color: #000000; font-size: 20px; text-align: center;'>This demo shows the responses of different models to a set of prompts. The responses are categorized as correct or incorrect. You can choose the number of responses, the model, the correctness of the responses, and the prompt IDs to see the responses.</p>
<p style='color: #000000; font-size: 20px; text-align: center;'>You can also see the correctness of the responses for different prompt IDs using the robustness plot tab.</p>
"""
        )

    with gr.Tab("Responses"):
        gr.Interface(
            response,
            [
                # step=1: without it Gradio derives a 0.1 step from the 2..20
                # range, offering fractional response counts.
                gr.Slider(2, 20, value=4, step=1, label="Number of responses", info="Choose between 2 and 20"),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
                gr.CheckboxGroup([("Correct", True), ("Incorrect", False)], label="Correct or not", info="Choose to see correct or incorrect responses"),
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
            ],
            gr.DataFrame(type="pandas", wrap=True, label="Responses"),
        )

    with gr.Tab("Robustness plot"):
        gr.Interface(
            barplot_for_prompt_id,
            [
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
            ],
            gr.BarPlot(title="Correctness for Prompt IDs"),
        )
if __name__ == '__main__':
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()
|