fix connection reuse
Files changed:
- app.py +49 -47
- data_access.py +141 -153
- eval_tables.py +5 -0
- tests/test_db_layer.py +21 -17
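The change in a sentence: every data_access function now takes an asyncpg connection as its first argument, and callers open one connection per unit of work with "async with get_async_connection() as conn:" instead of each helper opening and closing its own. Below is a minimal sketch of that context-manager pattern, inferred from the diff (asyncpg, a schema parameter, await conn.close() on exit); the environment-variable names are assumptions borrowed from the sync psycopg2 helper, not confirmed by this commit.

from contextlib import asynccontextmanager
import os

import asyncpg

@asynccontextmanager
async def get_async_connection(schema="talmudexplore"):
    # Open a single connection that the caller can reuse for several queries
    conn = await asyncpg.connect(
        user=os.getenv("pg_user"),          # assumed env var names
        password=os.getenv("pg_password"),
        database=os.getenv("pg_dbname"),
        host=os.getenv("pg_host"),
    )
    try:
        # Assumed: scope unqualified table names to the given schema
        await conn.execute(f"SET search_path TO {schema}")
        yield conn
    finally:
        # Closed exactly once, when the caller's async-with block exits
        await conn.close()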
app.py
CHANGED
@@ -5,7 +5,8 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
+    get_async_connection
 
 logger = logging.getLogger(__name__)
 
@@ -20,7 +21,7 @@ baseline_ranker_options = []
 run_ids = []
 available_run_id_dict = {}
 finder_options = []
-previous_run_id =
+previous_run_id = "initial_run"
 
 run_id_dropdown = None
 
@@ -29,13 +30,13 @@ run_id_dropdown = None
 # Initialize data in a single async function
 async def initialize_data():
     global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, baseline_rankers_dict, source_finders_dict, baseline_ranker_options
 
-    questions = await get_questions()
-    source_finders = await get_source_finders()
-
-    baseline_rankers = await get_baseline_rankers()
+    async with get_async_connection() as conn:
+        # Get questions and source finders
+        questions = await get_questions(conn)
+        source_finders = await get_source_finders(conn)
+        baseline_rankers = await get_baseline_rankers(conn)
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
-
     # Convert to dictionaries for easier lookup
     questions_dict = {q["text"]: q["id"] for q in questions}
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
@@ -52,7 +53,7 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
     if evt:
         logger.info(f"event: {evt.target.elem_id}")
         if evt.target.elem_id == "run_id_dropdown" and (type(run_id) == list or run_id == previous_run_id):
-            return gr.skip(), gr.skip(), gr.skip(), gr.skip()
+            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
 
     if type(run_id) == str:
         previous_run_id = run_id
@@ -65,55 +66,56 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
     if not question_option:
         return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
     logger.info("processing update")
-    baseline_ranker_name
-    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
-    finder_id_int = source_finders_dict.get(source_finder_name)
-    else:
-        finder_id_int = None
-    run_id_int = available_run_id_dict.get(run_id)
-    all_stats = await calculate_cumulative_statistics_for_all_questions(run_id_int, baseline_ranker_id_int)
-        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
-    if run_id not in run_id_options:
-        run_id = run_id_options[0]
-    source_runs = None
-    stats = None
-    # Get source runs data
-    if finder_id_int:
-        source_runs, stats = await get_unified_sources(question_id, run_id_int, baseline_ranker_id_int)
-    # Create DataFrame for display
-    df = pd.DataFrame(source_runs)
-        'folio', 'reason']
-    df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
+    async with get_async_connection() as conn:
+        if type(baseline_ranker_name) == list:
+            baseline_ranker_name = baseline_ranker_name[0]
+
+        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+
+        if len(source_finder_name):
+            finder_id_int = source_finders_dict.get(source_finder_name)
+        else:
+            finder_id_int = None
+
+        if question_option == "All questions":
+            if finder_id_int and type(run_id) == str:
+                run_id_int = available_run_id_dict.get(run_id)
+                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
+            else:
+                all_stats = None
+            return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
+
+        # Extract question ID from selection
+        question_id = questions_dict.get(question_option)
+
+        available_run_id_dict = await get_run_ids(conn, question_id, finder_id_int)
+        run_id_options = list(available_run_id_dict.keys())
+        if run_id not in run_id_options:
+            run_id = run_id_options[0]
+
+        run_id_int = available_run_id_dict.get(run_id)
+
+        source_runs = None
+        stats = None
+        # Get source runs data
+        if finder_id_int:
+            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
+        # Create DataFrame for display
+        df = pd.DataFrame(source_runs)
+
+        if not source_runs:
+            return None, None, run_id_options, "No results found for the selected filters", ""
+
+        # Format table columns
+        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
+                              'folio', 'reason']
+        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
+
+        # CSV for download
+        # csv_data = df.to_csv(index=False)
+        metadata = await get_metadata(conn, question_id, run_id_int)
 
     result_message = f"Found {len(source_runs)} results"
     return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
@@ -128,7 +130,8 @@ async def handle_row_selection_async(evt: gr.SelectData):
         # Get the ID from the selected row
        tractate_chunk_id = evt.row_value[0]
         # Get the source text
-        text = await get_source_text(tractate_chunk_id)
+        async with get_async_connection() as conn:
+            text = await get_source_text(conn, tractate_chunk_id)
         return text
     except Exception as e:
         return f"Error retrieving source text: {str(e)}"
@@ -248,7 +251,6 @@ async def main():
     question_dropdown.change(
         update_sources_list,
         inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-        # outputs=[run_id_dropdown, results_table, result_text, download_button]
         outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
     )
 
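One behavioral detail worth calling out from app.py: update_sources_list now tracks previous_run_id and returns gr.skip() for all five outputs when a run_id_dropdown event would not change anything, so Gradio leaves the components untouched instead of re-rendering them. A standalone sketch of that guard pattern follows; it assumes Gradio 4+ (where gr.skip() is available), and the component names are made up for illustration.

import gradio as gr

previous_choice = None

def on_change(choice):
    global previous_choice
    if choice == previous_choice:
        # Value did not really change: skip updating the output entirely
        return gr.skip()
    previous_choice = choice
    return f"You picked {choice}"

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(choices=["a", "b", "c"], label="Pick one")
    result = gr.Textbox(label="Result")
    dropdown.change(on_change, inputs=[dropdown], outputs=[result])

if __name__ == "__main__":
    demo.launch()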
data_access.py
CHANGED
@@ -30,85 +30,80 @@ async def get_async_connection(schema="talmudexplore"):
         await conn.close()
 
 
-async def get_questions():
-        return ""
-    return metadata.get('metadata')
+async def get_questions(conn: asyncpg.Connection):
+    questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
+    return [{"id": q["id"], "text": q["question_text"]} for q in questions]
+
+async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_id_run_id: int):
+    metadata = await conn.fetchrow('''
+        SELECT metadata
+        FROM source_finder_run_question_metadata sfrqm
+        WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
+    ''', question_id, source_finder_id_run_id)
+    if metadata is None:
+        return ""
+    return metadata.get('metadata')
 
 
 # Get distinct source finders
-async def get_source_finders():
-    return [{"id": f["id"], "name": f["name"]} for f in finders]
+async def get_source_finders(conn: asyncpg.Connection):
+    finders = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
+    return [{"id": f["id"], "name": f["name"]} for f in finders]
 
 
 # Get distinct run IDs for a question
-async def get_run_ids(question_id: int, source_finder_id: int):
+async def get_run_ids(conn: asyncpg.Connection, question_id: int, source_finder_id: int):
+    query = """
+        select distinct sfr.description, srs.source_finder_run_id as run_id
+        from talmudexplore.source_run_results srs
+        join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
+        join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
+        where sfr.source_finder_id = $1
+        and srs.question_id = $2
+    """
+    run_ids = await conn.fetch(query, source_finder_id, question_id)
+    return {r["description"]: r["run_id"] for r in run_ids}
+
+
+async def get_baseline_rankers(conn: asyncpg.Connection):
+    rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
+    return [{"id": f["id"], "name": f["ranker"]} for f in rankers]
+
+
-async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source_runs_sources):
+async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources , source_runs_sources):
     # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
     # e.g. overlap, high ranked overlap, etc.
-    async with get_async_connection() as conn:
-        actual_sources_set = {s["id"] for s in source_runs_sources}
-        baseline_sources_set = {s["id"] for s in baseline_sources}
-
-        # Calculate overlap
-        overlap = actual_sources_set.intersection(baseline_sources_set)
-        # only_in_1 = actual_sources_set - baseline_sources_set
-        # only_in_2 = baseline_sources_set - actual_sources_set
-
-        # Calculate high-ranked overlap (rank >= 4)
-        actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
-        baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
-
-        high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
-
-        results = {
-            "total_baseline_sources": len(baseline_sources),
-            "total_found_sources": len(source_runs_sources),
-            "overlap_count": len(overlap),
-            "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
-                                        2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
-            "num_high_ranked_baseline_sources": len(baseline_high_ranked),
-            "num_high_ranked_found_sources": len(actual_high_ranked),
-            "high_ranked_overlap_count": len(high_ranked_overlap),
-            "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
-        }
-        #convert results to dataframe
-        results_df = pd.DataFrame([results])
-        return results_df
+    actual_sources_set = {s["id"] for s in source_runs_sources}
+    baseline_sources_set = {s["id"] for s in baseline_sources}
+
+    # Calculate overlap
+    overlap = actual_sources_set.intersection(baseline_sources_set)
+    # only_in_1 = actual_sources_set - baseline_sources_set
+    # only_in_2 = baseline_sources_set - actual_sources_set
+
+    # Calculate high-ranked overlap (rank >= 4)
+    actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
+    baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
+
+    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
+
+    results = {
+        "total_baseline_sources": len(baseline_sources),
+        "total_found_sources": len(source_runs_sources),
+        "overlap_count": len(overlap),
+        "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
+                                    2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
+        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
+        "num_high_ranked_found_sources": len(actual_high_ranked),
+        "high_ranked_overlap_count": len(high_ranked_overlap),
+        "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
+    }
+    #convert results to dataframe
+    results_df = pd.DataFrame([results])
+    return results_df
+
+
+async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, source_finder_run_id: int, ranker_id: int):
     """
     Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
@@ -119,83 +114,75 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id
     Returns:
         pd.DataFrame: DataFrame containing aggregated statistics
     """
+    # Get all questions
+    query = "SELECT id FROM questions ORDER BY id"
+    questions = await conn.fetch(query)
+    question_ids = [q["id"] for q in questions]
+
+    # Initialize aggregates
+    total_baseline_sources = 0
+    total_found_sources = 0
+    total_overlap = 0
+    total_high_ranked_baseline = 0
+    total_high_ranked_found = 0
+    total_high_ranked_overlap = 0
+
+    # Process each question
+    valid_questions = 0
+    for question_id in question_ids:
+        try:
+            # Get unified sources for this question
+            sources, stats = await get_unified_sources(conn, question_id, ranker_id, source_finder_run_id)
+
+            if sources and len(sources) > 0:
+                valid_questions += 1
+                stats_dict = stats.iloc[0].to_dict()
+
+                # Add to running totals
+                total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
+                total_found_sources += stats_dict.get('total_found_sources', 0)
+                total_overlap += stats_dict.get('overlap_count', 0)
+                total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
+                total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
+                total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
+        except Exception as e:
+            # Skip questions with errors
+            continue
+
+    # Calculate overall percentages
+    overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
+        if max(total_baseline_sources, total_found_sources) > 0 else 0
+
+    high_ranked_overlap_percentage = round(
+        total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
+        if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
+
+    # Compile results
+    cumulative_stats = {
+        "total_questions_analyzed": valid_questions,
+        "total_baseline_sources": total_baseline_sources,
+        "total_found_sources": total_found_sources,
+        "total_overlap_count": total_overlap,
+        "overall_overlap_percentage": overlap_percentage,
+        "total_high_ranked_baseline_sources": total_high_ranked_baseline,
+        "total_high_ranked_found_sources": total_high_ranked_found,
+        "total_high_ranked_overlap_count": total_high_ranked_overlap,
+        "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
+        "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions,
+                                                   2) if valid_questions > 0 else 0,
+        "avg_found_sources_per_question": round(total_found_sources / valid_questions,
+                                                2) if valid_questions > 0 else 0
+    }
+
+    return pd.DataFrame([cumulative_stats])
+
+
-async def get_unified_sources(question_id: int, source_finder_run_id: int, ranker_id: int):
+async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int, ranker_id: int):
     """
     Create unified view of sources from both baseline_sources and source_runs
     with indicators of where each source appears and their respective ranks.
     """
-    async with get_async_connection() as conn:
-        stats_df, unified_results = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
-
-    return unified_results, stats_df
-
-async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
-    # Get sources from source_runs
     query_runs = """
         SELECT tb.tractate_chunk_id as id,
                sr.rank as source_rank,
@@ -217,7 +204,7 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
         AND bs.ranker_id = $2
     """
     baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
-    stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
+    stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
     # Convert to dictionaries for easier lookup
     source_runs_dict = {s["id"]: dict(s) for s in source_runs}
     baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
@@ -244,21 +231,22 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
             "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
         }
         unified_results.append(result)
+
+    return unified_results, stats_df
 
 
-async def get_source_text(tractate_chunk_id: int):
+async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
     """
     Retrieves the text content for a given tractate chunk ID.
     """
+    query = """
+        SELECT tb.text_with_nikud as text
+        FROM talmud_bavli tb
+        WHERE tb.tractate_chunk_id = $1
+    """
+    result = await conn.fetchrow(query, tractate_chunk_id)
+    return result["text"] if result else "Source text not found"
 
 def get_pg_sync_connection(schema="talmudexplore"):
     conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
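For reference, this is how the refactored data_access API reads from a caller's point of view: one connection, opened once, threaded through several queries. The IDs below (source finder 2, run 2, ranker 1) are placeholders borrowed from the tests, not meaningful values.

import asyncio

from data_access import (
    get_async_connection,
    get_questions,
    get_run_ids,
    get_unified_sources,
)

async def demo():
    async with get_async_connection() as conn:
        questions = await get_questions(conn)
        first_question_id = questions[0]["id"]

        # Which runs exist for this question and source finder 2
        run_ids = await get_run_ids(conn, first_question_id, 2)
        print(run_ids)

        # Unified baseline-vs-run view plus its per-question stats DataFrame
        sources, stats = await get_unified_sources(conn, first_question_id, 2, 1)
        print(len(sources), stats)

if __name__ == "__main__":
    asyncio.run(demo())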
eval_tables.py
CHANGED
@@ -92,6 +92,11 @@ def create_eval_database():
     );
     ''')
 
+    cursor.execute('''alter table source_run_results
+        add constraint source_run_results_pk
+            unique (source_finder_run_id, question_id, sugya_id);
+    ''')
+
     conn.commit()
     conn.close()
 
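The new unique constraint on (source_finder_run_id, question_id, sugya_id) is what makes idempotent writes into source_run_results possible. Below is a hedged sketch of an upsert that targets it; the rank column and the psycopg2-style call are assumptions based on the rest of this repo, not part of this commit.

import psycopg2  # hypothetical writer using the same sync driver as eval_tables.py

def upsert_result(conn, source_finder_run_id, question_id, sugya_id, rank):
    # Re-running an ingest for the same run/question/sugya updates in place
    # instead of raising a duplicate-key error.
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO source_run_results (source_finder_run_id, question_id, sugya_id, rank)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT ON CONSTRAINT source_run_results_pk
            DO UPDATE SET rank = EXCLUDED.rank;
            """,
            (source_finder_run_id, question_id, sugya_id, rank),
        )
    conn.commit()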
tests/test_db_layer.py
CHANGED
@@ -1,13 +1,15 @@
 import pandas as pd
 import pytest
 
-from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids
+from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids, \
+    get_async_connection
 from data_access import get_unified_sources
 
 
 @pytest.mark.asyncio
 async def test_get_unified_sources():
+    async with get_async_connection() as conn:
+        results, stats = await get_unified_sources(conn, 2, 2, 1)
     assert results is not None
     assert stats is not None
 
@@ -23,12 +25,12 @@ async def test_get_unified_sources():
 @pytest.mark.asyncio
 async def test_calculate_cumulative_statistics_for_all_questions():
     # Test with known source_finder_id, run_id, and ranker_id
-    run_id = 1
+    source_finder_run_id = 2
     ranker_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await calculate_cumulative_statistics_for_all_questions(conn, source_finder_run_id, ranker_id)
 
     # Check basic structure of results
     assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
@@ -65,12 +67,12 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 @pytest.mark.asyncio
 async def test_get_metadata_none_returned():
     # Test with known source_finder_id, run_id, and ranker_id
-    run_id = 1
+    source_finder_run_id = 1
     question_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result == "", "Should return empty string when no metadata is found"
 
@@ -81,7 +83,8 @@ async def test_get_metadata():
     question_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result is not None, "Should return metadata when it exists"
 
@@ -93,16 +96,17 @@ async def test_get_run_ids():
     source_finder_id = 2  # Using a source finder ID that exists in the test database
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_run_ids(conn, question_id, source_finder_id)
 
+        # Verify the result is a dictionary
+        assert isinstance(result, dict), "Result should be a dictionary"
 
+        # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
+        assert len(result) > 0, "Should return at least one run ID"
 
+        # Test with a non-existent question_id
+        non_existent_question_id = 9999
+        empty_result = await get_run_ids(conn, non_existent_question_id, source_finder_id)
     assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
     assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"