Pringled committed
Commit adde4af · 1 Parent(s): 5963317
Files changed (1):
  1. app.py (+65, -137)
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
  from model2vec import StaticModel
  from reach import Reach
  from difflib import ndiff
- import tqdm
+ import asyncio

  # Load the model at startup
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -30,7 +30,54 @@ def display_word_differences(x: str, y: str) -> str:
      diff = ndiff(x.split(), y.split())
      return " ".join([word for word in diff if word.startswith(('+', '-'))])

- def perform_deduplication(
+ async def compute_embeddings_async(texts, batch_size, progress, desc):
+     embeddings = []
+     total_batches = (len(texts) + batch_size - 1) // batch_size
+     for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
+         batch_embeddings = await asyncio.to_thread(model.encode, batch_texts, show_progressbar=False)
+         embeddings.append(batch_embeddings)
+         progress((i + 1) / total_batches, desc=desc)
+         await asyncio.sleep(0)
+     embedding_matrix = np.concatenate(embeddings, axis=0)
+     return embedding_matrix
+
+ async def deduplicate_async(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
+     """
+     Deduplicate embeddings asynchronously.
+     """
+     progress(0, desc="Building search index...")
+     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
+
+     deduplicated_indices = set(range(len(embedding_matrix)))
+     duplicate_to_original_mapping = {}
+
+     progress(0, desc="Finding nearest neighbors...")
+     results = await asyncio.to_thread(reach.nearest_neighbor_threshold,
+         embedding_matrix,
+         threshold=threshold,
+         batch_size=batch_size,
+         show_progressbar=False)
+
+     total_items = len(embedding_matrix)
+     for i, similar_items in enumerate(results):
+         if i not in deduplicated_indices:
+             continue
+
+         similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
+
+         for sim_idx in similar_indices:
+             if sim_idx in deduplicated_indices:
+                 deduplicated_indices.remove(sim_idx)
+                 duplicate_to_original_mapping[sim_idx] = i
+
+         if i % 100 == 0:
+             progress(i / total_items, desc="Processing duplicates")
+             await asyncio.sleep(0)
+
+     progress(1, desc="Processing duplicates")
+     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
+
+ async def perform_deduplication(
      deduplication_type,
      dataset1_name,
      dataset1_split,
@@ -65,19 +112,12 @@ def perform_deduplication(
              # Compute embeddings
              status = "Computing embeddings for Dataset 1..."
              yield status, ""
-             embeddings = []
-             batch_size = 64
-             total_batches = (len(texts) + batch_size - 1) // batch_size
-             # Use progress.tqdm without yielding inside the loop
-             for batch_texts in progress.tqdm(batch_iterable(texts, batch_size), desc="Computing embeddings for Dataset 1", total=total_batches):
-                 batch_embeddings = model.encode(batch_texts, show_progressbar=False)
-                 embeddings.append(batch_embeddings)
-             embedding_matrix = np.concatenate(embeddings, axis=0)
+             embedding_matrix = await compute_embeddings_async(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")

              # Deduplicate
              status = "Deduplicating embeddings..."
              yield status, ""
-             deduplicated_indices, duplicate_to_original_mapping = deduplicate(
+             deduplicated_indices, duplicate_to_original_mapping = await deduplicate_async(
                  embedding_matrix, threshold, progress=progress
              )

@@ -110,6 +150,7 @@ def perform_deduplication(
              yield status, result_text

          elif deduplication_type == "Cross-dataset":
+             # Similar code for cross-dataset deduplication, using async functions
              # Load Dataset 1
              status = "Loading Dataset 1..."
              yield status, ""
@@ -139,28 +180,17 @@ def perform_deduplication(
              # Compute embeddings for Dataset 1
              status = "Computing embeddings for Dataset 1..."
              yield status, ""
-             embeddings1 = []
-             batch_size = 64
-             total_batches1 = (len(texts1) + batch_size - 1) // batch_size
-             for batch_texts in progress.tqdm(batch_iterable(texts1, batch_size), desc="Computing embeddings for Dataset 1", total=total_batches1):
-                 batch_embeddings = model.encode(batch_texts, show_progressbar=False)
-                 embeddings1.append(batch_embeddings)
-             embedding_matrix1 = np.concatenate(embeddings1, axis=0)
+             embedding_matrix1 = await compute_embeddings_async(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")

              # Compute embeddings for Dataset 2
              status = "Computing embeddings for Dataset 2..."
              yield status, ""
-             embeddings2 = []
-             total_batches2 = (len(texts2) + batch_size - 1) // batch_size
-             for batch_texts in progress.tqdm(batch_iterable(texts2, batch_size), desc="Computing embeddings for Dataset 2", total=total_batches2):
-                 batch_embeddings = model.encode(batch_texts, show_progressbar=False)
-                 embeddings2.append(batch_embeddings)
-             embedding_matrix2 = np.concatenate(embeddings2, axis=0)
+             embedding_matrix2 = await compute_embeddings_async(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")

              # Deduplicate across datasets
              status = "Deduplicating embeddings across datasets..."
              yield status, ""
-             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
+             duplicate_indices_in_ds2, duplicate_to_original_mapping = await deduplicate_across_datasets_async(
                  embedding_matrix1, embedding_matrix2, threshold, progress=progress
              )

@@ -196,132 +226,30 @@ def perform_deduplication(
          yield f"An error occurred: {e}", ""
          raise e

- def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
-     """
-     Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
-     """
-     # Building the index
-     progress(0, desc="Building search index...")
-     reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
-
-     deduplicated_indices = set(range(len(embedding_matrix)))
-     duplicate_to_original_mapping = {}
-
-     # Finding nearest neighbors
-     progress(0, desc="Finding nearest neighbors...")
-     results = reach.nearest_neighbor_threshold(
-         embedding_matrix,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=False # Disable internal progress bar
-     )
-
-     # Processing duplicates with a progress bar
-     total_items = len(embedding_matrix)
-     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
-         if i not in deduplicated_indices:
-             continue
-
-         similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
-
-         for sim_idx in similar_indices:
-             if sim_idx in deduplicated_indices:
-                 deduplicated_indices.remove(sim_idx)
-                 duplicate_to_original_mapping[sim_idx] = i
-
-     return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
-
- def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
+ async def deduplicate_across_datasets_async(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
      """
-     Deduplicate embeddings across two datasets and return the indices of duplicates between them.
+     Deduplicate embeddings across two datasets asynchronously.
      """
-     # Building the index from Dataset 1
      progress(0, desc="Building search index from Dataset 1...")
      reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

      duplicate_indices_in_test = []
      duplicate_to_original_mapping = {}

-     # Finding nearest neighbors between datasets
      progress(0, desc="Finding nearest neighbors between datasets...")
-     results = reach.nearest_neighbor_threshold(
-         embedding_matrix_2,
-         threshold=threshold,
-         batch_size=batch_size,
-         show_progressbar=False # Disable internal progress bar
-     )
+     results = await asyncio.to_thread(reach.nearest_neighbor_threshold,
+         embedding_matrix_2,
+         threshold=threshold,
+         batch_size=batch_size,
+         show_progressbar=False)

      total_items = len(embedding_matrix_2)
-     # Processing duplicates with a progress bar
-     for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
+     for i, similar_items in enumerate(results):
          similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

          if similar_indices:
              duplicate_indices_in_test.append(i)
              duplicate_to_original_mapping[i] = similar_indices[0]

-     return duplicate_indices_in_test, duplicate_to_original_mapping
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Semantic Deduplication")
-
-     deduplication_type = gr.Radio(
-         choices=["Single dataset", "Cross-dataset"],
-         label="Deduplication Type",
-         value="Single dataset"
-     )
-
-     with gr.Row():
-         dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
-         dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
-         dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
-
-     dataset2_inputs = gr.Column(visible=False)
-     with dataset2_inputs:
-         gr.Markdown("### Dataset 2")
-         with gr.Row():
-             dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
-             dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
-             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
-
-     threshold = gr.Slider(
-         minimum=0.0,
-         maximum=1.0,
-         value=default_threshold,
-         label="Similarity Threshold"
-     )
-
-     compute_button = gr.Button("Compute")
-
-     status_output = gr.Markdown()
-     result_output = gr.Markdown()
-
-     # Function to update the visibility of dataset2_inputs
-     def update_visibility(deduplication_type_value):
-         if deduplication_type_value == "Cross-dataset":
-             return gr.update(visible=True)
-         else:
-             return gr.update(visible=False)
-
-     deduplication_type.change(
-         update_visibility,
-         inputs=deduplication_type,
-         outputs=dataset2_inputs
-     )
-
-     compute_button.click(
-         fn=perform_deduplication,
-         inputs=[
-             deduplication_type,
-             dataset1_name,
-             dataset1_split,
-             dataset1_text_column,
-             dataset2_name,
-             dataset2_split,
-             dataset2_text_column,
-             threshold
-         ],
-         outputs=[status_output, result_output]
-     )
-
- demo.launch()
+         if i % 100 == 0:
+             progress(i / total_items, desc="Processing duplicates across datasets")
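For context, the core pattern this commit introduces is moving the blocking calls (model.encode and Reach's nearest-neighbor search) off the event loop with asyncio.to_thread, reporting progress between batches, and handing control back with await asyncio.sleep(0) so the Gradio app stays responsive while perform_deduplication streams status updates. Below is a minimal, self-contained sketch of that pattern; fake_encode and report are hypothetical stand-ins used only for illustration and are not part of app.py.

import asyncio
import numpy as np

def fake_encode(batch):
    # Hypothetical stand-in for a blocking call such as model.encode.
    return np.ones((len(batch), 4))

async def compute_embeddings(texts, batch_size=64, report=print):
    # Split the inputs into batches and encode each batch in a worker thread,
    # so the event loop can keep serving progress updates in between.
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    embeddings = []
    for i, batch in enumerate(batches):
        embeddings.append(await asyncio.to_thread(fake_encode, batch))
        report(f"{i + 1}/{len(batches)} batches done")
        await asyncio.sleep(0)  # hand control back to the event loop
    return np.concatenate(embeddings, axis=0)

if __name__ == "__main__":
    matrix = asyncio.run(compute_embeddings([f"text {i}" for i in range(200)]))
    print(matrix.shape)  # (200, 4)

The same idea is applied to the neighbor search in the diff above: reach.nearest_neighbor_threshold is wrapped in asyncio.to_thread, which is why the tqdm import and the progress.tqdm loops could be dropped.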