add ability to see text
- app.py +38 -5
- data_access.py +26 -13
app.py
CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources
+    get_unified_sources, get_source_text
 
 logger = logging.getLogger(__name__)
 
@@ -105,6 +105,24 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
     return df_display, stats, run_id_options, result_message,
 
 
+# Add a new function to handle row selection
+async def handle_row_selection_async(evt: gr.SelectData):
+    if evt is None or evt.value is None:
+        return "No source selected"
+
+    try:
+        # Get the ID from the selected row
+        tractate_chunk_id = evt.row_value[0]
+        # Get the source text
+        text = await get_source_text(tractate_chunk_id)
+        return text
+    except Exception as e:
+        return f"Error retrieving source text: {str(e)}"
+
+
+def handle_row_selection(evt: gr.SelectData):
+    return asyncio.run(handle_row_selection_async(evt))
+
 # Create Gradio app
 
 # Ensure we clean up when done
@@ -180,10 +198,19 @@ async def main():
         with gr.Row():
             gr.Markdown("# Sources Found")
         with gr.Row():
-
-
-
-
+            with gr.Column(scale=3):
+                results_table = gr.DataFrame(
+                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'source_reason'],
+                    interactive=False
+                )
+            with gr.Column(scale=1):
+                source_text = gr.TextArea(
+                    value="Text of the source will appear here",
+                    lines=15,
+                    label="Source Text",
+                    interactive=False,
+                    elem_id="source_text"
+                )
 
         # download_button = gr.DownloadButton(
         #     label="Download Results as CSV",
@@ -194,6 +221,12 @@ async def main():
 
 
         # Set up event handlers
+        results_table.select(
+            handle_row_selection,
+            inputs=None,
+            outputs=source_text
+        )
+
         question_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
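
Note on the new selection wiring in app.py: the results table's select event hands the callback a gr.SelectData object, handle_row_selection_async reads the id from column 0 of evt.row_value, and the synchronous handle_row_selection wrapper runs it with asyncio.run so it can be registered as an ordinary Gradio callback. A minimal, self-contained sketch of that pattern follows (it assumes a Gradio 4.x release where SelectData exposes row_value, which the code above relies on; the component names and sample rows are illustrative, not taken from the Space):

import gradio as gr

def show_selected_id(evt: gr.SelectData):
    # evt.row_value is the full clicked row; column 0 holds the id,
    # mirroring tractate_chunk_id = evt.row_value[0] in handle_row_selection_async
    return f"Selected id: {evt.row_value[0]}"

with gr.Blocks() as demo:
    table = gr.DataFrame(
        value=[[1, "Berakhot", "2a"], [2, "Shabbat", "31a"]],
        headers=["id", "tractate", "folio"],
        interactive=False,
    )
    selection = gr.Textbox(label="Selection")
    table.select(show_selected_id, inputs=None, outputs=selection)

if __name__ == "__main__":
    demo.launch()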
data_access.py
CHANGED
@@ -80,8 +80,8 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source_runs_sources):
     # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
     # e.g. overlap, high ranked overlap, etc.
     async with get_async_connection() as conn:
-        actual_sources_set = {s["
-        baseline_sources_set = {s["
+        actual_sources_set = {s["id"] for s in source_runs_sources}
+        baseline_sources_set = {s["id"] for s in baseline_sources}
 
         # Calculate overlap
         overlap = actual_sources_set.intersection(baseline_sources_set)
@@ -89,8 +89,8 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source_runs_sources):
         # only_in_2 = baseline_sources_set - actual_sources_set
 
         # Calculate high-ranked overlap (rank >= 4)
-        actual_high_ranked = {s["
-        baseline_high_ranked = {s["
+        actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
+        baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
 
         high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
 
@@ -118,16 +118,16 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
     async with get_async_connection() as conn:
         # Get sources from source_runs
         query_runs = """
-            SELECT 
-            FROM source_runs sr
+            SELECT tb.tractate_chunk_id as id, sr.rank as source_rank, sr.tractate, sr.folio, sr.reason as source_reason
+            FROM source_runs sr join talmud_bavli tb on sr.sugya_id = tb.xml_id
             WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
         """
         source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)
 
         # Get sources from baseline_sources
         query_baseline = """
-            SELECT 
-            FROM baseline_sources bs
+            SELECT tb.tractate_chunk_id as id, bs.rank as baseline_rank, bs.tractate, bs.folio
+            FROM baseline_sources bs join talmud_bavli tb on bs.sugya_id = tb.xml_id
             WHERE bs.question_id = $1 AND bs.ranker_id = $2
         """
         baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
@@ -135,8 +135,8 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
         stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
 
         # Convert to dictionaries for easier lookup
-        source_runs_dict = {s["
-        baseline_dict = {s["
+        source_runs_dict = {s["id"]: dict(s) for s in source_runs}
+        baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
 
         # Get all unique sugya_ids
         all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
@@ -151,9 +151,9 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
             else:
                 info = source_runs_dict[sugya_id]
             result = {
-                "
-                "tractate": info.get("tractate"
-                "folio": info.get("folio"
+                "id": sugya_id,
+                "tractate": info.get("tractate"),
+                "folio": info.get("folio"),
                 "in_baseline": "Yes" if in_baseline else "No",
                 "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
                 "in_source_run": "Yes" if in_source_run else "No",
@@ -166,6 +166,19 @@ async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
         return unified_results, stats_df
 
 
+async def get_source_text(tractate_chunk_id: int):
+    """
+    Retrieves the text content for a given tractate chunk ID.
+    """
+    async with get_async_connection() as conn:
+        query = """
+            SELECT tb.text_with_nikud as text
+            FROM talmud_bavli tb
+            WHERE tb.tractate_chunk_id = $1
+        """
+        result = await conn.fetchrow(query, tractate_chunk_id)
+        return result["text"] if result else "Source text not found"
+
 def get_pg_sync_connection(schema="talmudexplore"):
     conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
                             user=os.getenv("pg_user"),
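
A quick way to exercise the new get_source_text helper outside the UI (an ad-hoc check, not part of the commit; it assumes the database environment variables read by get_async_connection are configured, and 42 is a made-up tractate_chunk_id):

import asyncio

from data_access import get_source_text

async def main():
    # fetch and print the text_with_nikud value for one tractate chunk
    text = await get_source_text(42)
    print(text)

if __name__ == "__main__":
    asyncio.run(main())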