# eval_results/data_access.py
# Postgres data-access helpers: async access via asyncpg, sync access via psycopg2.
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Optional
import asyncpg
import psycopg2
from dotenv import load_dotenv
import pandas as pd
# Load DB settings (pg_*) from a .env file at import time; connections are
# created per request below — there is no shared/global pool.
load_dotenv()
@asynccontextmanager
async def get_async_connection(schema="talmudexplore"):
    """Async context manager yielding a dedicated asyncpg connection.

    Opens a fresh connection for each use (no shared pool), sets the
    connection's search_path to *schema*, and always closes it on exit.

    Args:
        schema: Postgres schema installed as the connection's search_path.

    Yields:
        asyncpg.Connection: an open connection with search_path applied.
    """
    conn = None
    try:
        # Create a single connection without relying on a shared pool.
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port")
        )
        # NOTE: schema is interpolated into the SQL text; callers pass
        # trusted, code-defined schema names only.
        await conn.execute(f'SET search_path TO {schema}')
        yield conn
    finally:
        # Bug fix: if asyncpg.connect() itself raised, `conn` was never bound
        # and the original `await conn.close()` raised NameError, masking the
        # real connection error. Only close a connection we actually opened.
        if conn is not None:
            await conn.close()
async def get_questions():
    """Fetch every question, ordered by id, as [{'id', 'text'}] dicts."""
    sql = "SELECT id, question_text FROM questions ORDER BY id"
    async with get_async_connection() as conn:
        rows = await conn.fetch(sql)
        return [{"id": row["id"], "text": row["question_text"]} for row in rows]
# Get distinct source finders
async def get_source_finders():
    """Fetch all source finders, ordered by id, as [{'id', 'name'}] dicts."""
    sql = "SELECT id, source_finder_type as name FROM source_finders ORDER BY id"
    async with get_async_connection() as conn:
        rows = await conn.fetch(sql)
        return [{"id": row["id"], "name": row["name"]} for row in rows]
# Get distinct run IDs for a question
async def get_run_ids(question_id: int):
    """List the distinct run ids recorded for *question_id*, newest first."""
    sql = ("SELECT DISTINCT run_id FROM source_runs "
           "WHERE question_id = $1 order by run_id desc")
    async with get_async_connection() as conn:
        rows = await conn.fetch(sql, question_id)
        return [row["run_id"] for row in rows]
# Get source runs for a specific question with filters
async def get_source_runs(question_id: int, source_finder_id: Optional[int] = None,
                          run_id: Optional[int] = None):
    """Fetch source_runs rows (plus finder name) for a question.

    Args:
        question_id: required question filter.
        source_finder_id: restrict results to one source finder when given.
        run_id: restrict results to one run when given.

    Returns:
        List of row dicts (all source_runs columns plus ``finder_name``),
        ordered by rank descending.
    """
    async with get_async_connection() as conn:
        query = """
            SELECT sr.*, sf.source_finder_type as finder_name
            FROM source_runs sr
            JOIN source_finders sf ON sr.source_finder_id = sf.id
            WHERE sr.question_id = $1
        """
        params = [question_id]
        # Bug fix: source_finder_id and run_id are declared Optional, but the
        # original SQL always compared them with "=", which never matches a
        # NULL parameter — calls omitting a filter silently returned nothing.
        # Build the WHERE clause dynamically so an omitted filter means
        # "no filter"; behavior is unchanged when both filters are supplied.
        if run_id is not None:
            params.append(run_id)
            query += f" AND sr.run_id = ${len(params)}"
        if source_finder_id is not None:
            params.append(source_finder_id)
            query += f" AND sr.source_finder_id = ${len(params)}"
        query += " ORDER BY sr.rank DESC"
        sources = await conn.fetch(query, *params)
        return [dict(s) for s in sources]
async def get_baseline_rankers():
    """Fetch all baseline rankers, ordered by id, as [{'id', 'name'}] dicts."""
    sql = "SELECT id, ranker FROM rankers ORDER BY id"
    async with get_async_connection() as conn:
        rows = await conn.fetch(sql)
        return [{"id": row["id"], "name": row["ranker"]} for row in rows]
async def calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs_sources):
    """Compare baseline sources with found sources for one question.

    Computes overlap statistics between the two source collections, keyed by
    ``sugya_id``: total counts, overlap count/percentage, and the same for
    "high-ranked" sources (rank >= 4 on either side).

    Args:
        baseline_sources: rows with at least ``sugya_id`` and ``baseline_rank``.
        source_runs_sources: rows with at least ``sugya_id`` and ``source_rank``.

    Returns:
        pd.DataFrame: a single-row frame with the stat columns.
    """
    # Bug fix: the original acquired a database connection here and never
    # used it — this function is pure computation, so no connection is opened.
    actual_sources_set = {s["sugya_id"] for s in source_runs_sources}
    baseline_sources_set = {s["sugya_id"] for s in baseline_sources}
    overlap = actual_sources_set & baseline_sources_set
    # High-ranked = rank >= 4; int() because ranks may arrive as strings.
    actual_high_ranked = {s["sugya_id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
    baseline_high_ranked = {s["sugya_id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
    high_ranked_overlap = actual_high_ranked & baseline_high_ranked

    def _overlap_pct(hits, set_a, set_b):
        # Percentage relative to the larger of the two sets; 0 when both are
        # empty (avoids ZeroDivisionError).
        denom = max(len(set_a), len(set_b))
        return round(len(hits) * 100 / denom, 2) if denom > 0 else 0

    results = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(overlap),
        "overlap_percentage": _overlap_pct(overlap, actual_sources_set, baseline_sources_set),
        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
        "num_high_ranked_found_sources": len(actual_high_ranked),
        "high_ranked_overlap_count": len(high_ranked_overlap),
        "high_ranked_overlap_percentage": _overlap_pct(high_ranked_overlap, actual_high_ranked, baseline_high_ranked)
    }
    # Single-row DataFrame so callers can display/concatenate stats easily.
    return pd.DataFrame([results])
async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
    """
    Create unified view of sources from both baseline_sources and source_runs
    with indicators of where each source appears and their respective ranks.

    Args:
        question_id: question whose sources are compared.
        source_finder_id: source finder used on the source_runs side.
        run_id: run of that source finder.
        ranker_id: ranker used on the baseline side.

    Returns:
        (unified_results, stats_df): a list of per-sugya dicts and the one-row
        stats DataFrame from calculate_baseline_vs_source_stats_for_question.
    """
    async with get_async_connection() as conn:
        # Sources produced by the selected source-finder run.
        query_runs = """
            SELECT sr.sugya_id, sr.rank as source_rank, sr.tractate, sr.folio, sr.reason as source_reason
            FROM source_runs sr
            WHERE sr.question_id = $1 AND sr.source_finder_id = $2 AND sr.run_id = $3
        """
        source_runs = await conn.fetch(query_runs, question_id, source_finder_id, run_id)
        # Baseline sources for the selected ranker.
        query_baseline = """
            SELECT bs.sugya_id, bs.rank as baseline_rank, bs.tractate, bs.folio
            FROM baseline_sources bs
            WHERE bs.question_id = $1 AND bs.ranker_id = $2
        """
        baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
        stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
        # Index both result sets by sugya_id for O(1) membership/rank lookups.
        source_runs_dict = {s["sugya_id"]: dict(s) for s in source_runs}
        baseline_dict = {s["sugya_id"]: dict(s) for s in baseline_sources}
        all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())
        # Build unified results; baseline metadata (tractate/folio) wins when
        # a sugya appears on both sides.
        unified_results = []
        for sugya_id in all_sugya_ids:
            in_source_run = sugya_id in source_runs_dict
            in_baseline = sugya_id in baseline_dict
            info = baseline_dict[sugya_id] if in_baseline else source_runs_dict[sugya_id]
            result = {
                "sugya_id": sugya_id,
                "tractate": info.get("tractate", "N/A"),
                "folio": info.get("folio", "N/A"),
                "in_baseline": "Yes" if in_baseline else "No",
                "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
                "in_source_run": "Yes" if in_source_run else "No",
                "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
                # Bug fix: the query aliases sr.reason AS source_reason, so the
                # row key is "source_reason"; the original looked up "reason"
                # and therefore always emitted "N/A" here.
                "source_reason": source_runs_dict.get(sugya_id, {}).get("source_reason", "N/A")
            }
            unified_results.append(result)
        return unified_results, stats_df
def get_pg_sync_connection(schema="talmudexplore"):
    """Open a synchronous psycopg2 connection with search_path set to *schema*."""
    settings = {
        "dbname": os.getenv("pg_dbname"),
        "user": os.getenv("pg_user"),
        "password": os.getenv("pg_password"),
        "host": os.getenv("pg_host"),
        "port": os.getenv("pg_port"),
        # Apply the schema via server options so every statement sees it.
        "options": f"-c search_path={schema}",
    }
    return psycopg2.connect(**settings)