eval_results / app.py
davidr70's picture
add ability to see text
5f4f31d
raw
history blame
9.72 kB
import asyncio
import gradio as gr
import pandas as pd
import logging
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
get_unified_sources, get_source_text
logger = logging.getLogger(__name__)

# Shared module-level state. The record lists, lookup tables and choice lists
# below start empty and are filled in by initialize_data().

# Records as returned by the data-access layer.
questions: list = []
source_finders: list = []

# Lookup tables mapping a display string (question text / name) to its id.
questions_dict: dict = {}
source_finders_dict: dict = {}
baseline_rankers_dict: dict = {}

# Choice lists handed to the dropdown widgets.
question_options: list = []
finder_options: list = []
baseline_ranker_options: list = []
run_ids: list = []

# Last string run id seen by update_sources_list(); compared there to skip
# run-id change events that carry an unchanged value.
previous_run_id = 1

# Handle to the run-id dropdown widget; assigned in main().
run_id_dropdown = None
# Load questions, source finders and baseline rankers in a single async function
async def initialize_data():
    """Fetch reference data and populate the module-level lookup state.

    Awaits ``get_questions``, ``get_source_finders`` and
    ``get_baseline_rankers`` one after another, then rebuilds:

    * ``questions_dict``: question text -> question id
    * ``source_finders_dict`` / ``baseline_rankers_dict``: name -> id
    * ``question_options`` / ``finder_options`` /
      ``baseline_ranker_options``: the display strings for the dropdowns

    Each fetched record is read as a mapping with an ``"id"`` key plus a
    ``"text"`` key (questions) or a ``"name"`` key (finders and rankers).
    """
    global questions, source_finders
    global questions_dict, source_finders_dict, baseline_rankers_dict
    global question_options, finder_options, baseline_ranker_options

    questions = await get_questions()
    source_finders = await get_source_finders()
    baseline_rankers = await get_baseline_rankers()

    # Lookup tables: display string -> id.
    questions_dict = {q["text"]: q["id"] for q in questions}
    source_finders_dict = {f["name"]: f["id"] for f in source_finders}
    baseline_rankers_dict = {b["name"]: b["id"] for b in baseline_rankers}

    # Display strings offered in the dropdowns.
    question_options = [q["text"] for q in questions]
    finder_options = [s["name"] for s in source_finders]
    baseline_ranker_options = [b["name"] for b in baseline_rankers]
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str, evt: gr.EventData = None):
    """Synchronous event handler that refreshes the results for the current filters.

    Args:
        question_option: Selected question text.
        source_finder_id: Selected source-finder value, forwarded as the
            source-finder name to ``update_sources_list_async``.
        run_id: Current value of the run-id dropdown.
        baseline_ranker_id: Selected baseline-ranker value, forwarded as the
            ranker name to ``update_sources_list_async``.
        evt: Event metadata; when present, its ``target.elem_id`` identifies
            the component that fired the event.

    Returns:
        A 4-tuple. Either four ``gr.skip()`` markers (event ignored) or the
        tuple produced by ``update_sources_list_async``.
    """
    global previous_run_id
    if evt:
        logger.info("event: %s", evt.target.elem_id)
        # Ignore run-id change events that carry a list value or repeat the
        # previously processed run id.
        # NOTE(review): presumably the list case arises because the async
        # handler writes the list of run-id options back into the run-id
        # dropdown - confirm against the wiring in main().
        is_run_id_event = evt.target.elem_id == "run_id_dropdown"
        if is_run_id_event and (isinstance(run_id, list) or run_id == previous_run_id):
            return gr.skip(), gr.skip(), gr.skip(), gr.skip()
    if isinstance(run_id, str):
        previous_run_id = run_id
    return asyncio.run(update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id))
# Main function to handle UI interactions
async def update_sources_list_async(question_option, source_finder_name, run_id: str, baseline_ranker_name: str):
    """Look up the unified sources for the selected filters.

    Args:
        question_option: Selected question text; falsy means nothing is
            selected.
        source_finder_name: Selected source-finder name; ``None`` or empty
            means no finder is selected, in which case no sources are fetched.
        run_id: Requested run id. If it is not among the run ids available
            for the question, the first available one is used instead.
        baseline_ranker_name: Selected baseline-ranker name, or a list whose
            first element is the name. When missing or empty, ranker id 1 is
            used.

    Returns:
        A 4-tuple ``(results, stats, run_id_options, message)`` matching the
        handler outputs. ``results`` and ``stats`` are ``None`` when nothing
        was found; entries may be ``gr.skip()`` when the corresponding
        component should be left untouched.
    """
    if not question_option:
        return gr.skip(), gr.skip(), gr.skip(), "No question selected"
    logger.info("processing update")

    # Resolve the question and the run ids recorded for it.
    question_id = questions_dict.get(question_option)
    available_run_ids = await get_run_ids(question_id)
    run_id_options = [str(r_id) for r_id in available_run_ids]
    if not run_id_options:
        # No runs for this question: there is no run id to fall back to, so
        # report "no results" instead of indexing into an empty list.
        return None, None, gr.skip(), "No results found for the selected filters"
    if run_id not in run_id_options:
        run_id = run_id_options[0]
    run_id_int = int(run_id)

    # A dropdown without a selection may deliver None, so use truthiness
    # rather than len().
    finder_id_int = source_finders_dict.get(source_finder_name) if source_finder_name else None

    if isinstance(baseline_ranker_name, list):
        baseline_ranker_name = baseline_ranker_name[0] if baseline_ranker_name else None
    if baseline_ranker_name:
        baseline_ranker_id_int = baseline_rankers_dict.get(baseline_ranker_name)
    else:
        # Fallback ranker id used when no baseline ranker is selected.
        baseline_ranker_id_int = 1

    source_runs = None
    stats = None
    if finder_id_int:
        source_runs, stats = await get_unified_sources(question_id, finder_id_int, run_id_int, baseline_ranker_id_int)

    if not source_runs:
        return None, None, run_id_options, "No results found for the selected filters"

    # Build the table only once we know there is something to show.
    df = pd.DataFrame(source_runs)
    columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
                          'folio', 'reason']
    # Show the curated column subset when all of it is present; otherwise
    # fall back to the full frame.
    df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df

    result_message = f"Found {len(source_runs)} results"
    return df_display, stats, run_id_options, result_message
# Add a new function to handle row selection
async def handle_row_selection_async(evt: gr.SelectData):
    """Return the source text for the row selected in the results table.

    Args:
        evt: Selection event; ``evt.row_value[0]`` (the first cell of the
            selected row) is passed to ``get_source_text`` as the id.

    Returns:
        The text returned by ``get_source_text``, ``"No source selected"``
        when there is no selection, or an error message string if the lookup
        raises.
    """
    if evt is None or evt.value is None:
        return "No source selected"
    try:
        tractate_chunk_id = evt.row_value[0]
        return await get_source_text(tractate_chunk_id)
    except Exception as e:
        # UI boundary: surface the failure in the text box, but keep the
        # traceback in the log so the error is not lost.
        logger.exception("Failed to retrieve source text")
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Synchronous entry point for the results-table select event.

    Runs ``handle_row_selection_async(evt)`` to completion via
    ``asyncio.run`` and returns whatever it returns.
    """
    selection_coro = handle_row_selection_async(evt)
    return asyncio.run(selection_coro)
# Create Gradio app
# Ensure we clean up when done
async def main():
    """Load reference data, build the Gradio UI, wire its events and launch it.

    Assigns the run-id dropdown widget to the module-level
    ``run_id_dropdown`` while building the layout.
    """
    global run_id_dropdown
    await initialize_data()

    # Column headers for the two tables.
    statistics_headers = [
        "num_high_ranked_baseline_sources",
        "num_high_ranked_found_sources",
        "overlap_count",
        "overlap_percentage",
        "high_ranked_overlap_count",
        "high_ranked_overlap_percentage"
    ]
    results_headers = ['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'source_reason']

    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")

        with gr.Row():
            # Main content area: the four filter dropdowns in a 2x2 grid.
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=1):
                        question_dropdown = gr.Dropdown(
                            choices=question_options,
                            label="Select Question",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        run_id_dropdown = gr.Dropdown(
                            choices=run_ids,
                            value="1",
                            allow_custom_value=True,
                            label="Run id for Question and source finder",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )

            # Sidebar area: static help text and counts.
            with gr.Column(scale=1):
                gr.Markdown("### About")
                gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
                gr.Markdown("Start by selecting a question, then optionally filter by source finder and run ID.")
                gr.Markdown("### Statistics")
                gr.Markdown(f"Total Questions: {len(questions)}")
                gr.Markdown(f"Source Finders: {len(source_finders)}")

        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")

        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=statistics_headers,
                interactive=False,
            )

        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                results_table = gr.DataFrame(
                    headers=results_headers,
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )

        # Selecting a row in the results table shows its text.
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )

        # Every filter dropdown triggers the same refresh with the same
        # inputs and outputs.
        filter_inputs = [question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown]
        refresh_outputs = [results_table, statistics_table, run_id_dropdown, result_text]
        for trigger in (question_dropdown, source_finder_dropdown, run_id_dropdown):
            trigger.change(
                update_sources_list,
                inputs=filter_inputs,
                outputs=refresh_outputs
            )

    app.queue()
    app.launch()
# Script entry point: enable INFO-level logging, then run main() to
# completion with asyncio.run.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())