davidr70 committed
Commit 5f4f31d · 1 Parent(s): 09d4cda

add ability to see text

Files changed (2)
  1. app.py +38 -5
  2. data_access.py +26 -13
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources
+    get_unified_sources, get_source_text
 
 logger = logging.getLogger(__name__)
 
@@ -105,6 +105,24 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
     return df_display, stats, run_id_options, result_message,
 
 
+# Add a new function to handle row selection
+async def handle_row_selection_async(evt: gr.SelectData):
+    if evt is None or evt.value is None:
+        return "No source selected"
+
+    try:
+        # Get the ID from the selected row
+        tractate_chunk_id = evt.row_value[0]
+        # Get the source text
+        text = await get_source_text(tractate_chunk_id)
+        return text
+    except Exception as e:
+        return f"Error retrieving source text: {str(e)}"
+
+
+def handle_row_selection(evt: gr.SelectData):
+    return asyncio.run(handle_row_selection_async(evt))
+
 # Create Gradio app
 
 # Ensure we clean up when done
@@ -180,10 +198,19 @@ async def main():
         with gr.Row():
             gr.Markdown("# Sources Found")
         with gr.Row():
-            results_table = gr.DataFrame(
-                headers=['Source Finder', 'Run ID', 'Sugya ID', 'Tractate', 'Folio', 'Rank', 'Reason'],
-                interactive=False
-            )
+            with gr.Column(scale=3):
+                results_table = gr.DataFrame(
+                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'source_reason'],
+                    interactive=False
+                )
+            with gr.Column(scale=1):
+                source_text = gr.TextArea(
+                    value="Text of the source will appear here",
+                    lines=15,
+                    label="Source Text",
+                    interactive=False,
+                    elem_id="source_text"
+                )
 
         # download_button = gr.DownloadButton(
         #     label="Download Results as CSV",
@@ -194,6 +221,12 @@ async def main():
 
 
         # Set up event handlers
+        results_table.select(
+            handle_row_selection,
+            inputs=None,
+            outputs=source_text
+        )
+
         question_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
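
Note on the wiring above: DataFrame.select injects a gr.SelectData event object into the bound handler (hence inputs=None), and whatever the handler returns is routed to the component named in outputs. The sketch below shows the same select-to-textbox pattern as a standalone demo, assuming a recent Gradio 4.x where SelectData exposes row_value; the sample rows and the FAKE_TEXTS lookup are invented stand-ins for the app's database-backed get_source_text.

import gradio as gr

# Hypothetical stand-in data; the real app resolves ids against Postgres via get_source_text().
FAKE_TEXTS = {1: "Text for chunk 1", 2: "Text for chunk 2"}

def lookup_text(evt: gr.SelectData):
    # Gradio passes evt automatically for handlers bound with .select(), so inputs=None suffices.
    if evt is None or evt.value is None:
        return "No source selected"
    chunk_id = evt.row_value[0]  # column 0 is the id column, as in the commit
    return FAKE_TEXTS.get(chunk_id, "Source text not found")

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            table = gr.DataFrame(
                value=[[1, "Berakhot", "2a"], [2, "Shabbat", "31a"]],
                headers=["id", "tractate", "folio"],
                interactive=False,
            )
        with gr.Column(scale=1):
            text_box = gr.TextArea(label="Source Text", lines=15, interactive=False)
    # Same wiring as in the diff: row selection drives the text area.
    table.select(lookup_text, inputs=None, outputs=text_box)

if __name__ == "__main__":
    demo.launch()

Gradio also accepts async functions as event handlers, so the asyncio.run wrapper in handle_row_selection is one possible bridge to the async data layer rather than a requirement.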
data_access.py CHANGED
@@ -80,8 +80,8 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , sou
     # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
     # e.g. overlap, high ranked overlap, etc.
     async with get_async_connection() as conn:
-        actual_sources_set = {s["sugya_id"] for s in source_runs_sources}
-        baseline_sources_set = {s["sugya_id"] for s in baseline_sources}
+        actual_sources_set = {s["id"] for s in source_runs_sources}
+        baseline_sources_set = {s["id"] for s in baseline_sources}
 
         # Calculate overlap
         overlap = actual_sources_set.intersection(baseline_sources_set)
@@ -89,8 +89,8 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , sou
         # only_in_2 = baseline_sources_set - actual_sources_set
 
         # Calculate high-ranked overlap (rank >= 4)
-        actual_high_ranked = {s["sugya_id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
-        baseline_high_ranked = {s["sugya_id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
+        actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
+        baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
 
         high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
 
@@ -118,16 +118,16 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: i
     async with get_async_connection() as conn:
         # Get sources from source_runs
         query_runs = """
-            SELECT sr.sugya_id, sr.rank as source_rank, sr.tractate, sr.folio, sr.reason as source_reason
-            FROM source_runs sr
+            SELECT tb.tractate_chunk_id as id, sr.rank as source_rank, sr.tractate, sr.folio, sr.reason as source_reason
+            FROM source_runs sr join talmud_bavli tb on sr.sugya_id = tb.xml_id
             WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
         """
         source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)
 
         # Get sources from baseline_sources
         query_baseline = """
-            SELECT bs.sugya_id, bs.rank as baseline_rank, bs.tractate, bs.folio
-            FROM baseline_sources bs
+            SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
+            FROM baseline_sources bs join talmud_bavli tb on bs.sugya_id = tb.xml_id
             WHERE bs.question_id = $1 AND bs.ranker_id = $2
         """
         baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
@@ -135,8 +135,8 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: i
         stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
 
         # Convert to dictionaries for easier lookup
-        source_runs_dict = {s["sugya_id"]: dict(s) for s in source_runs}
-        baseline_dict = {s["sugya_id"]: dict(s) for s in baseline_sources}
+        source_runs_dict = {s["id"]: dict(s) for s in source_runs}
+        baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
 
         # Get all unique sugya_ids
         all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
@@ -151,9 +151,9 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: i
             else:
                 info = source_runs_dict[sugya_id]
             result = {
-                "sugya_id": sugya_id,
-                "tractate": info.get("tractate", "N/A"),
-                "folio": info.get("folio", "N/A"),
+                "id": sugya_id,
+                "tractate": info.get("tractate"),
+                "folio": info.get("folio"),
                 "in_baseline": "Yes" if in_baseline else "No",
                 "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
                 "in_source_run": "Yes" if in_source_run else "No",
@@ -166,6 +166,19 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: i
     return unified_results, stats_df
 
 
+async def get_source_text(tractate_chunk_id: int):
+    """
+    Retrieves the text content for a given tractate chunk ID.
+    """
+    async with get_async_connection() as conn:
+        query = """
+            SELECT tb.text_with_nikud as text
+            FROM talmud_bavli tb
+            WHERE tb.tractate_chunk_id = $1
+        """
+        result = await conn.fetchrow(query, tractate_chunk_id)
+        return result["text"] if result else "Source text not found"
+
 def get_pg_sync_connection(schema="talmudexplore"):
     conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
                             user=os.getenv("pg_user"),
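
Note on the data_access.py side: both queries now join talmud_bavli on sugya_id = xml_id and expose tractate_chunk_id as id, which is the value placed in the first table column and later passed to get_source_text to fetch text_with_nikud. A small usage sketch of that lookup from synchronous code, assuming the Postgres env vars (pg_dbname, pg_user, ...) are configured; the chunk id is made up for illustration.

import asyncio

from data_access import get_source_text

def fetch_source_text_sync(tractate_chunk_id: int) -> str:
    # Bridge from synchronous code into the async data-access layer,
    # mirroring handle_row_selection() in app.py. asyncio.run() raises
    # if an event loop is already running in this thread.
    return asyncio.run(get_source_text(tractate_chunk_id))

if __name__ == "__main__":
    print(fetch_source_text_sync(1234))  # 1234 is an illustrative, made-up chunk id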