davidr70 committed on
Commit
83afd54
·
1 Parent(s): 6e35819

improvements

Browse files
Files changed (2) hide show
  1. app.py +22 -19
  2. data_access.py +9 -36
app.py CHANGED
@@ -2,8 +2,9 @@ import asyncio
2
 
3
  import gradio as gr
4
  import pandas as pd
5
- from data_access import get_pool, get_async_connection, close_pool, get_questions, get_source_finders, get_run_ids, \
6
- get_source_runs, get_baseline_rankers, calculate_baseline_vs_source_stats_for_question, get_unified_sources
 
7
 
8
  # Initialize data at the module level
9
  questions = []
@@ -43,16 +44,19 @@ async def initialize_data():
43
  baseline_ranker_labels = {str(f["id"]): f["name"] for f in source_finders}
44
 
45
 
 
 
 
 
46
  # Main function to handle UI interactions
47
- def update_sources_list(question_option, source_finder_id, baseline_ranker_id: str, run_id:str):
48
  if not question_option:
49
  return None, [], "No question selected", None
50
 
51
  # Extract question ID from selection
52
  question_id = int(question_option.split(":")[0])
53
 
54
- # Get run_ids for filtering - use asyncio.run for each independent operation
55
- available_run_ids = asyncio.run(get_run_ids(question_id))
56
  run_id_options = [str(r_id) for r_id in available_run_ids]
57
  if run_id not in run_id_options:
58
  run_id = run_id_options[0]
@@ -67,16 +71,16 @@ def update_sources_list(question_option, source_finder_id, baseline_ranker_id: s
67
  stats = None
68
  # Get source runs data
69
  if finder_id_int:
70
- source_runs, stats = asyncio.run(get_unified_sources(question_id, finder_id_int, run_id_int, baseline_ranker_id_int))
71
  # Create DataFrame for display
72
  df = pd.DataFrame(source_runs)
73
 
74
  if not source_runs:
75
  return None, None, run_id_options, "No results found for the selected filters",
76
 
77
-
78
  # Format table columns
79
- columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate', 'folio', 'reason']
 
80
  df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
81
 
82
  # CSV for download
@@ -90,7 +94,6 @@ def update_sources_list(question_option, source_finder_id, baseline_ranker_id: s
90
 
91
  # Ensure we clean up when done
92
  async def main():
93
- await get_pool()
94
  await initialize_data()
95
  with gr.Blocks(title="Source Runs Explorer") as app:
96
  gr.Markdown("# Source Runs Explorer")
@@ -129,17 +132,16 @@ async def main():
129
  interactive=True
130
  )
131
 
132
-
133
  result_text = gr.Markdown("Select a question to view source runs")
134
  gr.Markdown("# Source Run Statistics")
135
  statistics_table = gr.DataFrame(
136
  headers=["num_high_ranked_baseline_sources",
137
- "num_high_ranked_found_sources",
138
- "overlap_count",
139
- "overlap_percentage",
140
- "high_ranked_overlap_count",
141
- "high_ranked_overlap_percentage"
142
- ],
143
  interactive=False,
144
  )
145
  gr.Markdown("# Sources Found")
@@ -187,9 +189,9 @@ async def main():
187
  )
188
 
189
  run_id_dropdown.change(
190
- update_sources_list,
191
- inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
192
- outputs=[results_table, statistics_table, run_id_dropdown, result_text]
193
  )
194
 
195
  # Initial load of data when question is selected
@@ -202,5 +204,6 @@ async def main():
202
  app.queue()
203
  app.launch()
204
 
 
205
  if __name__ == "__main__":
206
  asyncio.run(main())
 
2
 
3
  import gradio as gr
4
  import pandas as pd
5
+
6
+ from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
7
+ get_unified_sources
8
 
9
  # Initialize data at the module level
10
  questions = []
 
44
  baseline_ranker_labels = {str(f["id"]): f["name"] for f in source_finders}
45
 
46
 
47
def update_sources_list(question_option, source_finder_id, baseline_ranker_id: str, run_id: str):
    """Blocking entry point for the UI callbacks.

    Builds the ``update_sources_list_async`` coroutine with the same four
    arguments, drives it to completion with ``asyncio.run`` and returns
    whatever that coroutine returns.
    """
    # NOTE(review): the run_id_dropdown.change wiring lists run_id_dropdown
    # before baseline_rankers_dropdown, while this signature takes
    # baseline_ranker_id before run_id -- confirm the argument order at the
    # call sites.
    pending = update_sources_list_async(
        question_option,
        source_finder_id,
        baseline_ranker_id,
        run_id,
    )
    return asyncio.run(pending)
49
+
50
+
51
  # Main function to handle UI interactions
52
+ async def update_sources_list_async(question_option, source_finder_id, baseline_ranker_id: str, run_id: str):
53
  if not question_option:
54
  return None, [], "No question selected", None
55
 
56
  # Extract question ID from selection
57
  question_id = int(question_option.split(":")[0])
58
 
59
+ available_run_ids = await get_run_ids(question_id)
 
60
  run_id_options = [str(r_id) for r_id in available_run_ids]
61
  if run_id not in run_id_options:
62
  run_id = run_id_options[0]
 
71
  stats = None
72
  # Get source runs data
73
  if finder_id_int:
74
+ source_runs, stats = await get_unified_sources(question_id, finder_id_int, run_id_int, baseline_ranker_id_int)
75
  # Create DataFrame for display
76
  df = pd.DataFrame(source_runs)
77
 
78
  if not source_runs:
79
  return None, None, run_id_options, "No results found for the selected filters",
80
 
 
81
  # Format table columns
82
+ columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
83
+ 'folio', 'reason']
84
  df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
85
 
86
  # CSV for download
 
94
 
95
  # Ensure we clean up when done
96
  async def main():
 
97
  await initialize_data()
98
  with gr.Blocks(title="Source Runs Explorer") as app:
99
  gr.Markdown("# Source Runs Explorer")
 
132
  interactive=True
133
  )
134
 
 
135
  result_text = gr.Markdown("Select a question to view source runs")
136
  gr.Markdown("# Source Run Statistics")
137
  statistics_table = gr.DataFrame(
138
  headers=["num_high_ranked_baseline_sources",
139
+ "num_high_ranked_found_sources",
140
+ "overlap_count",
141
+ "overlap_percentage",
142
+ "high_ranked_overlap_count",
143
+ "high_ranked_overlap_percentage"
144
+ ],
145
  interactive=False,
146
  )
147
  gr.Markdown("# Sources Found")
 
189
  )
190
 
191
  run_id_dropdown.change(
192
+ update_sources_list,
193
+ inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
194
+ outputs=[results_table, statistics_table, run_id_dropdown, result_text]
195
  )
196
 
197
  # Initial load of data when question is selected
 
204
  app.queue()
205
  app.launch()
206
 
207
+
208
  if __name__ == "__main__":
209
  asyncio.run(main())
data_access.py CHANGED
@@ -9,52 +9,25 @@ from dotenv import load_dotenv
9
  import pandas as pd
10
 
11
  # Global connection pool
12
- _pool = None
13
  load_dotenv()
14
 
15
 
16
- async def get_pool(schema="talmudexplore", min_size=2, max_size=5):
17
- """Initialize and return the connection pool with the specified schema."""
18
- global _pool
19
- if _pool is not None:
20
- current_loop = asyncio.get_running_loop()
21
- if getattr(_pool, '_loop', None) != current_loop:
22
- try:
23
- await _pool.close()
24
- except:
25
- pass
26
- _pool = None
27
-
28
- if _pool is None:
29
- _pool = await asyncpg.create_pool(
30
  database=os.getenv("pg_dbname"),
31
  user=os.getenv("pg_user"),
32
  password=os.getenv("pg_password"),
33
  host=os.getenv("pg_host"),
34
- port=os.getenv("pg_port"),
35
- min_size=min_size,
36
- max_size=max_size,
37
- setup=lambda conn: conn.execute(f'SET search_path TO {schema}')
38
-
39
  )
40
- return _pool
41
-
42
- @asynccontextmanager
43
- async def get_async_connection():
44
- """Get a connection from the pool as an async context manager."""
45
- pool = await get_pool()
46
- conn = await pool.acquire()
47
- try:
48
  yield conn
49
  finally:
50
- await pool.release(conn)
51
-
52
- async def close_pool():
53
- """Close the connection pool."""
54
- global _pool
55
- if _pool:
56
- await _pool.close()
57
- _pool = None
58
 
59
 
60
  async def get_questions():
 
9
  import pandas as pd
10
 
11
  # Global connection pool
 
12
  load_dotenv()
13
 
14
 
15
@asynccontextmanager
async def get_async_connection(schema="talmudexplore"):
    """Yield a dedicated database connection with ``search_path`` set to *schema*.

    A new connection is opened on every use (no shared pool) and is closed
    when the ``async with`` block exits, whether normally or through an
    exception. Connection parameters are read from the ``pg_dbname``,
    ``pg_user``, ``pg_password``, ``pg_host`` and ``pg_port`` environment
    variables.

    Args:
        schema: Schema name to place on the connection's ``search_path``.
            It is interpolated directly into the SQL text, so it must come
            from trusted code, never from user input.

    Yields:
        The open connection returned by ``asyncpg.connect``.

    Raises:
        Any exception raised by ``asyncpg.connect`` or by the
        ``SET search_path`` statement is propagated unchanged.
    """
    # Connect *before* entering the try block: if connect() itself fails there
    # is no connection to close, and touching ``conn`` in ``finally`` would
    # raise UnboundLocalError and hide the real connection error.
    conn = await asyncpg.connect(
        database=os.getenv("pg_dbname"),
        user=os.getenv("pg_user"),
        password=os.getenv("pg_password"),
        host=os.getenv("pg_host"),
        port=os.getenv("pg_port"),
    )
    try:
        # Inside the try so the connection is still closed if this fails.
        await conn.execute(f'SET search_path TO {schema}')
        yield conn
    finally:
        await conn.close()
 
 
 
 
 
 
 
31
 
32
 
33
  async def get_questions():