added features - to see totals

This commit adds an "All questions" option to the question dropdown and a new calculate_cumulative_statistics_for_all_questions helper in the data-access layer (plus a test for it), so cumulative totals across all questions can be displayed for a selected source finder, run, and baseline ranker.

Changed files:
- app.py +19 -9
- data_access.py +76 -0
- tests/test_db_layer.py +49 -1
app.py
CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions
 
 logger = logging.getLogger(__name__)
 
@@ -63,6 +63,23 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
     if not question_option:
         return gr.skip(), gr.skip(), gr.skip(), "No question selected"
     logger.info("processing update")
+    if type(baseline_ranker_name) == list:
+        baseline_ranker_name = baseline_ranker_name[0]
+
+    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+
+    if len(source_finder_name):
+        finder_id_int = source_finders_dict.get(source_finder_name)
+    else:
+        finder_id_int = None
+
+    if question_option == "All questions":
+        if finder_id_int and type(run_id) == str:
+            all_stats = await calculate_cumulative_statistics_for_all_questions(finder_id_int, int(run_id), baseline_ranker_id_int)
+        else:
+            all_stats = None
+        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results"
+
     # Extract question ID from selection
     question_id = questions_dict.get(question_option)
 
@@ -72,15 +89,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
     run_id = run_id_options[0]
 
     run_id_int = int(run_id)
-    if len(source_finder_name):
-        finder_id_int = source_finders_dict.get(source_finder_name)
-    else:
-        finder_id_int = None
 
-    if type(baseline_ranker_name) == list:
-        baseline_ranker_name = baseline_ranker_name[0]
 
-    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
 
     source_runs = None
     stats = None
@@ -138,7 +148,7 @@ async def main():
         with gr.Column(scale=1):
             # Main content area
             question_dropdown = gr.Dropdown(
-                choices=question_options,
+                choices=["All questions"] + question_options,
                 label="Select Question",
                 value=None,
                 interactive=True,
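Note on the handler changes: the finder/ranker lookups are moved above the new early-return branch, and the type checks (type(baseline_ranker_name) == list, type(run_id) == str) are presumably there because Gradio dropdown values can arrive either as a scalar or as a list depending on how the component is configured. A minimal isinstance-based sketch of the same normalization, purely for illustration and not part of this commit:

def normalize_choice(value, default=None):
    # Gradio dropdown values may arrive as a list; keep only the first entry.
    if isinstance(value, list):
        return value[0] if value else default
    return value

# e.g. baseline_ranker_name = normalize_choice(baseline_ranker_name, default="")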
data_access.py
CHANGED
@@ -110,6 +110,82 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , sou
     return results_df
 
 
+async def calculate_cumulative_statistics_for_all_questions(source_finder_id: int, run_id: int, ranker_id: int):
+    """
+    Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
+
+    Args:
+        source_finder_id (int): ID of the source finder
+        run_id (int): Run ID to analyze
+        ranker_id (int): ID of the baseline ranker
+
+    Returns:
+        pd.DataFrame: DataFrame containing aggregated statistics
+    """
+    async with get_async_connection() as conn:
+        # Get all questions
+        query = "SELECT id FROM questions ORDER BY id"
+        questions = await conn.fetch(query)
+        question_ids = [q["id"] for q in questions]
+
+        # Initialize aggregates
+        total_baseline_sources = 0
+        total_found_sources = 0
+        total_overlap = 0
+        total_high_ranked_baseline = 0
+        total_high_ranked_found = 0
+        total_high_ranked_overlap = 0
+
+        # Process each question
+        valid_questions = 0
+        for question_id in question_ids:
+            try:
+                # Get unified sources for this question
+                sources, stats = await get_unified_sources(question_id, source_finder_id, run_id, ranker_id)
+
+                if sources and len(sources) > 0:
+                    valid_questions += 1
+                    stats_dict = stats.iloc[0].to_dict()
+
+                    # Add to running totals
+                    total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
+                    total_found_sources += stats_dict.get('total_found_sources', 0)
+                    total_overlap += stats_dict.get('overlap_count', 0)
+                    total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
+                    total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
+                    total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
+            except Exception as e:
+                # Skip questions with errors
+                continue
+
+        # Calculate overall percentages
+        overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
+            if max(total_baseline_sources, total_found_sources) > 0 else 0
+
+        high_ranked_overlap_percentage = round(
+            total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
+            if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
+
+        # Compile results
+        cumulative_stats = {
+            "total_questions_analyzed": valid_questions,
+            "total_baseline_sources": total_baseline_sources,
+            "total_found_sources": total_found_sources,
+            "total_overlap_count": total_overlap,
+            "overall_overlap_percentage": overlap_percentage,
+            "total_high_ranked_baseline_sources": total_high_ranked_baseline,
+            "total_high_ranked_found_sources": total_high_ranked_found,
+            "total_high_ranked_overlap_count": total_high_ranked_overlap,
+            "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
+            "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions,
+                                                       2) if valid_questions > 0 else 0,
+            "avg_found_sources_per_question": round(total_found_sources / valid_questions,
+                                                    2) if valid_questions > 0 else 0
+        }
+
+        return pd.DataFrame([cumulative_stats])
+
+
 async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
     """
     Create unified view of sources from both baseline_sources and source_runs
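For a quick sanity check of the new helper outside the Gradio app, it can be called directly from a small driver script. The snippet below is only a sketch: the hard-coded ids are assumptions, and it presumes the database credentials expected by get_async_connection are available in the environment.

import asyncio

from data_access import calculate_cumulative_statistics_for_all_questions


async def preview_totals():
    # Assumed example ids; in the app they come from the dropdown selections.
    stats_df = await calculate_cumulative_statistics_for_all_questions(
        source_finder_id=2, run_id=1, ranker_id=1
    )
    # One-row DataFrame with the cumulative totals shown in the stats panel.
    print(stats_df.to_string(index=False))


if __name__ == "__main__":
    asyncio.run(preview_totals())

Design note: the helper calls get_unified_sources once per question, so runtime grows linearly with the number of questions, and any question that raises is silently skipped and simply does not count toward valid_questions.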
tests/test_db_layer.py
CHANGED
@@ -1,5 +1,7 @@
+import pandas as pd
 import pytest
 
+from data_access import calculate_cumulative_statistics_for_all_questions
 from data_access import get_unified_sources
 
 
@@ -16,4 +18,50 @@ async def test_get_unified_sources():
     assert stats.shape[0] > 0, "Stats DataFrame should contain at least one row"
 
     # You can also check specific stats columns
-    assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
+    assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
+
+
+
+
+@pytest.mark.asyncio
+async def test_calculate_cumulative_statistics_for_all_questions():
+    # Test with known source_finder_id, run_id, and ranker_id
+    source_finder_id = 2
+    run_id = 1
+    ranker_id = 1
+
+    # Call the function to test
+    result = await calculate_cumulative_statistics_for_all_questions(source_finder_id, run_id, ranker_id)
+
+    # Check basic structure of results
+    assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
+    assert result.shape[0] == 1, "Result should have one row"
+
+    # Check required columns exist
+    expected_columns = [
+        "total_questions_analyzed",
+        "total_baseline_sources",
+        "total_found_sources",
+        "total_overlap_count",
+        "overall_overlap_percentage",
+        "total_high_ranked_baseline_sources",
+        "total_high_ranked_found_sources",
+        "total_high_ranked_overlap_count",
+        "overall_high_ranked_overlap_percentage",
+        "avg_baseline_sources_per_question",
+        "avg_found_sources_per_question"
+    ]
+
+    for column in expected_columns:
+        assert column in result.columns, f"Column {column} should be in result DataFrame"
+
+    # Check some basic value validations
+    assert result["total_questions_analyzed"].iloc[0] >= 0, "Should have zero or more questions analyzed"
+    assert result["total_baseline_sources"].iloc[0] >= 0, "Should have zero or more baseline sources"
+    assert result["total_found_sources"].iloc[0] >= 0, "Should have zero or more found sources"
+
+    # Check that percentages are within valid ranges
+    assert 0 <= result["overall_overlap_percentage"].iloc[0] <= 100, "Overlap percentage should be between 0 and 100"
+    assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[
+               0] <= 100, "High ranked overlap percentage should be between 0 and 100"
+
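The new test is decorated with @pytest.mark.asyncio, so pytest-asyncio must be installed for it to run (the existing async test_get_unified_sources presumably relies on the same plugin). A small convenience runner using pytest's Python API, shown only as a sketch and not part of the commit; invoking pytest from the command line with the same test node id works just as well:

# Runs only the new cumulative-statistics test via pytest's Python API.
# Requires pytest-asyncio for the @pytest.mark.asyncio marker and a reachable database.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "tests/test_db_layer.py::test_calculate_cumulative_statistics_for_all_questions",
        "-v",
    ]))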