davidr70 committed
Commit ea4284c · Parent(s): 92c49ff

changes for new version

Files changed (4):

1. app.py +85 -74
2. data_access.py +32 -13
3. requirements.txt +1 -1
4. tests/test_db_layer.py +14 -5
app.py CHANGED

```diff
@@ -33,45 +33,60 @@ run_id_dropdown = None
 
 # Initialize data in a single async function
 async def initialize_data():
-    global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, baseline_rankers_dict, source_finders_dict, baseline_ranker_options
+    global source_finders, source_finders_dict, finder_options, baseline_rankers_dict, baseline_ranker_options
     async with get_async_connection() as conn:
-        # Get questions and source finders
-        questions = await get_questions(conn)
         source_finders = await get_source_finders(conn)
         baseline_rankers = await get_baseline_rankers(conn)
 
         # Convert to dictionaries for easier lookup
-        questions_dict = {q["text"]: q["id"] for q in questions}
         baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
         source_finders_dict = {f["name"]: f["id"] for f in source_finders}
 
         # Create formatted options for dropdowns
-        question_options = [q['text'] for q in questions]
         finder_options = [s["name"] for s in source_finders]
         baseline_ranker_options = [b["name"] for b in baseline_rankers]
-        await update_run_ids_async(ALL_QUESTIONS_STR, list(source_finders_dict.keys())[0])
 
 
-def update_run_ids(question_option, source_finder_name):
-    return asyncio.run(update_run_ids_async(question_option, source_finder_name))
+def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
+    return asyncio.run(update_run_ids_async(question_option, source_finder_name, baseline_ranker_name))
 
 
-async def update_run_ids_async(question_option, source_finder_name):
-    global previous_run_id, available_run_id_dict, run_id_options
+async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
+    global question_options, questions_dict, previous_run_id, available_run_id_dict, run_id_options
     async with get_async_connection() as conn:
         finder_id_int = source_finders_dict.get(source_finder_name)
-        if question_option and question_option != ALL_QUESTIONS_STR:
-            question_id = questions_dict.get(question_option)
-            available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
-        else:
-            available_run_id_dict = await get_run_ids(conn, finder_id_int)
-
-        run_id = list(available_run_id_dict.keys())[0]
-        previous_run_id = run_id
-        run_id_options = list(available_run_id_dict.keys())
-        return None, None, gr.Dropdown(choices=run_id_options,
-                                       value=run_id), "Select Question to see results", ""
+        available_run_id_dict = await get_run_ids(conn, finder_id_int)
+        run_id_options = list(available_run_id_dict.keys())
+        return gr.Dropdown(choices=[]), None, None, gr.Dropdown(choices=run_id_options,
+                                                                value=None), "Select Question to see results", ""
+
+
+def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
+    return asyncio.run(update_questions_list_async(source_finder_name, run_id, baseline_ranker_name))
+
+
+async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
+    global available_run_id_dict
+    if source_finder_name and run_id and baseline_ranker_name:
+        async with get_async_connection() as conn:
+            run_id_int = available_run_id_dict.get(run_id)
+            baseline_ranker_id = baseline_rankers_dict.get(baseline_ranker_name)
+            questions = await get_updated_question_list(conn, baseline_ranker_id, run_id_int)
+        return gr.Dropdown(choices=questions, value=None), None, None, None, None
+    else:
+        return None, None, None, None, None
+
+
+async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
+    global questions_dict, questions
+    questions = await get_questions(conn, finder_id_int, baseline_ranker_id)
+    if questions:
+        questions_dict = {q["text"]: q["id"] for q in questions}
+        question_options = [ALL_QUESTIONS_STR] + [q['text'] for q in questions]
+    else:
+        question_options = []
+    return question_options
+
 
 def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                         evt: gr.EventData = None):
@@ -88,9 +103,11 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
 
 # Main function to handle UI interactions
 async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
-    global available_run_id_dict, previous_run_id
+    global available_run_id_dict, previous_run_id, questions_dict
     if not question_option:
-        return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
+        return gr.skip(), gr.skip(), "No question selected", ""
+    if not source_finder_name or not run_id or not baseline_ranker_name:
+        return gr.skip(), gr.skip(), "Need to select source finder and baseline", ""
     logger.info("processing update")
     async with get_async_connection() as conn:
         if type(baseline_ranker_name) == list:
@@ -106,28 +123,18 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
         if question_option == ALL_QUESTIONS_STR:
             if finder_id_int:
-                if run_id is None:
-                    available_run_id_dict = await get_run_ids(conn, finder_id_int)
-                    run_id = list(available_run_id_dict.keys())[0]
-                    previous_run_id = run_id
                 run_id_int = available_run_id_dict.get(run_id)
-                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int,
+                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, list(questions_dict.values()),
+                                                                                    run_id_int,
                                                                                     baseline_ranker_id_int)
-
             else:
-                run_id_options = list(available_run_id_dict.keys())
                 all_stats = None
-                run_id_options = list(available_run_id_dict.keys())
-            return None, all_stats, gr.Dropdown(choices=run_id_options,
-                                                value=run_id), "Select Run Id and source finder to see results", ""
+            return None, all_stats, "Select Run Id and source finder to see results", ""
 
         # Extract question ID from selection
        question_id = questions_dict.get(question_option)
 
         available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
-        run_id_options = list(available_run_id_dict.keys())
-        if run_id not in run_id_options:
-            run_id = run_id_options[0]
         previous_run_id = run_id
         run_id_int = available_run_id_dict.get(run_id)
 
@@ -140,7 +147,7 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
         df = pd.DataFrame(source_runs)
 
         if not source_runs:
-            return None, None, run_id_options, "No results found for the selected filters",
+            return None, None, "No results found for the selected filters", ""
 
         # Format table columns
         columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
@@ -152,8 +159,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
         # csv_data = df.to_csv(index=False)
         metadata = await get_metadata(conn, question_id, run_id_int)
 
         result_message = f"Found {len(source_runs)} results"
-        return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
+        return df_display, stats, result_message, metadata
 
 
 # Add a new function to handle row selection
@@ -189,46 +196,50 @@ async def main():
         with gr.Column(scale=3):
             with gr.Row():
                 with gr.Column(scale=1):
-                    # Main content area
-                    question_dropdown = gr.Dropdown(
-                        choices=[ALL_QUESTIONS_STR] + question_options,
-                        label="Select Question",
+                    source_finder_dropdown = gr.Dropdown(
+                        choices=finder_options,
                         value=None,
+                        label="Source Finder",
                         interactive=True,
-                        elem_id="question_dropdown"
+                        elem_id="source_finder_dropdown"
+                    )
+                with gr.Column(scale=1):
+                    run_id_dropdown = gr.Dropdown(
+                        choices=run_id_options,
+                        value=None,
+                        allow_custom_value=True,
+                        label="Source Finder Run ID",
+                        interactive=True,
+                        elem_id="run_id_dropdown"
                     )
                 with gr.Column(scale=1):
                     baseline_rankers_dropdown = gr.Dropdown(
                         choices=baseline_ranker_options,
+                        value=None,
                         label="Select Baseline Ranker",
                         interactive=True,
                         elem_id="baseline_rankers_dropdown"
                     )
-
             with gr.Row():
                 with gr.Column(scale=1):
-                    source_finder_dropdown = gr.Dropdown(
-                        choices=finder_options,
-                        label="Source Finder",
-                        interactive=True,
-                        elem_id="source_finder_dropdown"
-                    )
-                with gr.Column(scale=1):
-                    run_id_dropdown = gr.Dropdown(
-                        choices=run_id_options,
-                        allow_custom_value=True,
-                        label="Run id for Question and source finder",
+                    # Main content area
+                    question_dropdown = gr.Dropdown(
+                        choices=[ALL_QUESTIONS_STR] + question_options,
+                        label="Select Question (if the list is empty there is no overlap between the source run and the baseline)",
+                        value=None,
                         interactive=True,
-                        elem_id="run_id_dropdown"
+                        elem_id="question_dropdown"
                     )
+
             with gr.Column(scale=1):
                 # Sidebar area
-                gr.Markdown("### About")
-                gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
-
-                gr.Markdown("### Statistics")
-                gr.Markdown(f"Total Questions: {len(questions)}")
-                gr.Markdown(f"Source Finders: {len(source_finders)}")
+                gr.Markdown("""To get started, select the following:
+                    * Source Finder
+                    * Source Finder Run ID (corresponds to a run of the source finder for a group of questions)
+                    * Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)
+
+                **Note: if there is no overlap between the baseline questions and the source finder questions, the question list will be empty.**
+                """)
 
         with gr.Row():
             result_text = gr.Markdown("Select a question to view source runs")
@@ -283,29 +294,29 @@ async def main():
     )
 
     baseline_rankers_dropdown.change(
-        update_sources_list,
-        inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-        outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+        update_questions_list,
+        inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
+        outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table]
+    )
+
+    run_id_dropdown.change(
+        update_questions_list,
+        inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
+        outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table]
    )
 
     question_dropdown.change(
         update_sources_list,
         inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-        outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+        outputs=[results_table, statistics_table, result_text, metadata_text]
     )
 
     source_finder_dropdown.change(
         update_run_ids,
-        inputs=[question_dropdown, source_finder_dropdown],
+        inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
         # outputs=[run_id_dropdown, results_table, result_text, download_button]
-        outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+        outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
     )
 
     app.queue()
```
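The rework inverts the UI flow: instead of the question list driving everything, the user now picks a source finder, then a run id, and only then sees a question list filtered to the overlap between that run and the chosen baseline. A minimal, self-contained sketch of this dependent-dropdown pattern (toy data standing in for the app's database-backed `update_run_ids` / `update_questions_list`):

```python
# Dependent-dropdown cascade, as used by the reworked app.py.
# RUNS/QUESTIONS are illustrative stand-ins for the database lookups.
import gradio as gr

RUNS = {"finder-a": ["run-1", "run-2"], "finder-b": ["run-3"]}
QUESTIONS = {"run-1": ["q1", "q2"], "run-2": ["q3"], "run-3": ["q4"]}

def update_runs(finder):
    # An upstream change repopulates the downstream dropdown and clears its value
    return gr.Dropdown(choices=RUNS.get(finder, []), value=None)

def update_questions(run_id):
    return gr.Dropdown(choices=QUESTIONS.get(run_id, []), value=None)

with gr.Blocks() as demo:
    finder = gr.Dropdown(choices=list(RUNS), label="Source Finder")
    run_id = gr.Dropdown(choices=[], label="Run ID")
    question = gr.Dropdown(choices=[], label="Question")
    finder.change(update_runs, inputs=finder, outputs=run_id)
    run_id.change(update_questions, inputs=run_id, outputs=question)

if __name__ == "__main__":
    demo.launch()
```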
data_access.py CHANGED

```diff
@@ -14,9 +14,17 @@ load_dotenv()
 
 
 @asynccontextmanager
-async def get_async_connection(schema="talmudexplore"):
-    """Get a connection for the current request."""
+async def get_async_connection(schema="talmudexplore", auto_commit=True):
+    """
+    Get a connection for the current request.
+
+    Args:
+        schema: Database schema to use
+        auto_commit: If True (default), each statement auto-commits.
+                     If False, requires an explicit commit.
+    """
     conn = None
+    tx = None
     try:
         # Create a single connection without relying on a shared pool
         conn = await asyncpg.connect(
@@ -27,14 +35,27 @@ async def get_async_connection(schema="talmudexplore"):
             port=os.getenv("pg_port")
         )
         await conn.execute(f'SET search_path TO {schema}')
+
+        if not auto_commit:
+            # Start a transaction that requires explicit commit
+            tx = conn.transaction()
+            await tx.start()
         yield conn
+        if not auto_commit and tx:
+            await tx.commit()
     finally:
         if conn:
             await conn.close()
 
 
-async def get_questions(conn: asyncpg.Connection):
-    questions = await conn.fetch("SELECT id, question_text FROM questions where question_group_id = 1 ORDER BY id")
+async def get_questions(conn: asyncpg.Connection, source_finder_run_id: int, baseline_source_finder_run_id: int):
+    questions = await conn.fetch("""
+        select distinct q.id, question_text from talmudexplore.questions q
+        join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $1) sfrqm1
+            on sfrqm1.question_id = q.id
+        join (select question_id from talmudexplore.source_finder_run_question_metadata where source_finder_run_id = $2) sfrqm2
+            on sfrqm2.question_id = q.id;
+        """, source_finder_run_id, baseline_source_finder_run_id)
     return [{"id": q["id"], "text": q["question_text"]} for q in questions]
 
 @cached(cache=TTLCache(ttl=1800, maxsize=1024))
@@ -96,7 +117,7 @@ async def get_baseline_rankers(conn: asyncpg.Connection):
                 FROM source_run_results srr
                 WHERE srr.source_finder_run_id = sfr.id
             )
-        ORDER BY sf.id
+        ORDER BY sf.id DESC
     """
 
     rankers = await conn.fetch(query)
@@ -131,26 +152,24 @@ async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connecti
         "high_ranked_overlap_count": len(high_ranked_overlap),
         "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
     }
     # convert results to dataframe
     results_df = pd.DataFrame([results])
     return results_df
 
 
-async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, source_finder_run_id: int, ranker_id: int):
+async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, question_ids, source_finder_run_id: int, ranker_id: int):
     """
     Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
     Args:
+        conn (asyncpg.Connection): Database connection
+        question_ids (list): List of question IDs to analyze
         source_finder_run_id (int): ID of the source finder and run as appears in source runs
         ranker_id (int): ID of the baseline ranker
 
     Returns:
         pd.DataFrame: DataFrame containing aggregated statistics
     """
-    # Get all questions
-    query = "SELECT id FROM questions ORDER BY id"
-    questions = await conn.fetch(query)
-    question_ids = [q["id"] for q in questions]
 
     # Initialize aggregates
     total_baseline_sources = 0
```
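For callers, the new `auto_commit` flag only matters for writes: with the default `auto_commit=True` each statement commits on its own, while `auto_commit=False` wraps the block in a single transaction that commits only if the block finishes without raising. A usage sketch (the table and column names here are illustrative, not taken from this repo):

```python
import asyncio
from data_access import get_async_connection

async def rename_finder(finder_id: int, new_name: str) -> None:
    # auto_commit=False: everything in the block commits together, or not at all
    async with get_async_connection(auto_commit=False) as conn:
        await conn.execute(
            "UPDATE source_finders SET name = $1 WHERE id = $2",  # illustrative table
            new_name, finder_id,
        )
        # leaving the block normally commits the transaction

asyncio.run(rename_finder(1, "finder-v2"))
```

If the block raises, the commit is skipped and the still-open transaction is discarded by PostgreSQL when the connection closes in the `finally` clause, so the net effect is a rollback.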
requirements.txt CHANGED

```diff
@@ -1,5 +1,5 @@
 asyncpg
 gradio
 dotenv
-psycopg2
+psycopg2-binary
 cachetools
```
tests/test_db_layer.py CHANGED

```diff
@@ -2,9 +2,16 @@ import pandas as pd
 import pytest
 
 from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids, \
-    get_async_connection
+    get_async_connection, get_questions
 from data_access import get_unified_sources
 
+
+@pytest.mark.asyncio
+async def test_get_questions():
+    source_run_id = 2
+    baseline_source_finder_run_id = 1
+    async with get_async_connection() as conn:
+        actual = await get_questions(conn, source_run_id, baseline_source_finder_run_id)
+    assert len(actual) == 10
 
 @pytest.mark.asyncio
 async def test_get_unified_sources():
@@ -30,9 +37,11 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 
     # Call the function to test
     async with get_async_connection() as conn:
-        result = await calculate_cumulative_statistics_for_all_questions(conn, source_finder_run_id, ranker_id)
+        questions = await get_questions(conn, source_finder_run_id, ranker_id)
+        question_ids = [question['id'] for question in questions]
+        result = await calculate_cumulative_statistics_for_all_questions(conn, question_ids, source_finder_run_id, ranker_id)
 
     # Check basic structure of results
     assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
     assert result.shape[0] == 1, "Result should have one row"
@@ -74,7 +83,7 @@ async def test_get_metadata_none_returned():
     async with get_async_connection() as conn:
         result = await get_metadata(conn, question_id, source_finder_run_id)
 
-    assert result == "", "Should return empty string when no metadata is found"
+    assert result == {}, "Should return an empty dict when no metadata is found"
 
 @pytest.mark.asyncio
 async def test_get_metadata():
```
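`test_get_questions` pins the expected overlap (10 questions) between source finder run 2 and baseline run 1 in the test database. A quick way to eyeball that same overlap outside pytest, assuming the same `.env` connection settings the tests use:

```python
import asyncio
from data_access import get_async_connection, get_questions

async def main():
    async with get_async_connection() as conn:
        # Questions present in BOTH run 2 and baseline run 1,
        # matching the ids asserted in test_get_questions
        questions = await get_questions(conn, 2, 1)
    print(f"{len(questions)} overlapping questions")
    for q in questions:
        print(q["id"], q["text"][:60])

asyncio.run(main())
```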