Upload folder using huggingface_hub
- .gitattributes +1 -0
- README.md +3 -9
- candidates.py +213 -0
- data.csv +0 -0
- metrics.py +188 -0
- output.csv +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+output.csv filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-
-colorFrom: green
-colorTo: yellow
+title: Candidates_viewer_NPR_challenge
+app_file: candidates.py
 sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned: false
+sdk_version: 5.15.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
candidates.py
ADDED
@@ -0,0 +1,213 @@
+"""
+Script for joining .csv candidate data with .duckdb results.
+Launches a Gradio app to review candidates.
+"""
+import argparse
+from pathlib import Path
+import pandas as pd
+from metrics import load_results
+import numpy as np
+import json
+import ast
+import gradio as gr
+from typing import List
+from hashlib import sha256
+import re
+
+def _query_format_models(models: List[str]) -> str:
+    """
+    Format model names for the SQL query `WHERE <this_model> IN <models>`.
+    """
+    return "('" + "','".join(["completions-" + m for m in models]) + "')"
+
+def _hash(text: str) -> str:
+    return sha256(bytes(text, "utf-8")).hexdigest()
+
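
For reference, a quick illustration of what the two helpers above return (editor's sketch, not part of the commit):

    >>> _query_format_models(["r1", "gemini2"])
    "('completions-r1','completions-gemini2')"
    >>> _hash("hello")
    '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824'
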
+SQL_QUERY = """
+WITH AllResults AS (
+    SELECT
+        results.parent_dir AS model,
+        *
+    FROM
+        results.completions results
+    JOIN
+        challenges challenges
+    ON
+        results.prompt_id = challenges.ID
+)
+SELECT prompt_id, model, completion, answer as solution, prompt
+FROM AllResults
+WHERE
+    AllResults.model IN {models}
+""".format(models=_query_format_models(['r1_distill_qwen32b','r1','gemini2']))
+
+
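
For reference, after `.format(...)` the final clause of `SQL_QUERY` expands to:

    WHERE
        AllResults.model IN ('completions-r1_distill_qwen32b','completions-r1','completions-gemini2')
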
+def print_info(db_connection):
+    tables = db_connection.execute("SHOW TABLES").fetchall()
+    # Iterate over each table and print its name and columns
+    for table in tables:
+        table_name = table[0]
+        print(f"Table: {table_name}")
+
+        # Get the columns for this table
+        columns = db_connection.execute(f"DESCRIBE {table_name}").fetchall()
+
+        # Print the column details
+        for column in columns:
+            print(f" - {column[0]} ({column[1]})")  # column[0] is the column name, column[1] is the data type
+
+        print()  # Add a blank line between tables for readability
+
+def _parse(x):
+    if isinstance(x, str):
+        if len(x.strip()) == 0 or x.strip() in ["]", "["]:
+            return []  # bad gen
+        else:
+            try:
+                return ast.literal_eval(x)
+            except:
+                raise ValueError(f"Bad gen: {x}")
+    elif np.isnan(x):
+        return []
+    else:
+        raise ValueError(f"Found unexpected type {type(x)}: {x}")
+
+def _concat(series: pd.Series) -> np.array:
+    items = list(filter(lambda x: len(x) > 0, map(_parse, series)))
+    if len(items) > 0:
+        return np.unique(np.concatenate(items))
+    else:
+        return items
+
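
A small illustration (editor's sketch; the column values are hypothetical) of how `_parse` and `_concat` merge stringified candidate lists coming out of the CSV:

    >>> _concat(pd.Series(['["CAT", "DOG"]', '["DOG", "EMU"]', '']))
    array(['CAT', 'DOG', 'EMU'], dtype='<U3')
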
+def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
+    """
+    Perform a variety of sanity checks, e.g.:
+    - all chunks are present
+    - all attempted answers are in the completion
+    """
+    MANUALLY_CHECKED_SPECIAL_CASES = [
+        "4fd9a9adf162fe558cd94ab7ebcf8f42882873dca133aa1a4620572caa364c0c",  # extracted as a str list, e.g. `FIED, GOA`
+        "7dd4a475af16d67ed896275674d6a9b51911a3ee22aaca84411fb0a946245fa1"
+    ]
+    for _, row in merged_df.iterrows():
+        candidates = json.loads(row["candidates"])
+        comp = row["completion"].lower()
+        for c in candidates:
+            assert c.lower() in comp or \
+                   c.lower() in re.sub(r'[^a-z0-9]', '', comp) or \
+                   row["_original_completion_hash"] in MANUALLY_CHECKED_SPECIAL_CASES, \
+                   json.dumps({"candidate": c, "completion": row["completion"], "hash": row["_original_completion_hash"]}, indent=4)
+
+    # grouped = candidates.groupby(["model","prompt_id"]).agg({"chunk_id": "unique", "num_chunks":"first"})
+    # for _,row in grouped.iterrows():
+    #     assert list(row["chunk_id"]) == range(row["num_chunks"]+1), (row["chunk_id"], row["num_chunks"])
+
+
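
The second disjunct of the assert covers candidates that only match once punctuation and spacing are stripped, which is what the `FIED, GOA` special case alludes to. Concretely (editor's illustration):

    >>> comp = "the answer is FI-ED".lower()
    >>> "fied" in re.sub(r'[^a-z0-9]', '', comp)
    True
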
+def launch_app(df: pd.DataFrame, share_demo: bool = False):
+
+    # Define function to display table and toggle completion
+    def show_table(show_completion, example_idx):
+        # Extract the row based on the slider index
+        example = df.iloc[example_idx]
+
+        # Function to highlight words from the candidates list
+        def highlight_words(text, candidates):
+            for word in candidates:
+                # Use word boundaries to ensure we only match whole words
+                text = re.sub(rf'\b({re.escape(word)})\b', r'<mark>\1</mark>', text, flags=re.IGNORECASE)
+            return text
+
+        # Highlight words in the 'completion' column
+        candidates = json.loads(example['candidates'])
+        highlighted_completion = highlight_words(example['completion'], candidates)
+
+        # Create a table with the core columns
+        table_html = f"""
+        <table>
+            <tr><td><b>Completion hash</b></td><td>{example['_original_completion_hash']}</td></tr>
+            <tr><td><b>Model</b></td><td>{example['model']}</td></tr>
+            <tr><td><b>Prompt ID</b></td><td>{example['prompt_id']}</td></tr>
+            <tr><td><b>Solution</b></td><td>{example['solution']}</td></tr>
+            <tr><td><b>Prompt</b></td><td>{example['prompt']}</td></tr>
+            <tr><td><b>Candidates</b></td><td>{candidates}</td></tr>
+        </table>
+        """
+
+        # If the toggle is checked, show the 'completion' column with highlighted words
+        if show_completion:
+            table_html += f"""
+            <br><b>Completion:</b><br>
+            <p>{highlighted_completion}</p>
+            """
+
+        return table_html
+
+    # Create the Gradio interface
+    with gr.Blocks() as demo:
+        # Slider to navigate through examples
+        example_slider = gr.Slider(minimum=0, maximum=len(df)-1, step=1, label="Example", value=0)
+
+        # Toggle button for showing/hiding completion
+        toggle_button = gr.Checkbox(label="Show Completion", value=False)
+
+        with gr.Row():
+            gr.HTML('<h1>Candidates Table</h1>')
+
+        # Table display
+        table_output = gr.HTML()
+
+        # Set interaction behavior: update the table when slider or checkbox changes
+        example_slider.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])
+        toggle_button.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])
+
+    # Launch the app
+    demo.launch(share=share_demo)
+
+
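
The highlighting relies on a word-boundary regex, so only whole-word, case-insensitive matches are wrapped in `<mark>` tags. Concretely (editor's illustration of the `re.sub` call above):

    >>> re.sub(rf'\b({re.escape("goa")})\b', r'<mark>\1</mark>', "Goa, not goat", flags=re.IGNORECASE)
    '<mark>Goa</mark>, not goat'
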
+def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool):
+    if not output_csv.exists():
+        candidates = pd.read_csv(candidates.as_posix())
+        conn = load_results()
+        completions = conn.sql(SQL_QUERY).df()
+
+        candidates = candidates.groupby(["model", "prompt_id", "solution", "prompt", "_original_completion_hash"]).agg({
+            "candidates": "unique"
+        }).reset_index()
+        candidates["candidates"] = candidates["candidates"].apply(lambda x: json.dumps(list(_concat(x))))
+        completions["_original_completion_hash"] = completions["completion"].apply(_hash)
+
+        df = candidates.merge(completions, on=["model", "prompt_id", "prompt", "solution", "_original_completion_hash"])
+        print(df, candidates, completions, sep="\n")
+        # print_info(conn)
+        # check_candidates(candidates, df)
+        df.to_csv(output_csv)
+
+        # tables = conn.execute("SHOW TABLES").fetchall()
+        # if not ("candidates", ) in tables:
+        #     # Create a table in DuckDB and insert the candidate data
+        #     conn.execute("CREATE TABLE candidates (model VARCHAR, prompt_id INTEGER, \
+        #         prompt VARCHAR, completion VARCHAR, solution VARCHAR, candidates VARCHAR)")
+
+        #     # Insert the list of rows into the table
+        #     for _, row in df.iterrows():
+        #         drow = [row["model"], row["prompt_id"], row["prompt"], row["completion"], row["solution"], row["candidates"]]
+        #         conn.execute("INSERT INTO candidates VALUES (?, ?, ?, ?, ?, ?)", drow)
+
+        #     conn.commit()
+        #     print_info(conn)
+        #     conn.close()
+    else:
+        df = pd.read_csv(output_csv.as_posix())
+
+    print(df)
+    if launch_gradio:
+        launch_app(df, share_demo)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("candidates", type=Path, help="path to .csv data containing extracted candidates")
+    parser.add_argument("output_csv", type=Path, help="path to .csv output file; will reload from here if path exists")
+    parser.add_argument("-gr", "--launch_gradio", action="store_true")
+    parser.add_argument("-s", "--share_demo", action="store_true")
+    args = parser.parse_args()
+    main(**vars(args))
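
A usage sketch of the entry point, equivalent to `python candidates.py data.csv output.csv --launch_gradio` (assuming `results.duckdb` and `puzzles_cleaned.csv` are present for the first run):

    from pathlib import Path
    from candidates import main

    main(Path("data.csv"), Path("output.csv"), launch_gradio=True, share_demo=False)
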
data.csv
ADDED
The diff for this file is too large to render. See raw diff.
metrics.py
ADDED
@@ -0,0 +1,188 @@
+import re
+import duckdb
+import textwrap
+from typing import List, Tuple
+import argparse
+
+def _parse_answer(text: str) -> List[List[str]]:
+    """
+    Converts text to lowercase. Then interprets ";" as a separator between
+    alternatives. Within each alternative, interprets "," and "-->" as separators
+    for elements of a set. Within each set, drops all non-alphanumeric characters
+    and returns that set.
+
+    Another way to describe this is that we interpret adjacent words as
+    phrases that must be present literally. However, comma and arrow separate
+    distinct phrases that may be present in any order. All other characters
+    are dropped.
+    """
+    text = text.lower()
+    alternatives = re.split(r';', text)
+    result = []
+    for alternative in alternatives:
+        groups = re.split(r'–?-?-?>|,', alternative)
+        result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups])
+    return result
+
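
For instance, one alternative given as a literal phrase and another given as order-free parts (editor's illustration):

    >>> _parse_answer("LOW EARTH ORBIT; LOW --> EARTH, ORBIT")
    [['low earth orbit'], ['low', 'earth', 'orbit']]
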
+def _answer_without_thoughts(completion: str) -> str:
+    if "<think>" not in completion[:200]:
+        return completion
+
+    chunks = completion.split("</think>")
+    if len(chunks) <= 1:
+        return ""
+
+    return chunks[-1].strip()
+
+def _check_answer(completion: str, answer: str) -> bool:
+    """
+    Check that all the phrases that must appear in the answer appear in the
+    completion. We ignore "thoughts", capitalization, and punctuation.
+    """
+    completion = _answer_without_thoughts(completion).lower()
+    completion = re.sub(r'[^\w\s]', ' ', completion)  # replace punctuation with spaces, aligning with _parse_answer's " ".join
+    completion = re.sub(r'\s+', ' ', completion)  # normalize consecutive (Unicode) spaces to finish aligning with _parse_answer
+    alternative_answers = _parse_answer(answer)
+    for answer_phrases in alternative_answers:
+        # if all(phrase in completion for phrase in answer_phrases):
+        if all(re.search(rf'\b{re.escape(phrase)}\b', completion) for phrase in answer_phrases):
+            return True
+    return False
+
+
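
Putting the three helpers together (editor's sketch of the expected behaviour):

    >>> _check_answer("<think>hmm</think> The answer is EARTH, then ORBIT!", "earth, orbit; moon")
    True
    >>> _check_answer("I'd guess MOON.", "earth, orbit; moon")
    True
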
+def _clip_text(text: str, width: int) -> str:
+    return text if len(text) <= width else text[:width] + "..."
+
+def _wrap_text(text: str, width: int) -> str:
+    return textwrap.fill(text, width=width)
+
+def load_results():
+    conn = duckdb.connect(":memory:")
+    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
+    # conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
+    conn.execute("""
+        CREATE TABLE challenges AS
+        SELECT * FROM 'puzzles_cleaned.csv'
+        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
+    """)
+    conn.create_function("check_answer", _check_answer)
+    conn.create_function("clip_text", _clip_text)
+    conn.create_function("wrap_text", _wrap_text)
+    return conn
+
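
`load_results` registers the Python helpers as DuckDB scalar UDFs, so later queries can call them from SQL. A minimal sketch, assuming `results.duckdb` and `puzzles_cleaned.csv` exist in the working directory:

    conn = load_results()
    # check_answer is now callable inside SQL
    ok = conn.sql("SELECT check_answer('the answer is moon', 'moon') AS ok").df()
    assert bool(ok["ok"][0])
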
+def r1_accuracy_by_completion_length(conn, model_name):
+    """
+    For the responses from the completions-r1 model:
+    1. We calculate completion length and correctness for each problem.
+    2. We sort by length.
+    3. We compute the cumulative number of correct responses.
+    """
+    r1_completions = conn.sql(f"""
+        WITH LengthsAndCorrectness AS (
+            SELECT
+                LENGTH(results.completion) AS length,
+                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
+            FROM results.completions results JOIN challenges
+            ON results.prompt_id = challenges.ID
+            WHERE results.parent_dir = '{model_name}'
+        ),
+        TotalItems AS (
+            SELECT COUNT(*) as total_count
+            FROM LengthsAndCorrectness
+        ),
+        CumulativeCorrect AS (
+            SELECT
+                length,
+                SUM(correct) OVER (ORDER BY length) as cumulative_correct,
+            FROM LengthsAndCorrectness
+        )
+        SELECT
+            length,
+            cumulative_correct,
+            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
+        FROM CumulativeCorrect, TotalItems
+        ORDER BY length
+    """)
+    return r1_completions
+
+
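
A usage sketch: pull the cumulative-accuracy curve into pandas for plotting (the `completions-r1` directory name is an assumption, inferred from the model list in candidates.py):

    conn = load_results()
    curve = r1_accuracy_by_completion_length(conn, "completions-r1").df()
    print(curve.head())  # columns: length, cumulative_correct, cumulative_accuracy
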
+def accuracy_by_model_and_time(conn):
+    model_accuracies = conn.sql("""
+        WITH ChallengesWithDates AS (
+            SELECT
+                ID,
+                answer,
+                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
+            FROM
+                challenges
+        ),
+        DateAnswerCheck AS (
+            SELECT
+                results.parent_dir AS model,
+                dates.year,
+                COUNT(*) AS total,
+                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
+            FROM
+                results.completions results
+            JOIN
+                ChallengesWithDates dates
+            ON
+                results.prompt_id = dates.ID
+            GROUP BY
+                results.parent_dir,
+                dates.year
+        )
+        SELECT
+            model,
+            year,
+            total,
+            correct,
+            ROUND(correct / total, 2) AS accuracy
+        FROM
+            DateAnswerCheck
+        ORDER BY
+            model,
+            year
+    """)
+
+    return model_accuracies
+
+def accuracy_by_model(conn):
+    return conn.sql("""
+        WITH AnswerCheck AS (
+            SELECT
+                results.parent_dir AS model,
+                SUM(results.count) AS total,
+                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
+            FROM
+                results.completions results
+            JOIN
+                challenges challenges
+            ON
+                results.prompt_id = challenges.ID
+            GROUP BY
+                results.parent_dir
+        )
+        SELECT
+            model,
+            total,
+            correct,
+            ROUND(correct / total, 2) AS accuracy
+        FROM
+            AnswerCheck
+    """)
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--by-model-and-time", action="store_true")
+    args = parser.parse_args()
+    conn = load_results()
+    if args.by_model_and_time:
+        print(accuracy_by_model_and_time(conn))
+    else:
+        print(accuracy_by_model(conn))
+
+if __name__ == "__main__":
+    main()
output.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ad970dce3fb60473dcbfa707515ab67dd78b3cbcc2856feeff6fcb33c918e69
+size 18655953