fix filtering

app.py CHANGED

@@ -15,6 +15,14 @@ st.set_page_config(
     initial_sidebar_state="expanded"
 )
 
+# Initialize session state
+if 'page_number' not in st.session_state:
+    st.session_state.page_number = 0
+if 'analysis_results' not in st.session_state:
+    st.session_state.analysis_results = None
+if 'filtered_results' not in st.session_state:
+    st.session_state.filtered_results = None
+
 # Load a pre-trained model for embeddings with HF caching
 @st.cache_resource
 def load_model():
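Why this hunk matters: Streamlit re-executes app.py from top to bottom on every widget interaction, so ordinary local variables do not survive a rerun; only keys stored in st.session_state do. Initializing the keys at module level guarantees they exist before main() reads them. A minimal, self-contained sketch of the pattern (the click_count key is hypothetical, for illustration only):

import streamlit as st

# Create the key once; later reruns skip this branch.
if "click_count" not in st.session_state:
    st.session_state.click_count = 0

if st.button("Click me"):
    st.session_state.click_count += 1  # survives the rerun the click triggers

st.write(f"Clicks so far: {st.session_state.click_count}")
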
@@ -38,7 +46,6 @@ def load_data():
         return df[["uuid", "problem", "source", "question_type", "problem_type"]]
     except Exception as e:
         st.error(f"Error loading dataset: {e}")
-        # Return empty DataFrame with correct columns if loading fails
         return pd.DataFrame(columns=["uuid", "problem", "source", "question_type", "problem_type"])
 
 # Cache embeddings computation with error handling
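A side note on the two caching decorators this file leans on: @st.cache_resource holds one shared, unserializable object (the embedding model) across reruns and sessions, while @st.cache_data memoizes serializable return values such as DataFrames per distinct argument. The decorator on load_data is cut off in this diff, so the split below is a sketch; the model checkpoint and loader body are assumptions for illustration:

import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer

@st.cache_resource  # one shared instance for all sessions, never copied
def load_model():
    return SentenceTransformer("all-MiniLM-L6-v2")  # assumed checkpoint

@st.cache_data  # memoized per argument; results are copied to callers
def load_data(path: str):
    return pd.read_csv(path)  # hypothetical loader; the real one is not shown
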
@@ -51,13 +58,11 @@ def compute_embeddings(problems):
         st.error(f"Error computing embeddings: {e}")
         return np.array([])
 
-# ================== FUNCTION DEFINITIONS ==================
 def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
     """Find similar problems using cosine similarity, optimized for speed."""
     if df.empty:
         return []
 
-    # Compute embeddings with progress tracking
     embeddings = compute_embeddings(df['problem'].tolist())
     if embeddings.size == 0:
         return []
@@ -65,17 +70,14 @@ def find_similar_problems(df, similarity_threshold=0.9, progress_bar=None):
     if progress_bar:
         progress_bar.progress(0.33, "Computing similarity matrix...")
 
-    # Compute similarity matrix
     similarity_matrix = util.cos_sim(embeddings, embeddings).numpy()
     if progress_bar:
         progress_bar.progress(0.66, "Finding similar pairs...")
 
-    # Use numpy operations for better performance
     num_problems = len(df)
     upper_triangle_indices = np.triu_indices(num_problems, k=1)
     similarity_scores = similarity_matrix[upper_triangle_indices]
 
-    # Filter based on threshold
     mask = similarity_scores > similarity_threshold
     filtered_indices = np.where(mask)[0]
 
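The deleted comments described a trick worth spelling out: np.triu_indices(n, k=1) enumerates every unordered pair (i, j) with i < j exactly once, so thresholding the whole similarity matrix needs no Python-level double loop. A standalone numpy illustration with toy data, not the app's code:

import numpy as np

sim = np.array([[1.0, 0.95, 0.2],
                [0.95, 1.0, 0.3],
                [0.2, 0.3, 1.0]])  # toy 3x3 cosine-similarity matrix

rows, cols = np.triu_indices(len(sim), k=1)  # pairs (0,1), (0,2), (1,2)
scores = sim[rows, cols]                     # [0.95, 0.2, 0.3]

keep = scores > 0.9
for i, j, s in zip(rows[keep], cols[keep], scores[keep]):
    print(f"problems {i} and {j} look like duplicates (cos = {s})")  # 0 and 1
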
@@ -121,19 +123,22 @@ def analyze_clusters(_df, pairs):
     })
     return detailed_analysis
 
-
+def apply_filters(results, df, selected_source, selected_qtype):
+    """Apply filters to results."""
+    filtered = results.copy()
+    if selected_source:
+        filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["source"].values[0] == selected_source]
+    if selected_qtype:
+        filtered = [r for r in filtered if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
+    return filtered
+
 def main():
     st.title("🔍 Problem Deduplication Explorer")
 
-    # Check if model loaded successfully
     if model is None:
         st.error("Failed to load the model. Please try again later.")
         return
 
-    # Initialize session state for pagination
-    if 'page_number' not in st.session_state:
-        st.session_state.page_number = 0
-
     # Sidebar configuration
     with st.sidebar:
         st.header("Settings")
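One observation on the new helper: each list comprehension re-scans the DataFrame per result, since df[df["uuid"] == r["base_uuid"]] is a full boolean-mask pass. For large result lists, a uuid-keyed dict built once makes each check O(1). A sketch of an equivalent under the same result and column layout; this is an alternative, not what the commit ships:

def apply_filters_indexed(results, df, selected_source, selected_qtype):
    # One pass to build uuid -> {"source": ..., "question_type": ...}
    meta = df.set_index("uuid")[["source", "question_type"]].to_dict("index")
    filtered = results
    if selected_source:
        filtered = [r for r in filtered
                    if meta[r["base_uuid"]]["source"] == selected_source]
    if selected_qtype:
        filtered = [r for r in filtered
                    if meta[r["base_uuid"]]["question_type"] == selected_qtype]
    return filtered
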
@@ -168,12 +173,13 @@ def main():
         )
 
     # Analysis section
-    if st.sidebar.button("Run Deduplication Analysis", type="primary"):
-        progress_bar = st.progress(0, "Starting analysis...")
+    if st.sidebar.button("Run Deduplication Analysis", type="primary") or st.session_state.analysis_results is not None:
+        if st.session_state.analysis_results is None:
+            progress_bar = st.progress(0, "Starting analysis...")
+            pairs = find_similar_problems(df, similarity_threshold, progress_bar)
+            st.session_state.analysis_results = analyze_clusters(df, pairs)
 
-
-        pairs = find_similar_problems(df, similarity_threshold, progress_bar)
-        results = analyze_clusters(df, pairs)
+        results = st.session_state.analysis_results
 
         if not results:
             st.warning("No similar problems found with the current threshold.")
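This hunk is the heart of the "fix filtering" commit: previously, touching any filter widget triggered a rerun, the button read False, and the whole analysis block (and its results) vanished. The reworked condition keeps the block alive whenever cached results exist and recomputes only when they do not. The pattern in isolation, with a hypothetical expensive_analysis() standing in for the embedding pipeline:

import streamlit as st

if "analysis_results" not in st.session_state:
    st.session_state.analysis_results = None

def expensive_analysis():  # hypothetical stand-in for the real pipeline
    return ["pair-1", "pair-2"]

if st.sidebar.button("Run", type="primary") or st.session_state.analysis_results is not None:
    if st.session_state.analysis_results is None:
        st.session_state.analysis_results = expensive_analysis()  # runs once
    results = st.session_state.analysis_results  # reused on every later rerun
    st.write(results)

One consequence worth noting: once analysis_results is set, neither moving the similarity-threshold slider nor pressing the button again recomputes; clearing the cached value, or keying it on the threshold, would be needed for that.
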
@@ -189,18 +195,17 @@ def main():
         with col2:
             selected_qtype = st.selectbox("Filter by Question Type", [None] + question_types)
 
-        # Apply filters
-        if selected_source:
-            results = [r for r in results if df[df["uuid"] == r["base_uuid"]]["source"].values[0] == selected_source]
-        if selected_qtype:
-            results = [r for r in results if df[df["uuid"] == r["base_uuid"]]["question_type"].values[0] == selected_qtype]
+        # Apply filters and store in session state
+        filtered_results = apply_filters(results, df, selected_source, selected_qtype)
+        st.session_state.filtered_results = filtered_results
 
-        if not results:
+        if not filtered_results:
             st.warning("No results found with the current filters.")
             return
 
         # Pagination
-        total_pages = len(results) // items_per_page
+        total_pages = (len(filtered_results) - 1) // items_per_page
+        st.session_state.page_number = min(st.session_state.page_number, total_pages)
 
         col1, col2, col3 = st.columns([1, 3, 1])
         with col1:
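The pagination arithmetic: (len(filtered_results) - 1) // items_per_page is the zero-based index of the last page, and the min() clamp repairs a stale page number when a new filter shrinks the list; without it, start_idx below could point past the end and render an empty page. A worked example:

items_per_page = 10
filtered_results = list(range(25))  # stand-in for 25 filtered results

total_pages = (len(filtered_results) - 1) // items_per_page  # -> 2 (pages 0, 1, 2)
page_number = min(7, total_pages)                            # stale page 7 clamps to 2

start_idx = page_number * items_per_page                     # 20
print(filtered_results[start_idx:start_idx + items_per_page])  # [20, 21, 22, 23, 24]

The empty-list edge case, where (0 - 1) // 10 would give -1, never arrives here because the "No results found" branch returns first.
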
@@ -215,7 +220,7 @@ def main():
         # Display results
         start_idx = st.session_state.page_number * items_per_page
         end_idx = start_idx + items_per_page
-        page_results = results[start_idx:end_idx]
+        page_results = filtered_results[start_idx:end_idx]
 
         for entry in page_results:
             with st.container():