changes to use new table and descriptive runs
- app.py +23 -14
- data_access.py +77 -74
- eval_tables.py +29 -6
- scripts/__init__.py +0 -0
- tests/test_db_layer.py +45 -4
app.py
CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -18,8 +18,9 @@ question_options = []
 baseline_rankers_dict = {}
 baseline_ranker_options = []
 run_ids = []
+available_run_id_dict = {}
 finder_options = []
-previous_run_id =
+previous_run_id = None
 
 run_id_dropdown = None
 
@@ -60,8 +61,9 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
 
 # Main function to handle UI interactions
 async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
+    global available_run_id_dict
     if not question_option:
-        return gr.skip(), gr.skip(), gr.skip(), "No question selected"
+        return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
     logger.info("processing update")
     if type(baseline_ranker_name) == list:
        baseline_ranker_name = baseline_ranker_name[0]
@@ -75,20 +77,21 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
     if question_option == "All questions":
         if finder_id_int and type(run_id) == str:
-
+            run_id_int = available_run_id_dict.get(run_id)
+            all_stats = await calculate_cumulative_statistics_for_all_questions(run_id_int, baseline_ranker_id_int)
         else:
             all_stats = None
-        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results"
+        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
 
     # Extract question ID from selection
     question_id = questions_dict.get(question_option)
 
-
-    run_id_options =
+    available_run_id_dict = await get_run_ids(question_id, finder_id_int)
+    run_id_options = list(available_run_id_dict.keys())
     if run_id not in run_id_options:
         run_id = run_id_options[0]
 
-    run_id_int =
+    run_id_int = available_run_id_dict.get(run_id)
 
 
 
@@ -96,7 +99,7 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
     stats = None
     # Get source runs data
     if finder_id_int:
-        source_runs, stats = await get_unified_sources(question_id,
+        source_runs, stats = await get_unified_sources(question_id, run_id_int, baseline_ranker_id_int)
     # Create DataFrame for display
     df = pd.DataFrame(source_runs)
 
@@ -110,9 +113,10 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
     # CSV for download
     # csv_data = df.to_csv(index=False)
+    metadata = await get_metadata(question_id, run_id_int)
 
     result_message = f"Found {len(source_runs)} results"
-    return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message,
+    return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
 
 
 # Add a new function to handle row selection
@@ -182,7 +186,6 @@ async def main():
             # Sidebar area
             gr.Markdown("### About")
             gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
-            gr.Markdown("Start by selecting a question, then optionally filter by source finder and run ID.")
 
             gr.Markdown("### Statistics")
             gr.Markdown(f"Total Questions: {len(questions)}")
@@ -204,6 +207,12 @@ async def main():
                 ],
                 interactive=False,
             )
+            with gr.Row():
+                metadata_text = gr.TextArea(
+                    label="Metadata of Source Finder for Selected Question",
+                    elem_id="metadata",
+                    lines=2
+                )
             with gr.Row():
                 gr.Markdown("# Sources Found")
             with gr.Row():
@@ -240,20 +249,20 @@ async def main():
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
             # outputs=[run_id_dropdown, results_table, result_text, download_button]
-            outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
         )
 
         source_finder_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
             # outputs=[run_id_dropdown, results_table, result_text, download_button]
-            outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
         )
 
         run_id_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-            outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
 
 
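The handlers above now expect five values back from update_sources_list: the results table, the statistics table, an updated run-id dropdown, the status message, and the metadata text. A minimal sketch of that wiring, assuming a Gradio Blocks API of the kind app.py uses and placeholder data in place of the real database-backed callback:

import gradio as gr

def update(question):
    # Return one value per output component, in the same order as `outputs=`
    rows = [[1, "Berakhot", "2a"]]                      # placeholder results table
    stats = [[3, 2]]                                    # placeholder statistics
    run_dropdown = gr.Dropdown(choices=["run A", "run B"], value="run A")
    message = f"Found {len(rows)} results"
    metadata = '{"model": "example"}'                   # hypothetical metadata payload
    return rows, stats, run_dropdown, message, metadata

with gr.Blocks() as demo:
    question = gr.Dropdown(choices=["All questions", "Question 1"], label="Question")
    results_table = gr.Dataframe(headers=["id", "tractate", "folio"])
    statistics_table = gr.Dataframe(headers=["overlap_count", "baseline_count"])
    run_id_dropdown = gr.Dropdown(choices=["run A"], label="Run")
    result_text = gr.Textbox(label="Status")
    metadata_text = gr.TextArea(label="Metadata", lines=2)
    # Five outputs, so the callback must return five values
    question.change(update, inputs=[question],
                    outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text])

if __name__ == "__main__":
    demo.launch()

The count and order of returned values must match the outputs= list, which is why every .change() handler in the diff adds metadata_text as a fifth output.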
data_access.py
CHANGED
@@ -35,6 +35,17 @@ async def get_questions():
         questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
         return [{"id": q["id"], "text": q["question_text"]} for q in questions]
 
+async def get_metadata(question_id: int, source_finder_id_run_id: int):
+    async with get_async_connection() as conn:
+        metadata = await conn.fetchrow('''
+            SELECT metadata
+            FROM source_finder_run_question_metadata sfrqm
+            WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
+        ''', question_id, source_finder_id_run_id)
+        if metadata is None:
+            return ""
+        return metadata.get('metadata')
+
 
 # Get distinct source finders
 async def get_source_finders():
@@ -44,32 +55,19 @@ async def get_source_finders():
 
 
 # Get distinct run IDs for a question
-async def get_run_ids(question_id: int):
-    async with get_async_connection() as conn:
-        query = "SELECT DISTINCT run_id FROM source_runs WHERE question_id = $1 order by run_id desc"
-        params = [question_id]
-        run_ids = await conn.fetch(query, *params)
-        return [r["run_id"] for r in run_ids]
-
-
-# Get source runs for a specific question with filters
-async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
-                          run_id: Optional[int] = None):
+async def get_run_ids(question_id: int, source_finder_id: int):
     async with get_async_connection() as conn:
-        # Build query with filters
         query = """
-
-
-
-
-
+        select distinct sfr.description, srs.source_finder_run_id as run_id
+        from talmudexplore.source_run_results srs
+        join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
+        join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
+        where sfr.source_finder_id = $1
+        and srs.question_id = $2
         """
-
-
-        query += " ORDER BY sr.rank DESC"
+        run_ids = await conn.fetch(query, source_finder_id, question_id)
+        return {r["description"]:r["run_id"] for r in run_ids}
 
-        sources = await conn.fetch(query, *params)
-        return [dict(s) for s in sources]
 
 async def get_baseline_rankers():
     async with get_async_connection() as conn:
@@ -110,13 +108,12 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , sou
     return results_df
 
 
-async def calculate_cumulative_statistics_for_all_questions(
+async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id: int, ranker_id: int):
     """
     Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
     Args:
-
-        run_id (int): Run ID to analyze
+        source_finder_run_id (int): ID of the source finder and run as appears in source runs
         ranker_id (int): ID of the baseline ranker
 
     Returns:
@@ -141,7 +138,7 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_id: in
         for question_id in question_ids:
             try:
                 # Get unified sources for this question
-
+                stats, sources = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
 
                 if sources and len(sources) > 0:
                     valid_questions += 1
@@ -186,62 +183,68 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_id: in
     return pd.DataFrame([cumulative_stats])
 
 
-async def get_unified_sources(question_id: int,
+async def get_unified_sources(question_id: int, source_finder_run_id: int, ranker_id: int):
     """
     Create unified view of sources from both baseline_sources and source_runs
     with indicators of where each source appears and their respective ranks.
     """
     async with get_async_connection() as conn:
-
-        query_runs = """
-            SELECT tb.tractate_chunk_id as id, sr.rank as source_rank, sr.tractate, sr.folio,
-            sr.reason as source_reason, sr.metadata
-            FROM source_runs sr join talmud_bavli tb on sr.sugya_id = tb.xml_id
-            WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
-        """
-        source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)
-
-
-        query_baseline = """
-            SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
-            FROM baseline_sources bs join talmud_bavli tb on bs.sugya_id = tb.xml_id
-            WHERE bs.question_id = $1 AND bs.ranker_id = $2
-        """
-        baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
-
-        stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
-
-        # Convert to dictionaries for easier lookup
-        source_runs_dict = {s["id"]: dict(s) for s in source_runs}
-        baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
-
-        # Get all unique sugya_ids
-        all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
-
-        # Build unified results
-        unified_results = []
-        for sugya_id in all_sugya_ids:
-            in_source_run = sugya_id in source_runs_dict
-            in_baseline = sugya_id in baseline_dict
-            if in_baseline:
-                info = baseline_dict[sugya_id]
-            else:
-                info = source_runs_dict[sugya_id]
-            result = {
-                "id": sugya_id,
-                "tractate": info.get("tractate"),
-                "folio": info.get("folio"),
-                "in_baseline": "Yes" if in_baseline else "No",
-                "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
-                "in_source_run": "Yes" if in_source_run else "No",
-                "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
-                "source_reason": source_runs_dict.get(sugya_id, {}).get("reason", "N/A"),
-                "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
-            }
-            unified_results.append(result)
-
-
-
+        stats_df, unified_results = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
+
+    return unified_results, stats_df
+
+
+async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
+    # Get sources from source_runs
+    query_runs = """
+        SELECT tb.tractate_chunk_id as id,
+               sr.rank as source_rank,
+               sr.tractate,
+               sr.folio,
+               sr.reason as source_reason
+        FROM source_run_results sr
+        join talmud_bavli tb on sr.sugya_id = tb.xml_id
+        WHERE sr.question_id = $1
+        AND sr.source_finder_run_id = $2
+    """
+    source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)
+    # Get sources from baseline_sources
+    query_baseline = """
+        SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
+        FROM baseline_sources bs
+        join talmud_bavli tb on bs.sugya_id = tb.xml_id
+        WHERE bs.question_id = $1
+        AND bs.ranker_id = $2
+    """
+    baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
+    stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
+    # Convert to dictionaries for easier lookup
+    source_runs_dict = {s["id"]: dict(s) for s in source_runs}
+    baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
+    # Get all unique sugya_ids
+    all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
+    # Build unified results
+    unified_results = []
+    for sugya_id in all_sugya_ids:
+        in_source_run = sugya_id in source_runs_dict
+        in_baseline = sugya_id in baseline_dict
+        if in_baseline:
+            info = baseline_dict[sugya_id]
+        else:
+            info = source_runs_dict[sugya_id]
+        result = {
+            "id": sugya_id,
+            "tractate": info.get("tractate"),
+            "folio": info.get("folio"),
+            "in_baseline": "Yes" if in_baseline else "No",
+            "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
+            "in_source_run": "Yes" if in_source_run else "No",
+            "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
+            "source_reason": source_runs_dict.get(sugya_id, {}).get("reason", "N/A"),
+            "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
+        }
+        unified_results.append(result)
+    return stats_df, unified_results
 
 
 async def get_source_text(tractate_chunk_id: int):
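get_run_ids now returns a dict mapping each run's description to its source_finder_run_id, and that id is what get_unified_sources and get_metadata take. A hedged usage sketch, assuming a configured database behind get_async_connection and this module's functions as committed; the ids passed in are placeholders:

import asyncio

from data_access import get_metadata, get_run_ids, get_unified_sources

async def show_runs(question_id: int, source_finder_id: int, ranker_id: int):
    # description -> source_finder_run_id, e.g. {"descriptive run name": 4}
    runs = await get_run_ids(question_id, source_finder_id)
    for description, source_finder_run_id in runs.items():
        sources, stats = await get_unified_sources(question_id, source_finder_run_id, ranker_id)
        metadata = await get_metadata(question_id, source_finder_run_id)  # "" when no row exists
        print(description, len(sources), metadata)

if __name__ == "__main__":
    asyncio.run(show_runs(question_id=1, source_finder_id=1, ranker_id=1))

Because get_metadata returns an empty string rather than None when no metadata row exists, the UI can bind its result directly to a text component.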
eval_tables.py
CHANGED
@@ -51,12 +51,35 @@ def create_eval_database():
         );
     ''')
 
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS source_finder_runs (
+            id SERIAL PRIMARY KEY,
+            run_id INTEGER NOT NULL,
+            source_finder_id INTEGER NOT NULL,
+            description TEXT,
+            FOREIGN KEY (source_finder_id) REFERENCES source_finders(id),
+            CONSTRAINT unique_source_per_run_id UNIQUE(run_id, source_finder_id)
+        );
+    ''')
+
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS source_finder_run_question_metadata (
+            id SERIAL PRIMARY KEY,
+            question_id INTEGER NOT NULL,
+            source_finder_run_id INTEGER NOT NULL,
+            metadata JSON,
+            FOREIGN KEY (source_finder_run_id) REFERENCES source_finder_runs(id),
+            FOREIGN KEY (question_id) REFERENCES questions(id),
+            CONSTRAINT unique_question_per_run_id UNIQUE(question_id, source_finder_run_id)
+        );
+    ''')
+
+
     # Create table for logging all sources from each run
     cursor.execute('''
-        CREATE TABLE IF NOT EXISTS
+        CREATE TABLE IF NOT EXISTS source_run_results (
             id SERIAL PRIMARY KEY,
-
-            run_id TEXT NOT NULL,
+            source_finder_run_id INTEGER NOT NULL,
             run_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
             question_id INTEGER NOT NULL,
             tractate TEXT NOT NULL,
@@ -64,7 +87,7 @@ def create_eval_database():
             sugya_id TEXT NOT NULL,
             rank INTEGER NOT NULL,
             reason TEXT,
-            FOREIGN KEY (
+            FOREIGN KEY (source_finder_run_id) REFERENCES source_finder_runs(id),
             FOREIGN KEY (question_id) REFERENCES questions(id)
         );
     ''')
@@ -99,8 +122,8 @@ def load_baseline_sources():
 
 if __name__ == '__main__':
     # Create the database
-
-    load_baseline_sources()
+    create_eval_database()
+    # load_baseline_sources()
 
 
 
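source_finder_runs now holds one row per (run_id, source_finder_id) pair, and both source_run_results and source_finder_run_question_metadata reference its surrogate key instead of carrying run_id and source_finder_id themselves. A sketch of how a run could be recorded under this schema, assuming a psycopg2-style connection like the one create_eval_database() uses; the helper name and values are illustrative, not part of the commit:

def record_run(conn, run_id: int, source_finder_id: int, description: str,
               question_id: int, sources: list[dict]):
    cur = conn.cursor()
    # One row per (run_id, source_finder_id); the UNIQUE constraint guards duplicates
    cur.execute(
        """
        INSERT INTO source_finder_runs (run_id, source_finder_id, description)
        VALUES (%s, %s, %s)
        RETURNING id
        """,
        (run_id, source_finder_id, description),
    )
    source_finder_run_id = cur.fetchone()[0]
    # Results reference the surrogate key instead of carrying run_id/source_finder_id
    for s in sources:
        cur.execute(
            """
            INSERT INTO source_run_results
                (source_finder_run_id, question_id, tractate, folio, sugya_id, rank, reason)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            (source_finder_run_id, question_id, s["tractate"], s["folio"],
             s["sugya_id"], s["rank"], s.get("reason")),
        )
    conn.commit()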
scripts/__init__.py
ADDED
File without changes
tests/test_db_layer.py
CHANGED
@@ -1,7 +1,7 @@
 import pandas as pd
 import pytest
 
-from data_access import calculate_cumulative_statistics_for_all_questions
+from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids
 from data_access import get_unified_sources
 
 
@@ -20,9 +20,6 @@ async def test_get_unified_sources():
     # You can also check specific stats columns
     assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
 
-
-
-
 @pytest.mark.asyncio
 async def test_calculate_cumulative_statistics_for_all_questions():
     # Test with known source_finder_id, run_id, and ranker_id
@@ -65,3 +62,47 @@ async def test_calculate_cumulative_statistics_for_all_questions():
     assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[
         0] <= 100, "High ranked overlap percentage should be between 0 and 100"
 
+@pytest.mark.asyncio
+async def test_get_metadata_none_returned():
+    # Test with known source_finder_id, run_id, and ranker_id
+    source_finder_id = 1
+    run_id = 1
+    question_id = 1
+
+    # Call the function to test
+    result = await get_metadata(question_id, source_finder_id, run_id)
+
+    assert result == "", "Should return empty string when no metadata is found"
+
+@pytest.mark.asyncio
+async def test_get_metadata():
+    # Test with known source_finder_id, run_id, and ranker_id
+    source_finder_run_id = 4
+    question_id = 1
+
+    # Call the function to test
+    result = await get_metadata(question_id, source_finder_run_id)
+
+    assert result is not None, "Should return metadata when it exists"
+
+
+@pytest.mark.asyncio
+async def test_get_run_ids():
+    # Test with known question_id and source_finder_id
+    question_id = 2  # Using a question ID that exists in the test database
+    source_finder_id = 2  # Using a source finder ID that exists in the test database
+
+    # Call the function to test
+    result = await get_run_ids(question_id, source_finder_id)
+
+    # Verify the result is a dictionary
+    assert isinstance(result, dict), "Result should be a dictionary"
+
+    # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
+    assert len(result) > 0, "Should return at least one run ID"
+
+    # Test with a non-existent question_id
+    non_existent_question_id = 9999
+    empty_result = await get_run_ids(non_existent_question_id, source_finder_id)
+    assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
+    assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"