Blair Yang commited on
Commit
5264831
·
1 Parent(s): e159d95
Config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASETS = [
2
+ 'mmlu',
3
+ # 'Anthropic_safety_eval'
4
+ ]
5
+
6
+ TOPICS = {
7
+ 'mmlu':
8
+ [
9
+ # 'high_school_biology',
10
+ 'high_school_physics'
11
+ ],
12
+ 'Anthropic_safety_eval':
13
+ [
14
+ 'myopia'
15
+ ]
16
+ }
17
+
18
+ MODELS = ['Llama-2-70b-chat-hf',
19
+ 'Llama-2-13b-chat-hf',
20
+ 'Mixtral-8x7B-Instruct-v0.1',
21
+ 'Mistral-7B-Instruct-v0.2'
22
+ ]
23
+
24
+ RANDOM_SEED = 42
__pycache__/Config.cpython-311.pyc ADDED
Binary file (472 Bytes). View file
 
app.py CHANGED
@@ -1,7 +1,90 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import plotly.express as px
3
+ import plotly.graph_objs as go
4
+ from collections import defaultdict
5
+ import json, math, gdown
6
+ import numpy as np
7
+ import pandas as pd
8
+ from Config import *
9
+ pd.options.display.float_format = '{:.2f}'.format
10
 
 
 
11
 
12
+ battles = np.linspace(0, 100, 100)
13
+
14
+ meta_topics = ['mmlu']
15
+
16
+ def generate_plot(meta_index, topic_index):
17
+ """
18
+ Bar plot of a specific dataset
19
+ """
20
+ # battles = np.linspace(0, 100, 100)
21
+ meta_topic = meta_topics[meta_index]
22
+ print(meta_topic)
23
+ topic = TOPICS[meta_topic][topic_index]
24
+
25
+ data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
26
+
27
+ topic_data = data[data['sub_topic'] == topic]
28
+
29
+ # Compute human and llm accuracy
30
+ topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
31
+ topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)
32
+
33
+ # Calculate mean and standard deviation for the sample data
34
+ mean_data = topic_data.groupby('model_name').mean().reset_index()
35
+ std_deviation = topic_data.groupby('model_name').std().reset_index()
36
+
37
+ # Prepare the plot data
38
+ plot_data = []
39
+
40
+ # Define a consistent color scheme
41
+ colors = ['#FFA07A', '#20B2AA', '#778899'] # Light Salmon, Light Sea Green, Light Slate Gray
42
+ opacities = [0.7, 0.7, 0.7] # Opacity for average bars
43
+
44
+ # Add bars with error bars for the averages
45
+ for acc_type, color, opacity in zip(['oracle_acc', 'human_acc', 'llm_acc'], colors, opacities):
46
+ plot_data.append(go.Bar(
47
+ x=mean_data['model_name'],
48
+ y=mean_data[acc_type],
49
+ error_y=dict(
50
+ type='data',
51
+ array=std_deviation[acc_type],
52
+ visible=True
53
+ ),
54
+ name=acc_type.split('_')[0].capitalize(),
55
+ marker=dict(color=color, opacity=opacity)
56
+ ))
57
+
58
+ # Layout
59
+ layout = go.Layout(
60
+ title=f"Accuracy for {meta_topic} ({topic})",
61
+ xaxis=dict(title='Model Name'),
62
+ yaxis=dict(title='Accuracy'),
63
+ showlegend=True,
64
+ legend=dict(title='Accuracy Type'),
65
+ barmode='group'
66
+ )
67
+
68
+ fig = go.Figure(data=plot_data, layout=layout)
69
+ return fig
70
+
71
+
72
+ # Gradio interface with grid layout
73
+ with gr.Blocks() as interface:
74
+ with gr.Row(): # Row 1
75
+ plot1 = gr.Plot(generate_plot(0, 0))
76
+ # plot1.update(inputs=[0, 0])
77
+ plot2 = gr.Plot(generate_plot(0, 0))
78
+ # plot2.update(inputs=[0, 1])
79
+ with gr.Row(): # Row 2
80
+ plot3 = gr.Plot(generate_plot(0, 0))
81
+ # plot3.update(inputs=[1, 0])
82
+ plot4 = gr.Plot(generate_plot(0, 0))
83
+ # plot4.update(inputs=[1, 1])
84
+ with gr.Row(): # Row 3
85
+ plot5 = gr.Plot(generate_plot(0, 0))
86
+ # plot5.update(inputs=[2, 0])
87
+ plot6 = gr.Plot(generate_plot(0, 0))
88
+ # plot6.update(inputs=[2, 1])
89
+
90
+ interface.launch()
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/mmlu/response_rec.csv CHANGED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sub_topic,model_name,card_idx,no_responses_human,no_correct_human,no_responses_llm,no_correct_llm,oracle_acc
2
+ high_school_physics,Mixtral-8x7B-Instruct-v0.1,-1,10,8,10,7,0.68
3
+ high_school_physics,Mixtral-8x7B-Instruct-v0.1,0,6,4,6,3,0.66
4
+ high_school_physics,Mixtral-8x7B-Instruct-v0.1,1,4,4,4,4,0.7
5
+ high_school_physics,Mistral-7B-Instruct-v0.2,-1,0,0,0,0,0
6
+ high_school_physics,Mistral-7B-Instruct-v0.2,0,0,0,0,0,0
7
+ high_school_physics,Mistral-7B-Instruct-v0.2,1,0,0,0,0,0
8
+ high_school_biology,Mixtral-8x7B-Instruct-v0.1,-1,10,8,10,7,0.68
9
+ high_school_biology,Mixtral-8x7B-Instruct-v0.1,0,6,4,6,3,0.66
10
+ high_school_biology,Mixtral-8x7B-Instruct-v0.1,1,4,4,4,4,0.7
11
+ high_school_biology,Mistral-7B-Instruct-v0.2,-1,0,0,0,0,0
12
+ high_school_biology,Mistral-7B-Instruct-v0.2,0,0,0,0,0,0
13
+ high_school_biology,Mistral-7B-Instruct-v0.2,1,0,0,0,0,0
plot.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import json, math, gdown
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ from tqdm import tqdm
7
+ pd.options.display.float_format = '{:.2f}'.format
8
+
9
+
10
+ battles = np.linspace(0, 100, 100)
11
+ fig = px.bar(battles,
12
+ title="Counts of Battle Outcomes", text_auto=True, height=400)
13
+ fig.update_layout(xaxis_title="Battle Outcome", yaxis_title="Count",
14
+ showlegend=False)
15
+ fig.show()
requirements.txt CHANGED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ plotly
2
+ numpy
3
+ pandas
4
+ tqdm
util.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ def read_data(file_path):
4
+ """
5
+ Read data from a csv file
6
+ """
7
+ return pd.read_csv(file_path, sep=",")
8
+
9
+
10
+ if __name__ == "__main__":
11
+ file_path = "data/mmlu/response_rec.csv"
12
+ data = read_data(file_path)
13
+ high_school_physics = data[data['sub_topic'] == 'high_school_physics']
14
+ print(high_school_physics.head(5))