"""Gradio UI for exploring source-finder runs against baseline rankers.

The user picks a source finder, one of its run ids and a baseline ranker;
the app then lists the questions returned by the data layer and shows, per
question (or for all questions), the statistics and the unified source table.

Every Gradio callback is a thin synchronous wrapper that drives its async
implementation with ``asyncio.run``.
"""
import asyncio
import logging

import gradio as gr
import pandas as pd

from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
    get_async_connection

logger = logging.getLogger(__name__)

ALL_QUESTIONS_STR = "All questions"

# Help text shown next to the question dropdown.
GETTING_STARTED_MD = (
    "To Get started select the following:\n"
    "* Source Finder\n"
    "* Source Finder Run ID (corresponds to a run of the source finder for a group of questions)\n"
    "* Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)\n"
    "\n"
    "**Note: if there is no overlap between the baseline questions and the source finder questions, "
    "the question list will be empty.**\n"
)

# Columns shown in the results table when all of them are present.
RESULT_COLUMNS = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
                  'tractate', 'folio', 'reason']

# Module-level state shared between the callbacks below.
questions = []
source_finders = []
questions_dict = {}  # question text -> question id
source_finders_dict = {}  # source finder name -> id
question_options = []
baseline_rankers_dict = {}  # baseline ranker name -> id
baseline_ranker_options = []
run_ids = []
available_run_id_dict = {}  # run id label -> run id (as returned by get_run_ids)
finder_options = []
previous_run_id = "initial_run"
run_id_options = []
run_id_dropdown = None


async def initialize_data():
    """Load source finders and baseline rankers and fill the lookup globals.

    Populates ``source_finders``, ``source_finders_dict``, ``finder_options``,
    ``baseline_rankers_dict`` and ``baseline_ranker_options``.
    """
    global source_finders, source_finders_dict, finder_options
    global baseline_rankers_dict, baseline_ranker_options

    async with get_async_connection() as conn:
        source_finders = await get_source_finders(conn)
        baseline_rankers = await get_baseline_rankers(conn)

        # Name -> id dictionaries for easier lookup from dropdown values.
        baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
        source_finders_dict = {f["name"]: f["id"] for f in source_finders}

        # Plain name lists used as dropdown choices.
        finder_options = [s["name"] for s in source_finders]
        baseline_ranker_options = [b["name"] for b in baseline_rankers]


def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
    """Synchronous Gradio callback wrapping :func:`update_run_ids_async`."""
    return asyncio.run(update_run_ids_async(question_option, source_finder_name, baseline_ranker_name))


async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
    """Refresh the run-id choices after the source finder changed.

    ``question_option`` and ``baseline_ranker_name`` are accepted because the
    event wiring passes them, but they are not used here.

    Returns:
        A 6-tuple matching the outputs wired in :func:`main`: an emptied
        question dropdown, cleared results and statistics tables, the run-id
        dropdown with the new choices, a status message and empty metadata.
    """
    global available_run_id_dict, run_id_options

    async with get_async_connection() as conn:
        finder_id_int = source_finders_dict.get(source_finder_name)
        available_run_id_dict = await get_run_ids(conn, finder_id_int)
        run_id_options = list(available_run_id_dict.keys())

    return (
        gr.Dropdown(choices=[]),
        None,
        None,
        gr.Dropdown(choices=run_id_options, value=None),
        "Select Question to see results.csv",
        "",
    )


def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
    """Synchronous Gradio callback wrapping :func:`update_questions_list_async`."""
    return asyncio.run(update_questions_list_async(source_finder_name, run_id, baseline_ranker_name))


async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
    """Rebuild the question dropdown once finder, run id and baseline are all set.

    Returns:
        A 5-tuple for (question dropdown, result text, metadata text,
        results table, statistics table). Everything except the dropdown is
        cleared; when a selection is missing all five are ``None``.
    """
    if not (source_finder_name and run_id and baseline_ranker_name):
        return None, None, None, None, None

    async with get_async_connection() as conn:
        run_id_int = available_run_id_dict.get(run_id)
        baseline_ranker_id = baseline_rankers_dict.get(baseline_ranker_name)
        choices = await get_updated_question_list(conn, baseline_ranker_id, run_id_int)

    return gr.Dropdown(choices=choices, value=None), None, None, None, None


async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
    """Fetch the questions and return the dropdown options for them.

    Updates the ``questions`` global, and ``questions_dict`` when any
    questions are found.

    NOTE(review): the parameter is named ``finder_id_int`` but the only caller
    in this file passes a run id; it is forwarded unchanged as the second
    positional argument of ``get_questions`` — confirm which one that expects.

    Returns:
        ``[ALL_QUESTIONS_STR, *question texts]``, or ``[]`` when no questions
        are found.
    """
    global questions_dict, questions

    questions = await get_questions(conn, finder_id_int, baseline_ranker_id)
    if not questions:
        return []

    questions_dict = {q["text"]: q["id"] for q in questions}
    return [ALL_QUESTIONS_STR] + [q['text'] for q in questions]


def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Synchronous Gradio callback that refreshes the results for a question.

    Events coming from the run-id dropdown are skipped when the value is a
    list or equals the previously handled run id.

    Returns:
        A 4-tuple for (results table, statistics table, result text,
        metadata text).
    """
    global previous_run_id

    if evt:
        logger.info("event: %s", evt.target.elem_id)
        if evt.target.elem_id == "run_id_dropdown" and (isinstance(run_id, list) or run_id == previous_run_id):
            # One skip per wired output (four), so the current values stay untouched.
            return gr.skip(), gr.skip(), gr.skip(), gr.skip()

    if isinstance(run_id, str):
        previous_run_id = run_id

    return asyncio.run(update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id))


async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load statistics and sources for the selected question.

    With ``ALL_QUESTIONS_STR`` selected, only the cumulative statistics over
    every question in ``questions_dict`` are computed. Otherwise the unified
    source table, its statistics and the metadata are loaded for that question.

    Returns:
        A 4-tuple for (results table, statistics table, result text,
        metadata text). Every branch returns exactly four values so the tuple
        always matches the outputs wired in :func:`main`.
    """
    global available_run_id_dict, previous_run_id

    if not question_option:
        return gr.skip(), gr.skip(), "No question selected", ""
    if not source_finder_name or not run_id or not baseline_ranker_name:
        return gr.skip(), gr.skip(), "Need to select source finder and baseline", ""

    logger.info("processing update")
    async with get_async_connection() as conn:
        if isinstance(baseline_ranker_name, list):
            baseline_ranker_name = baseline_ranker_name[0]

        # An empty baseline name falls back to ranker id 1.
        if len(baseline_ranker_name) == 0:
            baseline_ranker_id_int = 1
        else:
            baseline_ranker_id_int = baseline_rankers_dict.get(baseline_ranker_name)

        finder_id_int = source_finders_dict.get(source_finder_name) if len(source_finder_name) else None

        if question_option == ALL_QUESTIONS_STR:
            all_stats = None
            if finder_id_int:
                run_id_int = available_run_id_dict.get(run_id)
                all_stats = await calculate_cumulative_statistics_for_all_questions(
                    conn, list(questions_dict.values()), run_id_int, baseline_ranker_id_int)
            return None, all_stats, "Select Run Id and source finder to see results.csv", ""

        # Resolve the question id from the selected text.
        question_id = questions_dict.get(question_option)
        available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id)

        source_runs = None
        stats = None
        if finder_id_int:
            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)

        if not source_runs:
            # The fourth value (metadata) was missing here before, leaving
            # only three values for four outputs.
            return None, None, "No results.csv found for the selected filters", ""

        df = pd.DataFrame(source_runs)
        # Show the curated column subset only when every column is present.
        if all(col in df.columns for col in RESULT_COLUMNS):
            df_display = df[RESULT_COLUMNS]
        else:
            df_display = df

        metadata = await get_metadata(conn, question_id, run_id_int)
        result_message = f"Found {len(source_runs)} results.csv"
        return df_display, stats, result_message, metadata


async def handle_row_selection_async(evt: gr.SelectData):
    """Return the source text for the row selected in the results table.

    The first cell of the selected row is used as the id passed to
    ``get_source_text``. Any failure is logged and reported to the user as a
    message instead of being raised.
    """
    if evt is None or evt.value is None:
        return "No source selected"

    try:
        tractate_chunk_id = evt.row_value[0]
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # Deliberate best-effort: keep the UI alive, but record the traceback.
        logger.exception("Failed to retrieve source text")
        return f"Error retrieving source text: {str(e)}"


def handle_row_selection(evt: gr.SelectData):
    """Synchronous Gradio callback wrapping :func:`handle_row_selection_async`."""
    return asyncio.run(handle_row_selection_async(evt))


async def main():
    """Load the dropdown data, build the Gradio Blocks UI and launch it."""
    global run_id_dropdown

    await initialize_data()

    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            value=None,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            value=None,
                            allow_custom_value=True,
                            label="source finder Run ID",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            value=None,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )

                with gr.Row():
                    with gr.Column(scale=1):
                        # Main content area
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question (if list is empty this means there is no overlap between source run and baseline)",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )

            with gr.Column(scale=1):
                # Sidebar area
                gr.Markdown(GETTING_STARTED_MD)

        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")

        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                         ],
                interactive=False,
            )

        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )

        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                             'source_run_rank', 'source_reason', 'metadata'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )

        # Set up event handlers
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )

        # update_questions_list always returns five values, so both events
        # that use it share the same five-component output list.
        question_list_inputs = [source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown]
        question_list_outputs = [question_dropdown, result_text, metadata_text, results_table, statistics_table]

        baseline_rankers_dropdown.change(
            update_questions_list,
            inputs=question_list_inputs,
            outputs=question_list_outputs
        )

        run_id_dropdown.change(
            update_questions_list,
            inputs=question_list_inputs,
            outputs=question_list_outputs
        )

        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, result_text, metadata_text]
        )

        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text,
                     metadata_text]
        )

    app.queue()
    app.launch()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())