davidr70 committed on
Commit 3a7a44c · 1 Parent(s): 312213e

fix connection reuse

Files changed (4)
  1. app.py +49 -47
  2. data_access.py +141 -153
  3. eval_tables.py +5 -0
  4. tests/test_db_layer.py +21 -17
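
Every change in this commit follows the same pattern: the data-access helpers no longer open and close their own database connection on every call; instead the caller opens a single asyncpg connection through get_async_connection() and passes it into each helper. The sketch below shows the general shape of such a context manager. The body of get_async_connection is not part of this diff (only its signature and the closing await conn.close() are visible in the context lines), so the connect arguments and the search_path handling here are assumptions for illustration only.

    # Minimal sketch of a reusable-connection helper (assumed implementation;
    # only the signature and the trailing `await conn.close()` appear in the diff).
    import os
    from contextlib import asynccontextmanager

    import asyncpg


    @asynccontextmanager
    async def get_async_connection(schema="talmudexplore"):
        # Open one connection; callers keep it for the whole request and pass it
        # to get_questions(conn), get_run_ids(conn, ...), etc.
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),  # env var names assumed, mirroring get_pg_sync_connection
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port"),
        )
        try:
            # Scope unqualified table names to the project schema.
            await conn.execute(f"SET search_path TO {schema}")
            yield conn
        finally:
            await conn.close()
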
app.py CHANGED
@@ -5,7 +5,8 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
+    get_async_connection
 
 logger = logging.getLogger(__name__)
 
@@ -20,7 +21,7 @@ baseline_ranker_options = []
 run_ids = []
 available_run_id_dict = {}
 finder_options = []
-previous_run_id = None
+previous_run_id = "initial_run"
 
 run_id_dropdown = None
 
@@ -29,13 +30,13 @@ run_id_dropdown = None
 # Initialize data in a single async function
 async def initialize_data():
     global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, baseline_rankers_dict, source_finders_dict, baseline_ranker_options
-
-    questions = await get_questions()
-    source_finders = await get_source_finders()
-
-    baseline_rankers = await get_baseline_rankers()
+    async with get_async_connection() as conn:
+        # Get questions and source finders
+        questions = await get_questions(conn)
+        source_finders = await get_source_finders(conn)
+        baseline_rankers = await get_baseline_rankers(conn)
+
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
-
     # Convert to dictionaries for easier lookup
     questions_dict = {q["text"]: q["id"] for q in questions}
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
@@ -52,7 +53,7 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
     if evt:
         logger.info(f"event: {evt.target.elem_id}")
         if evt.target.elem_id == "run_id_dropdown" and (type(run_id) == list or run_id == previous_run_id):
-            return gr.skip(), gr.skip(), gr.skip(), gr.skip()
+            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
 
     if type(run_id) == str:
         previous_run_id = run_id
@@ -65,55 +66,56 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
     if not question_option:
         return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
     logger.info("processing update")
-    if type(baseline_ranker_name) == list:
-        baseline_ranker_name = baseline_ranker_name[0]
-
-    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
-
-    if len(source_finder_name):
-        finder_id_int = source_finders_dict.get(source_finder_name)
-    else:
-        finder_id_int = None
-
-    if question_option == "All questions":
-        if finder_id_int and type(run_id) == str:
-            run_id_int = available_run_id_dict.get(run_id)
-            all_stats = await calculate_cumulative_statistics_for_all_questions(run_id_int, baseline_ranker_id_int)
-        else:
-            all_stats = None
-        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
-
-    # Extract question ID from selection
-    question_id = questions_dict.get(question_option)
-
-    available_run_id_dict = await get_run_ids(question_id, finder_id_int)
-    run_id_options = list(available_run_id_dict.keys())
-    if run_id not in run_id_options:
-        run_id = run_id_options[0]
-
-    run_id_int = available_run_id_dict.get(run_id)
-
-    source_runs = None
-    stats = None
-    # Get source runs data
-    if finder_id_int:
-        source_runs, stats = await get_unified_sources(question_id, run_id_int, baseline_ranker_id_int)
-        # Create DataFrame for display
-        df = pd.DataFrame(source_runs)
-
-    if not source_runs:
-        return None, None, run_id_options, "No results found for the selected filters",
-
-    # Format table columns
-    columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
-                          'folio', 'reason']
-    df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
-
-    # CSV for download
-    # csv_data = df.to_csv(index=False)
-    metadata = await get_metadata(question_id, run_id_int)
+    async with get_async_connection() as conn:
+        if type(baseline_ranker_name) == list:
+            baseline_ranker_name = baseline_ranker_name[0]
+
+        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+
+        if len(source_finder_name):
+            finder_id_int = source_finders_dict.get(source_finder_name)
+        else:
+            finder_id_int = None
+
+        if question_option == "All questions":
+            if finder_id_int and type(run_id) == str:
+                run_id_int = available_run_id_dict.get(run_id)
+                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
+            else:
+                all_stats = None
+            return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
+
+        # Extract question ID from selection
+        question_id = questions_dict.get(question_option)
+
+        available_run_id_dict = await get_run_ids(conn, question_id, finder_id_int)
+        run_id_options = list(available_run_id_dict.keys())
+        if run_id not in run_id_options:
+            run_id = run_id_options[0]
+
+        run_id_int = available_run_id_dict.get(run_id)
+
+        source_runs = None
+        stats = None
+        # Get source runs data
+        if finder_id_int:
+            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
+            # Create DataFrame for display
+            df = pd.DataFrame(source_runs)
+
+        if not source_runs:
+            return None, None, run_id_options, "No results found for the selected filters",
+
+        # Format table columns
+        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
+                              'folio', 'reason']
+        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
+
+        # CSV for download
+        # csv_data = df.to_csv(index=False)
+        metadata = await get_metadata(conn, question_id, run_id_int)
 
     result_message = f"Found {len(source_runs)} results"
     return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
@@ -128,7 +130,8 @@ async def handle_row_selection_async(evt: gr.SelectData):
         # Get the ID from the selected row
         tractate_chunk_id = evt.row_value[0]
         # Get the source text
-        text = await get_source_text(tractate_chunk_id)
+        async with get_async_connection() as conn:
+            text = await get_source_text(conn, tractate_chunk_id)
         return text
     except Exception as e:
         return f"Error retrieving source text: {str(e)}"
@@ -248,7 +251,6 @@ async def main():
     question_dropdown.change(
         update_sources_list,
         inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-        # outputs=[run_id_dropdown, results_table, result_text, download_button]
         outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
     )
 
data_access.py CHANGED
@@ -30,85 +30,80 @@ async def get_async_connection(schema="talmudexplore"):
30
  await conn.close()
31
 
32
 
33
- async def get_questions():
34
- async with get_async_connection() as conn:
35
- questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
36
- return [{"id": q["id"], "text": q["question_text"]} for q in questions]
37
-
38
- async def get_metadata(question_id: int, source_finder_id_run_id: int):
39
- async with get_async_connection() as conn:
40
- metadata = await conn.fetchrow('''
41
- SELECT metadata
42
- FROM source_finder_run_question_metadata sfrqm
43
- WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
44
- ''', question_id, source_finder_id_run_id)
45
- if metadata is None:
46
- return ""
47
- return metadata.get('metadata')
48
 
49
 
50
  # Get distinct source finders
51
- async def get_source_finders():
52
- async with get_async_connection() as conn:
53
- finders = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
54
- return [{"id": f["id"], "name": f["name"]} for f in finders]
55
 
56
 
57
  # Get distinct run IDs for a question
58
- async def get_run_ids(question_id: int, source_finder_id: int):
59
- async with get_async_connection() as conn:
60
- query = """
61
- select distinct sfr.description, srs.source_finder_run_id as run_id
62
- from talmudexplore.source_run_results srs
63
- join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
64
- join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
65
- where sfr.source_finder_id = $1
66
- and srs.question_id = $2
67
- """
68
- run_ids = await conn.fetch(query, source_finder_id, question_id)
69
- return {r["description"]:r["run_id"] for r in run_ids}
70
-
71
-
72
- async def get_baseline_rankers():
73
- async with get_async_connection() as conn:
74
- rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
75
- return [{"id": f["id"], "name": f["ranker"]} for f in rankers]
76
-
77
- async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source_runs_sources):
78
  # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
79
  # e.g. overlap, high ranked overlap, etc.
80
- async with get_async_connection() as conn:
81
- actual_sources_set = {s["id"] for s in source_runs_sources}
82
- baseline_sources_set = {s["id"] for s in baseline_sources}
83
-
84
- # Calculate overlap
85
- overlap = actual_sources_set.intersection(baseline_sources_set)
86
- # only_in_1 = actual_sources_set - baseline_sources_set
87
- # only_in_2 = baseline_sources_set - actual_sources_set
88
-
89
- # Calculate high-ranked overlap (rank >= 4)
90
- actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
91
- baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
92
-
93
- high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
94
-
95
- results = {
96
- "total_baseline_sources": len(baseline_sources),
97
- "total_found_sources": len(source_runs_sources),
98
- "overlap_count": len(overlap),
99
- "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
100
- 2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
101
- "num_high_ranked_baseline_sources": len(baseline_high_ranked),
102
- "num_high_ranked_found_sources": len(actual_high_ranked),
103
- "high_ranked_overlap_count": len(high_ranked_overlap),
104
- "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
105
- }
106
- #convert results to dataframe
107
- results_df = pd.DataFrame([results])
108
- return results_df
109
 
110
-
111
- async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id: int, ranker_id: int):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  """
113
  Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
114
 
@@ -119,83 +114,75 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id
119
  Returns:
120
  pd.DataFrame: DataFrame containing aggregated statistics
121
  """
122
- async with get_async_connection() as conn:
123
- # Get all questions
124
- query = "SELECT id FROM questions ORDER BY id"
125
- questions = await conn.fetch(query)
126
- question_ids = [q["id"] for q in questions]
127
-
128
- # Initialize aggregates
129
- total_baseline_sources = 0
130
- total_found_sources = 0
131
- total_overlap = 0
132
- total_high_ranked_baseline = 0
133
- total_high_ranked_found = 0
134
- total_high_ranked_overlap = 0
135
-
136
- # Process each question
137
- valid_questions = 0
138
- for question_id in question_ids:
139
- try:
140
- # Get unified sources for this question
141
- stats, sources = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
142
-
143
- if sources and len(sources) > 0:
144
- valid_questions += 1
145
- stats_dict = stats.iloc[0].to_dict()
146
-
147
- # Add to running totals
148
- total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
149
- total_found_sources += stats_dict.get('total_found_sources', 0)
150
- total_overlap += stats_dict.get('overlap_count', 0)
151
- total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
152
- total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
153
- total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
154
- except Exception as e:
155
- # Skip questions with errors
156
- continue
157
-
158
- # Calculate overall percentages
159
- overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
160
- if max(total_baseline_sources, total_found_sources) > 0 else 0
161
-
162
- high_ranked_overlap_percentage = round(
163
- total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
164
- if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
165
-
166
- # Compile results
167
- cumulative_stats = {
168
- "total_questions_analyzed": valid_questions,
169
- "total_baseline_sources": total_baseline_sources,
170
- "total_found_sources": total_found_sources,
171
- "total_overlap_count": total_overlap,
172
- "overall_overlap_percentage": overlap_percentage,
173
- "total_high_ranked_baseline_sources": total_high_ranked_baseline,
174
- "total_high_ranked_found_sources": total_high_ranked_found,
175
- "total_high_ranked_overlap_count": total_high_ranked_overlap,
176
- "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
177
- "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions,
178
- 2) if valid_questions > 0 else 0,
179
- "avg_found_sources_per_question": round(total_found_sources / valid_questions,
180
- 2) if valid_questions > 0 else 0
181
- }
182
-
183
- return pd.DataFrame([cumulative_stats])
184
-
185
-
186
- async def get_unified_sources(question_id: int, source_finder_run_id: int, ranker_id: int):
187
  """
188
  Create unified view of sources from both baseline_sources and source_runs
189
  with indicators of where each source appears and their respective ranks.
190
  """
191
- async with get_async_connection() as conn:
192
- stats_df, unified_results = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
193
-
194
- return unified_results, stats_df
195
 
196
-
197
- async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
198
- # Get sources from source_runs
199
  query_runs = """
200
  SELECT tb.tractate_chunk_id as id,
201
  sr.rank as source_rank,
@@ -217,7 +204,7 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
217
  AND bs.ranker_id = $2
218
  """
219
  baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
220
- stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
221
  # Convert to dictionaries for easier lookup
222
  source_runs_dict = {s["id"]: dict(s) for s in source_runs}
223
  baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
@@ -244,21 +231,22 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
244
  "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
245
  }
246
  unified_results.append(result)
247
- return stats_df, unified_results
 
248
 
249
 
250
- async def get_source_text(tractate_chunk_id: int):
251
  """
252
  Retrieves the text content for a given tractate chunk ID.
253
  """
254
- async with get_async_connection() as conn:
255
- query = """
256
- SELECT tb.text_with_nikud as text
257
- FROM talmud_bavli tb
258
- WHERE tb.tractate_chunk_id = $1
259
- """
260
- result = await conn.fetchrow(query, tractate_chunk_id)
261
- return result["text"] if result else "Source text not found"
262
 
263
  def get_pg_sync_connection(schema="talmudexplore"):
264
  conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
 
30
  await conn.close()
31
 
32
 
33
+ async def get_questions(conn: asyncpg.Connection):
34
+ questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
35
+ return [{"id": q["id"], "text": q["question_text"]} for q in questions]
36
+
37
+ async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_id_run_id: int):
38
+ metadata = await conn.fetchrow('''
39
+ SELECT metadata
40
+ FROM source_finder_run_question_metadata sfrqm
41
+ WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
42
+ ''', question_id, source_finder_id_run_id)
43
+ if metadata is None:
44
+ return ""
45
+ return metadata.get('metadata')
 
 
46
 
47
 
48
  # Get distinct source finders
49
+ async def get_source_finders(conn: asyncpg.Connection):
50
+ finders = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
51
+ return [{"id": f["id"], "name": f["name"]} for f in finders]
 
52
 
53
 
54
  # Get distinct run IDs for a question
55
+ async def get_run_ids(conn: asyncpg.Connection, question_id: int, source_finder_id: int):
56
+ query = """
57
+ select distinct sfr.description, srs.source_finder_run_id as run_id
58
+ from talmudexplore.source_run_results srs
59
+ join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
60
+ join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
61
+ where sfr.source_finder_id = $1
62
+ and srs.question_id = $2
63
+ """
64
+ run_ids = await conn.fetch(query, source_finder_id, question_id)
65
+ return {r["description"]:r["run_id"] for r in run_ids}
66
+
67
+
68
+ async def get_baseline_rankers(conn: asyncpg.Connection):
69
+ rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
70
+ return [{"id": f["id"], "name": f["ranker"]} for f in rankers]
71
+
72
+ async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources , source_runs_sources):
 
 
73
  # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
74
  # e.g. overlap, high ranked overlap, etc.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ actual_sources_set = {s["id"] for s in source_runs_sources}
77
+ baseline_sources_set = {s["id"] for s in baseline_sources}
78
+
79
+ # Calculate overlap
80
+ overlap = actual_sources_set.intersection(baseline_sources_set)
81
+ # only_in_1 = actual_sources_set - baseline_sources_set
82
+ # only_in_2 = baseline_sources_set - actual_sources_set
83
+
84
+ # Calculate high-ranked overlap (rank >= 4)
85
+ actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
86
+ baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
87
+
88
+ high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
89
+
90
+ results = {
91
+ "total_baseline_sources": len(baseline_sources),
92
+ "total_found_sources": len(source_runs_sources),
93
+ "overlap_count": len(overlap),
94
+ "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
95
+ 2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
96
+ "num_high_ranked_baseline_sources": len(baseline_high_ranked),
97
+ "num_high_ranked_found_sources": len(actual_high_ranked),
98
+ "high_ranked_overlap_count": len(high_ranked_overlap),
99
+ "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
100
+ }
101
+ #convert results to dataframe
102
+ results_df = pd.DataFrame([results])
103
+ return results_df
104
+
105
+
106
+ async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, source_finder_run_id: int, ranker_id: int):
107
  """
108
  Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
109
 
 
114
  Returns:
115
  pd.DataFrame: DataFrame containing aggregated statistics
116
  """
117
+ # Get all questions
118
+ query = "SELECT id FROM questions ORDER BY id"
119
+ questions = await conn.fetch(query)
120
+ question_ids = [q["id"] for q in questions]
121
+
122
+ # Initialize aggregates
123
+ total_baseline_sources = 0
124
+ total_found_sources = 0
125
+ total_overlap = 0
126
+ total_high_ranked_baseline = 0
127
+ total_high_ranked_found = 0
128
+ total_high_ranked_overlap = 0
129
+
130
+ # Process each question
131
+ valid_questions = 0
132
+ for question_id in question_ids:
133
+ try:
134
+ # Get unified sources for this question
135
+ sources, stats = await get_unified_sources(conn, question_id, ranker_id, source_finder_run_id)
136
+
137
+ if sources and len(sources) > 0:
138
+ valid_questions += 1
139
+ stats_dict = stats.iloc[0].to_dict()
140
+
141
+ # Add to running totals
142
+ total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
143
+ total_found_sources += stats_dict.get('total_found_sources', 0)
144
+ total_overlap += stats_dict.get('overlap_count', 0)
145
+ total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
146
+ total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
147
+ total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
148
+ except Exception as e:
149
+ # Skip questions with errors
150
+ continue
151
+
152
+ # Calculate overall percentages
153
+ overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
154
+ if max(total_baseline_sources, total_found_sources) > 0 else 0
155
+
156
+ high_ranked_overlap_percentage = round(
157
+ total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
158
+ if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
159
+
160
+ # Compile results
161
+ cumulative_stats = {
162
+ "total_questions_analyzed": valid_questions,
163
+ "total_baseline_sources": total_baseline_sources,
164
+ "total_found_sources": total_found_sources,
165
+ "total_overlap_count": total_overlap,
166
+ "overall_overlap_percentage": overlap_percentage,
167
+ "total_high_ranked_baseline_sources": total_high_ranked_baseline,
168
+ "total_high_ranked_found_sources": total_high_ranked_found,
169
+ "total_high_ranked_overlap_count": total_high_ranked_overlap,
170
+ "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
171
+ "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions,
172
+ 2) if valid_questions > 0 else 0,
173
+ "avg_found_sources_per_question": round(total_found_sources / valid_questions,
174
+ 2) if valid_questions > 0 else 0
175
+ }
176
+
177
+ return pd.DataFrame([cumulative_stats])
178
+
179
+
180
+ async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int, ranker_id: int):
 
181
  """
182
  Create unified view of sources from both baseline_sources and source_runs
183
  with indicators of where each source appears and their respective ranks.
184
  """
 
 
 
 
185
 
 
 
 
186
  query_runs = """
187
  SELECT tb.tractate_chunk_id as id,
188
  sr.rank as source_rank,
 
204
  AND bs.ranker_id = $2
205
  """
206
  baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
207
+ stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
208
  # Convert to dictionaries for easier lookup
209
  source_runs_dict = {s["id"]: dict(s) for s in source_runs}
210
  baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
 
231
  "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
232
  }
233
  unified_results.append(result)
234
+
235
+ return unified_results, stats_df
236
 
237
 
238
+ async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
239
  """
240
  Retrieves the text content for a given tractate chunk ID.
241
  """
242
+
243
+ query = """
244
+ SELECT tb.text_with_nikud as text
245
+ FROM talmud_bavli tb
246
+ WHERE tb.tractate_chunk_id = $1
247
+ """
248
+ result = await conn.fetchrow(query, tractate_chunk_id)
249
+ return result["text"] if result else "Source text not found"
250
 
251
  def get_pg_sync_connection(schema="talmudexplore"):
252
  conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
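
With this refactor every helper in data_access.py takes an already-open connection as its first argument instead of opening its own, so a caller that needs several queries opens the connection once and reuses it. That is the connection reuse the commit title refers to. The snippet below is only an illustrative sketch of the new calling pattern; the script and the sample IDs are hypothetical, while the function names and signatures are the ones shown in the diff above.

    # Hypothetical caller: one connection shared across several data-access calls.
    import asyncio

    from data_access import get_async_connection, get_questions, get_run_ids


    async def show_runs_for_first_question():
        # A single connection for the whole task, passed into each helper.
        async with get_async_connection() as conn:
            questions = await get_questions(conn)
            first_question_id = questions[0]["id"]
            # source_finder_id=2 is just a sample value, as in the tests below.
            runs = await get_run_ids(conn, first_question_id, source_finder_id=2)
            print(runs)


    if __name__ == "__main__":
        asyncio.run(show_runs_for_first_question())
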
eval_tables.py CHANGED
@@ -92,6 +92,11 @@ def create_eval_database():
     );
     ''')
 
+    cursor.execute('''alter table source_run_results
+        add constraint source_run_results_pk
+            unique (source_finder_run_id, question_id, sugya_id);
+    ''')
+
     conn.commit()
     conn.close()
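
The added constraint makes (source_finder_run_id, question_id, sugya_id) unique, so a given sugya can appear at most once per question within a source-finder run. One plausible use of such a constraint, not shown in this commit, is an idempotent insert; the sketch below is purely illustrative and assumes a rank column and a psycopg2 cursor.

    # Hypothetical upsert relying on the new unique constraint (illustration only;
    # the `rank` column and the cursor wiring are assumptions).
    def upsert_source_run_result(cursor, source_finder_run_id, question_id, sugya_id, rank):
        cursor.execute('''
            INSERT INTO source_run_results (source_finder_run_id, question_id, sugya_id, rank)
            VALUES (%s, %s, %s, %s)
            ON CONFLICT (source_finder_run_id, question_id, sugya_id)
            DO UPDATE SET rank = EXCLUDED.rank;
        ''', (source_finder_run_id, question_id, sugya_id, rank))
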
 
tests/test_db_layer.py CHANGED
@@ -1,13 +1,15 @@
 import pandas as pd
 import pytest
 
-from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids
+from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids, \
+    get_async_connection
 from data_access import get_unified_sources
 
 
 @pytest.mark.asyncio
 async def test_get_unified_sources():
-    results, stats = await get_unified_sources(2, 2, 1, 1)
+    async with get_async_connection() as conn:
+        results, stats = await get_unified_sources(conn,2, 2, 1)
     assert results is not None
     assert stats is not None
 
@@ -23,12 +25,12 @@ async def test_get_unified_sources():
 @pytest.mark.asyncio
 async def test_calculate_cumulative_statistics_for_all_questions():
     # Test with known source_finder_id, run_id, and ranker_id
-    source_finder_id = 2
-    run_id = 1
+    source_finder_run_id = 2
     ranker_id = 1
 
     # Call the function to test
-    result = await calculate_cumulative_statistics_for_all_questions(source_finder_id, run_id, ranker_id)
+    async with get_async_connection() as conn:
+        result = await calculate_cumulative_statistics_for_all_questions(conn, source_finder_run_id, ranker_id)
 
     # Check basic structure of results
     assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
@@ -65,12 +67,12 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 @pytest.mark.asyncio
 async def test_get_metadata_none_returned():
     # Test with known source_finder_id, run_id, and ranker_id
-    source_finder_id = 1
-    run_id = 1
+    source_finder_run_id = 1
     question_id = 1
 
     # Call the function to test
-    result = await get_metadata(question_id, source_finder_id, run_id)
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result == "", "Should return empty string when no metadata is found"
 
@@ -81,7 +83,8 @@ async def test_get_metadata():
     question_id = 1
 
     # Call the function to test
-    result = await get_metadata(question_id, source_finder_run_id)
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result is not None, "Should return metadata when it exists"
 
@@ -93,16 +96,17 @@ async def test_get_run_ids():
     source_finder_id = 2  # Using a source finder ID that exists in the test database
 
     # Call the function to test
-    result = await get_run_ids(question_id, source_finder_id)
+    async with get_async_connection() as conn:
+        result = await get_run_ids(conn, question_id, source_finder_id)
 
-    # Verify the result is a dictionary
-    assert isinstance(result, dict), "Result should be a dictionary"
+        # Verify the result is a dictionary
+        assert isinstance(result, dict), "Result should be a dictionary"
 
-    # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
-    assert len(result) > 0, "Should return at least one run ID"
+        # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
+        assert len(result) > 0, "Should return at least one run ID"
 
-    # Test with a non-existent question_id
-    non_existent_question_id = 9999
-    empty_result = await get_run_ids(non_existent_question_id, source_finder_id)
+        # Test with a non-existent question_id
+        non_existent_question_id = 9999
+        empty_result = await get_run_ids(conn, non_existent_question_id, source_finder_id)
     assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
     assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"