Pringled committed on
Commit 6b0e834 · 1 Parent(s): 73b7a75

Updated app with code for deduplication

Files changed (1)
  1. app.py +35 -47
app.py CHANGED
@@ -4,9 +4,7 @@ import numpy as np
4
  from model2vec import StaticModel
5
  from reach import Reach
6
  from difflib import ndiff
7
- import sys
8
  import tqdm
9
- from tqdm.utils import format_interval, format_num, format_sizeof
10
 
11
  # Load the model at startup
12
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -23,26 +21,41 @@ default_threshold = 0.9
23
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
24
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
25
 
26
- def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
  Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
29
  """
30
  # Building the index
 
31
  reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
32
 
33
  deduplicated_indices = set(range(len(embedding_matrix)))
34
  duplicate_to_original_mapping = {}
35
 
36
  # Finding nearest neighbors
 
37
  results = reach.nearest_neighbor_threshold(
38
  embedding_matrix,
39
  threshold=threshold,
40
  batch_size=batch_size,
41
- show_progressbar=True # Allow internal progress bar
42
  )
43
 
44
- # Processing duplicates
45
- for i, similar_items in enumerate(results):
 
46
  if i not in deduplicated_indices:
47
  continue
48
 
@@ -55,26 +68,29 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
55
 
56
  return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
57
 
58
- def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
59
  """
60
  Deduplicate embeddings across two datasets and return the indices of duplicates between them.
61
  """
62
  # Building the index from Dataset 1
 
63
  reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
64
 
65
  duplicate_indices_in_test = []
66
  duplicate_to_original_mapping = {}
67
 
68
  # Finding nearest neighbors between datasets
 
69
  results = reach.nearest_neighbor_threshold(
70
  embedding_matrix_2,
71
  threshold=threshold,
72
  batch_size=batch_size,
73
- show_progressbar=True # Allow internal progress bar
74
  )
75
 
76
- # Processing duplicates
77
- for i, similar_items in enumerate(results):
 
78
  similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
79
 
80
  if similar_indices:
@@ -98,31 +114,6 @@ def perform_deduplication(
98
  threshold=default_threshold,
99
  progress=gr.Progress(track_tqdm=True)
100
  ):
101
- # Custom tqdm class that wraps progress.tqdm and includes module-level attributes
102
- class TqdmWrapper(tqdm.std.tqdm):
103
- def __init__(self, *args, **kwargs):
104
- super().__init__(*args, **kwargs)
105
-
106
- # Copy module-level attributes from original tqdm module
107
- TqdmWrapper.format_interval = staticmethod(format_interval)
108
- TqdmWrapper.format_num = staticmethod(format_num)
109
- TqdmWrapper.format_sizeof = staticmethod(format_sizeof)
110
-
111
- # Monkey-patch tqdm.tqdm with our wrapper
112
- original_tqdm_tqdm = tqdm.tqdm
113
- tqdm.tqdm = progress.tqdm
114
-
115
- # Monkey-patch model2vec's tqdm reference if needed
116
- import model2vec.model
117
- if hasattr(model2vec.model, 'tqdm'):
118
- original_model2vec_tqdm = model2vec.model.tqdm
119
- model2vec.model.tqdm = TqdmWrapper
120
-
121
- # Monkey-patch reach's tqdm reference if needed
122
- if hasattr(Reach, 'tqdm'):
123
- original_reach_tqdm = Reach.tqdm
124
- Reach.tqdm = TqdmWrapper
125
-
126
  try:
127
  # Convert threshold to float
128
  threshold = float(threshold)
@@ -147,13 +138,13 @@ def perform_deduplication(
147
  # Compute embeddings
148
  status = "Computing embeddings for Dataset 1..."
149
  yield status, ""
150
- embedding_matrix = model.encode(texts, show_progressbar=True) # Enable internal progress bar
151
 
152
  # Deduplicate
153
  status = "Deduplicating embeddings..."
154
  yield status, ""
155
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
156
- embedding_matrix, threshold
157
  )
158
 
159
  # Prepare the results
@@ -214,18 +205,18 @@ def perform_deduplication(
214
  # Compute embeddings for Dataset 1
215
  status = "Computing embeddings for Dataset 1..."
216
  yield status, ""
217
- embedding_matrix1 = model.encode(texts1, show_progressbar=True)
218
 
219
  # Compute embeddings for Dataset 2
220
  status = "Computing embeddings for Dataset 2..."
221
  yield status, ""
222
- embedding_matrix2 = model.encode(texts2, show_progressbar=True)
223
 
224
  # Deduplicate across datasets
225
  status = "Deduplicating embeddings across datasets..."
226
  yield status, ""
227
  duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
228
- embedding_matrix1, embedding_matrix2, threshold
229
  )
230
 
231
  num_duplicates = len(duplicate_indices_in_ds2)
@@ -256,13 +247,9 @@ def perform_deduplication(
256
  status = "Deduplication completed."
257
  yield status, result_text
258
 
259
- finally:
260
- # Restore original tqdm functions
261
- tqdm.tqdm = original_tqdm_tqdm
262
- if hasattr(model2vec.model, 'tqdm'):
263
- model2vec.model.tqdm = original_model2vec_tqdm
264
- if hasattr(Reach, 'tqdm'):
265
- Reach.tqdm = original_reach_tqdm
266
 
267
  with gr.Blocks() as demo:
268
  gr.Markdown("# Semantic Deduplication")
@@ -330,6 +317,7 @@ demo.launch()
330
 
331
 
332
 
 
333
  # import gradio as gr
334
  # from datasets import load_dataset
335
  # import numpy as np
 
4
  from model2vec import StaticModel
5
  from reach import Reach
6
  from difflib import ndiff
 
7
  import tqdm
 
8
 
9
  # Load the model at startup
10
  model = StaticModel.from_pretrained("minishlab/M2V_base_output")
 
21
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
22
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
23
 
24
+ def batch_iterable(iterable, batch_size):
25
+ """Helper function to create batches from an iterable."""
26
+ for i in range(0, len(iterable), batch_size):
27
+ yield iterable[i:i + batch_size]
28
+
29
+ def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
30
+ embeddings = []
31
+ for batch in progress.tqdm(batch_iterable(texts, batch_size), total=(len(texts) + batch_size - 1) // batch_size, desc=desc):
32
+ batch_embeddings = model.encode(batch, show_progressbar=False)
33
+ embeddings.append(batch_embeddings)
34
+ return np.concatenate(embeddings, axis=0)
35
+
36
+ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[np.ndarray, dict[int, int]]:
37
  """
38
  Deduplicate embeddings and return the deduplicated indices and a mapping of removed indices to their corresponding original indices.
39
  """
40
  # Building the index
41
+ progress(0, desc="Building search index...")
42
  reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])
43
 
44
  deduplicated_indices = set(range(len(embedding_matrix)))
45
  duplicate_to_original_mapping = {}
46
 
47
  # Finding nearest neighbors
48
+ progress(0, desc="Finding nearest neighbors...")
49
  results = reach.nearest_neighbor_threshold(
50
  embedding_matrix,
51
  threshold=threshold,
52
  batch_size=batch_size,
53
+ show_progressbar=False # Disable internal progress bar
54
  )
55
 
56
+ # Processing duplicates with a progress bar
57
+ total_items = len(embedding_matrix)
58
+ for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates", total=total_items)):
59
  if i not in deduplicated_indices:
60
  continue
61
 
 
68
 
69
  return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
70
 
71
+ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024, progress=None) -> tuple[list[int], dict[int, int]]:
72
  """
73
  Deduplicate embeddings across two datasets and return the indices of duplicates between them.
74
  """
75
  # Building the index from Dataset 1
76
+ progress(0, desc="Building search index from Dataset 1...")
77
  reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])
78
 
79
  duplicate_indices_in_test = []
80
  duplicate_to_original_mapping = {}
81
 
82
  # Finding nearest neighbors between datasets
83
+ progress(0, desc="Finding nearest neighbors between datasets...")
84
  results = reach.nearest_neighbor_threshold(
85
  embedding_matrix_2,
86
  threshold=threshold,
87
  batch_size=batch_size,
88
+ show_progressbar=False # Disable internal progress bar
89
  )
90
 
91
+ total_items = len(embedding_matrix_2)
92
+ # Processing duplicates with a progress bar
93
+ for i, similar_items in enumerate(progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)):
94
  similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
95
 
96
  if similar_indices:
 
114
  threshold=default_threshold,
115
  progress=gr.Progress(track_tqdm=True)
116
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  try:
118
  # Convert threshold to float
119
  threshold = float(threshold)
 
138
  # Compute embeddings
139
  status = "Computing embeddings for Dataset 1..."
140
  yield status, ""
141
+ embedding_matrix = compute_embeddings(texts, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
142
 
143
  # Deduplicate
144
  status = "Deduplicating embeddings..."
145
  yield status, ""
146
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
147
+ embedding_matrix, threshold, progress=progress
148
  )
149
 
150
  # Prepare the results
 
205
  # Compute embeddings for Dataset 1
206
  status = "Computing embeddings for Dataset 1..."
207
  yield status, ""
208
+ embedding_matrix1 = compute_embeddings(texts1, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 1")
209
 
210
  # Compute embeddings for Dataset 2
211
  status = "Computing embeddings for Dataset 2..."
212
  yield status, ""
213
+ embedding_matrix2 = compute_embeddings(texts2, batch_size=64, progress=progress, desc="Computing embeddings for Dataset 2")
214
 
215
  # Deduplicate across datasets
216
  status = "Deduplicating embeddings across datasets..."
217
  yield status, ""
218
  duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
219
+ embedding_matrix1, embedding_matrix2, threshold, progress=progress
220
  )
221
 
222
  num_duplicates = len(duplicate_indices_in_ds2)
 
247
  status = "Deduplication completed."
248
  yield status, result_text
249
 
250
+ except Exception as e:
251
+ yield f"An error occurred: {e}", ""
252
+ raise e
 
 
 
 
253
 
254
  with gr.Blocks() as demo:
255
  gr.Markdown("# Semantic Deduplication")
 
317
 
318
 
319
 
320
+
321
  # import gradio as gr
322
  # from datasets import load_dataset
323
  # import numpy as np