Commit: changes for new version

Files changed:
- app.py (+85, -74)
- data_access.py (+32, -13)
- requirements.txt (+1, -1)
- tests/test_db_layer.py (+14, -5)
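In outline: app.py reworks the selection flow so the user first picks a source finder, a run ID, and a baseline ranker, and the question dropdown is then filled from the overlap between the chosen run and the baseline; data_access.py adds an auto_commit flag to the connection helper, restricts get_questions to questions shared by two source finder runs, and passes explicit question_ids into the cumulative-statistics calculation; requirements.txt swaps psycopg2 for psycopg2-binary; tests/test_db_layer.py covers the new get_questions signature.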
app.py (CHANGED)

@@ -33,45 +33,60 @@ run_id_dropdown = None
 
 # Initialize data in a single async function
 async def initialize_data():
-    global …
+    global source_finders, source_finders_dict, finder_options, baseline_rankers_dict, baseline_ranker_options
     async with get_async_connection() as conn:
-        # Get questions and source finders
-        questions = await get_questions(conn)
         source_finders = await get_source_finders(conn)
         baseline_rankers = await get_baseline_rankers(conn)
 
         # Convert to dictionaries for easier lookup
-        questions_dict = {q["text"]: q["id"] for q in questions}
         baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
         source_finders_dict = {f["name"]: f["id"] for f in source_finders}
 
         # Create formatted options for dropdowns
-        question_options = [q['text'] for q in questions]
         finder_options = [s["name"] for s in source_finders]
         baseline_ranker_options = [b["name"] for b in baseline_rankers]
-        await update_run_ids_async(ALL_QUESTIONS_STR, list(source_finders_dict.keys())[0])
 
 
-def update_run_ids(question_option, source_finder_name):
-    return asyncio.run(update_run_ids_async(question_option, source_finder_name))
+def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
+    return asyncio.run(update_run_ids_async(question_option, source_finder_name, baseline_ranker_name))
 
 
-async def update_run_ids_async(question_option, source_finder_name):
-    global previous_run_id, available_run_id_dict, run_id_options
+async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
+    global question_options, questions_dict, previous_run_id, available_run_id_dict, run_id_options
     async with get_async_connection() as conn:
         finder_id_int = source_finders_dict.get(source_finder_name)
-        …
         available_run_id_dict = await get_run_ids(conn, finder_id_int)
-        …
+        run_id_options = list(available_run_id_dict.keys())
+        return gr.Dropdown(choices=[]), None, None, gr.Dropdown(choices=run_id_options,
+                                                                value=None), "Select Question to see results", ""
+
+
+def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
+    return asyncio.run(update_questions_list_async(source_finder_name, run_id, baseline_ranker_name))
+
+
+async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
+    global available_run_id_dict
+    if source_finder_name and run_id and baseline_ranker_name:
+        async with get_async_connection() as conn:
+            run_id_int = available_run_id_dict.get(run_id)
+            baseline_ranker_id = baseline_rankers_dict.get(baseline_ranker_name)
+            questions = await get_updated_question_list(conn, baseline_ranker_id, run_id_int)
+            return gr.Dropdown(choices=questions, value=None), None, None, None, None
+    else:
+        return None, None, None, None, None
+
+
+async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
+    global questions_dict, questions
+    questions = await get_questions(conn, finder_id_int, baseline_ranker_id)
+    if questions:
+        questions_dict = {q["text"]: q["id"] for q in questions}
+        question_options = [ALL_QUESTIONS_STR] + [q['text'] for q in questions]
+    else:
+        question_options = []
+    return question_options
 
 
 def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                         evt: gr.EventData = None):
@@ -88,9 +103,11 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
 
 # Main function to handle UI interactions
 async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
-    global available_run_id_dict, previous_run_id
+    global available_run_id_dict, previous_run_id, questions_dict
     if not question_option:
-        return gr.skip(), gr.skip(),
+        return gr.skip(), gr.skip(), "No question selected", ""
+    if not source_finder_name or not run_id or not baseline_ranker_name:
+        return gr.skip(), gr.skip(), "Need to select source finder and baseline", ""
     logger.info("processing update")
     async with get_async_connection() as conn:
         if type(baseline_ranker_name) == list:
@@ -106,28 +123,18 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
         if question_option == ALL_QUESTIONS_STR:
             if finder_id_int:
-                if run_id is None:
-                    available_run_id_dict = await get_run_ids(conn, finder_id_int)
-                    run_id = list(available_run_id_dict.keys())[0]
-                    previous_run_id = run_id
                 run_id_int = available_run_id_dict.get(run_id)
-                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int,
+                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, list(questions_dict.values()),
+                                                                                    run_id_int,
                                                                                     baseline_ranker_id_int)
-
             else:
-                run_id_options = list(available_run_id_dict.keys())
                 all_stats = None
-
-            return None, all_stats, gr.Dropdown(choices=run_id_options,
-                                                value=run_id), "Select Run Id and source finder to see results", ""
+            return None, all_stats, "Select Run Id and source finder to see results", ""
 
         # Extract question ID from selection
         question_id = questions_dict.get(question_option)
 
         available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
-        run_id_options = list(available_run_id_dict.keys())
-        if run_id not in run_id_options:
-            run_id = run_id_options[0]
         previous_run_id = run_id
         run_id_int = available_run_id_dict.get(run_id)
 
@@ -140,7 +147,7 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
         df = pd.DataFrame(source_runs)
 
         if not source_runs:
-            return None, None,
+            return None, None, "No results found for the selected filters",
 
         # Format table columns
         columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
@@ -152,8 +159,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
         # csv_data = df.to_csv(index=False)
         metadata = await get_metadata(conn, question_id, run_id_int)
 
         result_message = f"Found {len(source_runs)} results"
-        return df_display, stats,
+        return df_display, stats, result_message, metadata
 
 
 # Add a new function to handle row selection
@@ -189,46 +196,50 @@ async def main():
         with gr.Column(scale=3):
             with gr.Row():
                 with gr.Column(scale=1):
-                    question_dropdown = gr.Dropdown(
-                        choices=[ALL_QUESTIONS_STR] + question_options,
-                        label="Select Question",
+                    source_finder_dropdown = gr.Dropdown(
+                        choices=finder_options,
                         value=None,
+                        label="Source Finder",
                         interactive=True,
-                        elem_id="question_dropdown"
+                        elem_id="source_finder_dropdown"
+                    )
+                with gr.Column(scale=1):
+                    run_id_dropdown = gr.Dropdown(
+                        choices=run_id_options,
+                        value=None,
+                        allow_custom_value=True,
+                        label="source finder Run ID",
+                        interactive=True,
+                        elem_id="run_id_dropdown"
                     )
                 with gr.Column(scale=1):
                     baseline_rankers_dropdown = gr.Dropdown(
                        choices=baseline_ranker_options,
+                        value=None,
                        label="Select Baseline Ranker",
                        interactive=True,
                        elem_id="baseline_rankers_dropdown"
                    )
-
            with gr.Row():
                with gr.Column(scale=1):
-                    source_finder_dropdown = gr.Dropdown(
-                        …
-                    )
-                with gr.Column(scale=1):
-                    run_id_dropdown = gr.Dropdown(
-                        choices=run_id_options,
-                        allow_custom_value=True,
-                        label="Run id for Question and source finder",
-                        interactive=True,
-                        elem_id="run_id_dropdown"
+                    # Main content area
+                    question_dropdown = gr.Dropdown(
+                        choices=[ALL_QUESTIONS_STR] + question_options,
+                        label="Select Question (if list is empty this means there is no overlap between source run and baseline)",
+                        value=None,
+                        interactive=True,
+                        elem_id="question_dropdown"
                    )
+
                with gr.Column(scale=1):
                    # Sidebar area
-                    gr.Markdown("…")
+                    gr.Markdown("""To get started, select the following:
+                    * Source Finder
+                    * Source Finder Run ID (corresponds to a run of the source finder for a group of questions)
+                    * Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)
+
+                    **Note: if there is no overlap between the baseline questions and the source finder questions, the question list will be empty.**
+                    """)
 
            with gr.Row():
                result_text = gr.Markdown("Select a question to view source runs")
@@ -283,29 +294,29 @@ async def main():
        )
 
        baseline_rankers_dropdown.change(
-            …
-            inputs=[…],
-            outputs=[…]
+            update_questions_list,
+            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
+            outputs=[question_dropdown, result_text, metadata_text]
+
+        )
+
+        run_id_dropdown.change(
+            update_questions_list,
+            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
+            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table]
        )
 
        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-            outputs=[results_table, statistics_table, …]
+            outputs=[results_table, statistics_table, result_text, metadata_text]
        )
 
        source_finder_dropdown.change(
            update_run_ids,
-            inputs=[question_dropdown, source_finder_dropdown],
+            inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
            # outputs=[run_id_dropdown, results_table, result_text, download_button]
-            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
-        )
-
-        run_id_dropdown.change(
-            update_sources_list,
-            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+            outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
 
        app.queue()
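For readers unfamiliar with the pattern this commit leans on: in Gradio, a `.change` handler can repopulate a downstream dropdown by returning a `gr.Dropdown(...)` with new `choices`, which is how the source-finder → run-ID → baseline → question cascade above is wired. A minimal, self-contained sketch of that pattern (the `RUNS` data and component names here are invented for illustration; the real app loads everything from Postgres):

```python
import gradio as gr

# Invented stand-in data; the real app queries finders and run IDs from the database.
RUNS = {"finder-a": ["run-1", "run-2"], "finder-b": ["run-3"]}

def on_finder_change(finder_name):
    # Returning a gr.Dropdown from a .change handler replaces the target
    # component's choices and clears its value, driving the cascade.
    return gr.Dropdown(choices=RUNS.get(finder_name, []), value=None)

with gr.Blocks() as demo:
    finder = gr.Dropdown(choices=list(RUNS), label="Source Finder", interactive=True)
    run_id = gr.Dropdown(choices=[], label="Run ID", interactive=True)
    finder.change(on_finder_change, inputs=[finder], outputs=[run_id])

if __name__ == "__main__":
    demo.launch()
```

The same mechanism explains the `asyncio.run(...)` wrappers in the diff: Gradio calls the synchronous wrapper, which drives the async database work to completion before the dropdown update is returned.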
data_access.py (CHANGED)

@@ -14,9 +14,17 @@ load_dotenv()
 
 
 @asynccontextmanager
-async def get_async_connection(schema="talmudexplore"):
-    """…"""
+async def get_async_connection(schema="talmudexplore", auto_commit=True):
+    """
+    Get a connection for the current request.
+
+    Args:
+        schema: Database schema to use
+        auto_commit: If True (default), each statement auto-commits.
+                     If False, requires explicit commit.
+    """
     conn = None
+    tx = None
     try:
         # Create a single connection without relying on a shared pool
         conn = await asyncpg.connect(
@@ -27,14 +35,27 @@ async def get_async_connection(schema="talmudexplore"):
             port=os.getenv("pg_port")
         )
         await conn.execute(f'SET search_path TO {schema}')
+
+        if not auto_commit:
+            # Start a transaction that requires explicit commit
+            tx = conn.transaction()
+            await tx.start()
         yield conn
+        if not auto_commit and tx:
+            await tx.commit()
     finally:
         if conn:
             await conn.close()
 
 
-async def get_questions(conn: asyncpg.Connection):
-    questions = await conn.fetch("…")
+async def get_questions(conn: asyncpg.Connection, source_finder_run_id: int, baseline_source_finder_run_id: int):
+    questions = await conn.fetch("""
+        select distinct q.id, question_text from talmudexplore.questions q
+        join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $1) sfrqm1
+            on sfrqm1.question_id = q.id
+        join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $2) sfrqm2
+            on sfrqm2.question_id = q.id;
+        """, source_finder_run_id, baseline_source_finder_run_id)
     return [{"id": q["id"], "text": q["question_text"]} for q in questions]
 
 @cached(cache=TTLCache(ttl=1800, maxsize=1024))
@@ -96,7 +117,7 @@ async def get_baseline_rankers(conn: asyncpg.Connection):
                 FROM source_run_results srr
                 WHERE srr.source_finder_run_id = sfr.id
             )
-            ORDER BY sf.id
+            ORDER BY sf.id DESC
         """
 
     rankers = await conn.fetch(query)
@@ -131,26 +152,24 @@ async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connecti
         "high_ranked_overlap_count": len(high_ranked_overlap),
         "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
     }
     # convert results to dataframe
     results_df = pd.DataFrame([results])
     return results_df
 
 
-async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, source_finder_run_id: int, ranker_id: int):
+async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, question_ids, source_finder_run_id: int, ranker_id: int):
     """
     Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
     Args:
+        conn (asyncpg.Connection): Database connection
+        question_ids (list): List of question IDs to analyze
         source_finder_run_id (int): ID of the source finder and run as appears in source runs
         ranker_id (int): ID of the baseline ranker
 
     Returns:
         pd.DataFrame: DataFrame containing aggregated statistics
     """
-    # Get all questions
-    query = "SELECT id FROM questions ORDER BY id"
-    questions = await conn.fetch(query)
-    question_ids = [q["id"] for q in questions]
 
     # Initialize aggregates
     total_baseline_sources = 0
@@ -190,7 +209,7 @@ async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connec
         total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
         if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
 
     # Compile results
     cumulative_stats = {
         "total_questions_analyzed": valid_questions,
         "total_baseline_sources": total_baseline_sources,
@@ -237,7 +256,7 @@ async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source
     baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
     # Get all unique sugya_ids
     all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
     # Build unified results
     unified_results = []
     for sugya_id in all_sugya_ids:
         in_source_run = sugya_id in source_runs_dict
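A usage sketch of the new auto_commit flag. With auto_commit=False the context manager opens a single asyncpg transaction and commits it only if the block exits without raising; on an exception the connection is closed with the transaction uncommitted, so the statements are discarded by the server. The table and column names below are invented for illustration:

```python
import asyncio

from data_access import get_async_connection

async def save_ranker(name: str) -> None:
    # auto_commit=False wraps the whole block in one transaction;
    # it commits only if the block completes without an exception.
    async with get_async_connection(auto_commit=False) as conn:
        # Hypothetical insert; the real schema's columns may differ.
        await conn.execute("INSERT INTO baseline_rankers (name) VALUES ($1)", name)

asyncio.run(save_ranker("demo-ranker"))
```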
requirements.txt (CHANGED)

@@ -1,5 +1,5 @@
 asyncpg
 gradio
 dotenv
-psycopg2
+psycopg2-binary
 cachetools
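The swap to psycopg2-binary is presumably to avoid compiling against libpq inside the Space's container: psycopg2 builds from source and needs pg_config plus a C toolchain, while psycopg2-binary ships prebuilt wheels.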
tests/test_db_layer.py (CHANGED)

@@ -2,9 +2,16 @@ import pandas as pd
 import pytest
 
 from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids, \
-    get_async_connection
+    get_async_connection, get_questions
 from data_access import get_unified_sources
 
+@pytest.mark.asyncio
+async def test_get_questions():
+    source_run_id = 2
+    baseline_source_finder_run_id = 1
+    async with get_async_connection() as conn:
+        actual = await get_questions(conn, source_run_id, baseline_source_finder_run_id)
+        assert len(actual) == 10
 
 @pytest.mark.asyncio
 async def test_get_unified_sources():
@@ -13,7 +20,7 @@ async def test_get_unified_sources():
     assert results is not None
     assert stats is not None
 
     # Check number of rows in results list
     assert len(results) > 4, "Results should contain at least one row"
 
     # Check number of rows in stats DataFrame
@@ -30,9 +37,11 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 
     # Call the function to test
     async with get_async_connection() as conn:
-        result = await calculate_cumulative_statistics_for_all_questions(conn, source_finder_run_id, ranker_id)
+        questions = await get_questions(conn, source_finder_run_id, ranker_id)
+        question_ids = [question['id'] for question in questions]
+        result = await calculate_cumulative_statistics_for_all_questions(conn, question_ids, source_finder_run_id, ranker_id)
 
     # Check basic structure of results
     assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
     assert result.shape[0] == 1, "Result should have one row"
 
@@ -74,7 +83,7 @@ async def test_get_metadata_none_returned():
     async with get_async_connection() as conn:
         result = await get_metadata(conn, question_id, source_finder_run_id)
 
-    assert result ==
+    assert result == {}, "Should return an empty dict when no metadata is found"
 
 @pytest.mark.asyncio
 async def test_get_metadata():