File size: 5,910 Bytes
201434d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2e5007
da6fc76
201434d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05dae02
 
 
 
 
 
201434d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import pandas as pd

# Raw model identifiers (as stored in responses.jsonl) -> short display names
# used throughout the UI. Several identifiers map to the same display name, so
# their responses are pooled under one label (e.g. the three open-mistral-7b
# variants all appear as "Mistral 7b").
MODEL_MAPPINGS = {
    "gpt-4o-2024-05-13": "GPT-4o",
    "gpt-4-0613": "GPT-4",
    "gpt-4-turbo-2024-04-09": "GPT-4 Turbo",
    "gpt-4-0125-preview": "GPT-4 Preview",
    "gpt-3.5": "GPT-3.5",
    "gpt-3.5-turbo-0125": "GPT-3.5 Turbo",
    "claude-3-opus-20240229": "Claude-3 O",
    "claude-3-sonnet-20240229": "Claude-3 S",
    "claude-3-haiku-20240307": "Claude-3 H",
    "claude-3-5-sonnet-20240620": "Claude-3.5 S",
    "llama-2-70b-chat": "Llama-2 70b",
    "llama-2-13b-chat": "Llama-2 13b",
    "llama-2-7b-chat": "Llama-2 7b",
    "llama-3-8b-chat": "Llama-3 8b",
    "llama-3-70b-chat": "Llama-3 70b",
    "codellama-70b-instruct": "Codellama 70b",
    "mistral-large-2402": "Mistral Large",
    "mistral-medium-2312": "Mistral Medium",
    "open-mixtral-8x22b-instruct-v0.1": "Mixtral 8x22b",
    "open-mixtral-8x7b-instruct": "Mixtral 8x7b",
    "open-mistral-7b-instruct": "Mistral 7b",
    "open-mistral-7b": "Mistral 7b",
    "open-mixtral-8x22b": "Mixtral 8x22b",
    "open-mixtral-8x7b": "Mixtral 8x7b",
    "open-mistral-7b-instruct-v0.1": "Mistral 7b",
    "dbrx-instruct": "DBRX",
    "command-r-plus": "Command R Plus",
    "gemma-7b-it": "Gemma 7b",
    "gemma-2b-it": "Gemma 2b",
    "gemini-1.5-pro-latest": "Gemini 1.5",
    "gemini-pro": "Gemini 1.0",
    "qwen1.5-7b-chat": "Qwen 1.5 7b",
    "qwen1.5-14b-chat": "Qwen 1.5 14b",
    "qwen1.5-32b-chat": "Qwen 1.5 32b",
    "qwen1.5-72b-chat": "Qwen 1.5 72b",
    "qwen1.5-0.5b-chat": "Qwen 1.5 0.5b",
    "qwen1.5-1.8b-chat": "Qwen 1.5 1.8b",
    "qwen2-72b-instruct": "Qwen 2 72b",
    "codestral-2405": "Codestral",
}

# Collected model responses, one JSON record per line (fetched at import time).
resp_url = 'https://github.com/LAION-AI/AIW/raw/main/collected_responses/responses.jsonl'
df = pd.read_json(resp_url, lines=True)

# Swap raw identifiers for display names. Identifiers missing from
# MODEL_MAPPINGS keep their raw name: a plain .map(MODEL_MAPPINGS) would turn
# them into NaN, which groupby() silently drops and which shows up as a bogus
# "nan" choice in the model dropdown.
df['model'] = df['model'].map(lambda name: MODEL_MAPPINGS.get(name, name))
# Suffix each prompt with its id ("<prompt> [<id>]"); this is the string the
# prompt dropdowns display and filter on.
df['prompt'] = df['prompt'].astype(str) + ' [' + df['prompt_id'].astype(str) + ']'


model_list = df['model'].unique()
# Prompt choices in ascending numeric prompt-id order. The id is read from the
# column instead of being parsed back out of the display string, which would
# pick the wrong number (or crash) if the prompt text itself contained '['.
prompt_id_list = (
    df.drop_duplicates('prompt')
    .sort_values('prompt_id', key=pd.to_numeric, kind='stable')['prompt']
    .tolist()
)


def response(num_responses, model, correct, prompt_ids, data=None):
    """Return up to ``num_responses`` filtered responses for each model.

    Args:
        num_responses: Maximum number of rows to keep per model. Coerced to
            ``int`` because ``DataFrame.sample`` rejects a non-integer ``n``
            and a Gradio slider may deliver a float.
        model: Model display names to keep; falsy means no model filter.
        correct: ``correct`` values to keep (e.g. ``[True]``); falsy means no
            correctness filter.
        prompt_ids: Prompt display strings to keep; falsy means no prompt
            filter.
        data: Frame to draw responses from; defaults to the module-level
            ``df`` when omitted.

    Returns:
        A DataFrame with columns ``model``, ``prompt``, ``model_response`` and
        ``correct``. Rows are grouped by model in sorted model order and
        carry a fresh 0..n-1 index. Models with more than ``num_responses``
        rows are randomly sampled down to that many; smaller groups are kept
        whole, in their original order.
    """
    responses = df if data is None else data
    if model:
        responses = responses[responses['model'].isin(model)]
    if correct:
        responses = responses[responses['correct'].isin(correct)]
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]

    limit = int(num_responses)
    # Sample each model's rows separately. An explicit loop + concat replaces
    # groupby().apply().reset_index(): apply prepends the 'model' group key to
    # the index while 'model' is still a column, so reset_index() raises
    # "cannot insert model, already exists".
    per_model = [
        group if len(group) <= limit else group.sample(limit)
        for _, group in responses.groupby('model')
    ]
    # pd.concat([]) raises, so an empty selection falls back to an empty frame
    # that still has the expected columns.
    sampled = pd.concat(per_model) if per_model else responses.iloc[0:0]
    return sampled[['model', 'prompt', 'model_response', 'correct']].reset_index(drop=True)


def barplot_for_prompt_id(prompt_ids, models):
    """Plot mean correctness per prompt id, grouped by model.

    Args:
        prompt_ids: Prompt display strings ("<prompt> [<id>]") to include;
            falsy means all prompts.
        models: Model display names to include; falsy means all models.

    Returns:
        A ``gr.BarPlot`` of the mean ``correct`` value for every
        (model, prompt_id) pair that survives the filters. The title lists
        the plotted prompt ids in ascending order.
    """
    responses = df
    if prompt_ids:
        responses = responses[responses['prompt'].isin(prompt_ids)]
    if models:
        responses = responses[responses['model'].isin(models)]

    means = responses.groupby(['model', 'prompt_id'])['correct'].mean().reset_index()
    # Build the title's id list before stringifying, so ids sort by their
    # original dtype (numerically if they are numbers: 9 before 10). The
    # previous list(set(...)) of strings produced an arbitrary order that
    # changed between runs.
    shown_ids = ', '.join(str(pid) for pid in sorted(means['prompt_id'].unique()))
    # Stringify ids for the x axis (presumably so the plot treats them as
    # categories rather than as a numeric scale).
    means['prompt_id'] = means['prompt_id'].astype(str)
    return gr.BarPlot(
        means,
        x='prompt_id',
        y='correct',
        group='model',
        color='prompt_id',
        group_title="",
        title=f'Correctness for Prompt IDs: {shown_ids}',
        x_title="",
    )

title = "🎩🐇 Alice in Wonderland: Simple Tasks Showing Complete Reasoning Breakdown in State-Of-the-Art Large Language Models"

# Page layout: a header with links, then one tab for browsing individual
# responses and one tab with the per-prompt correctness bar plot.
with gr.Blocks() as demo:
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"""<h1 style='font-size: 30px; font-weight: bold; text-align: center;'>{title}</h1>
            <h4 align="center"><a href="https://marianna13.github.io/aiw/" target="_blank">🌐Homepage</a> | <a href="https://arxiv.org/pdf/2406.02061" target="_blank"> 📝Paper</a> | <a href="https://github.com/LAION-AI/AIW" target="_blank">🛠️Code</a></h4>
            <p style='color: #000000; font-size: 20px; text-align: center;'>This demo shows the responses of different models to a set of prompts. The responses are categorized as correct or incorrect. You can choose the number of responses, the model, the correctness of the responses, and the prompt IDs to see the responses.</p>
            <p style='color: #000000; font-size: 20px; text-align: center;'>You can also see the correctness of the responses for different prompt IDs using the robustness plot tab.</p>
            """
        )
    with gr.Tab("Responses"):
        gr.Interface(
            response,
            [
                # step=1 keeps the value integral. Without an explicit step,
                # Gradio derives a fractional step from the 2-20 range, and
                # DataFrame.sample() rejects a non-integer n.
                gr.Slider(2, 20, value=4, step=1, label="Number of responses", info="Choose between 2 and 20"),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
                gr.CheckboxGroup([("Correct", True), ("Incorrect", False)], label="Correct or not", info="Choose to see correct or incorrect responses"),
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
            ],
            gr.DataFrame(type="pandas", wrap=True, label="Responses"),
        )

    with gr.Tab("Robustness plot"):
        gr.Interface(
            barplot_for_prompt_id,
            [
                gr.Dropdown(
                    prompt_id_list, multiselect=True, label="Prompt IDs", info="Choose to see responses for a specific prompt ID(s)"
                ),
                gr.Dropdown(
                    list(model_list), label="Model", info="Choose to see responses", multiselect=True
                ),
            ],
            gr.BarPlot(title="Correctness for Prompt IDs"),
        )

if __name__ == "__main__":
    demo.launch()