davidr70 committed on
Commit
6e35819
·
1 Parent(s): 0d42969

improvements

Browse files
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import asyncio
2
- from typing import Optional
3
  import gradio as gr
4
  import pandas as pd
5
- from data_access import get_pool, get_async_connection, close_pool
 
6
 
7
  # Initialize data at the module level
8
  questions = []
@@ -10,61 +11,26 @@ source_finders = []
10
  questions_dict = {}
11
  source_finders_dict = {}
12
  question_options = []
 
 
 
13
  run_ids = []
14
  finder_options = []
15
  finder_labels = {"All": "All Source Finders"}
16
 
17
 
18
  # Get all questions
19
- async def get_questions():
20
- async with get_async_connection() as conn:
21
- questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
22
- return [{"id": q["id"], "text": q["question_text"]} for q in questions]
23
-
24
-
25
- # Get distinct source finders
26
- async def get_source_finders():
27
- async with get_async_connection() as conn:
28
- finders = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
29
- return [{"id": f["id"], "name": f["name"]} for f in finders]
30
-
31
-
32
- # Get distinct run IDs for a question
33
- async def get_run_ids(question_id: int):
34
- async with get_async_connection() as conn:
35
- query = "SELECT DISTINCT run_id FROM source_runs WHERE question_id = $1 order by run_id desc"
36
- params = [question_id]
37
- run_ids = await conn.fetch(query, *params)
38
- return [r["run_id"] for r in run_ids]
39
-
40
-
41
- # Get source runs for a specific question with filters
42
- async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
43
- run_id: Optional[int] = None):
44
- async with get_async_connection() as conn:
45
- # Build query with filters
46
- query = """
47
- SELECT sr.*, sf.source_finder_type as finder_name
48
- FROM source_runs sr
49
- JOIN source_finders sf ON sr.source_finder_id = sf.id
50
- WHERE sr.question_id = $1 and sr.run_id = $2
51
- AND sr.source_finder_id = $3
52
- """
53
- params = [question_id, run_id, source_finder_id]
54
-
55
- query += " ORDER BY sr.rank DESC"
56
-
57
- sources = await conn.fetch(query, *params)
58
- return [dict(s) for s in sources]
59
-
60
 
61
  # Initialize data in a single async function
62
  async def initialize_data():
63
- global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, finder_labels
64
 
65
  questions = await get_questions()
66
  source_finders = await get_source_finders()
67
 
 
 
 
68
  # Convert to dictionaries for easier lookup
69
  questions_dict = {q["id"]: q["text"] for q in questions}
70
  source_finders_dict = {f["id"]: f["name"] for f in source_finders}
@@ -73,10 +39,12 @@ async def initialize_data():
73
  question_options = [f"{q['id']}: {q['text']}" for q in questions]
74
  finder_options = [str(f["id"]) for f in source_finders]
75
  finder_labels = {str(f["id"]): f["name"] for f in source_finders}
 
 
76
 
77
 
78
  # Main function to handle UI interactions
79
- def update_source_runs(question_option, source_finder_id, run_id):
80
  if not question_option:
81
  return None, [], "No question selected", None
82
 
@@ -86,35 +54,36 @@ def update_source_runs(question_option, source_finder_id, run_id):
86
  # Get run_ids for filtering - use asyncio.run for each independent operation
87
  available_run_ids = asyncio.run(get_run_ids(question_id))
88
  run_id_options = [str(r_id) for r_id in available_run_ids]
 
 
89
 
90
- # If the selected run_id is not in available options, reset it
91
- # if run_id not in run_id_options:
92
- # run_id = None
93
- #
94
- # # Convert run_id to int if not "All"
95
- run_id_int = available_run_ids[0]
96
  finder_id_int = None if len(source_finder_id) == 0 else int(source_finder_id)
 
 
 
97
 
 
 
98
  # Get source runs data
99
- source_runs = asyncio.run(get_source_runs(question_id, finder_id_int, run_id_int))
 
 
 
100
 
101
  if not source_runs:
102
- return None, run_id_options, "No results found for the selected filters"
103
 
104
- # Create DataFrame for display
105
- df = pd.DataFrame(source_runs)
106
 
107
  # Format table columns
108
- columns_to_display = ['finder_name', 'run_id', 'sugya_id', 'tractate', 'folio', 'rank', 'reason']
109
  df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
110
 
111
  # CSV for download
112
  # csv_data = df.to_csv(index=False)
113
 
114
  result_message = f"Found {len(source_runs)} results"
115
-
116
- return df_display, run_id_options, result_message,
117
-
118
 
119
 
120
  # Create Gradio app
@@ -128,31 +97,52 @@ async def main():
128
 
129
  with gr.Row():
130
  with gr.Column(scale=3):
131
- # Main content area
132
- question_dropdown = gr.Dropdown(
133
- choices=question_options,
134
- label="Select Question",
135
- value=None,
136
- interactive=True
137
- )
138
-
139
- run_id_dropdown = gr.Dropdown(
140
- choices=run_ids,
141
- value="1",
142
- allow_custom_value=True,
143
- label="Run ids for Question",
144
- interactive=True
145
- )
146
 
147
  with gr.Row():
148
- source_finder_dropdown = gr.Dropdown(
149
- choices=finder_options,
150
- label="Source Finder",
151
- interactive=True
152
- )
 
 
 
 
 
 
 
 
 
153
 
154
- result_text = gr.Markdown("Select a question to view source runs")
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  results_table = gr.DataFrame(
157
  headers=['Source Finder', 'Run ID', 'Sugya ID', 'Tractate', 'Folio', 'Rank', 'Reason'],
158
  interactive=False
@@ -177,37 +167,40 @@ async def main():
177
  gr.Markdown("### Source Finders")
178
  for f in source_finders:
179
  gr.Markdown(f"**{f['id']}**: {f['name']}")
 
 
 
180
 
181
  # Set up event handlers
182
  question_dropdown.change(
183
- update_source_runs,
184
- inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown],
185
  # outputs=[run_id_dropdown, results_table, result_text, download_button]
186
- outputs=[results_table, run_id_dropdown, result_text]
187
  )
188
 
189
  source_finder_dropdown.change(
190
- update_source_runs,
191
- inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown],
192
  # outputs=[run_id_dropdown, results_table, result_text, download_button]
193
- outputs=[results_table, run_id_dropdown, result_text]
194
  )
195
 
196
  run_id_dropdown.change(
197
- update_source_runs,
198
- inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown],
199
- outputs=[results_table, run_id_dropdown, result_text]
200
  )
201
 
202
  # Initial load of data when question is selected
203
  question_dropdown.change(
204
- update_source_runs,
205
- inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown],
206
- outputs=[results_table, run_id_dropdown, result_text]
207
  )
208
 
209
  app.queue()
210
  app.launch()
211
 
212
  if __name__ == "__main__":
213
- asyncio.run(main())
 
1
  import asyncio
2
+
3
  import gradio as gr
4
  import pandas as pd
5
+ from data_access import get_pool, get_async_connection, close_pool, get_questions, get_source_finders, get_run_ids, \
6
+ get_source_runs, get_baseline_rankers, calculate_baseline_vs_source_stats_for_question, get_unified_sources
7
 
8
  # Initialize data at the module level
9
  questions = []
 
11
  questions_dict = {}
12
  source_finders_dict = {}
13
  question_options = []
14
+ baseline_rankers = []
15
+ baseline_rankers_dict = {}
16
+ baseline_ranker_options = []
17
  run_ids = []
18
  finder_options = []
19
  finder_labels = {"All": "All Source Finders"}
20
 
21
 
22
  # Get all questions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Initialize data in a single async function
25
  async def initialize_data():
26
+ global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, finder_labels, baseline_rankers, baseline_ranker_options
27
 
28
  questions = await get_questions()
29
  source_finders = await get_source_finders()
30
 
31
+ baseline_rankers = await get_baseline_rankers()
32
+ baseline_rankers_dict = {f["id"]: f["name"] for f in baseline_rankers}
33
+
34
  # Convert to dictionaries for easier lookup
35
  questions_dict = {q["id"]: q["text"] for q in questions}
36
  source_finders_dict = {f["id"]: f["name"] for f in source_finders}
 
39
  question_options = [f"{q['id']}: {q['text']}" for q in questions]
40
  finder_options = [str(f["id"]) for f in source_finders]
41
  finder_labels = {str(f["id"]): f["name"] for f in source_finders}
42
+ baseline_ranker_options = [f["id"] for f in baseline_rankers]
43
+ baseline_ranker_labels = {str(f["id"]): f["name"] for f in source_finders}
44
 
45
 
46
  # Main function to handle UI interactions
47
+ def update_sources_list(question_option, source_finder_id, baseline_ranker_id: str, run_id:str):
48
  if not question_option:
49
  return None, [], "No question selected", None
50
 
 
54
  # Get run_ids for filtering - use asyncio.run for each independent operation
55
  available_run_ids = asyncio.run(get_run_ids(question_id))
56
  run_id_options = [str(r_id) for r_id in available_run_ids]
57
+ if run_id not in run_id_options:
58
+ run_id = run_id_options[0]
59
 
60
+ run_id_int = int(run_id)
 
 
 
 
 
61
  finder_id_int = None if len(source_finder_id) == 0 else int(source_finder_id)
62
+ if type(baseline_ranker_id) == list:
63
+ baseline_ranker_id = baseline_ranker_id[0]
64
+ baseline_ranker_id_int = 1 if len(baseline_ranker_id) == 0 else int(baseline_ranker_id)
65
 
66
+ source_runs = None
67
+ stats = None
68
  # Get source runs data
69
+ if finder_id_int:
70
+ source_runs, stats = asyncio.run(get_unified_sources(question_id, finder_id_int, run_id_int, baseline_ranker_id_int))
71
+ # Create DataFrame for display
72
+ df = pd.DataFrame(source_runs)
73
 
74
  if not source_runs:
75
+ return None, None, run_id_options, "No results found for the selected filters",
76
 
 
 
77
 
78
  # Format table columns
79
+ columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate', 'folio', 'reason']
80
  df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
81
 
82
  # CSV for download
83
  # csv_data = df.to_csv(index=False)
84
 
85
  result_message = f"Found {len(source_runs)} results"
86
+ return df_display, stats, run_id_options, result_message,
 
 
87
 
88
 
89
  # Create Gradio app
 
97
 
98
  with gr.Row():
99
  with gr.Column(scale=3):
100
+ with gr.Row():
101
+ with gr.Column(scale=1):
102
+ # Main content area
103
+ question_dropdown = gr.Dropdown(
104
+ choices=question_options,
105
+ label="Select Question",
106
+ value=None,
107
+ interactive=True
108
+ )
109
+ with gr.Column(scale=1):
110
+ baseline_rankers_dropdown = gr.Dropdown(
111
+ choices=baseline_ranker_options,
112
+ label="Select Baseline Ranker",
113
+ interactive=True
114
+ )
115
 
116
  with gr.Row():
117
+ with gr.Column(scale=1):
118
+ source_finder_dropdown = gr.Dropdown(
119
+ choices=finder_options,
120
+ label="Source Finder",
121
+ interactive=True
122
+ )
123
+ with gr.Column(scale=1):
124
+ run_id_dropdown = gr.Dropdown(
125
+ choices=run_ids,
126
+ value="1",
127
+ allow_custom_value=True,
128
+ label="Run id for Question and source finder",
129
+ interactive=True
130
+ )
131
 
 
132
 
133
+ result_text = gr.Markdown("Select a question to view source runs")
134
+ gr.Markdown("# Source Run Statistics")
135
+ statistics_table = gr.DataFrame(
136
+ headers=["num_high_ranked_baseline_sources",
137
+ "num_high_ranked_found_sources",
138
+ "overlap_count",
139
+ "overlap_percentage",
140
+ "high_ranked_overlap_count",
141
+ "high_ranked_overlap_percentage"
142
+ ],
143
+ interactive=False,
144
+ )
145
+ gr.Markdown("# Sources Found")
146
  results_table = gr.DataFrame(
147
  headers=['Source Finder', 'Run ID', 'Sugya ID', 'Tractate', 'Folio', 'Rank', 'Reason'],
148
  interactive=False
 
167
  gr.Markdown("### Source Finders")
168
  for f in source_finders:
169
  gr.Markdown(f"**{f['id']}**: {f['name']}")
170
+ gr.Markdown("### Baseline Source Rankers")
171
+ for f in baseline_rankers:
172
+ gr.Markdown(f"**{f['id']}**: {f['name']}")
173
 
174
  # Set up event handlers
175
  question_dropdown.change(
176
+ update_sources_list,
177
+ inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
178
  # outputs=[run_id_dropdown, results_table, result_text, download_button]
179
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text]
180
  )
181
 
182
  source_finder_dropdown.change(
183
+ update_sources_list,
184
+ inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
185
  # outputs=[run_id_dropdown, results_table, result_text, download_button]
186
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text]
187
  )
188
 
189
  run_id_dropdown.change(
190
+ update_sources_list,
191
+ inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
192
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text]
193
  )
194
 
195
  # Initial load of data when question is selected
196
  question_dropdown.change(
197
+ update_sources_list,
198
+ inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
199
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text]
200
  )
201
 
202
  app.queue()
203
  app.launch()
204
 
205
  if __name__ == "__main__":
206
+ asyncio.run(main())
data_access.py CHANGED
@@ -1,9 +1,12 @@
1
  import asyncio
2
  import os
3
  from contextlib import asynccontextmanager
 
4
 
5
  import asyncpg
 
6
  from dotenv import load_dotenv
 
7
 
8
  # Global connection pool
9
  _pool = None
@@ -53,3 +56,148 @@ async def close_pool():
53
  await _pool.close()
54
  _pool = None
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import os
3
  from contextlib import asynccontextmanager
4
+ from typing import Optional
5
 
6
  import asyncpg
7
+ import psycopg2
8
  from dotenv import load_dotenv
9
+ import pandas as pd
10
 
11
  # Global connection pool
12
  _pool = None
 
56
  await _pool.close()
57
  _pool = None
58
 
59
+
60
async def get_questions():
    """Fetch every question as ``[{"id": ..., "text": ...}]``, ordered by id."""
    async with get_async_connection() as conn:
        rows = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
        return [{"id": row["id"], "text": row["question_text"]} for row in rows]
64
+
65
+
66
# Get distinct source finders
async def get_source_finders():
    """Fetch every source finder as ``[{"id": ..., "name": ...}]``, ordered by id."""
    async with get_async_connection() as conn:
        rows = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
        return [{"id": row["id"], "name": row["name"]} for row in rows]
71
+
72
+
73
# Get distinct run IDs for a question
async def get_run_ids(question_id: int):
    """Return the distinct run_ids recorded for *question_id*, newest first."""
    async with get_async_connection() as conn:
        query = "SELECT DISTINCT run_id FROM source_runs WHERE question_id = $1 order by run_id desc"
        rows = await conn.fetch(query, question_id)
        return [row["run_id"] for row in rows]
80
+
81
+
82
# Get source runs for a specific question with filters
async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
                          run_id: Optional[int] = None):
    """Fetch source_runs rows for a question, optionally filtered.

    The signature declares ``source_finder_id`` and ``run_id`` as Optional, but
    the previous query used both unconditionally — passing ``None`` produced
    ``col = NULL``, which never matches in SQL, so "no filter" silently meant
    "no rows".  The WHERE clause is now built dynamically so ``None`` really
    means "do not filter on this column"; passing concrete ids behaves as
    before.

    Returns a list of plain dicts, one per matching source_runs row, each
    including ``finder_name`` from the joined source_finders table, ordered by
    rank descending.
    """
    async with get_async_connection() as conn:
        query = """
            SELECT sr.*, sf.source_finder_type as finder_name
            FROM source_runs sr
            JOIN source_finders sf ON sr.source_finder_id = sf.id
            WHERE sr.question_id = $1
        """
        params = [question_id]
        if run_id is not None:
            params.append(run_id)
            query += f" AND sr.run_id = ${len(params)}"
        if source_finder_id is not None:
            params.append(source_finder_id)
            query += f" AND sr.source_finder_id = ${len(params)}"

        query += " ORDER BY sr.rank DESC"

        sources = await conn.fetch(query, *params)
        return [dict(s) for s in sources]
100
+
101
async def get_baseline_rankers():
    """Fetch all baseline rankers as ``[{"id": ..., "name": ...}]``, ordered by id."""
    async with get_async_connection() as conn:
        rows = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
        return [{"id": row["id"], "name": row["ranker"]} for row in rows]
105
+
106
async def calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs_sources):
    """Compare one source-finder run against the baseline for a question.

    Both arguments are sequences of mapping-like rows: *baseline_sources* rows
    must expose ``sugya_id`` and ``baseline_rank``; *source_runs_sources* rows
    must expose ``sugya_id`` and ``source_rank``.

    Returns a one-row ``pandas.DataFrame`` with overlap counts/percentages,
    where "high ranked" means rank >= 4.  Percentages are computed against the
    larger of the two sets and are 0 when both sets are empty.

    Fix: the previous version acquired a database connection it never used;
    this is pure computation, so the connection has been removed.
    """
    actual_sources_set = {s["sugya_id"] for s in source_runs_sources}
    baseline_sources_set = {s["sugya_id"] for s in baseline_sources}

    # Calculate overlap
    overlap = actual_sources_set.intersection(baseline_sources_set)

    # Calculate high-ranked overlap (rank >= 4)
    actual_high_ranked = {s["sugya_id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
    baseline_high_ranked = {s["sugya_id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)

    # Denominators guard against empty inputs (max(...) == 0).
    denom_all = max(len(actual_sources_set), len(baseline_sources_set))
    denom_high = max(len(actual_high_ranked), len(baseline_high_ranked))

    results = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(overlap),
        "overlap_percentage": round(len(overlap) * 100 / denom_all, 2) if denom_all > 0 else 0,
        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
        "num_high_ranked_found_sources": len(actual_high_ranked),
        "high_ranked_overlap_count": len(high_ranked_overlap),
        "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / denom_high, 2) if denom_high > 0 else 0,
    }
    # Convert results to a one-row dataframe for display in the UI table.
    return pd.DataFrame([results])
138
+
139
+
140
async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
    """
    Create unified view of sources from both baseline_sources and source_runs
    with indicators of where each source appears and their respective ranks.

    Returns ``(unified_results, stats_df)``: a list of dicts (one per distinct
    sugya_id across both tables) and the one-row overlap-statistics DataFrame
    from calculate_baseline_vs_source_stats_for_question.

    Fix: the run query aliases ``sr.reason AS source_reason``, but the result
    lookup previously used the key ``"reason"``, so ``source_reason`` was
    always "N/A".  The lookup now uses ``"source_reason"``.
    """
    async with get_async_connection() as conn:
        # Get sources from source_runs
        query_runs = """
            SELECT sr.sugya_id, sr.rank as source_rank, sr.tractate, sr.folio, sr.reason as source_reason
            FROM source_runs sr
            WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
        """
        source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)

        # Get sources from baseline_sources
        query_baseline = """
            SELECT bs.sugya_id, bs.rank as baseline_rank, bs.tractate, bs.folio
            FROM baseline_sources bs
            WHERE bs.question_id = $1 AND bs.ranker_id = $2
        """
        baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)

        stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)

        # Convert to dictionaries for easier lookup
        source_runs_dict = {s["sugya_id"]: dict(s) for s in source_runs}
        baseline_dict = {s["sugya_id"]: dict(s) for s in baseline_sources}

        # Get all unique sugya_ids
        all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())

        # Build unified results
        unified_results = []
        for sugya_id in all_sugya_ids:
            in_source_run = sugya_id in source_runs_dict
            in_baseline = sugya_id in baseline_dict
            # Prefer the baseline row for tractate/folio when the sugya is in both.
            info = baseline_dict[sugya_id] if in_baseline else source_runs_dict[sugya_id]
            result = {
                "sugya_id": sugya_id,
                "tractate": info.get("tractate", "N/A"),
                "folio": info.get("folio", "N/A"),
                "in_baseline": "Yes" if in_baseline else "No",
                "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
                "in_source_run": "Yes" if in_source_run else "No",
                "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
                # Key matches the "source_reason" alias from query_runs (bug fix).
                "source_reason": source_runs_dict.get(sugya_id, {}).get("source_reason", "N/A"),
            }
            unified_results.append(result)

        return unified_results, stats_df
194
+
195
+
196
def get_pg_sync_connection(schema="talmudexplore"):
    """Open a synchronous psycopg2 connection with search_path set to *schema*.

    Connection parameters are read from the pg_* environment variables
    (pg_dbname, pg_user, pg_password, pg_host, pg_port).
    """
    return psycopg2.connect(
        dbname=os.getenv("pg_dbname"),
        user=os.getenv("pg_user"),
        password=os.getenv("pg_password"),
        host=os.getenv("pg_host"),
        port=os.getenv("pg_port"),
        options=f"-c search_path={schema}",
    )
eval_tables.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_access import get_pg_sync_connection
2
+
3
+ conn = get_pg_sync_connection()
4
+
5
+
6
+
7
def create_eval_database():
    """Create the PostgreSQL evaluation schema (tables created if missing).

    Fixes versus the previous version:
    * Docstring said "SQLite database", but this runs against PostgreSQL
      through the shared psycopg2 connection (SERIAL columns, %s params).
    * It closed the module-level ``conn``, which broke any loader function
      (load_source_finders, load_rankers, ...) called afterwards in the same
      process.  Only the cursor is closed now; the shared connection stays
      open for the other helpers.

    NOTE(review): source_runs.run_id is TEXT here, but callers compare it to
    integers (e.g. ``WHERE run_id = 1``) — confirm the intended type.
    """
    cursor = conn.cursor()

    # Create questions table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS questions (
        id SERIAL PRIMARY KEY,
        question_text TEXT NOT NULL,
        CONSTRAINT unique_question_text UNIQUE (question_text)
    );
    ''')

    cursor.execute('''
    CREATE TABLE IF NOT EXISTS rankers (
        id SERIAL PRIMARY KEY,
        ranker TEXT NOT NULL
    );
    ''')

    # Create table for unique sources
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS baseline_sources (
        id SERIAL PRIMARY KEY,
        question_id INTEGER NOT NULL,
        tractate TEXT NOT NULL,
        folio TEXT NOT NULL,
        sugya_id TEXT NOT NULL,
        rank INTEGER NOT NULL,
        reason TEXT,
        ranker_id INTEGER NOT NULL,
        FOREIGN KEY (question_id) REFERENCES questions(id),
        FOREIGN KEY (ranker_id) REFERENCES rankers(id),
        CONSTRAINT unique_source_per_question_ranker UNIQUE(question_id, sugya_id, ranker_id)
    );
    ''')

    cursor.execute('''
    CREATE TABLE IF NOT EXISTS source_finders (
        id SERIAL PRIMARY KEY,
        source_finder_type TEXT NOT NULL,
        description TEXT,
        source_finder_version TEXT NOT NULL
    );
    ''')

    # Create table for logging all sources from each run
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS source_runs (
        id SERIAL PRIMARY KEY,
        source_finder_id INTEGER NOT NULL,
        run_id TEXT NOT NULL,
        run_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        question_id INTEGER NOT NULL,
        tractate TEXT NOT NULL,
        folio TEXT NOT NULL,
        sugya_id TEXT NOT NULL,
        rank INTEGER NOT NULL,
        reason TEXT,
        FOREIGN KEY (source_finder_id) REFERENCES source_finders(id),
        FOREIGN KEY (question_id) REFERENCES questions(id)
    );
    ''')

    conn.commit()
    cursor.close()
74
+
75
def load_source_finders():
    """Seed the source_finders table with the known finder types (version 1)."""
    cursor = conn.cursor()
    for finder_type in ("claude_sources", "keywords", "lenses"):
        cursor.execute("INSERT INTO source_finders (source_finder_type, source_finder_version) VALUES (%s, 1)", (finder_type,))
    conn.commit()
80
+
81
def load_rankers():
    """Seed the rankers table with the single known baseline ranker."""
    cursor = conn.cursor()
    for ranker_name in ("claude_sources",):
        cursor.execute("INSERT INTO rankers (ranker) VALUES (%s)", (ranker_name,))
    conn.commit()
86
+
87
def load_baseline_sources():
    """Copy the claude run (run_id = 1, source_finder_id = 1) from source_runs
    into baseline_sources, attributing every row to ranker 1."""
    cursor = conn.cursor()
    cursor.execute('''
    INSERT INTO baseline_sources (question_id, tractate, folio, sugya_id, rank, reason, ranker_id)
    SELECT question_id, tractate, folio, sugya_id, rank, reason, 1
    FROM source_runs
    WHERE run_id = 1 and source_finder_id = 1
    ''')
    conn.commit()
97
+
98
+
99
+
100
if __name__ == '__main__':
    # One-off loader script. Run create_eval_database() first when the
    # schema does not exist yet:
    # create_eval_database()
    load_baseline_sources()
104
+
105
+
106
+
load_ground_truth.py ADDED
File without changes
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  asyncpg
2
  gradio
3
- dotenv
 
 
1
  asyncpg
2
  gradio
3
+ dotenv
4
+ psycopg2
tests/__init__.py ADDED
File without changes
tests/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pytest
2
+ pytest-asyncio
tests/test_db_layer.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from data_access import get_unified_sources
4
+
5
+
6
@pytest.mark.asyncio
async def test_get_unified_sources():
    # Integration test: requires the project database to be reachable and
    # seeded so that question_id=2, source_finder_id=2, run_id=1, ranker_id=1
    # exist — TODO confirm the fixture data this relies on.
    results, stats = await get_unified_sources(2, 2, 1, 1)
    assert results is not None
    assert stats is not None

    # Check number of rows in results list
    assert len(results) > 4, "Results should contain at least one row"

    # Check number of rows in stats DataFrame
    assert stats.shape[0] > 0, "Stats DataFrame should contain at least one row"

    # You can also check specific stats columns
    assert "overlap_count" in stats.columns, "Stats should contain overlap_count"