"""Async PostgreSQL data-access helpers for the Talmud source-finder app.

Every read helper takes an already-open ``asyncpg.Connection``; connection
lifetime is owned by :func:`get_async_connection`.  Database credentials are
read from ``pg_*`` environment variables populated by ``load_dotenv()``.
"""
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg
import pandas as pd
import psycopg2
from cachetools import TTLCache
from dotenv import load_dotenv

# Load pg_* credentials from a .env file at import time.
load_dotenv()

# BUGFIX: the original code applied cachetools' @cached decorator directly to
# coroutine functions.  That caches the *coroutine object*, not its result, so
# every cache hit after the first await raises "cannot reuse already awaited
# coroutine" — and the connection object was part of the cache key, pinning
# connections in memory.  Awaited results are now cached manually, keyed on
# the query arguments only (never the connection).
_metadata_cache = TTLCache(ttl=1800, maxsize=1024)
_run_ids_cache = TTLCache(ttl=1800, maxsize=1024)
_source_text_cache = TTLCache(ttl=1800, maxsize=1024)


@asynccontextmanager
async def get_async_connection(schema: str = "talmudexplore"):
    """Yield a dedicated connection with ``search_path`` set to *schema*.

    A fresh connection is created per request (no shared pool) and is always
    closed on exit, even if the body raises.

    Raises:
        ValueError: if *schema* is not a plain Python identifier.
    """
    conn = None
    try:
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port"),
        )
        # SET cannot take bound parameters, so the schema name must be
        # interpolated; restrict it to a bare identifier to rule out SQL
        # injection through the f-string.
        if not schema.isidentifier():
            raise ValueError(f"invalid schema name: {schema!r}")
        await conn.execute(f'SET search_path TO {schema}')
        yield conn
    finally:
        if conn:
            await conn.close()


async def get_questions(conn: asyncpg.Connection):
    """Return ``[{"id", "text"}]`` for every question in question group 1."""
    questions = await conn.fetch(
        "SELECT id, question_text FROM questions where question_group_id = 1 ORDER BY id"
    )
    return [{"id": q["id"], "text": q["question_text"]} for q in questions]


async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_id_run_id: int):
    """Return the stored metadata for a (question, source-finder run) pair.

    Results are cached for 30 minutes.  Returns ``""`` (also cached) when no
    matching row exists.
    """
    key = (question_id, source_finder_id_run_id)
    if key in _metadata_cache:
        return _metadata_cache[key]
    row = await conn.fetchrow('''
        SELECT metadata
        FROM source_finder_run_question_metadata sfrqm
        WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
    ''', question_id, source_finder_id_run_id)
    result = "" if row is None else row.get('metadata')
    _metadata_cache[key] = result
    return result


async def get_source_finders(conn: asyncpg.Connection):
    """Return ``[{"id", "name"}]`` for every source finder that has at least
    one run with recorded results."""
    finders = await conn.fetch("""
        SELECT distinct sf.id, sf.source_finder_type as name
        from talmudexplore.source_finder_runs sfr
        join talmudexplore.source_finders sf on sf.id = sfr.source_finder_id
        WHERE EXISTS (
            SELECT 1
            FROM talmudexplore.source_run_results srr
            WHERE srr.source_finder_run_id = sfr.id
        )
        ORDER BY sf.id
    """)
    return [{"id": f["id"], "name": f["name"]} for f in finders]


async def get_run_ids(conn: asyncpg.Connection,
                      source_finder_id: int, question_id: int = None):
    """Map run description -> run_id for one source finder.

    When *question_id* is given, only runs that produced results for that
    question are returned.  Results are cached for 30 minutes.
    """
    key = (source_finder_id, question_id)
    if key in _run_ids_cache:
        return _run_ids_cache[key]
    query = """
        select distinct sfr.description, srs.source_finder_run_id as run_id
        from source_run_results srs
        join source_finder_runs sfr on srs.source_finder_run_id = sfr.id
        join source_finders sf on sfr.source_finder_id = sf.id
        where sfr.source_finder_id = $1
    """
    if question_id is not None:
        query += " and srs.question_id = $2"
        params = (source_finder_id, question_id)
    else:
        params = (source_finder_id,)
    query += " order by run_id DESC;"
    run_ids = await conn.fetch(query, *params)
    result = {r["description"]: r["run_id"] for r in run_ids}
    _run_ids_cache[key] = result
    return result


async def get_baseline_rankers(conn: asyncpg.Connection):
    """Return ``[{"id", "name"}]`` for every run usable as a baseline ranker
    (i.e. every run that has recorded results)."""
    query = """
        SELECT sfr.id, sf.source_finder_type, sfr.description
        from source_finder_runs sfr
        join source_finders sf on sf.id = sfr.source_finder_id
        WHERE EXISTS (
            SELECT 1
            FROM source_run_results srr
            WHERE srr.source_finder_run_id = sfr.id
        )
        ORDER BY sf.id
    """
    rankers = await conn.fetch(query)
    return [{"id": r["id"], "name": f"{r['source_finder_type']} : {r['description']}"}
            for r in rankers]


async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection,
                                                          baseline_sources,
                                                          source_runs_sources):
    """Compare one question's found sources against its baseline sources.

    Computes overlap counts/percentages, overall and for "high ranked"
    sources (rank >= 4).  Percentages use the larger of the two set sizes as
    the denominator.  Returns a single-row ``pd.DataFrame``.
    """
    actual_sources_set = {s["id"] for s in source_runs_sources}
    baseline_sources_set = {s["id"] for s in baseline_sources}

    # Overall overlap between the two id sets.
    overlap = actual_sources_set.intersection(baseline_sources_set)

    # High-ranked overlap (rank >= 4 on both sides).
    actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
    baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)

    results = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(overlap),
        "overlap_percentage": round(
            len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)), 2)
            if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
        "num_high_ranked_found_sources": len(actual_high_ranked),
        "high_ranked_overlap_count": len(high_ranked_overlap),
        "high_ranked_overlap_percentage": round(
            len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2)
            if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0,
    }
    # Single-row DataFrame so callers can concat/aggregate uniformly.
    return pd.DataFrame([results])


async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection,
                                                            source_finder_run_id: int,
                                                            ranker_id: int):
    """Aggregate baseline-vs-found statistics across all questions.

    Args:
        source_finder_run_id: ID of the source finder run being evaluated.
        ranker_id: ID (run id) of the baseline ranker to compare against.

    Returns:
        pd.DataFrame: one row of aggregated counts, percentages and per-
        question averages.
    """
    questions = await conn.fetch("SELECT id FROM questions ORDER BY id")
    question_ids = [q["id"] for q in questions]

    # Running totals across all questions that yielded sources.
    total_baseline_sources = 0
    total_found_sources = 0
    total_overlap = 0
    total_high_ranked_baseline = 0
    total_high_ranked_found = 0
    total_high_ranked_overlap = 0
    valid_questions = 0

    for question_id in question_ids:
        try:
            sources, stats = await get_unified_sources(
                conn, question_id, source_finder_run_id, ranker_id)
            if sources:
                valid_questions += 1
                stats_dict = stats.iloc[0].to_dict()
                total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
                total_found_sources += stats_dict.get('total_found_sources', 0)
                total_overlap += stats_dict.get('overlap_count', 0)
                total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
                total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
                total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
        except Exception:
            # Best-effort aggregation: a failing question is skipped rather
            # than aborting the whole report.
            continue

    overlap_percentage = round(
        total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
        if max(total_baseline_sources, total_found_sources) > 0 else 0
    high_ranked_overlap_percentage = round(
        total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
        if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0

    cumulative_stats = {
        "total_questions_analyzed": valid_questions,
        "total_baseline_sources": total_baseline_sources,
        "total_found_sources": total_found_sources,
        "total_overlap_count": total_overlap,
        "overall_overlap_percentage": overlap_percentage,
        "total_high_ranked_baseline_sources": total_high_ranked_baseline,
        "total_high_ranked_found_sources": total_high_ranked_found,
        "total_high_ranked_overlap_count": total_high_ranked_overlap,
        "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
        "avg_baseline_sources_per_question":
            round(total_baseline_sources / valid_questions, 2) if valid_questions > 0 else 0,
        "avg_found_sources_per_question":
            round(total_found_sources / valid_questions, 2) if valid_questions > 0 else 0,
    }
    return pd.DataFrame([cumulative_stats])


async def get_unified_sources(conn: asyncpg.Connection, question_id: int,
                              source_finder_run_id: int, ranker_id: int):
    """Build a unified view of a question's sources from the evaluated run
    and the baseline run.

    Returns:
        tuple: (list of unified per-source dicts, single-row stats DataFrame
        from :func:`calculate_baseline_vs_source_stats_for_question`).
    """
    query_runs = """
        SELECT tb.tractate_chunk_id as id,
               sr.rank as source_rank,
               sr.tractate,
               sr.folio,
               sr.reason as source_reason
        FROM source_run_results sr
        join talmud_bavli tb on sr.sugya_id = tb.xml_id
        WHERE sr.question_id = $1
        AND sr.source_finder_run_id = $2
    """
    source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)

    # The baseline is the same query against the baseline run, with the rank
    # column re-aliased so both result sets can be merged by key.
    baseline_query = query_runs.replace("source_rank", "baseline_rank")
    baseline_sources = await conn.fetch(baseline_query, question_id, ranker_id)

    stats_df = await calculate_baseline_vs_source_stats_for_question(
        conn, baseline_sources, source_runs)

    # Index both result sets by chunk id for O(1) lookup.
    source_runs_dict = {s["id"]: dict(s) for s in source_runs}
    baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
    all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())

    unified_results = []
    for sugya_id in all_sugya_ids:
        in_source_run = sugya_id in source_runs_dict
        in_baseline = sugya_id in baseline_dict
        # Prefer the baseline row for tractate/folio when both exist.
        info = baseline_dict[sugya_id] if in_baseline else source_runs_dict[sugya_id]
        result = {
            "id": sugya_id,
            "tractate": info.get("tractate"),
            "folio": info.get("folio"),
            "in_baseline": "Yes" if in_baseline else "No",
            "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
            "in_source_run": "Yes" if in_source_run else "No",
            "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
            # BUGFIX: the query aliases the column as source_reason; the old
            # lookup used "reason" and therefore always yielded "N/A".
            "source_reason": source_runs_dict.get(sugya_id, {}).get("source_reason", "N/A"),
            # NOTE(review): "metadata" is not selected by query_runs, so this
            # always falls back to "" — presumably get_metadata() was meant to
            # supply it; confirm with callers before changing.
            "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", ""),
        }
        unified_results.append(result)

    return unified_results, stats_df


async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
    """Return the text of a tractate chunk, or a not-found message.

    Results are cached for 30 minutes.
    """
    if tractate_chunk_id in _source_text_cache:
        return _source_text_cache[tractate_chunk_id]
    query = """
        SELECT tb.text as text
        FROM talmud_bavli tb
        WHERE tb.tractate_chunk_id = $1
    """
    result = await conn.fetchrow(query, tractate_chunk_id)
    text = result["text"] if result else "Source text not found"
    _source_text_cache[tractate_chunk_id] = text
    return text


def get_pg_sync_connection(schema: str = "talmudexplore"):
    """Return a synchronous psycopg2 connection with search_path preset."""
    conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
                            user=os.getenv("pg_user"),
                            password=os.getenv("pg_password"),
                            host=os.getenv("pg_host"),
                            port=os.getenv("pg_port"),
                            options=f"-c search_path={schema}")
    return conn