davidr70 committed
Commit 312213e · 1 Parent(s): 322ed33

Changes to use the new source-run tables and descriptive run labels

Files changed (5):
  1. app.py +23 -14
  2. data_access.py +77 -74
  3. eval_tables.py +29 -6
  4. scripts/__init__.py +0 -0
  5. tests/test_db_layer.py +45 -4
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
- get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions
+ get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -18,8 +18,9 @@ question_options = []
 baseline_rankers_dict = {}
 baseline_ranker_options = []
 run_ids = []
+ available_run_id_dict = {}
 finder_options = []
- previous_run_id = 1
+ previous_run_id = None
 
 run_id_dropdown = None
 
@@ -60,8 +61,9 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
 
 # Main function to handle UI interactions
 async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
+ global available_run_id_dict
 if not question_option:
- return gr.skip(), gr.skip(), gr.skip(), "No question selected"
+ return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
 logger.info("processing update")
 if type(baseline_ranker_name) == list:
 baseline_ranker_name = baseline_ranker_name[0]
@@ -75,20 +77,21 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
 if question_option == "All questions":
 if finder_id_int and type(run_id) == str:
- all_stats = await calculate_cumulative_statistics_for_all_questions(finder_id_int, int(run_id), baseline_ranker_id_int)
+ run_id_int = available_run_id_dict.get(run_id)
+ all_stats = await calculate_cumulative_statistics_for_all_questions(run_id_int, baseline_ranker_id_int)
 else:
 all_stats = None
- return None, all_stats, gr.skip(), "Select Run Id and source finder to see results"
+ return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
 
 # Extract question ID from selection
 question_id = questions_dict.get(question_option)
 
- available_run_ids = await get_run_ids(question_id)
- run_id_options = [str(r_id) for r_id in available_run_ids]
+ available_run_id_dict = await get_run_ids(question_id, finder_id_int)
+ run_id_options = list(available_run_id_dict.keys())
 if run_id not in run_id_options:
 run_id = run_id_options[0]
 
- run_id_int = int(run_id)
+ run_id_int = available_run_id_dict.get(run_id)
 
 
 
@@ -96,7 +99,7 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 stats = None
 # Get source runs data
 if finder_id_int:
- source_runs, stats = await get_unified_sources(question_id, finder_id_int, run_id_int, baseline_ranker_id_int)
+ source_runs, stats = await get_unified_sources(question_id, run_id_int, baseline_ranker_id_int)
 # Create DataFrame for display
 df = pd.DataFrame(source_runs)
 
@@ -110,9 +113,10 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
 
 # CSV for download
 # csv_data = df.to_csv(index=False)
+ metadata = await get_metadata(question_id, run_id_int)
 
 result_message = f"Found {len(source_runs)} results"
- return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message,
+ return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
 
 
 # Add a new function to handle row selection
@@ -182,7 +186,6 @@ async def main():
 # Sidebar area
 gr.Markdown("### About")
 gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
- gr.Markdown("Start by selecting a question, then optionally filter by source finder and run ID.")
 
 gr.Markdown("### Statistics")
 gr.Markdown(f"Total Questions: {len(questions)}")
@@ -204,6 +207,12 @@ async def main():
 ],
 interactive=False,
 )
+ with gr.Row():
+ metadata_text = gr.TextArea(
+ label="Metadata of Source Finder for Selected Question",
+ elem_id="metadata",
+ lines = 2
+ )
 with gr.Row():
 gr.Markdown("# Sources Found")
 with gr.Row():
@@ -240,20 +249,20 @@ async def main():
 update_sources_list,
 inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
 # outputs=[run_id_dropdown, results_table, result_text, download_button]
- outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
 )
 
 source_finder_dropdown.change(
 update_sources_list,
 inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
 # outputs=[run_id_dropdown, results_table, result_text, download_button]
- outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
 )
 
 run_id_dropdown.change(
 update_sources_list,
 inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
- outputs=[results_table, statistics_table, run_id_dropdown, result_text]
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
 )
 
 
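For orientation: the run dropdown now shows run descriptions, and `available_run_id_dict` maps the selected description back to a `source_finder_run_id`. A minimal, self-contained sketch of that lookup (the stubbed `get_run_ids` and the sample descriptions are illustrative, not part of this commit):

```python
import asyncio

# Stub with the same shape as the reworked data_access.get_run_ids:
# {run description: source_finder_run_id} for one question/finder pair.
async def get_run_ids(question_id: int, source_finder_id: int) -> dict:
    return {"baseline prompt": 3, "improved chunking": 4}

async def resolve_run(question_id: int, source_finder_id: int, selected: str):
    available_run_id_dict = await get_run_ids(question_id, source_finder_id)
    run_id_options = list(available_run_id_dict.keys())
    # Fall back to the first available run when the previous selection is stale.
    if selected not in run_id_options:
        selected = run_id_options[0]
    return selected, available_run_id_dict.get(selected)

print(asyncio.run(resolve_run(1, 2, "improved chunking")))  # ('improved chunking', 4)
```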
data_access.py CHANGED
@@ -35,6 +35,17 @@ async def get_questions():
 questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
 return [{"id": q["id"], "text": q["question_text"]} for q in questions]
 
+ async def get_metadata(question_id: int, source_finder_id_run_id: int):
+ async with get_async_connection() as conn:
+ metadata = await conn.fetchrow('''
+ SELECT metadata
+ FROM source_finder_run_question_metadata sfrqm
+ WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
+ ''', question_id, source_finder_id_run_id)
+ if metadata is None:
+ return ""
+ return metadata.get('metadata')
+
 
 # Get distinct source finders
 async def get_source_finders():
@@ -44,32 +55,19 @@ async def get_source_finders():
 
 
 # Get distinct run IDs for a question
- async def get_run_ids(question_id: int):
- async with get_async_connection() as conn:
- query = "SELECT DISTINCT run_id FROM source_runs WHERE question_id = $1 order by run_id desc"
- params = [question_id]
- run_ids = await conn.fetch(query, *params)
- return [r["run_id"] for r in run_ids]
-
-
- # Get source runs for a specific question with filters
- async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
- run_id: Optional[int] = None):
+ async def get_run_ids(question_id: int, source_finder_id: int):
 async with get_async_connection() as conn:
- # Build query with filters
 query = """
- SELECT sr.*, sf.source_finder_type as finder_name
- FROM source_runs sr
- JOIN source_finders sf ON sr.source_finder_id = sf.id
- WHERE sr.question_id = $1 and sr.run_id = $2
- AND sr.source_finder_id = $3
+ select distinct sfr.description, srs.source_finder_run_id as run_id
+ from talmudexplore.source_run_results srs
+ join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
+ join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
+ where sfr.source_finder_id = $1
+ and srs.question_id = $2
 """
- params = [question_id, run_id, source_finder_id]
-
- query += " ORDER BY sr.rank DESC"
+ run_ids = await conn.fetch(query, source_finder_id, question_id)
+ return {r["description"]:r["run_id"] for r in run_ids}
 
- sources = await conn.fetch(query, *params)
- return [dict(s) for s in sources]
 
 async def get_baseline_rankers():
 async with get_async_connection() as conn:
@@ -110,13 +108,12 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , sou
 return results_df
 
 
- async def calculate_cumulative_statistics_for_all_questions(source_finder_id: int, run_id: int, ranker_id: int):
+ async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id: int, ranker_id: int):
 """
 Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
 Args:
- source_finder_id (int): ID of the source finder
- run_id (int): Run ID to analyze
+ source_finder_run_id (int): ID of the source finder and run as appears in source runs
 ranker_id (int): ID of the baseline ranker
 
 Returns:
@@ -141,7 +138,7 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_id: in
 for question_id in question_ids:
 try:
 # Get unified sources for this question
- sources, stats = await get_unified_sources(question_id, source_finder_id, run_id, ranker_id)
+ stats, sources = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
 
 if sources and len(sources) > 0:
 valid_questions += 1
@@ -186,62 +183,68 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_id: in
 return pd.DataFrame([cumulative_stats])
 
 
- async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
+ async def get_unified_sources(question_id: int, source_finder_run_id: int, ranker_id: int):
 """
 Create unified view of sources from both baseline_sources and source_runs
 with indicators of where each source appears and their respective ranks.
 """
 async with get_async_connection() as conn:
- # Get sources from source_runs
- query_runs = """
- SELECT tb.tractate_chunk_id as id, sr.rank as source_rank, sr.tractate, sr.folio,
- sr.reason as source_reason, sr.metadata
- FROM source_runs sr join talmud_bavli tb on sr.sugya_id = tb.xml_id
- WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
- """
- source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)
+ stats_df, unified_results = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
 
- # Get sources from baseline_sources
- query_baseline = """
- SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
- FROM baseline_sources bs join talmud_bavli tb on bs.sugya_id = tb.xml_id
- WHERE bs.question_id = $1 AND bs.ranker_id = $2
- """
- baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
-
- stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
-
- # Convert to dictionaries for easier lookup
- source_runs_dict = {s["id"]: dict(s) for s in source_runs}
- baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
-
- # Get all unique sugya_ids
- all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
-
- # Build unified results
- unified_results = []
- for sugya_id in all_sugya_ids:
- in_source_run = sugya_id in source_runs_dict
- in_baseline = sugya_id in baseline_dict
- if in_baseline:
- info = baseline_dict[sugya_id]
- else:
- info = source_runs_dict[sugya_id]
- result = {
- "id": sugya_id,
- "tractate": info.get("tractate"),
- "folio": info.get("folio"),
- "in_baseline": "Yes" if in_baseline else "No",
- "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
- "in_source_run": "Yes" if in_source_run else "No",
- "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
- "source_reason": source_runs_dict.get(sugya_id, {}).get("reason", "N/A"),
- "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
- }
- unified_results.append(result)
+ return unified_results, stats_df
 
 
- return unified_results, stats_df
+ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
+ # Get sources from source_runs
+ query_runs = """
+ SELECT tb.tractate_chunk_id as id,
+ sr.rank as source_rank,
+ sr.tractate,
+ sr.folio,
+ sr.reason as source_reason
+ FROM source_run_results sr
+ join talmud_bavli tb on sr.sugya_id = tb.xml_id
+ WHERE sr.question_id = $1
+ AND sr.source_finder_run_id = $2
+ """
+ source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)
+ # Get sources from baseline_sources
+ query_baseline = """
+ SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
+ FROM baseline_sources bs
+ join talmud_bavli tb on bs.sugya_id = tb.xml_id
+ WHERE bs.question_id = $1
+ AND bs.ranker_id = $2
+ """
+ baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
+ stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
+ # Convert to dictionaries for easier lookup
+ source_runs_dict = {s["id"]: dict(s) for s in source_runs}
+ baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
+ # Get all unique sugya_ids
+ all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
+ # Build unified results
+ unified_results = []
+ for sugya_id in all_sugya_ids:
+ in_source_run = sugya_id in source_runs_dict
+ in_baseline = sugya_id in baseline_dict
+ if in_baseline:
+ info = baseline_dict[sugya_id]
+ else:
+ info = source_runs_dict[sugya_id]
+ result = {
+ "id": sugya_id,
+ "tractate": info.get("tractate"),
+ "folio": info.get("folio"),
+ "in_baseline": "Yes" if in_baseline else "No",
+ "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
+ "in_source_run": "Yes" if in_source_run else "No",
+ "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
+ "source_reason": source_runs_dict.get(sugya_id, {}).get("reason", "N/A"),
+ "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
+ }
+ unified_results.append(result)
+ return stats_df, unified_results
 
 
 async def get_source_text(tractate_chunk_id: int):
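The merge performed by the new `get_stats` helper (shared by `get_unified_sources` and the cumulative-statistics loop) can be illustrated with in-memory rows; the sample data below is invented for the example and stands in for the two query results:

```python
# Invented rows mimicking the shapes returned by the two queries in get_stats.
source_runs = [{"id": 101, "source_rank": 1, "tractate": "Berakhot", "folio": "2a"}]
baseline_sources = [
    {"id": 101, "baseline_rank": 2, "tractate": "Berakhot", "folio": "2a"},
    {"id": 205, "baseline_rank": 1, "tractate": "Shabbat", "folio": "31a"},
]

source_runs_dict = {s["id"]: s for s in source_runs}
baseline_dict = {s["id"]: s for s in baseline_sources}

# Union of sugya ids, flagging where each source appears and with what rank.
unified_results = []
for sugya_id in set(source_runs_dict) | set(baseline_dict):
    in_baseline = sugya_id in baseline_dict
    info = baseline_dict.get(sugya_id) or source_runs_dict[sugya_id]
    unified_results.append({
        "id": sugya_id,
        "tractate": info["tractate"],
        "folio": info["folio"],
        "in_baseline": "Yes" if in_baseline else "No",
        "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
        "in_source_run": "Yes" if sugya_id in source_runs_dict else "No",
        "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
    })

print(unified_results)  # id 101 appears in both lists, id 205 only in the baseline
```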
eval_tables.py CHANGED
@@ -51,12 +51,35 @@ def create_eval_database():
 );
 ''')
 
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS source_finder_runs (
+ id SERIAL PRIMARY KEY,
+ run_id INTEGER NOT NULL,
+ source_finder_id INTEGER NOT NULL,
+ description TEXT,
+ FOREIGN KEY (source_finder_id) REFERENCES source_finders(id),
+ CONSTRAINT unique_source_per_run_id UNIQUE(run_id, source_finder_id)
+ );
+ ''')
+
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS source_finder_run_question_metadata (
+ id SERIAL PRIMARY KEY,
+ question_id INTEGER NOT NULL,
+ source_finder_run_id INTEGER NOT NULL,
+ metadata JSON,
+ FOREIGN KEY (source_finder_run_id) REFERENCES source_finder_runs(id),
+ FOREIGN KEY (question_id) REFERENCES questions(id),
+ CONSTRAINT unique_question_per_run_id UNIQUE(question_id, source_finder_run_id)
+ );
+ ''')
+
+
 # Create table for logging all sources from each run
 cursor.execute('''
- CREATE TABLE IF NOT EXISTS source_runs (
+ CREATE TABLE IF NOT EXISTS source_run_results (
 id SERIAL PRIMARY KEY,
- source_finder_id INTEGER NOT NULL,
- run_id TEXT NOT NULL,
+ source_finder_run_id INTEGER NOT NULL,
 run_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 question_id INTEGER NOT NULL,
 tractate TEXT NOT NULL,
@@ -64,7 +87,7 @@ def create_eval_database():
 sugya_id TEXT NOT NULL,
 rank INTEGER NOT NULL,
 reason TEXT,
- FOREIGN KEY (source_finder_id) REFERENCES source_finders(id),
+ FOREIGN KEY (source_finder_run_id) REFERENCES source_finder_runs(id),
 FOREIGN KEY (question_id) REFERENCES questions(id)
 );
 ''')
@@ -99,8 +122,8 @@ def load_baseline_sources():
 
 if __name__ == '__main__':
 # Create the database
- # create_eval_database()
- load_baseline_sources()
+ create_eval_database()
+ # load_baseline_sources()
 
 
 
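To make the new schema concrete: each finder/run pair gets one row in `source_finder_runs`, and that row's `id` is what `source_run_results` and `source_finder_run_question_metadata` reference. A hypothetical seeding snippet under that assumption, written against a psycopg2-style cursor like the one `create_eval_database` appears to use (the IDs, description, and values are made up):

```python
def seed_example_run(cursor):
    # One row per (run_id, source_finder_id) pair; its id keys everything else.
    cursor.execute(
        """INSERT INTO source_finder_runs (run_id, source_finder_id, description)
           VALUES (%s, %s, %s) RETURNING id""",
        (7, 2, "improved chunking"),
    )
    source_finder_run_id = cursor.fetchone()[0]

    # Per-question metadata for that run (JSON column).
    cursor.execute(
        """INSERT INTO source_finder_run_question_metadata
               (question_id, source_finder_run_id, metadata)
           VALUES (%s, %s, %s)""",
        (1, source_finder_run_id, '{"prompt_tokens": 1234}'),
    )

    # A ranked source for question 1 in that run.
    cursor.execute(
        """INSERT INTO source_run_results
               (source_finder_run_id, question_id, tractate, folio, sugya_id, rank, reason)
           VALUES (%s, %s, %s, %s, %s, %s, %s)""",
        (source_finder_run_id, 1, "Berakhot", "2a", "B.2a.1", 1, "direct citation"),
    )
```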
scripts/__init__.py ADDED
(new empty file)
tests/test_db_layer.py CHANGED
@@ -1,7 +1,7 @@
 import pandas as pd
 import pytest
 
- from data_access import calculate_cumulative_statistics_for_all_questions
+ from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids
 from data_access import get_unified_sources
 
 
@@ -20,9 +20,6 @@ async def test_get_unified_sources():
 # You can also check specific stats columns
 assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
 
-
-
-
 @pytest.mark.asyncio
 async def test_calculate_cumulative_statistics_for_all_questions():
 # Test with known source_finder_id, run_id, and ranker_id
@@ -65,3 +62,47 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[
 0] <= 100, "High ranked overlap percentage should be between 0 and 100"
 
+ @pytest.mark.asyncio
+ async def test_get_metadata_none_returned():
+ # Test with a run that has no stored metadata for this question
+ source_finder_id = 1
+ source_finder_run_id = 1
+ question_id = 1
+
+ # Call the function to test
+ result = await get_metadata(question_id, source_finder_run_id)
+
+ assert result == "", "Should return empty string when no metadata is found"
+
+ @pytest.mark.asyncio
+ async def test_get_metadata():
+ # Test with a known question_id and source_finder_run_id
+ source_finder_run_id = 4
+ question_id = 1
+
+ # Call the function to test
+ result = await get_metadata(question_id, source_finder_run_id)
+
+ assert result is not None, "Should return metadata when it exists"
+
+
+ @pytest.mark.asyncio
+ async def test_get_run_ids():
+ # Test with known question_id and source_finder_id
+ question_id = 2 # Using a question ID that exists in the test database
+ source_finder_id = 2 # Using a source finder ID that exists in the test database
+
+ # Call the function to test
+ result = await get_run_ids(question_id, source_finder_id)
+
+ # Verify the result is a dictionary
+ assert isinstance(result, dict), "Result should be a dictionary"
+
+ # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
+ assert len(result) > 0, "Should return at least one run ID"
+
+ # Test with a non-existent question_id
+ non_existent_question_id = 9999
+ empty_result = await get_run_ids(non_existent_question_id, source_finder_id)
+ assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
+ assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"