davidr70 committed
Commit 873b70f · 1 Parent(s): 5f4f31d

Added features to see totals

Files changed (3)
  1. app.py +19 -9
  2. data_access.py +76 -0
  3. tests/test_db_layer.py +49 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions
 
 logger = logging.getLogger(__name__)
 
@@ -63,6 +63,23 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
     if not question_option:
         return gr.skip(), gr.skip(), gr.skip(), "No question selected"
     logger.info("processing update")
+    if type(baseline_ranker_name) == list:
+        baseline_ranker_name = baseline_ranker_name[0]
+
+    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+
+    if len(source_finder_name):
+        finder_id_int = source_finders_dict.get(source_finder_name)
+    else:
+        finder_id_int = None
+
+    if question_option == "All questions":
+        if finder_id_int and type(run_id) == str:
+            all_stats = await calculate_cumulative_statistics_for_all_questions(finder_id_int, int(run_id), baseline_ranker_id_int)
+        else:
+            all_stats = None
+        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results"
+
     # Extract question ID from selection
     question_id = questions_dict.get(question_option)
 
@@ -72,15 +89,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id:
         run_id = run_id_options[0]
 
     run_id_int = int(run_id)
-    if len(source_finder_name):
-        finder_id_int = source_finders_dict.get(source_finder_name)
-    else:
-        finder_id_int = None
 
-    if type(baseline_ranker_name) == list:
-        baseline_ranker_name = baseline_ranker_name[0]
 
-    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
 
     source_runs = None
     stats = None
@@ -138,7 +148,7 @@ async def main():
         with gr.Column(scale=1):
             # Main content area
             question_dropdown = gr.Dropdown(
-                choices=question_options,
+                choices=["All questions"] + question_options,
                 label="Select Question",
                 value=None,
                 interactive=True,
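
The pattern behind the app.py change, in isolation: a sentinel choice is prepended to the dropdown, and the async handler returns early when it is selected, before any per-question lookup. A minimal standalone sketch of that pattern (ALL_QUESTIONS, handle_selection, and the sample choices are illustrative names, not from the commit):

# Minimal sketch of the sentinel-option pattern; assumes gradio is installed.
# ALL_QUESTIONS, handle_selection, and the sample data are hypothetical.
import gradio as gr

ALL_QUESTIONS = "All questions"
question_options = ["Q1: example question", "Q2: example question"]


async def handle_selection(question_option):
    if question_option == ALL_QUESTIONS:
        # Aggregate path: skip the single-question lookup entirely
        return "Showing cumulative totals for every question"
    # Per-question path continues as before
    return f"Showing sources for: {question_option}"


with gr.Blocks() as demo:
    dropdown = gr.Dropdown(choices=[ALL_QUESTIONS] + question_options,
                           label="Select Question", value=None, interactive=True)
    output = gr.Textbox(label="Result")
    dropdown.change(handle_selection, inputs=dropdown, outputs=output)

demo.launch()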
data_access.py CHANGED
@@ -110,6 +110,82 @@ async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source
     return results_df
 
 
+async def calculate_cumulative_statistics_for_all_questions(source_finder_id: int, run_id: int, ranker_id: int):
+    """
+    Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
+
+    Args:
+        source_finder_id (int): ID of the source finder
+        run_id (int): Run ID to analyze
+        ranker_id (int): ID of the baseline ranker
+
+    Returns:
+        pd.DataFrame: DataFrame containing aggregated statistics
+    """
+    async with get_async_connection() as conn:
+        # Get all questions
+        query = "SELECT id FROM questions ORDER BY id"
+        questions = await conn.fetch(query)
+        question_ids = [q["id"] for q in questions]
+
+        # Initialize aggregates
+        total_baseline_sources = 0
+        total_found_sources = 0
+        total_overlap = 0
+        total_high_ranked_baseline = 0
+        total_high_ranked_found = 0
+        total_high_ranked_overlap = 0
+
+        # Process each question
+        valid_questions = 0
+        for question_id in question_ids:
+            try:
+                # Get unified sources for this question
+                sources, stats = await get_unified_sources(question_id, source_finder_id, run_id, ranker_id)
+
+                if sources and len(sources) > 0:
+                    valid_questions += 1
+                    stats_dict = stats.iloc[0].to_dict()
+
+                    # Add to running totals
+                    total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
+                    total_found_sources += stats_dict.get('total_found_sources', 0)
+                    total_overlap += stats_dict.get('overlap_count', 0)
+                    total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
+                    total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
+                    total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
+            except Exception as e:
+                # Skip questions with errors
+                continue
+
+        # Calculate overall percentages
+        overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
+            if max(total_baseline_sources, total_found_sources) > 0 else 0
+
+        high_ranked_overlap_percentage = round(
+            total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
+            if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
+
+        # Compile results
+        cumulative_stats = {
+            "total_questions_analyzed": valid_questions,
+            "total_baseline_sources": total_baseline_sources,
+            "total_found_sources": total_found_sources,
+            "total_overlap_count": total_overlap,
+            "overall_overlap_percentage": overlap_percentage,
+            "total_high_ranked_baseline_sources": total_high_ranked_baseline,
+            "total_high_ranked_found_sources": total_high_ranked_found,
+            "total_high_ranked_overlap_count": total_high_ranked_overlap,
+            "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
+            "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions, 2) if valid_questions > 0 else 0,
+            "avg_found_sources_per_question": round(total_found_sources / valid_questions, 2) if valid_questions > 0 else 0
+        }
+
+        return pd.DataFrame([cumulative_stats])
+
+
 async def get_unified_sources(question_id: int, source_finder_id: int, run_id: int, ranker_id: int):
     """
     Create unified view of sources from both baseline_sources and source_runs
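
A possible way to exercise the new aggregation outside the Gradio app: await it from a small driver script and print the single-row DataFrame. This is a sketch, assuming the database behind get_async_connection() is reachable and that the IDs below exist (they mirror the values used in the new test):

# Hypothetical driver script; assumes a reachable database and valid IDs.
import asyncio

from data_access import calculate_cumulative_statistics_for_all_questions


async def main():
    stats = await calculate_cumulative_statistics_for_all_questions(
        source_finder_id=2, run_id=1, ranker_id=1)
    # One-row DataFrame; transpose for a metric-per-line printout
    print(stats.T)


if __name__ == "__main__":
    asyncio.run(main())

Note that the loop issues one get_unified_sources call per row of the questions table, so runtime grows linearly with the number of questions.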
tests/test_db_layer.py CHANGED
@@ -1,5 +1,7 @@
+import pandas as pd
 import pytest
 
+from data_access import calculate_cumulative_statistics_for_all_questions
 from data_access import get_unified_sources
 
 
@@ -16,4 +18,50 @@ async def test_get_unified_sources():
     assert stats.shape[0] > 0, "Stats DataFrame should contain at least one row"
 
     # You can also check specific stats columns
-    assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
+    assert "overlap_count" in stats.columns, "Stats should contain overlap_count"
+
+
+
+
+@pytest.mark.asyncio
+async def test_calculate_cumulative_statistics_for_all_questions():
+    # Test with known source_finder_id, run_id, and ranker_id
+    source_finder_id = 2
+    run_id = 1
+    ranker_id = 1
+
+    # Call the function to test
+    result = await calculate_cumulative_statistics_for_all_questions(source_finder_id, run_id, ranker_id)
+
+    # Check basic structure of results
+    assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
+    assert result.shape[0] == 1, "Result should have one row"
+
+    # Check required columns exist
+    expected_columns = [
+        "total_questions_analyzed",
+        "total_baseline_sources",
+        "total_found_sources",
+        "total_overlap_count",
+        "overall_overlap_percentage",
+        "total_high_ranked_baseline_sources",
+        "total_high_ranked_found_sources",
+        "total_high_ranked_overlap_count",
+        "overall_high_ranked_overlap_percentage",
+        "avg_baseline_sources_per_question",
+        "avg_found_sources_per_question"
+    ]
+
+    for column in expected_columns:
+        assert column in result.columns, f"Column {column} should be in result DataFrame"
+
+    # Check some basic value validations
+    assert result["total_questions_analyzed"].iloc[0] >= 0, "Should have zero or more questions analyzed"
+    assert result["total_baseline_sources"].iloc[0] >= 0, "Should have zero or more baseline sources"
+    assert result["total_found_sources"].iloc[0] >= 0, "Should have zero or more found sources"
+
+    # Check that percentages are within valid ranges
+    assert 0 <= result["overall_overlap_percentage"].iloc[0] <= 100, "Overlap percentage should be between 0 and 100"
+    assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[
+        0] <= 100, "High ranked overlap percentage should be between 0 and 100"
+
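
The new test requires pytest-asyncio (for the @pytest.mark.asyncio coroutine) and a live database in which the hard-coded IDs 2/1/1 exist. One further assertion that would hold regardless of the data, since each per-question overlap is an intersection and so cannot exceed either per-question total, is a consistency check like this sketch (not part of the commit):

# Hypothetical extra assertion: summed overlap is bounded by the larger
# of the summed source totals, by construction of the aggregation.
assert result["total_overlap_count"].iloc[0] <= max(
    result["total_baseline_sources"].iloc[0],
    result["total_found_sources"].iloc[0],
), "Overlap count cannot exceed the larger source total"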