Pringled committed on
Commit 24f7d5b · 1 Parent(s): 1282e7b
Files changed (1)
  1. app.py +261 -17
app.py CHANGED
@@ -20,8 +20,17 @@ def deduplicate_embeddings(
     threshold: float = 0.9,
     batch_size: int = 1024,
     progress=None
-):
-    """Deduplicate within one dataset or across two datasets."""
+) -> tuple[np.ndarray, dict[int, int]]:
+    """
+    Deduplicate embeddings within one dataset or across two datasets.
+
+    :param embeddings_a: Embeddings of Dataset 1.
+    :param embeddings_b: Optional, embeddings of Dataset 2.
+    :param threshold: Similarity threshold for deduplication.
+    :param batch_size: Batch size for similarity computation.
+    :param progress: Gradio progress tracker for feedback.
+    :return: Deduplicated indices and a mapping of removed indices to their original counterparts.
+    """
     if embeddings_b is None:
         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
         duplicate_to_original = {}
@@ -49,26 +58,53 @@ def deduplicate_embeddings(
         return duplicate_indices_in_b, duplicate_to_original

 def display_word_differences(x: str, y: str) -> str:
-    """Display differences between two texts."""
+    """
+    Display the word-level differences between two texts.
+
+    :param x: First text.
+    :param y: Second text.
+    :return: A string showing word-level differences.
+    """
     diff = ndiff(x.split(), y.split())
     return " ".join(word for word in diff if word.startswith(("+", "-")))

-def load_dataset_texts(dataset_name, dataset_split, text_column):
-    """Load texts from a specified dataset."""
+def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
+    """
+    Load texts from a specified dataset and split.
+
+    :param dataset_name: Name of the dataset.
+    :param dataset_split: Split of the dataset (e.g., 'train', 'validation').
+    :param text_column: Name of the text column.
+    :return: A list of texts from the dataset.
+    """
     ds = load_dataset(dataset_name, split=dataset_split)
     return [example[text_column] for example in ds]

 def perform_deduplication(
-    deduplication_type,
-    dataset1_name,
-    dataset1_split,
-    dataset1_text_column,
-    dataset2_name="",
-    dataset2_split="",
-    dataset2_text_column="",
-    threshold=default_threshold,
-    progress=gr.Progress(track_tqdm=True),
+    deduplication_type: str,
+    dataset1_name: str,
+    dataset1_split: str,
+    dataset1_text_column: str,
+    dataset2_name: str = "",
+    dataset2_split: str = "",
+    dataset2_text_column: str = "",
+    threshold: float = default_threshold,
+    progress: gr.Progress = gr.Progress(track_tqdm=True)
 ):
+    """
+    Perform deduplication on one or two datasets based on the deduplication type.
+
+    :param deduplication_type: 'Single dataset' or 'Cross-dataset'.
+    :param dataset1_name: Name of the first dataset.
+    :param dataset1_split: Split of the first dataset.
+    :param dataset1_text_column: Text column of the first dataset.
+    :param dataset2_name: Optional, name of the second dataset (for cross-dataset deduplication).
+    :param dataset2_split: Optional, split of the second dataset.
+    :param dataset2_text_column: Optional, text column of the second dataset.
+    :param threshold: Similarity threshold for deduplication.
+    :param progress: Gradio progress tracker.
+    :return: Status updates and result text for the Gradio interface.
+    """
     try:
         threshold = float(threshold)

@@ -76,8 +112,8 @@ def perform_deduplication(
         yield "Loading Dataset 1...", ""
         texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
         yield "Computing embeddings for Dataset 1...", ""
-        #embeddings1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Dataset 1 embeddings")
         embeddings1 = model.encode(texts1, show_progressbar=True)
+
         if deduplication_type == "Single dataset":
             # Deduplicate within Dataset 1
             yield "Deduplicating within Dataset 1...", ""
@@ -114,8 +150,8 @@ def perform_deduplication(
             yield "Loading Dataset 2...", ""
             texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
             yield "Computing embeddings for Dataset 2...", ""
-            #embeddings2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Dataset 2 embeddings")
             embeddings2 = model.encode(texts2, show_progressbar=True)
+
             # Deduplicate Dataset 2 against Dataset 1
             yield "Deduplicating Dataset 2 against Dataset 1...", ""
             duplicate_indices, duplicate_mapping = deduplicate_embeddings(
@@ -152,6 +188,12 @@ def perform_deduplication(

 with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
     gr.Markdown("# Semantic Deduplication")
+    gr.Markdown("""
+    This demo showcases semantic deduplication using Model2Vec.
+    It can be used to identify duplicate texts within a single dataset or across two datasets.
+    You can adjust the similarity threshold to control the strictness of the deduplication.
+    NOTE: this demo runs on a free CPU backend, so it may be slow for large datasets. For faster results, please run the code locally.
+    """)

     deduplication_type = gr.Radio(
         choices=["Single dataset", "Cross-dataset"],
@@ -177,7 +219,7 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
     status_output = gr.Markdown(elem_id="status_output")
     result_output = gr.Markdown()

-    def update_visibility(choice):
+    def update_visibility(choice: str):
         return gr.update(visible=choice == "Cross-dataset")

     deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)
@@ -198,3 +240,205 @@ with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
     )

 demo.launch()
+
+
+# import gradio as gr
+# from datasets import load_dataset
+# import numpy as np
+# from model2vec import StaticModel
+# from reach import Reach
+# from difflib import ndiff
+
+# # Load the model
+# model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+
+# # Default parameters
+# default_dataset_name = "sst2"
+# default_dataset_split = "train"
+# default_text_column = "sentence"
+# default_threshold = 0.9
+
+# def deduplicate_embeddings(
+#     embeddings_a: np.ndarray,
+#     embeddings_b: np.ndarray = None,
+#     threshold: float = 0.9,
+#     batch_size: int = 1024,
+#     progress=None
+# ):
+#     """Deduplicate within one dataset or across two datasets."""
+#     if embeddings_b is None:
+#         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+#         duplicate_to_original = {}
+#         results = reach.nearest_neighbor_threshold(
+#             embeddings_a, threshold=threshold, batch_size=batch_size, show_progressbar=False
+#         )
+#         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_a))):
+#             for sim_idx, _ in similar_items:
+#                 sim_idx = int(sim_idx)
+#                 if sim_idx != i and sim_idx not in duplicate_to_original:
+#                     duplicate_to_original[sim_idx] = i
+#         deduplicated_indices = set(range(len(embeddings_a))) - set(duplicate_to_original.keys())
+#         return deduplicated_indices, duplicate_to_original
+#     else:
+#         reach = Reach(vectors=embeddings_a, items=[str(i) for i in range(len(embeddings_a))])
+#         duplicate_indices_in_b = []
+#         duplicate_to_original = {}
+#         results = reach.nearest_neighbor_threshold(
+#             embeddings_b, threshold=threshold, batch_size=batch_size, show_progressbar=False
+#         )
+#         for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=len(embeddings_b))):
+#             if similar_items:
+#                 duplicate_indices_in_b.append(i)
+#                 duplicate_to_original[i] = int(similar_items[0][0])
+#         return duplicate_indices_in_b, duplicate_to_original
+
+# def display_word_differences(x: str, y: str) -> str:
+#     """Display differences between two texts."""
+#     diff = ndiff(x.split(), y.split())
+#     return " ".join(word for word in diff if word.startswith(("+", "-")))
+
+# def load_dataset_texts(dataset_name, dataset_split, text_column):
+#     """Load texts from a specified dataset."""
+#     ds = load_dataset(dataset_name, split=dataset_split)
+#     return [example[text_column] for example in ds]
+
+# def perform_deduplication(
+#     deduplication_type,
+#     dataset1_name,
+#     dataset1_split,
+#     dataset1_text_column,
+#     dataset2_name="",
+#     dataset2_split="",
+#     dataset2_text_column="",
+#     threshold=default_threshold,
+#     progress=gr.Progress(track_tqdm=True),
+# ):
+#     try:
+#         threshold = float(threshold)
+
+#         # Load and process Dataset 1
+#         yield "Loading Dataset 1...", ""
+#         texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)
+#         yield "Computing embeddings for Dataset 1...", ""
+#         #embeddings1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Dataset 1 embeddings")
+#         embeddings1 = model.encode(texts1, show_progressbar=True)
+#         if deduplication_type == "Single dataset":
+#             # Deduplicate within Dataset 1
+#             yield "Deduplicating within Dataset 1...", ""
+#             deduplicated_indices, duplicate_mapping = deduplicate_embeddings(
+#                 embeddings1, threshold=threshold, progress=progress
+#             )
+
+#             num_duplicates = len(duplicate_mapping)
+#             result_text = (
+#                 f"**Total documents:** {len(texts1)}\n\n"
+#                 f"**Duplicates found:** {num_duplicates}\n\n"
+#                 f"**Unique documents after deduplication:** {len(deduplicated_indices)}\n\n"
+#             )
+
+#             if num_duplicates > 0:
+#                 result_text += "**Sample duplicates:**\n\n"
+#                 for dup_idx, orig_idx in list(duplicate_mapping.items())[:5]:
+#                     orig_text = texts1[orig_idx]
+#                     dup_text = texts1[dup_idx]
+#                     differences = display_word_differences(orig_text, dup_text)
+#                     result_text += (
+#                         f"**Original:**\n{orig_text}\n\n"
+#                         f"**Duplicate:**\n{dup_text}\n\n"
+#                         f"**Differences:**\n{differences}\n"
+#                         + "-" * 50 + "\n\n"
+#                     )
+#             else:
+#                 result_text += "No duplicates found."
+
+#             yield "Deduplication completed.", result_text
+
+#         else:
+#             # Load and process Dataset 2
+#             yield "Loading Dataset 2...", ""
+#             texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)
+#             yield "Computing embeddings for Dataset 2...", ""
+#             #embeddings2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Dataset 2 embeddings")
+#             embeddings2 = model.encode(texts2, show_progressbar=True)
+#             # Deduplicate Dataset 2 against Dataset 1
+#             yield "Deduplicating Dataset 2 against Dataset 1...", ""
+#             duplicate_indices, duplicate_mapping = deduplicate_embeddings(
+#                 embeddings1, embeddings_b=embeddings2, threshold=threshold, progress=progress
+#             )
+
+#             num_duplicates = len(duplicate_indices)
+#             result_text = (
+#                 f"**Total documents in {dataset2_name}/{dataset2_split}:** {len(texts2)}\n\n"
+#                 f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
+#                 f"**Unique documents after deduplication:** {len(texts2) - num_duplicates}\n\n"
+#             )
+
+#             if num_duplicates > 0:
+#                 result_text += "**Sample duplicates from Dataset 2:**\n\n"
+#                 for idx in duplicate_indices[:5]:
+#                     orig_text = texts1[duplicate_mapping[idx]]
+#                     dup_text = texts2[idx]
+#                     differences = display_word_differences(orig_text, dup_text)
+#                     result_text += (
+#                         f"**Original (Dataset 1):**\n{orig_text}\n\n"
+#                         f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
+#                         f"**Differences:**\n{differences}\n"
+#                         + "-" * 50 + "\n\n"
+#                     )
+#             else:
+#                 result_text += "No duplicates found."
+
+#             yield "Deduplication completed.", result_text
+
+#     except Exception as e:
+#         yield f"An error occurred: {e}", ""
+#         raise e
+
+# with gr.Blocks(css="#status_output { height: 50px; overflow: auto; }") as demo:
+#     gr.Markdown("# Semantic Deduplication")
+
+#     deduplication_type = gr.Radio(
+#         choices=["Single dataset", "Cross-dataset"],
+#         label="Deduplication Type",
+#         value="Single dataset",
+#     )
+
+#     with gr.Row():
+#         dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
+#         dataset1_split = gr.Textbox(value=default_dataset_split, label="Dataset 1 Split")
+#         dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     dataset2_inputs = gr.Column(visible=False)
+#     with dataset2_inputs:
+#         gr.Markdown("### Dataset 2")
+#         with gr.Row():
+#             dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
+#             dataset2_split = gr.Textbox(value=default_dataset_split, label="Dataset 2 Split")
+#             dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
+
+#     threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")
+#     compute_button = gr.Button("Compute")
+#     status_output = gr.Markdown(elem_id="status_output")
+#     result_output = gr.Markdown()
+
+#     def update_visibility(choice):
+#         return gr.update(visible=choice == "Cross-dataset")
+
+#     deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)
+
+#     compute_button.click(
+#         fn=perform_deduplication,
+#         inputs=[
+#             deduplication_type,
+#             dataset1_name,
+#             dataset1_split,
+#             dataset1_text_column,
+#             dataset2_name,
+#             dataset2_split,
+#             dataset2_text_column,
+#             threshold,
+#         ],
+#         outputs=[status_output, result_output],
+#     )
+
+# demo.launch()
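
For anyone who wants to try the deduplication path outside this Space, the sketch below reproduces the single-dataset flow locally, using only calls that appear in app.py (Model2Vec encoding plus a Reach threshold query). The 1,000-text slice and the final print are illustrative choices, not part of the commit; the model name, dataset, text column, and threshold are the demo's defaults.

from datasets import load_dataset
from model2vec import StaticModel
from reach import Reach

# Demo defaults from app.py; the [:1000] slice just keeps a local run quick.
model = StaticModel.from_pretrained("minishlab/M2V_base_output")
texts = load_dataset("sst2", split="train")["sentence"][:1000]
embeddings = model.encode(texts, show_progressbar=True)

# Index every text, then query the index with the same vectors at a similarity threshold.
reach = Reach(vectors=embeddings, items=[str(i) for i in range(len(embeddings))])
results = reach.nearest_neighbor_threshold(
    embeddings, threshold=0.9, batch_size=1024, show_progressbar=True
)

# Mirror the app's logic: each matching neighbor becomes a duplicate of the
# first item that claimed it, so every duplicate maps to one kept original.
duplicate_to_original = {}
for i, similar_items in enumerate(results):
    for sim_idx, _ in similar_items:
        sim_idx = int(sim_idx)
        if sim_idx != i and sim_idx not in duplicate_to_original:
            duplicate_to_original[sim_idx] = i

print(f"{len(duplicate_to_original)} duplicates among {len(texts)} texts")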