import asyncio
import os
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg
import pandas as pd
import psycopg2
from dotenv import load_dotenv

# Load database credentials from the local .env file
load_dotenv()


@asynccontextmanager
async def get_async_connection(schema="talmudexplore"):
    """Get a dedicated connection for the current request."""
    conn = None
    try:
        # Create a single connection per request rather than relying on a shared pool
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port"),
        )
        # schema comes from internal callers only; it is not user input
        await conn.execute(f"SET search_path TO {schema}")
        yield conn
    finally:
        if conn is not None:
            await conn.close()


async def get_questions():
    async with get_async_connection() as conn:
        questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
        return [{"id": q["id"], "text": q["question_text"]} for q in questions]


# Get distinct source finders
async def get_source_finders():
    async with get_async_connection() as conn:
        finders = await conn.fetch(
            "SELECT id, source_finder_type AS name FROM source_finders ORDER BY id"
        )
        return [{"id": f["id"], "name": f["name"]} for f in finders]


# Get distinct run IDs for a question
async def get_run_ids(question_id: int):
    async with get_async_connection() as conn:
        query = "SELECT DISTINCT run_id FROM source_runs WHERE question_id = $1 ORDER BY run_id DESC"
        run_ids = await conn.fetch(query, question_id)
        return [r["run_id"] for r in run_ids]


# Get source runs for a specific question with optional filters
async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
                          run_id: Optional[int] = None):
    async with get_async_connection() as conn:
        # Build the query dynamically so the optional filters only apply when provided
        query = """
            SELECT sr.*, sf.source_finder_type AS finder_name
            FROM source_runs sr
            JOIN source_finders sf ON sr.source_finder_id = sf.id
            WHERE sr.question_id = $1
        """
        params = [question_id]
        if run_id is not None:
            params.append(run_id)
            query += f" AND sr.run_id = ${len(params)}"
        if source_finder_id is not None:
            params.append(source_finder_id)
            query += f" AND sr.source_finder_id = ${len(params)}"
        query += " ORDER BY sr.rank DESC"
        sources = await conn.fetch(query, *params)
        return [dict(s) for s in sources]


async def get_baseline_rankers():
    async with get_async_connection() as conn:
        rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
        return [{"id": r["id"], "name": r["ranker"]} for r in rankers]
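
# Illustrative usage sketch (not part of the original API): the helpers above are
# coroutines, so a synchronous caller drives them with asyncio. The function name
# `demo_list_questions` is hypothetical.
async def demo_list_questions():
    # Each helper opens its own connection, so the two fetches can run concurrently
    questions, finders = await asyncio.gather(get_questions(), get_source_finders())
    print(f"{len(questions)} questions, {len(finders)} source finders")

# To run from a script: asyncio.run(demo_list_questions())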

async def calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs_sources):
    """
    For a given question's baseline and found sources, calculate comparison stats:
    overlap, high-ranked overlap, etc. Operates purely on the rows passed in, so
    no database connection is needed.
    """
    actual_sources_set = {s["id"] for s in source_runs_sources}
    baseline_sources_set = {s["id"] for s in baseline_sources}

    # Calculate overlap
    overlap = actual_sources_set.intersection(baseline_sources_set)

    # Calculate high-ranked overlap (rank >= 4)
    actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
    baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)

    # Percentages are taken against the larger of the two sets, guarding against division by zero
    overlap_denom = max(len(actual_sources_set), len(baseline_sources_set))
    high_ranked_denom = max(len(actual_high_ranked), len(baseline_high_ranked))
    results = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(overlap),
        "overlap_percentage": round(len(overlap) * 100 / overlap_denom, 2) if overlap_denom > 0 else 0,
        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
        "num_high_ranked_found_sources": len(actual_high_ranked),
        "high_ranked_overlap_count": len(high_ranked_overlap),
        "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / high_ranked_denom, 2) if high_ranked_denom > 0 else 0,
    }

    # Convert results to a single-row dataframe
    return pd.DataFrame([results])


async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
    """
    Create a unified view of sources from both baseline_sources and source_runs,
    with indicators of where each source appears and their respective ranks.
    """
    async with get_async_connection() as conn:
        # Get sources found by the source finder for this run
        query_runs = """
            SELECT tb.tractate_chunk_id AS id, sr.rank AS source_rank,
                   sr.tractate, sr.folio, sr.reason AS source_reason
            FROM source_runs sr
            JOIN talmud_bavli tb ON sr.sugya_id = tb.xml_id
            WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
        """
        source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)

        # Get the baseline sources for this question and ranker
        query_baseline = """
            SELECT tb.tractate_chunk_id AS id, bs.rank AS baseline_rank, bs.tractate, bs.folio
            FROM baseline_sources bs
            JOIN talmud_bavli tb ON bs.sugya_id = tb.xml_id
            WHERE bs.question_id = $1 AND bs.ranker_id = $2
        """
        baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)

    stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)

    # Convert to dictionaries keyed by chunk id for easier lookup
    source_runs_dict = {s["id"]: dict(s) for s in source_runs}
    baseline_dict = {s["id"]: dict(s) for s in baseline_sources}

    # Union of all chunk ids seen in either result set
    all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())

    # Build unified results
    unified_results = []
    for sugya_id in all_sugya_ids:
        in_source_run = sugya_id in source_runs_dict
        in_baseline = sugya_id in baseline_dict
        # Prefer baseline metadata when available; both row types carry tractate/folio
        info = baseline_dict[sugya_id] if in_baseline else source_runs_dict[sugya_id]
        unified_results.append({
            "id": sugya_id,
            "tractate": info.get("tractate"),
            "folio": info.get("folio"),
            "in_baseline": "Yes" if in_baseline else "No",
            "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
            "in_source_run": "Yes" if in_source_run else "No",
            "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
            # The query aliases sr.reason to source_reason, so look up that key
            "source_reason": source_runs_dict.get(sugya_id, {}).get("source_reason", "N/A"),
        })
    return unified_results, stats_df
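
# Illustrative usage sketch (hypothetical IDs): pairing the unified view with the
# stats dataframe returned alongside it. `demo_compare` and its default IDs are
# assumptions, not part of the original module.
async def demo_compare(question_id=1, source_finder_id=1, run_id=1, ranker_id=1):
    unified, stats_df = await get_unified_sources(question_id, source_finder_id, run_id, ranker_id)
    both = [s for s in unified if s["in_baseline"] == "Yes" and s["in_source_run"] == "Yes"]
    print(f"{len(both)} of {len(unified)} sources appear in both baseline and run")
    print(stats_df.to_string(index=False))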

async def get_source_text(tractate_chunk_id: int):
    """
    Retrieve the text content for a given tractate chunk ID.
    """
    async with get_async_connection() as conn:
        query = """
            SELECT tb.text_with_nikud AS text
            FROM talmud_bavli tb
            WHERE tb.tractate_chunk_id = $1
        """
        result = await conn.fetchrow(query, tractate_chunk_id)
        return result["text"] if result else "Source text not found"


def get_pg_sync_connection(schema="talmudexplore"):
    """Open a synchronous psycopg2 connection with the schema on the search path."""
    return psycopg2.connect(
        dbname=os.getenv("pg_dbname"),
        user=os.getenv("pg_user"),
        password=os.getenv("pg_password"),
        host=os.getenv("pg_host"),
        port=os.getenv("pg_port"),
        options=f"-c search_path={schema}",
    )
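
# Illustrative sketch: the synchronous connection pairs naturally with pandas for
# ad-hoc analysis. `load_questions_df` is a hypothetical helper; pandas warns about
# raw DBAPI connections but executes the query fine.
def load_questions_df(schema: str = "talmudexplore") -> pd.DataFrame:
    conn = get_pg_sync_connection(schema)
    try:
        return pd.read_sql("SELECT id, question_text FROM questions ORDER BY id", conn)
    finally:
        conn.close()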