# eval_results/app.py
# Last commit: 714c819 "fix init" (author: davidr70); file size 12.9 kB.
import asyncio
import logging
import gradio as gr
import pandas as pd
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
get_async_connection
logger = logging.getLogger(__name__)

# Extra first entry of the question dropdown; selecting it switches the
# handlers to cumulative statistics across all questions.
ALL_QUESTIONS_STR = "All questions"

# ---------------------------------------------------------------------------
# Module-level UI state, (re)populated by initialize_data() and
# update_run_ids_async() and read by the Gradio event handlers.
# ---------------------------------------------------------------------------
# Raw rows loaded by initialize_data().
questions = []
source_finders = []

# Lookup tables: display text / name -> id.
questions_dict = {}
source_finders_dict = {}
baseline_rankers_dict = {}

# Choices offered by the dropdowns built in main().
question_options = []
finder_options = []
baseline_ranker_options = []

# Run-id state for the currently selected filters.
run_ids = []                      # not referenced anywhere else in this file
available_run_id_dict = {}        # dropdown key -> run id passed to data_access
run_id_options = []               # keys of available_run_id_dict
previous_run_id = "initial_run"   # last run id acted on; used to drop duplicate change events
run_id_dropdown = None            # rebound to the gr.Dropdown component inside main()
# Load all lookup data needed by the UI in a single async function.
async def initialize_data():
    """Load questions, source finders and baseline rankers into module globals.

    Builds the text/name -> id lookup dicts and the dropdown option lists, then
    primes the run-id state for the first source finder via
    ``update_run_ids_async``. When no source finders exist, a warning is logged
    and the run-id state is left empty instead of raising IndexError.
    """
    global questions, source_finders, questions_dict, question_options, finder_options
    global source_finders_dict, baseline_rankers_dict, baseline_ranker_options
    async with get_async_connection() as conn:
        questions = await get_questions(conn)
        source_finders = await get_source_finders(conn)
        baseline_rankers = await get_baseline_rankers(conn)

    # Lookup tables for translating dropdown selections back to ids.
    questions_dict = {q["text"]: q["id"] for q in questions}
    baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
    source_finders_dict = {f["name"]: f["id"] for f in source_finders}

    # Dropdown choices.
    question_options = [q["text"] for q in questions]
    finder_options = [s["name"] for s in source_finders]
    baseline_ranker_options = [b["name"] for b in baseline_rankers]

    # Prime run ids for the first source finder. This runs after the connection
    # above is released, because update_run_ids_async opens its own connection.
    first_finder = next(iter(source_finders_dict), None)
    if first_finder is None:
        logger.warning("No source finders found; run id dropdown will start empty")
        return
    await update_run_ids_async(ALL_QUESTIONS_STR, first_finder)
def update_run_ids(question_option, source_finder_name):
    """Synchronous Gradio handler that refreshes the run-id dropdown.

    Drives ``update_run_ids_async`` to completion with ``asyncio.run`` and
    returns its output tuple unchanged. ``asyncio.run`` cannot be called from a
    thread that already has a running event loop; presumably Gradio runs sync
    handlers in worker threads — confirm if this is ever called elsewhere.
    """
    pending = update_run_ids_async(question_option, source_finder_name)
    return asyncio.run(pending)
async def update_run_ids_async(question_option, source_finder_name):
    """Reload the run ids available for the selected question and source finder.

    Side effects: rebinds the module globals ``available_run_id_dict``,
    ``run_id_options`` and ``previous_run_id``.

    Returns updates for (results_table, statistics_table, run_id_dropdown,
    result_text, metadata_text). Both tables are cleared. The run-id dropdown is
    set to the first available run id, or left without a value when there are
    no runs, instead of raising IndexError.
    """
    global previous_run_id, available_run_id_dict, run_id_options
    finder_id_int = source_finders_dict.get(source_finder_name)
    async with get_async_connection() as conn:
        if question_option and question_option != ALL_QUESTIONS_STR:
            question_id = questions_dict.get(question_option)
            available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        else:
            available_run_id_dict = await get_run_ids(conn, finder_id_int)

    run_id_options = list(available_run_id_dict)
    run_id = run_id_options[0] if run_id_options else None
    # Remember the value being pushed into the dropdown. Its own .change event
    # then matches previous_run_id, and update_sources_list skips it.
    previous_run_id = run_id
    return None, None, gr.Dropdown(choices=run_id_options,
                                   value=run_id), "Select Question to see results", ""
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Synchronous Gradio handler that refreshes results for the current selections.

    Despite their names, ``source_finder_id`` and ``baseline_ranker_id`` receive
    the selected dropdown *names*. They are forwarded unchanged to
    ``update_sources_list_async``, which resolves them to ids.

    Changes on the run-id dropdown are skipped (all five outputs return
    ``gr.skip()``) when the value is a list or equals ``previous_run_id``. The
    latter means another handler just set it, so reloading would be redundant.
    """
    global previous_run_id
    if evt:
        elem_id = evt.target.elem_id
        logger.info("event: %s", elem_id)
        if elem_id == "run_id_dropdown" and (isinstance(run_id, list) or run_id == previous_run_id):
            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
    # Only a concrete selection becomes the new "last seen" run id.
    if isinstance(run_id, str):
        previous_run_id = run_id
    return asyncio.run(update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id))
# Main function to handle UI interactions
async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load results for the selected question, source finder, run id and baseline ranker.

    Always returns five values, one per output component: (results_table,
    statistics_table, run_id_dropdown, result_text, metadata_text).

    * No question selected: the first three outputs are left untouched.
    * ``ALL_QUESTIONS_STR``: cumulative statistics across all questions for the
      selected run. No per-source table is shown.
    * A specific question: the per-source table, its statistics and the run's
      metadata for that question.

    Side effects: may rebind the module globals ``available_run_id_dict`` and
    ``previous_run_id``.
    """
    global available_run_id_dict, previous_run_id
    if not question_option:
        return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
    logger.info("processing update")

    # Dropdown values may arrive as None or as a (possibly empty) list rather
    # than a plain string; normalise them before any len()/[0] access.
    if isinstance(baseline_ranker_name, list):
        baseline_ranker_name = baseline_ranker_name[0] if baseline_ranker_name else ""
    # With no baseline ranker selected, fall back to ranker id 1 (as before;
    # presumably a default ranker — confirm).
    baseline_ranker_id_int = baseline_rankers_dict.get(baseline_ranker_name) if baseline_ranker_name else 1
    finder_id_int = source_finders_dict.get(source_finder_name) if source_finder_name else None

    async with get_async_connection() as conn:
        if question_option == ALL_QUESTIONS_STR:
            all_stats = None
            if finder_id_int:
                # List membership (not dict lookup) so a list-valued run_id
                # cannot raise an unhashable-type TypeError.
                if run_id not in list(available_run_id_dict):
                    # Missing or stale run id: reload this finder's runs across all questions.
                    available_run_id_dict = await get_run_ids(conn, finder_id_int)
                    if run_id not in list(available_run_id_dict):
                        run_id = next(iter(available_run_id_dict), None)
                        previous_run_id = run_id
                if run_id is not None:
                    all_stats = await calculate_cumulative_statistics_for_all_questions(
                        conn, available_run_id_dict.get(run_id), baseline_ranker_id_int)
            run_id_options = list(available_run_id_dict)
            return None, all_stats, gr.Dropdown(choices=run_id_options,
                                                value=run_id), "Select Run Id and source finder to see results", ""

        question_id = questions_dict.get(question_option)
        available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        run_id_options = list(available_run_id_dict)
        if run_id not in run_id_options:
            # Fall back to the first run; with no runs at all leave it empty
            # instead of raising IndexError.
            run_id = run_id_options[0] if run_id_options else None
            previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id) if run_id is not None else None
        run_id_update = gr.Dropdown(choices=run_id_options, value=run_id)

        source_runs, stats = None, None
        if finder_id_int and run_id is not None:
            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
        if not source_runs:
            # Must return all five outputs, and a Dropdown update rather than a
            # bare list (which Gradio would treat as the dropdown's value).
            return None, None, run_id_update, "No results found for the selected filters", ""

        df = pd.DataFrame(source_runs)
        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
                              'tractate', 'folio', 'reason']
        # Show the curated column subset only when every column is present.
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
        metadata = await get_metadata(conn, question_id, run_id_int)

    result_message = f"Found {len(source_runs)} results"
    return df_display, stats, run_id_update, result_message, metadata
# Show the source text for the row selected in the results table.
async def handle_row_selection_async(evt: gr.SelectData):
    """Return the source text for the results-table row the user clicked.

    The first value of the selected row is passed to ``get_source_text`` as the
    source id. This assumes the table's first column holds that id; with the
    default column selection it is ``sugya_id`` (TODO confirm). The check uses
    the row, not the clicked cell, so clicking an empty cell still shows the
    row's source.

    Returns the text, "No source selected" when no row or id is available, or an
    error message when retrieval fails. The failure is also logged.
    """
    if evt is None:
        return "No source selected"
    row = getattr(evt, "row_value", None)
    if not row or row[0] is None:
        return "No source selected"
    tractate_chunk_id = row[0]
    try:
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # Best effort: show the error in the UI, but keep the traceback in the logs.
        logger.exception("Failed to retrieve source text for %r", tractate_chunk_id)
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Synchronous Gradio handler for results-table row selection.

    Runs ``handle_row_selection_async`` via ``asyncio.run`` and returns the
    string it produces.
    """
    lookup = handle_row_selection_async(evt)
    return asyncio.run(lookup)
# Build and launch the Gradio app.
async def main():
    """Load lookup data, build the Source Runs Explorer UI, wire its events and launch it.

    Side effects: calls ``initialize_data()`` to populate module globals, and
    rebinds the module global ``run_id_dropdown`` to the created component.
    ``app.launch()`` presumably blocks until the server stops (Gradio default
    outside notebooks) — confirm.
    """
    global run_id_dropdown
    await initialize_data()
    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=1):
                        # Main content area. The first choice, ALL_QUESTIONS_STR,
                        # requests cumulative statistics over every question.
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        # elem_id "run_id_dropdown" is matched in update_sources_list
                        # to drop duplicate change events on this component.
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            allow_custom_value=True,
                            label="Run id for Question and source finder",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )
            with gr.Column(scale=1):
                # Sidebar area
                gr.Markdown("### About")
                gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
                gr.Markdown("### Statistics")
                gr.Markdown(f"Total Questions: {len(questions)}")
                gr.Markdown(f"Source Finders: {len(source_finders)}")
        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")
        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                         ],
                interactive=False,
            )
        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )
        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                # NOTE(review): these headers differ from the columns that
                # update_sources_list_async displays (sugya_id, ..., reason).
                # They are presumably replaced once a DataFrame is returned — confirm.
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                             'source_run_rank', 'source_reason', 'metadata'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )
        # download_button = gr.DownloadButton(
        #     label="Download Results as CSV",
        #     interactive=True,
        #     visible=True
        # )
        # Set up event handlers (must be registered inside the Blocks context).
        # Clicking a results row shows that source's text.
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )
        # Changing the baseline ranker, question or run id reloads the results.
        # All three share the same inputs and the same five outputs.
        baseline_rankers_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
        # Changing the source finder only refreshes the run-id choices and
        # clears both tables (see update_run_ids_async).
        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown],
            # outputs=[run_id_dropdown, results_table, result_text, download_button]
            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
        run_id_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
        )
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Configure root logging first so handler log lines at INFO and above are
    # emitted, then run the async entry point, which builds and launches the app.
    logging.basicConfig(level="INFO")
    asyncio.run(main())