import pandas as pd
import pytest
from data_access import (
    calculate_cumulative_statistics_for_all_questions,
    get_async_connection,
    get_metadata,
    get_questions,
    get_run_ids,
    get_unified_sources,
)
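
# These tests run against a pre-seeded test database reached through
# get_async_connection(); the hard-coded run, question, and ranker IDs used
# below are assumptions about that seed data, not guarantees of the
# data_access layer. The @pytest.mark.asyncio marker requires the
# pytest-asyncio plugin.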


@pytest.mark.asyncio
async def test_get_questions():
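    """get_questions returns the ten seeded questions for a known run pair."""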
source_run_id = 2
baseline_source_finder_run_id = 1
async with get_async_connection() as conn:
actual = await get_questions(conn, source_run_id, baseline_source_finder_run_id)
assert len(actual) == 10


@pytest.mark.asyncio
async def test_get_unified_sources():
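    """get_unified_sources returns a non-empty results list and a stats DataFrame."""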
async with get_async_connection() as conn:
        results, stats = await get_unified_sources(conn, 2, 2, 1)
assert results is not None
assert stats is not None
        # Check the number of rows in the results list
        assert len(results) > 4, "Results should contain more than four rows"
# Check number of rows in stats DataFrame
assert stats.shape[0] > 0, "Stats DataFrame should contain at least one row"
        # Check that expected stats columns are present
assert "overlap_count" in stats.columns, "Stats should contain overlap_count"


@pytest.mark.asyncio
async def test_calculate_cumulative_statistics_for_all_questions():
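    """Cumulative statistics come back as a single-row DataFrame with the expected columns."""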
    # Test with a known source_finder_run_id and ranker_id
source_finder_run_id = 2
ranker_id = 1
# Call the function to test
async with get_async_connection() as conn:
questions = await get_questions(conn, source_finder_run_id, ranker_id)
question_ids = [question['id'] for question in questions]
        result = await calculate_cumulative_statistics_for_all_questions(
            conn, question_ids, source_finder_run_id, ranker_id
        )
        # Check the basic structure of the result
assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
assert result.shape[0] == 1, "Result should have one row"
# Check required columns exist
expected_columns = [
"total_questions_analyzed",
"total_baseline_sources",
"total_found_sources",
"total_overlap_count",
"overall_overlap_percentage",
"total_high_ranked_baseline_sources",
"total_high_ranked_found_sources",
"total_high_ranked_overlap_count",
"overall_high_ranked_overlap_percentage",
"avg_baseline_sources_per_question",
"avg_found_sources_per_question"
]
for column in expected_columns:
assert column in result.columns, f"Column {column} should be in result DataFrame"
# Check some basic value validations
assert result["total_questions_analyzed"].iloc[0] >= 0, "Should have zero or more questions analyzed"
assert result["total_baseline_sources"].iloc[0] >= 0, "Should have zero or more baseline sources"
assert result["total_found_sources"].iloc[0] >= 0, "Should have zero or more found sources"
# Check that percentages are within valid ranges
assert 0 <= result["overall_overlap_percentage"].iloc[0] <= 100, "Overlap percentage should be between 0 and 100"
assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[
0] <= 100, "High ranked overlap percentage should be between 0 and 100"


@pytest.mark.asyncio
async def test_get_metadata_none_returned():
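    """get_metadata returns an empty dict when no metadata exists for the run."""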
    # Test with a source_finder_run_id that has no metadata for this question
source_finder_run_id = 1
question_id = 1
# Call the function to test
async with get_async_connection() as conn:
result = await get_metadata(conn, question_id, source_finder_run_id)
        assert result == {}, "Should return an empty dict when no metadata is found"


@pytest.mark.asyncio
async def test_get_metadata():
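    """get_metadata returns stored metadata when it exists for the run."""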
    # Test with a known question_id and source_finder_run_id
source_finder_run_id = 4
question_id = 1
# Call the function to test
async with get_async_connection() as conn:
result = await get_metadata(conn, question_id, source_finder_run_id)
        assert result, "Should return non-empty metadata when it exists"


@pytest.mark.asyncio
async def test_get_run_ids():
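    """get_run_ids returns a dict of run IDs, and an empty dict for unknown questions."""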
# Test with known question_id and source_finder_id
question_id = 2 # Using a question ID that exists in the test database
source_finder_id = 2 # Using a source finder ID that exists in the test database
# Call the function to test
async with get_async_connection() as conn:
result = await get_run_ids(conn, source_finder_id, question_id)
# Verify the result is a dictionary
assert isinstance(result, dict), "Result should be a dictionary"
# Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
assert len(result) > 0, "Should return at least one run ID"
# Test with a non-existent question_id
non_existent_question_id = 9999
empty_result = await get_run_ids(conn, source_finder_id, non_existent_question_id)
assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"


@pytest.mark.asyncio
async def test_get_run_ids_no_question_id():
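    """get_run_ids without a question_id returns runs for the whole source finder."""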
source_finder_id = 2 # Using a source finder ID that exists in the test database
# Call the function to test
async with get_async_connection() as conn:
result = await get_run_ids(conn, source_finder_id)
# Verify the result is a dictionary
assert isinstance(result, dict), "Result should be a dictionary"
        # Check that the dictionary is not empty (assuming there are run IDs for this source finder)
assert len(result) > 0, "Should return at least one run ID"