Spaces:
Sleeping
Sleeping
import asyncio | |
import logging | |
import gradio as gr | |
import pandas as pd | |
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \ | |
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \ | |
get_async_connection | |
logger = logging.getLogger(__name__)

# Dropdown entry meaning "aggregate over every question".
ALL_QUESTIONS_STR = "All questions"

# ---------------------------------------------------------------------------
# Module-level state, filled in by initialize_data() and the update handlers.
# ---------------------------------------------------------------------------
# Rows fetched from the database.
questions = []
source_finders = []

# Lookup tables translating dropdown labels back to database ids.
questions_dict = {}
source_finders_dict = {}
baseline_rankers_dict = {}
available_run_id_dict = {}

# Choices shown in the dropdowns.
question_options = []
finder_options = []
baseline_ranker_options = []
run_id_options = []

run_ids = []
# Last run id handled; used to ignore redundant run-id change events.
previous_run_id = "initial_run"
# Run-id dropdown component; assigned when main() builds the UI.
run_id_dropdown = None
# Initialize all dropdown data from the database in a single async function.
async def initialize_data():
    """Load questions, source finders and baseline rankers into module state.

    Fills the module-level lookup dicts (label -> id) and the dropdown option
    lists. It then primes the run-id choices for the "All questions" view
    using the first source finder, when at least one exists.
    """
    global questions, source_finders, questions_dict, source_finders_dict, question_options, \
        finder_options, baseline_rankers_dict, baseline_ranker_options
    async with get_async_connection() as conn:
        questions = await get_questions(conn)
        source_finders = await get_source_finders(conn)
        baseline_rankers = await get_baseline_rankers(conn)

    # Lookup tables for translating dropdown selections back to database ids.
    questions_dict = {q["text"]: q["id"] for q in questions}
    source_finders_dict = {f["name"]: f["id"] for f in source_finders}
    baseline_rankers_dict = {b["name"]: b["id"] for b in baseline_rankers}

    # Display options for the dropdowns.
    question_options = [q["text"] for q in questions]
    finder_options = [s["name"] for s in source_finders]
    baseline_ranker_options = [b["name"] for b in baseline_rankers]

    # Prime the run-id choices. Skip this when there are no source finders;
    # indexing [0] on an empty key list would raise IndexError.
    first_finder = next(iter(source_finders_dict), None)
    if first_finder is not None:
        await update_run_ids_async(ALL_QUESTIONS_STR, first_finder)
def update_run_ids(question_option, source_finder_name):
    """Synchronous Gradio handler wrapping update_run_ids_async.

    Runs the coroutine to completion on a fresh event loop via ``asyncio.run``
    and returns its output tuple unchanged.
    """
    pending = update_run_ids_async(question_option, source_finder_name)
    return asyncio.run(pending)
async def update_run_ids_async(question_option, source_finder_name):
    """Refresh the run-id choices for the selected question and source finder.

    Args:
        question_option: Selected question text. ``ALL_QUESTIONS_STR`` or an
            empty value means "all questions".
        source_finder_name: Name of the selected source finder.

    Returns:
        A 5-tuple of Gradio outputs:
        - None for the results table;
        - None for the statistics;
        - the run-id dropdown update, with the first run preselected, or
          empty when no runs exist;
        - a prompt message;
        - empty metadata text.
    """
    global previous_run_id, available_run_id_dict, run_id_options
    finder_id_int = source_finders_dict.get(source_finder_name)
    async with get_async_connection() as conn:
        if question_option and question_option != ALL_QUESTIONS_STR:
            question_id = questions_dict.get(question_option)
            available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        else:
            available_run_id_dict = await get_run_ids(conn, finder_id_int)
    run_id_options = list(available_run_id_dict)
    # Preselect the first run. Fall back to no selection instead of raising
    # IndexError when the finder/question has no runs.
    run_id = run_id_options[0] if run_id_options else None
    previous_run_id = run_id
    return None, None, gr.Dropdown(choices=run_id_options,
                                   value=run_id), "Select Question to see results", ""
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Synchronous Gradio handler that reloads results for the current selection.

    A change event from the run-id dropdown is ignored when its value is a list
    or equals ``previous_run_id``. Presumably this stops the dropdown update
    returned by this same handler from triggering an immediate second reload.

    Args:
        question_option: Selected question text.
        source_finder_id: Selected source finder *name* (forwarded as such).
        run_id: Selected run id label.
        baseline_ranker_id: Selected baseline ranker *name* (forwarded as such).
        evt: Gradio event data, used to identify the component that fired.

    Returns:
        The 5-tuple from update_sources_list_async, or five ``gr.skip()`` values
        when the event is ignored.
    """
    global previous_run_id
    if evt is not None:
        # getattr guards against event data without a target component.
        elem_id = getattr(evt.target, "elem_id", None)
        logger.info("event: %s", elem_id)
        if elem_id == "run_id_dropdown" and (isinstance(run_id, list) or run_id == previous_run_id):
            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
    if isinstance(run_id, str):
        previous_run_id = run_id
    return asyncio.run(update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id))
# Main handler behind the question, baseline-ranker and run-id dropdowns.
async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load sources, statistics and metadata for the current dropdown selection.

    Args:
        question_option: Selected question text, ``ALL_QUESTIONS_STR``, or an
            empty value (nothing is loaded in that case).
        source_finder_name: Selected source finder name. An empty value or None
            means no finder, so no per-question sources are fetched.
        run_id: Selected run id label, or None to pick the first available run.
        baseline_ranker_name: Selected baseline ranker name, or a list whose
            first element is used. An empty value falls back to ranker id 1.

    Returns:
        A 5-tuple matching the outputs wired in ``main``: (results table,
        statistics, run-id dropdown update, status message, metadata text).
    """
    global available_run_id_dict, previous_run_id
    if not question_option:
        return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
    logger.info("processing update")

    # Resolve dropdown labels to database ids. Truthiness checks (rather than
    # len()) also tolerate None from an unset dropdown.
    if isinstance(baseline_ranker_name, list):
        baseline_ranker_name = baseline_ranker_name[0] if baseline_ranker_name else ""
    baseline_ranker_id_int = baseline_rankers_dict.get(baseline_ranker_name) if baseline_ranker_name else 1
    finder_id_int = source_finders_dict.get(source_finder_name) if source_finder_name else None

    async with get_async_connection() as conn:
        if question_option == ALL_QUESTIONS_STR:
            # Aggregate view: cumulative statistics over all questions for one run.
            all_stats = None
            if finder_id_int:
                if run_id is None:
                    available_run_id_dict = await get_run_ids(conn, finder_id_int)
                    run_id = next(iter(available_run_id_dict), None)
                    previous_run_id = run_id
                if run_id is not None:
                    run_id_int = available_run_id_dict.get(run_id)
                    all_stats = await calculate_cumulative_statistics_for_all_questions(
                        conn, run_id_int, baseline_ranker_id_int)
            run_id_options = list(available_run_id_dict)
            return (None, all_stats, gr.Dropdown(choices=run_id_options, value=run_id),
                    "Select Run Id and source finder to see results", "")

        # Single-question view: refresh the runs available for this question.
        question_id = questions_dict.get(question_option)
        available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        run_id_options = list(available_run_id_dict)
        if not run_id_options:
            previous_run_id = None
            return (None, None, gr.Dropdown(choices=[], value=None),
                    "No results found for the selected filters", "")
        if run_id not in run_id_options:
            run_id = run_id_options[0]
        previous_run_id = run_id
        run_id_int = available_run_id_dict.get(run_id)
        run_id_update = gr.Dropdown(choices=run_id_options, value=run_id)

        source_runs, stats = None, None
        if finder_id_int:
            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
        if not source_runs:
            # Always return five values to match the five outputs wired in main().
            return None, None, run_id_update, "No results found for the selected filters", ""

        df = pd.DataFrame(source_runs)
        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
                              'tractate', 'folio', 'reason']
        # Narrow the columns only when every expected column is present.
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
        metadata = await get_metadata(conn, question_id, run_id_int)

    result_message = f"Found {len(source_runs)} results"
    return df_display, stats, run_id_update, result_message, metadata
async def handle_row_selection_async(evt: gr.SelectData):
    """Fetch the text of the source in the selected results-table row.

    Args:
        evt: Gradio selection event for the results table. The first cell of
            the selected row is used as the source id.

    Returns:
        The source text. Returns a human-readable message instead when nothing
        is selected or the lookup fails.
    """
    if evt is None or evt.value is None:
        return "No source selected"
    row = evt.row_value
    if not row:
        return "No source selected"
    tractate_chunk_id = row[0]
    try:
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # The broad catch is deliberate at this UI boundary: the error is shown
        # in the text box. Also log the traceback so failures are diagnosable.
        logger.exception("Failed to retrieve source text for %r", tractate_chunk_id)
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Synchronous ``select`` handler for the results table.

    Delegates to handle_row_selection_async on a fresh event loop.
    """
    coroutine = handle_row_selection_async(evt)
    return asyncio.run(coroutine)
# Build the Gradio app and launch it.
async def main():
    """Load reference data, build the Source Runs Explorer UI and launch it.

    ``app.launch()`` is called with default arguments at the end. That
    normally blocks until the server stops.
    """
    global run_id_dropdown
    await initialize_data()
    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")
        with gr.Row():
            with gr.Column(scale=3):
                # Main content area: the four selection dropdowns.
                with gr.Row():
                    with gr.Column(scale=1):
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            allow_custom_value=True,
                            label="Run id for Question and source finder",
                            interactive=True,
                            # update_sources_list compares against this elem_id.
                            elem_id="run_id_dropdown"
                        )
            with gr.Column(scale=1):
                # Sidebar area
                gr.Markdown("### About")
                gr.Markdown("This tool allows you to explore source runs for Talmudic questions.")
                gr.Markdown("### Statistics")
                gr.Markdown(f"Total Questions: {len(questions)}")
                gr.Markdown(f"Source Finders: {len(source_finders)}")
        with gr.Row():
            result_text = gr.Markdown("Select a question to view source runs")
        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                         ],
                interactive=False,
            )
        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )
        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                             'source_run_rank', 'source_reason', 'metadata'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )
        # TODO: a "Download Results as CSV" button was sketched here but is disabled.

        # Event handlers. All reload handlers share the same inputs and outputs;
        # keep them in one place so the wiring cannot drift apart.
        selection_inputs = [question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown]
        result_outputs = [results_table, statistics_table, run_id_dropdown, result_text, metadata_text]

        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )
        baseline_rankers_dropdown.change(update_sources_list, inputs=selection_inputs, outputs=result_outputs)
        question_dropdown.change(update_sources_list, inputs=selection_inputs, outputs=result_outputs)
        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown],
            outputs=result_outputs
        )
        run_id_dropdown.change(update_sources_list, inputs=selection_inputs, outputs=result_outputs)
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Root logging at INFO so the handlers' logger.info lines are printed;
    # basicConfig accepts the level name as a string.
    logging.basicConfig(level="INFO")
    # Drive the async entry point to completion.
    asyncio.run(main())