dsleo committed
Commit 74d7b60 · verified · 1 Parent(s): 17186a1

fix filtering

Persist the analysis results in st.session_state so that changing the source or question-type filters (each change triggers a Streamlit rerun) no longer discards them; extract the filter logic into apply_filters() and clamp the page number when the filtered result set shrinks.

Files changed (1)
app.py +30 -25
app.py CHANGED
@@ -15,6 +15,14 @@ st.set_page_config(
     initial_sidebar_state="expanded"
 )
 
+# Initialize session state
+if 'page_number' not in st.session_state:
+    st.session_state.page_number = 0
+if 'analysis_results' not in st.session_state:
+    st.session_state.analysis_results = None
+if 'filtered_results' not in st.session_state:
+    st.session_state.filtered_results = None
+
 # Load a pre-trained model for embeddings with HF caching
 @st.cache_resource
 def load_model():
@@ -38,7 +46,6 @@ def load_data():
         return df[["uuid", "problem", "source", "question_type", "problem_type"]]
     except Exception as e:
         st.error(f"Error loading dataset: {e}")
-        # Return empty DataFrame with correct columns if loading fails
         return pd.DataFrame(columns=["uuid", "problem", "source", "question_type", "problem_type"])
 
 # Cache embeddings computation with error handling
@@ -51,13 +58,11 @@ def compute_embeddings(problems):
         st.error(f"Error computing embeddings: {e}")
         return np.array([])
 
-# ================== FUNCTION DEFINITIONS ==================
 def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
     """Find similar problems using cosine similarity, optimized for speed."""
     if df.empty:
         return []
 
-    # Compute embeddings with progress tracking
     embeddings = compute_embeddings(df['problem'].tolist())
     if embeddings.size == 0:
         return []
@@ -65,17 +70,14 @@ def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
     if progress_bar:
         progress_bar.progress(0.33, "Computing similarity matrix...")
 
-    # Compute similarity matrix
     similarity_matrix = util.cos_sim(embeddings, embeddings).numpy()
     if progress_bar:
         progress_bar.progress(0.66, "Finding similar pairs...")
 
-    # Use numpy operations for better performance
     num_problems = len(df)
    upper_triangle_indices = np.triu_indices(num_problems, k=1)
     similarity_scores = similarity_matrix[upper_triangle_indices]
 
-    # Filter based on threshold
     mask = similarity_scores > similarity_threshold
     filtered_indices = np.where(mask)[0]
 
@@ -121,19 +123,22 @@ def analyze_clusters(_df, pairs):
         })
     return detailed_analysis
 
-# ================== STREAMLIT UI ==================
+def apply_filters(results, df, selected_source, selected_qtype):
+    """Apply filters to results."""
+    filtered = results.copy()
+    if selected_source:
+        filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["source"].values[0] == selected_source]
+    if selected_qtype:
+        filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
+    return filtered
+
 def main():
     st.title("🔍 Problem Deduplication Explorer")
 
-    # Check if model loaded successfully
     if model is None:
         st.error("Failed to load the model. Please try again later.")
         return
 
-    # Initialize session state for pagination
-    if 'page_number' not in st.session_state:
-        st.session_state.page_number = 0
-
     # Sidebar configuration
     with st.sidebar:
         st.header("Settings")
@@ -168,12 +173,13 @@ def main():
     )
 
     # Analysis section
-    if st.sidebar.button("Run Deduplication Analysis", type="primary"):
-        progress_bar = st.progress(0, "Starting analysis...")
+    if st.sidebar.button("Run Deduplication Analysis", type="primary") or st.session_state.analysis_results is not None:
+        if st.session_state.analysis_results is None:
+            progress_bar = st.progress(0, "Starting analysis...")
+            pairs = find_similar_problems(df, similarity_threshold, progress_bar)
+            st.session_state.analysis_results = analyze_clusters(df, pairs)
 
-        # Run analysis
-        pairs = find_similar_problems(df, similarity_threshold, progress_bar)
-        results = analyze_clusters(df, pairs)
+        results = st.session_state.analysis_results
 
         if not results:
             st.warning("No similar problems found with the current threshold.")
@@ -189,18 +195,17 @@ def main():
         with col2:
             selected_qtype = st.selectbox("Filter by Question Type", [None] + question_types)
 
-        # Apply filters
-        if selected_source:
-            results = [r for r in results if df[df["uuid"] == r["base_uuid"]]["source"].values[0] == selected_source]
-        if selected_qtype:
-            results = [r for r in results if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
+        # Apply filters and store in session state
+        filtered_results = apply_filters(results, df, selected_source, selected_qtype)
+        st.session_state.filtered_results = filtered_results
 
-        if not results:
+        if not filtered_results:
             st.warning("No results found with the current filters.")
             return
 
         # Pagination
-        total_pages = len(results) // items_per_page
+        total_pages = (len(filtered_results) - 1) // items_per_page
+        st.session_state.page_number = min(st.session_state.page_number, total_pages)
 
         col1, col2, col3 = st.columns([1, 3, 1])
         with col1:
@@ -215,7 +220,7 @@ def main():
         # Display results
         start_idx = st.session_state.page_number * items_per_page
         end_idx = start_idx + items_per_page
-        page_results = results[start_idx:end_idx]
+        page_results = filtered_results[start_idx:end_idx]
 
         for entry in page_results:
             with st.container():
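Why the change fixes the filtering: Streamlit reruns the whole script on every widget interaction, and a button is only True on the run in which it was clicked. In the old code the filter widgets lived inside the if st.sidebar.button(...) block, so touching a filter selectbox triggered a rerun, the button evaluated to False, and the results disappeared along with the filters. The pattern the commit adopts, as a minimal self-contained sketch (the widget labels and the fake results are illustrative stand-ins, not code from app.py):

import streamlit as st

# Streamlit reruns this script top to bottom on every widget interaction,
# and a button is only True on the exact run in which it was clicked.
if "results" not in st.session_state:
    st.session_state.results = None

if st.button("Run analysis") or st.session_state.results is not None:
    if st.session_state.results is None:
        # The expensive step runs only once; a stand-in for the real analysis.
        st.session_state.results = [f"item {i}" for i in range(25)]

    results = st.session_state.results

    # Changing this selectbox reruns the script; because the results were
    # persisted above, the filtered view survives the rerun.
    choice = st.selectbox("Filter", [None] + results)
    filtered = [r for r in results if choice is None or r == choice]
    st.write(filtered)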
 
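A note on the pair extraction that these hunks leave unchanged in find_similar_problems: np.triu_indices(n, k=1) enumerates each unordered index pair (i, j) with i < j exactly once, so thresholding the upper triangle of the cosine-similarity matrix yields every candidate duplicate pair with no self-pairs and no double counting. A toy sketch with hand-made unit vectors standing in for the sentence embeddings:

import numpy as np

# Toy "embeddings": four vectors, two of which are near-duplicates.
emb = np.array([
    [1.0, 0.0],
    [0.99, 0.14],   # nearly parallel to the first vector
    [0.0, 1.0],
    [0.7, 0.7],
])
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)

sim = emb @ emb.T                             # cosine similarity (unit-norm rows)
rows, cols = np.triu_indices(len(emb), k=1)   # each pair (i, j) with i < j, once
scores = sim[rows, cols]

mask = scores > 0.9
for i, j, s in zip(rows[mask], cols[mask], scores[mask]):
    print(f"pair ({i}, {j}) similarity {s:.3f}")
# -> pair (0, 1) similarity 0.990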
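Finally, the pagination arithmetic: (len(filtered_results) - 1) // items_per_page is the zero-based index of the last page, and the min() clamp pulls a stale page_number back into range when filtering shrinks the result set; read against the zero-based page_number, the old len(results) // items_per_page allowed one empty trailing page whenever the length was an exact multiple. A quick check of the arithmetic (the helper name is ours, not the app's):

# Zero-based index of the last page; assumes items_per_page > 0 and
# n_results > 0 (the app returns early when the filtered list is empty).
def last_page_index(n_results: int, items_per_page: int) -> int:
    return (n_results - 1) // items_per_page

assert last_page_index(25, 10) == 2   # pages 0, 1, 2
assert last_page_index(20, 10) == 1   # exact multiple: no empty third page
assert last_page_index(1, 10) == 0

# Filters shrink the results from 50 to 25 while the user sits on page 4:
page_number = min(4, last_page_index(25, 10))   # clamped back to 2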