davidr70 commited on
Commit
5cca310
·
1 Parent(s): 3a7a44c

fix case of "all Questions" when selected first

Browse files
Files changed (3) hide show
  1. app.py +12 -5
  2. data_access.py +15 -8
  3. tests/test_db_layer.py +20 -4
app.py CHANGED
@@ -62,7 +62,7 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
62
 
63
  # Main function to handle UI interactions
64
  async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
65
- global available_run_id_dict
66
  if not question_option:
67
  return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
68
  logger.info("processing update")
@@ -78,21 +78,28 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
78
  finder_id_int = None
79
 
80
  if question_option == "All questions":
81
- if finder_id_int and type(run_id) == str:
 
 
 
 
82
  run_id_int = available_run_id_dict.get(run_id)
83
  all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
 
84
  else:
 
85
  all_stats = None
86
- return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
 
87
 
88
  # Extract question ID from selection
89
  question_id = questions_dict.get(question_option)
90
 
91
- available_run_id_dict = await get_run_ids(conn, question_id, finder_id_int)
92
  run_id_options = list(available_run_id_dict.keys())
93
  if run_id not in run_id_options:
94
  run_id = run_id_options[0]
95
-
96
  run_id_int = available_run_id_dict.get(run_id)
97
 
98
 
 
62
 
63
  # Main function to handle UI interactions
64
  async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
65
+ global available_run_id_dict, previous_run_id
66
  if not question_option:
67
  return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
68
  logger.info("processing update")
 
78
  finder_id_int = None
79
 
80
  if question_option == "All questions":
81
+ if finder_id_int:
82
+ if run_id is None:
83
+ available_run_id_dict = await get_run_ids(conn, finder_id_int)
84
+ run_id = list(available_run_id_dict.keys())[0]
85
+ previous_run_id = run_id
86
  run_id_int = available_run_id_dict.get(run_id)
87
  all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
88
+
89
  else:
90
+ run_id_options = list(available_run_id_dict.keys())
91
  all_stats = None
92
+ run_id_options = list(available_run_id_dict.keys())
93
+ return None, all_stats, gr.Dropdown(choices=run_id_options, value=run_id), "Select Run Id and source finder to see results", ""
94
 
95
  # Extract question ID from selection
96
  question_id = questions_dict.get(question_option)
97
 
98
+ available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
99
  run_id_options = list(available_run_id_dict.keys())
100
  if run_id not in run_id_options:
101
  run_id = run_id_options[0]
102
+ previous_run_id = run_id
103
  run_id_int = available_run_id_dict.get(run_id)
104
 
105
 
data_access.py CHANGED
@@ -52,16 +52,23 @@ async def get_source_finders(conn: asyncpg.Connection):
52
 
53
 
54
  # Get distinct run IDs for a question
55
- async def get_run_ids(conn: asyncpg.Connection, question_id: int, source_finder_id: int):
56
  query = """
57
  select distinct sfr.description, srs.source_finder_run_id as run_id
58
- from talmudexplore.source_run_results srs
59
- join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
60
- join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
61
- where sfr.source_finder_id = $1
62
- and srs.question_id = $2
63
  """
64
- run_ids = await conn.fetch(query, source_finder_id, question_id)
 
 
 
 
 
 
 
 
65
  return {r["description"]:r["run_id"] for r in run_ids}
66
 
67
 
@@ -132,7 +139,7 @@ async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connec
132
  for question_id in question_ids:
133
  try:
134
  # Get unified sources for this question
135
- sources, stats = await get_unified_sources(conn, question_id, ranker_id, source_finder_run_id)
136
 
137
  if sources and len(sources) > 0:
138
  valid_questions += 1
 
52
 
53
 
54
  # Get distinct run IDs for a question
55
+ async def get_run_ids(conn: asyncpg.Connection, source_finder_id: int, question_id: int = None):
56
  query = """
57
  select distinct sfr.description, srs.source_finder_run_id as run_id
58
+ from source_run_results srs
59
+ join source_finder_runs sfr on srs.source_finder_run_id = sfr.id
60
+ join source_finders sf on sfr.source_finder_id = sf.id
61
+ where sfr.source_finder_id = $1
 
62
  """
63
+
64
+ if question_id is not None:
65
+ query += " and srs.question_id = $2"
66
+ params = (source_finder_id, question_id)
67
+ else:
68
+ params = (source_finder_id,)
69
+ query += " order by run_id DESC;"
70
+
71
+ run_ids = await conn.fetch(query, *params)
72
  return {r["description"]:r["run_id"] for r in run_ids}
73
 
74
 
 
139
  for question_id in question_ids:
140
  try:
141
  # Get unified sources for this question
142
+ sources, stats = await get_unified_sources(conn, question_id, source_finder_run_id, ranker_id)
143
 
144
  if sources and len(sources) > 0:
145
  valid_questions += 1
tests/test_db_layer.py CHANGED
@@ -97,7 +97,7 @@ async def test_get_run_ids():
97
 
98
  # Call the function to test
99
  async with get_async_connection() as conn:
100
- result = await get_run_ids(conn, question_id, source_finder_id)
101
 
102
  # Verify the result is a dictionary
103
  assert isinstance(result, dict), "Result should be a dictionary"
@@ -107,6 +107,22 @@ async def test_get_run_ids():
107
 
108
  # Test with a non-existent question_id
109
  non_existent_question_id = 9999
110
- empty_result = await get_run_ids(conn, non_existent_question_id, source_finder_id)
111
- assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
112
- assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Call the function to test
99
  async with get_async_connection() as conn:
100
+ result = await get_run_ids(conn, source_finder_id, question_id)
101
 
102
  # Verify the result is a dictionary
103
  assert isinstance(result, dict), "Result should be a dictionary"
 
107
 
108
  # Test with a non-existent question_id
109
  non_existent_question_id = 9999
110
+ empty_result = await get_run_ids(conn, source_finder_id, non_existent_question_id)
111
+ assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
112
+ assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"
113
+
114
+ @pytest.mark.asyncio
115
+ async def test_get_run_ids_no_question_id():
116
+ source_finder_id = 2 # Using a source finder ID that exists in the test database
117
+
118
+ # Call the function to test
119
+ async with get_async_connection() as conn:
120
+ result = await get_run_ids(conn, source_finder_id)
121
+
122
+ # Verify the result is a dictionary
123
+ assert isinstance(result, dict), "Result should be a dictionary"
124
+
125
+ # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
126
+ assert len(result) > 0, "Should return at least one run ID"
127
+
128
+