davidr70 committed
Commit 0a408c8 · 1 Parent(s): 5cca310

fix to have baseline run from the runs table

app.py CHANGED
@@ -1,8 +1,8 @@
 import asyncio
+import logging
 
 import gradio as gr
 import pandas as pd
-import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
     get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
@@ -10,6 +10,8 @@ from data_access import get_questions, get_source_finders, get_run_ids, get_base
 
 logger = logging.getLogger(__name__)
 
+ALL_QUESTIONS_STR = "All questions"
+
 # Initialize data at the module level
 questions = []
 source_finders = []
@@ -22,9 +24,11 @@ run_ids = []
 available_run_id_dict = {}
 finder_options = []
 previous_run_id = "initial_run"
+run_id_options = []
 
 run_id_dropdown = None
 
+
 # Get all questions
 
 # Initialize data in a single async function
@@ -36,7 +40,6 @@ async def initialize_data():
         source_finders = await get_source_finders(conn)
         baseline_rankers = await get_baseline_rankers(conn)
 
-        baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
         # Convert to dictionaries for easier lookup
         questions_dict = {q["text"]: q["id"] for q in questions}
         baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
@@ -46,9 +49,32 @@ async def initialize_data():
         question_options = [q['text'] for q in questions]
         finder_options = [s["name"] for s in source_finders]
         baseline_ranker_options = [b["name"] for b in baseline_rankers]
+        update_run_ids(ALL_QUESTIONS_STR, list(source_finders_dict.keys())[0])
+
 
-def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str, evt: gr.EventData = None):
+def update_run_ids(question_option, source_finder_name):
+    return asyncio.run(update_run_ids_async(question_option, source_finder_name))
 
+
+async def update_run_ids_async(question_option, source_finder_name):
+    global previous_run_id, available_run_id_dict, run_id_options
+    async with get_async_connection() as conn:
+        finder_id_int = source_finders_dict.get(source_finder_name)
+        if question_option and question_option != ALL_QUESTIONS_STR:
+            question_id = questions_dict.get(question_option)
+            available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
+        else:
+            available_run_id_dict = await get_run_ids(conn, finder_id_int)
+
+    run_id = list(available_run_id_dict.keys())[0]
+    previous_run_id = run_id
+    run_id_options = list(available_run_id_dict.keys())
+    return None, None, gr.Dropdown(choices=run_id_options,
+                                   value=run_id), "Select Question to see results", ""
+
+
+def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
+                        evt: gr.EventData = None):
     global previous_run_id
     if evt:
         logger.info(f"event: {evt.target.elem_id}")
@@ -70,27 +96,30 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
         if type(baseline_ranker_name) == list:
             baseline_ranker_name = baseline_ranker_name[0]
 
-        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(
+            baseline_ranker_name)
 
         if len(source_finder_name):
            finder_id_int = source_finders_dict.get(source_finder_name)
         else:
            finder_id_int = None
 
-        if question_option == "All questions":
+        if question_option == ALL_QUESTIONS_STR:
            if finder_id_int:
                if run_id is None:
                    available_run_id_dict = await get_run_ids(conn, finder_id_int)
                    run_id = list(available_run_id_dict.keys())[0]
                    previous_run_id = run_id
                run_id_int = available_run_id_dict.get(run_id)
-               all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
+               all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int,
+                                                                                   baseline_ranker_id_int)
 
            else:
                run_id_options = list(available_run_id_dict.keys())
                all_stats = None
            run_id_options = list(available_run_id_dict.keys())
-           return None, all_stats, gr.Dropdown(choices=run_id_options, value=run_id), "Select Run Id and source finder to see results", ""
+           return None, all_stats, gr.Dropdown(choices=run_id_options,
+                                               value=run_id), "Select Run Id and source finder to see results", ""
 
         # Extract question ID from selection
         question_id = questions_dict.get(question_option)
@@ -102,8 +131,6 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
            previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id)
 
-
-
        source_runs = None
        stats = None
        # Get source runs data
@@ -116,7 +143,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
            return None, None, run_id_options, "No results found for the selected filters",
 
        # Format table columns
-       columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
+       columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
+                             'tractate',
                              'folio', 'reason']
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
 
@@ -147,6 +175,7 @@ async def handle_row_selection_async(evt: gr.SelectData):
 def handle_row_selection(evt: gr.SelectData):
     return asyncio.run(handle_row_selection_async(evt))
 
+
 # Create Gradio app
 
 # Ensure we clean up when done
@@ -162,7 +191,7 @@ async def main():
         with gr.Column(scale=1):
             # Main content area
             question_dropdown = gr.Dropdown(
-                choices=["All questions"] + question_options,
+                choices=[ALL_QUESTIONS_STR] + question_options,
                 label="Select Question",
                 value=None,
                 interactive=True,
@@ -186,7 +215,7 @@ async def main():
             )
         with gr.Column(scale=1):
             run_id_dropdown = gr.Dropdown(
-                choices=[],
+                choices=run_id_options,
                 allow_custom_value=True,
                 label="Run id for Question and source finder",
                 interactive=True,
@@ -201,7 +230,6 @@ async def main():
             gr.Markdown(f"Total Questions: {len(questions)}")
             gr.Markdown(f"Source Finders: {len(source_finders)}")
 
-
         with gr.Row():
             result_text = gr.Markdown("Select a question to view source runs")
         with gr.Row():
@@ -221,14 +249,15 @@ async def main():
                 metadata_text = gr.TextArea(
                     label="Metadata of Source Finder for Selected Question",
                     elem_id="metadata",
-                    lines = 2
+                    lines=2
                 )
         with gr.Row():
             gr.Markdown("# Sources Found")
         with gr.Row():
             with gr.Column(scale=3):
                 results_table = gr.DataFrame(
-                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'source_reason', 'metadata'],
+                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
+                             'source_run_rank', 'source_reason', 'metadata'],
                     interactive=False
                 )
             with gr.Column(scale=1):
@@ -246,8 +275,6 @@ async def main():
         #     visible=True
         # )
 
-
-
         # Set up event handlers
         results_table.select(
             handle_row_selection,
@@ -255,15 +282,22 @@ async def main():
             outputs=source_text
         )
 
-        question_dropdown.change(
+        baseline_rankers_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
             outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+
         )
 
-        source_finder_dropdown.change(
+        question_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
+            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
+        )
+
+        source_finder_dropdown.change(
+            update_run_ids,
+            inputs=[question_dropdown, source_finder_dropdown],
             # outputs=[run_id_dropdown, results_table, result_text, download_button]
             outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
         )
@@ -274,7 +308,6 @@ async def main():
             outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
         )
 
-
     app.queue()
     app.launch()
 
 
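Note on the event wiring above: source_finder_dropdown.change now calls update_run_ids, which returns a gr.Dropdown(choices=..., value=...) so that picking a source finder repopulates the run-id dropdown in place. Below is a minimal, self-contained sketch of that refresh pattern; the RUNS dict and component names are hypothetical stand-ins, not the app's actual data.

import gradio as gr

# Hypothetical stand-in for available_run_id_dict, keyed by finder name.
RUNS = {"finder_a": ["run_1", "run_2"], "finder_b": ["run_3"]}

def refresh_runs(finder_name):
    options = RUNS.get(finder_name, [])
    # Returning a gr.Dropdown from a handler replaces the target
    # component's choices and value in place (Gradio 4.x behaviour).
    return gr.Dropdown(choices=options, value=options[0] if options else None)

with gr.Blocks() as demo:
    finder = gr.Dropdown(choices=list(RUNS), label="Source finder")
    run_id = gr.Dropdown(choices=[], label="Run id")
    finder.change(refresh_runs, inputs=finder, outputs=run_id)

demo.launch()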
data_access.py CHANGED
@@ -15,6 +15,7 @@ load_dotenv()
 @asynccontextmanager
 async def get_async_connection(schema="talmudexplore"):
     """Get a connection for the current request."""
+    conn = None
     try:
         # Create a single connection without relying on a shared pool
         conn = await asyncpg.connect(
@@ -27,7 +28,8 @@ async def get_async_connection(schema="talmudexplore"):
         await conn.execute(f'SET search_path TO {schema}')
         yield conn
     finally:
-        await conn.close()
+        if conn:
+            await conn.close()
 
 
 async def get_questions(conn: asyncpg.Connection):
@@ -73,8 +75,13 @@ async def get_run_ids(conn: asyncpg.Connection, source_finder_id: int, question_
 
 
 async def get_baseline_rankers(conn: asyncpg.Connection):
-    rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
-    return [{"id": f["id"], "name": f["ranker"]} for f in rankers]
+    query = """
+        select sfr.id, sf.source_finder_type, sfr.description from talmudexplore.source_finder_runs sfr
+        join source_finders sf on sf.id = sfr.source_finder_id
+        order by sf.id
+        """
+    rankers = await conn.fetch(query)
+    return [{"id": r["id"], "name": f"{r['source_finder_type']} : {r['description']}"} for r in rankers]
 
 async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources, source_runs_sources):
     # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
@@ -203,14 +210,8 @@ async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source
     """
     source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)
     # Get sources from baseline_sources
-    query_baseline = """
-        SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
-        FROM baseline_sources bs
-        join talmud_bavli tb on bs.sugya_id = tb.xml_id
-        WHERE bs.question_id = $1
-        AND bs.ranker_id = $2
-    """
-    baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
+    baseline_query = query_runs.replace("source_rank", "baseline_rank")
+    baseline_sources = await conn.fetch(baseline_query, question_id, ranker_id)
     stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
     # Convert to dictionaries for easier lookup
     source_runs_dict = {s["id"]: dict(s) for s in source_runs}
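Why the conn = None guard in get_async_connection matters: if asyncpg.connect raises inside the try block, the finally clause runs before conn is ever bound, so an unconditional await conn.close() would raise UnboundLocalError on top of the real connection error. A minimal sketch of the failure mode, with a fake connect standing in for asyncpg.connect:

import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def get_connection(fail=True):
    conn = None  # bound up front so the finally block can test it safely
    try:
        if fail:
            # stand-in for asyncpg.connect(...) failing (bad DSN, db down)
            raise ConnectionError("connect() failed")
        conn = object()
        yield conn
    finally:
        # without conn = None this would raise UnboundLocalError when
        # connect fails; without the if-check it would call close() on None
        if conn:
            print("closing connection")  # stand-in for await conn.close()

async def main():
    try:
        async with get_connection() as conn:
            pass
    except ConnectionError as exc:
        print(f"original error surfaces cleanly: {exc}")

asyncio.run(main())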
load_ground_truth.py DELETED
File without changes
eval_tables.py → scripts/eval_tables.py RENAMED
File without changes