franlucc committed
Commit 8a17d1d · verified · 1 parent: 95a4b2a

Upload folder using huggingface_hub

Files changed (6)
  1. .gitattributes +1 -0
  2. README.md +3 -9
  3. candidates.py +213 -0
  4. data.csv +0 -0
  5. metrics.py +188 -0
  6. output.csv +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ output.csv filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Candidates Viewer NPR Challenge
- emoji: 🦀
- colorFrom: green
- colorTo: yellow
+ title: Candidates_viewer_NPR_challenge
+ app_file: candidates.py
  sdk: gradio
- sdk_version: 5.20.0
- app_file: app.py
- pinned: false
+ sdk_version: 5.15.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
candidates.py ADDED
@@ -0,0 +1,213 @@
+ """
+ Script for joining .csv candidate data with results from a .duckdb database.
+ Launches a Gradio app to review candidates.
+ """
+ import argparse
+ from pathlib import Path
+ import pandas as pd
+ from metrics import load_results
+ import numpy as np
+ import json
+ import ast
+ import gradio as gr
+ from typing import List
+ from hashlib import sha256
+ import re
+
+ def _query_format_models(models: List[str]) -> str:
+     """
+     Format model names for the SQL query `WHERE <this_model> IN <models>`.
+     """
+     return "('" + "','".join(["completions-" + m for m in models]) + "')"
+
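+ # Illustrative call, using model names passed below in SQL_QUERY:
+ #   _query_format_models(['r1', 'gemini2']) == "('completions-r1','completions-gemini2')"
+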
+ def _hash(text: str) -> str:
+     return sha256(bytes(text, "utf-8")).hexdigest()
+
+ SQL_QUERY = """
+ WITH AllResults AS (
+     SELECT
+         results.parent_dir AS model,
+         *
+     FROM
+         results.completions results
+     JOIN
+         challenges challenges
+     ON
+         results.prompt_id = challenges.ID
+ )
+ SELECT prompt_id, model, completion, answer AS solution, prompt
+ FROM AllResults
+ WHERE
+     AllResults.model IN {models}
+ """.format(models=_query_format_models(['r1_distill_qwen32b', 'r1', 'gemini2']))
+
+
+ def print_info(db_connection):
+     tables = db_connection.execute("SHOW TABLES").fetchall()
+     # Iterate over each table and print its name and columns
+     for table in tables:
+         table_name = table[0]
+         print(f"Table: {table_name}")
+
+         # Get the columns for this table
+         columns = db_connection.execute(f"DESCRIBE {table_name}").fetchall()
+
+         # Print the column details
+         for column in columns:
+             print(f" - {column[0]} ({column[1]})")  # column[0] is the column name, column[1] is the data type
+
+         print()  # Add a blank line between tables for readability
+
+ def _parse(x):
+     if isinstance(x, str):
+         if len(x.strip()) == 0 or x.strip() in ["]", "["]:
+             return []  # bad gen
+         else:
+             try:
+                 return ast.literal_eval(x)
+             except (ValueError, SyntaxError):
+                 raise ValueError(f"Bad gen: {x}")
+     elif np.isnan(x):
+         return []
+     else:
+         raise ValueError(f"Found unexpected type {type(x)}: {x}")
+
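+ # Illustrative inputs: _parse("['FOO', 'BAR']") -> ['FOO', 'BAR'], while both
+ # _parse("") and _parse(float("nan")) -> [].
+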
+ def _concat(series: pd.Series) -> np.ndarray:
+     items = list(filter(lambda x: len(x) > 0, map(_parse, series)))
+     if len(items) > 0:
+         return np.unique(np.concatenate(items))
+     else:
+         return np.array(items)
+
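+ # Illustrative input: _concat(pd.Series(["['A', 'B']", "['B', 'C']", ""]))
+ # yields the deduplicated array(['A', 'B', 'C']).
+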
+ def check_candidates(candidates: pd.DataFrame, merged_df: pd.DataFrame):
+     """
+     Perform a variety of sanity checks, e.g.:
+     - all chunks are present
+     - all attempted answers are in the completion
+     """
+     MANUALLY_CHECKED_SPECIAL_CASES = [
+         "4fd9a9adf162fe558cd94ab7ebcf8f42882873dca133aa1a4620572caa364c0c",  # extracted as a str list, e.g. `FIED, GOA`
+         "7dd4a475af16d67ed896275674d6a9b51911a3ee22aaca84411fb0a946245fa1"
+     ]
+     for _, row in merged_df.iterrows():
+         cands = json.loads(row["candidates"])
+         comp = row["completion"].lower()
+         for c in cands:
+             assert c.lower() in comp or \
+                 c.lower() in re.sub(r'[^a-z0-9]', '', comp) or \
+                 row["_original_completion_hash"] in MANUALLY_CHECKED_SPECIAL_CASES, \
+                 json.dumps({"candidate": c, "completion": row["completion"], "hash": row["_original_completion_hash"]}, indent=4)
+
+     # grouped = candidates.groupby(["model","prompt_id"]).agg({"chunk_id": "unique", "num_chunks": "first"})
+     # for _, row in grouped.iterrows():
+     #     assert list(row["chunk_id"]) == list(range(row["num_chunks"] + 1)), (row["chunk_id"], row["num_chunks"])
+
+
+ def launch_app(df: pd.DataFrame, share_demo: bool = False):
+
+     # Define function to display table and toggle completion
+     def show_table(show_completion, example_idx):
+         # Extract the row based on the slider index
+         example = df.iloc[example_idx]
+
+         # Function to highlight words from the candidates list
+         def highlight_words(text, candidates):
+             for word in candidates:
+                 # Use word boundaries to ensure we only match whole words
+                 text = re.sub(rf'\b({re.escape(word)})\b', r'<mark>\1</mark>', text, flags=re.IGNORECASE)
+             return text
+
+         # Highlight words in the 'completion' column
+         candidates = json.loads(example['candidates'])
+         highlighted_completion = highlight_words(example['completion'], candidates)
+
+         # Create a table with the core columns
+         table_html = f"""
+         <table>
+             <tr><td><b>Completion hash</b></td><td>{example['_original_completion_hash']}</td></tr>
+             <tr><td><b>Model</b></td><td>{example['model']}</td></tr>
+             <tr><td><b>Prompt ID</b></td><td>{example['prompt_id']}</td></tr>
+             <tr><td><b>Solution</b></td><td>{example['solution']}</td></tr>
+             <tr><td><b>Prompt</b></td><td>{example['prompt']}</td></tr>
+             <tr><td><b>Candidates</b></td><td>{candidates}</td></tr>
+         </table>
+         """
+
+         # If the toggle is checked, show the 'completion' column with highlighted words
+         if show_completion:
+             table_html += f"""
+             <br><b>Completion:</b><br>
+             <p>{highlighted_completion}</p>
+             """
+
+         return table_html
+
+     # Create the Gradio interface
+     with gr.Blocks() as demo:
+         # Slider to navigate through examples
+         example_slider = gr.Slider(minimum=0, maximum=len(df) - 1, step=1, label="Example", value=0)
+
+         # Toggle button for showing/hiding completion
+         toggle_button = gr.Checkbox(label="Show Completion", value=False)
+
+         with gr.Row():
+             gr.HTML('<h1>Candidates Table</h1>')
+
+         # Table display
+         table_output = gr.HTML()
+
+         # Set interaction behavior: update the table when slider or checkbox changes
+         example_slider.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])
+         toggle_button.change(show_table, inputs=[toggle_button, example_slider], outputs=[table_output])
+
+     # Launch the app
+     demo.launch(share=share_demo)
+
+
+ def main(candidates: Path, output_csv: Path, launch_gradio: bool, share_demo: bool):
+     if not output_csv.exists():
+         candidates = pd.read_csv(candidates.as_posix())
+         conn = load_results()
+         completions = conn.sql(SQL_QUERY).df()
+
+         candidates = candidates.groupby(["model", "prompt_id", "solution", "prompt", "_original_completion_hash"]).agg({
+             "candidates": "unique"
+         }).reset_index()
+         candidates["candidates"] = candidates["candidates"].apply(lambda x: json.dumps(list(_concat(x))))
+         completions["_original_completion_hash"] = completions["completion"].apply(_hash)
+
+         df = candidates.merge(completions, on=["model", "prompt_id", "prompt", "solution", "_original_completion_hash"])
+         print(df, candidates, completions, sep="\n")
+         # print_info(conn)
+         # check_candidates(candidates, df)
+         df.to_csv(output_csv)
+
+         # tables = conn.execute("SHOW TABLES").fetchall()
+         # if not ("candidates",) in tables:
+         #     # Create a table in DuckDB and insert the candidate data
+         #     conn.execute("CREATE TABLE candidates (model VARCHAR, prompt_id INTEGER, \
+         #                  prompt VARCHAR, completion VARCHAR, solution VARCHAR, candidates VARCHAR)")
+
+         #     # Insert the list of rows into the table
+         #     for _, row in df.iterrows():
+         #         drow = [row["model"], row["prompt_id"], row["prompt"], row["completion"], row["solution"], row["candidates"]]
+         #         conn.execute("INSERT INTO candidates VALUES (?, ?, ?, ?, ?, ?)", drow)
+
+         #     conn.commit()
+         #     print_info(conn)
+         #     conn.close()
+     else:
+         df = pd.read_csv(output_csv.as_posix())
+
+     print(df)
+     if launch_gradio:
+         launch_app(df, share_demo)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("candidates", type=Path, help="path to .csv data containing extracted candidates")
+     parser.add_argument("output_csv", type=Path, help="path to .csv output file; will reload from here if path exists")
+     parser.add_argument("-gr", "--launch_gradio", action="store_true")
+     parser.add_argument("-s", "--share_demo", action="store_true")
+     args = parser.parse_args()
+     main(**vars(args))
data.csv ADDED
The diff for this file is too large to render. See raw diff
 
metrics.py ADDED
@@ -0,0 +1,188 @@
+ import re
+ import duckdb
+ import textwrap
+ from typing import List
+ import argparse
+
+ def _parse_answer(text: str) -> List[List[str]]:
+     """
+     Converts text to lowercase. Then interprets ";" as a separator between
+     alternatives. Within each alternative, interprets "," and "-->" as
+     separators for elements of a set. Within each set, drops all
+     non-alphanumeric characters and returns that set.
+
+     Another way to describe this is that we interpret adjacent words as
+     phrases that must be present literally. However, comma and arrow separate
+     distinct phrases that may be present in any order. All other characters
+     are dropped.
+     """
+     text = text.lower()
+     alternatives = re.split(r';', text)
+     result = []
+     for alternative in alternatives:
+         groups = re.split(r'–?-?-?>|,', alternative)
+         result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups])
+     return result
+
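+ # Illustrative input: _parse_answer("Foo, Bar; Baz") -> [["foo", "bar"], ["baz"]],
+ # i.e. either both "foo" and "bar" must appear, or "baz" alone.
+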
+ def _answer_without_thoughts(completion: str) -> str:
+     if "<think>" not in completion[:200]:
+         return completion
+
+     chunks = completion.split("</think>")
+     if len(chunks) <= 1:
+         return ""
+
+     return chunks[-1].strip()
+
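+ # Illustrative inputs: _answer_without_thoughts("<think>hmm</think> FINAL") -> "FINAL";
+ # a completion with an unterminated <think> block -> "".
+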
+ def _check_answer(completion: str, answer: str) -> bool:
+     """
+     Check that all the phrases that must appear in the answer appear in the
+     completion. We ignore "thoughts", capitalization, and punctuation.
+     """
+     completion = _answer_without_thoughts(completion).lower()
+     completion = re.sub(r'[^\w\s]', ' ', completion)  # replace punctuation with spaces, aligning with _parse_answer's " ".join
+     completion = re.sub(r'\s+', ' ', completion)  # normalize consecutive (Unicode) spaces to finish aligning with _parse_answer
+     alternative_answers = _parse_answer(answer)
+     for answer_phrases in alternative_answers:
+         # if all(phrase in completion for phrase in answer_phrases):
+         if all(re.search(rf'\b{re.escape(phrase)}\b', completion) for phrase in answer_phrases):
+             return True
+     return False
+
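+ # Illustrative inputs: _check_answer("The answer is Foo Bar.", "foo bar; baz") is
+ # True (the phrase "foo bar" appears); _check_answer("Baz!", "foo bar") is False.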
+
+ def _clip_text(text: str, width: int) -> str:
+     return text if len(text) <= width else text[:width] + "..."
+
+ def _wrap_text(text: str, width: int) -> str:
+     return textwrap.fill(text, width=width)
+
+ def load_results():
+     conn = duckdb.connect(":memory:")
+     conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
+     # conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
+     conn.execute("""
+         CREATE TABLE challenges AS
+         SELECT * FROM 'puzzles_cleaned.csv'
+         WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
+     """)
+     conn.create_function("check_answer", _check_answer)
+     conn.create_function("clip_text", _clip_text)
+     conn.create_function("wrap_text", _wrap_text)
+     return conn
+
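+ # Usage sketch (assumes results.duckdb and puzzles_cleaned.csv exist in the
+ # working directory): `conn = load_results()` attaches the results database
+ # read-only, filters out challenges with (E) warnings, and registers
+ # check_answer/clip_text/wrap_text as SQL functions used by the queries below.
+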
+ def r1_accuracy_by_completion_length(conn, model_name):
+     """
+     For the responses from the completions-r1 model:
+     1. We calculate completion length and correctness for each problem.
+     2. We sort by length.
+     3. We compute the cumulative number of correct responses.
+     """
+     r1_completions = conn.sql(f"""
+         WITH LengthsAndCorrectness AS (
+             SELECT
+                 LENGTH(results.completion) AS length,
+                 CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
+             FROM results.completions results JOIN challenges
+             ON results.prompt_id = challenges.ID
+             WHERE results.parent_dir = '{model_name}'
+         ),
+         TotalItems AS (
+             SELECT COUNT(*) AS total_count
+             FROM LengthsAndCorrectness
+         ),
+         CumulativeCorrect AS (
+             SELECT
+                 length,
+                 SUM(correct) OVER (ORDER BY length) AS cumulative_correct
+             FROM LengthsAndCorrectness
+         )
+
+         SELECT
+             length,
+             cumulative_correct,
+             CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
+         FROM CumulativeCorrect, TotalItems
+         ORDER BY length
+     """)
+     return r1_completions
+
+
+ def accuracy_by_model_and_time(conn):
+     model_accuracies = conn.sql("""
+         WITH ChallengesWithDates AS (
+             SELECT
+                 ID,
+                 answer,
+                 EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
+             FROM
+                 challenges
+         ),
+         DateAnswerCheck AS (
+             SELECT
+                 results.parent_dir AS model,
+                 dates.year,
+                 COUNT(*) AS total,
+                 SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
+             FROM
+                 results.completions results
+             JOIN
+                 ChallengesWithDates dates
+             ON
+                 results.prompt_id = dates.ID
+             GROUP BY
+                 results.parent_dir,
+                 dates.year
+         )
+         SELECT
+             model,
+             year,
+             total,
+             correct,
+             ROUND(correct / total, 2) AS accuracy
+         FROM
+             DateAnswerCheck
+         ORDER BY
+             model,
+             year
+     """)
+
+     return model_accuracies
+
+ def accuracy_by_model(conn):
+     return conn.sql("""
+         WITH AnswerCheck AS (
+             SELECT
+                 results.parent_dir AS model,
+                 SUM(results.count) AS total,
+                 SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
+             FROM
+                 results.completions results
+             JOIN
+                 challenges challenges
+             ON
+                 results.prompt_id = challenges.ID
+             GROUP BY
+                 results.parent_dir
+         )
+         SELECT
+             model,
+             total,
+             correct,
+             ROUND(correct / total, 2) AS accuracy
+         FROM
+             AnswerCheck
+     """)
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--by-model-and-time", action="store_true")
+     args = parser.parse_args()
+     conn = load_results()
+     if args.by_model_and_time:
+         print(accuracy_by_model_and_time(conn))
+     else:
+         print(accuracy_by_model(conn))
+
+ if __name__ == "__main__":
+     main()
output.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ad970dce3fb60473dcbfa707515ab67dd78b3cbcc2856feeff6fcb33c918e69
+ size 18655953