Pringled committed
Commit b9fcd2c · 1 Parent(s): 4f9641d
Files changed (1)
  1. app.py +400 -162
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
  import numpy as np
4
- #from model2vec import StaticModel
5
  import model2vec
6
  from reach import Reach
7
  from difflib import ndiff
8
-
9
 
10
  # Load the model at startup
11
  model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
@@ -22,52 +21,19 @@ default_threshold = 0.9
22
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
23
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
24
 
25
-
26
- # Patch tqdm to use Gradio's progress bar
27
- #from tqdm import tqdm as original_tqdm
28
-
29
- # Patch tqdm to use Gradio's progress bar
30
- # Patch tqdm to use Gradio's progress bar
31
- # def patch_tqdm_for_gradio(progress):
32
- # class GradioTqdm(original_tqdm):
33
- # def __init__(self, *args, **kwargs):
34
- # super().__init__(*args, **kwargs)
35
- # self.progress = progress
36
- # self.total_batches = kwargs.get('total', len(args[0])) if len(args) > 0 else 1
37
- # self.update_interval = max(1, self.total_batches // 100) # Update every 1%
38
-
39
- # def update(self, n=1):
40
- # super().update(n)
41
- # # Update Gradio progress bar every update_interval steps
42
- # if self.n % self.update_interval == 0 or self.n == self.total_batches:
43
- # self.progress(self.n / self.total_batches)
44
-
45
- # return GradioTqdm
46
-
47
- # def patch_model2vec_tqdm(progress):
48
- # patched_tqdm = patch_tqdm_for_gradio(progress)
49
- # model2vec.tqdm = patched_tqdm # Replace tqdm in model2vec
50
-
51
- # # Function to patch the original encode function with our Gradio tqdm
52
- # def original_encode_with_tqdm(original_encode_func, patched_tqdm):
53
- # def new_encode(*args, **kwargs):
54
- # original_tqdm_backup = original_tqdm
55
- # try:
56
- # # Patch the `tqdm` within encode
57
- # globals()['tqdm'] = patched_tqdm
58
- # return original_encode_func(*args, **kwargs)
59
- # finally:
60
- # # Restore original tqdm after calling encode
61
- # globals()['tqdm'] = original_tqdm_backup
62
-
63
- # return new_encode
64
-
65
-
66
  def batch_iterable(iterable, batch_size):
67
  """Helper function to create batches from an iterable."""
68
  for i in range(0, len(iterable), batch_size):
69
  yield iterable[i:i + batch_size]
70
 
71
  def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
72
  embeddings = []
73
  total_batches = (len(texts) + batch_size - 1) // batch_size
@@ -122,7 +88,6 @@ def display_word_differences(x: str, y: str) -> str:
122
  diff = ndiff(x.split(), y.split())
123
  return " ".join([word for word in diff if word.startswith(("+", "-"))])
124
 
125
-
126
  def encode_texts(texts, progress=None):
127
  embedding_matrix = model.encode(texts, show_progressbar=False)
128
  return embedding_matrix
@@ -147,7 +112,8 @@ def perform_deduplication(
147
 
148
  if deduplication_type == "Single dataset":
149
  # Load Dataset 1
150
- status = "Loading Dataset 1..."
 
151
  yield status, ""
152
  if (
153
  dataset1_name == default_dataset1_name
@@ -156,29 +122,34 @@ def perform_deduplication(
156
  ds = ds_default1
157
  else:
158
  ds = load_dataset(dataset1_name, split=dataset1_split)
 
 
159
 
160
  # Extract texts
161
- status = "Extracting texts from Dataset 1..."
 
162
  yield status, ""
163
  texts = [example[dataset1_text_column] for example in ds]
 
 
 
164
  # Compute embeddings
165
- status = "Computing embeddings for Dataset 1..."
 
166
  yield status, ""
167
  embedding_matrix = encode_texts(texts, progress=progress)
168
- #embedding_matrix = model.encode(texts, show_progressbar=True)
169
- # embedding_matrix = compute_embeddings(
170
- # texts,
171
- # batch_size=64,
172
- # progress=progress,
173
- # desc="Computing embeddings for Dataset 1",
174
- # )
175
 
176
  # Deduplicate
177
- status = "Deduplicating embeddings..."
 
178
  yield status, ""
179
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
180
  embedding_matrix, threshold, progress=progress
181
  )
 
 
182
 
183
  # Prepare the results
184
  num_duplicates = len(duplicate_to_original_mapping)
@@ -207,13 +178,13 @@ def perform_deduplication(
207
  result_text += "No duplicates found."
208
 
209
  # Final status
210
- status = "Deduplication completed."
211
  yield status, result_text
212
 
213
  elif deduplication_type == "Cross-dataset":
214
- # Similar code for cross-dataset deduplication
215
- # Load Dataset 1
216
- status = "Loading Dataset 1..."
217
  yield status, ""
218
  if (
219
  dataset1_name == default_dataset1_name
@@ -222,9 +193,11 @@ def perform_deduplication(
222
  ds1 = ds_default1
223
  else:
224
  ds1 = load_dataset(dataset1_name, split=dataset1_split)
 
 
225
 
226
- # Load Dataset 2
227
- status = "Loading Dataset 2..."
228
  yield status, ""
229
  if (
230
  dataset2_name == default_dataset2_name
@@ -233,114 +206,15 @@ def perform_deduplication(
233
  ds2 = ds_default2
234
  else:
235
  ds2 = load_dataset(dataset2_name, split=dataset2_split)
236
-
237
- # Extract texts from Dataset 1
238
- status = "Extracting texts from Dataset 1..."
239
- yield status, ""
240
- texts1 = [example[dataset1_text_column] for example in ds1]
241
-
242
- # Extract texts from Dataset 2
243
- status = "Extracting texts from Dataset 2..."
244
- yield status, ""
245
- texts2 = [example[dataset2_text_column] for example in ds2]
246
-
247
- # Compute embeddings for Dataset 1
248
- status = "Computing embeddings for Dataset 1..."
249
- yield status, ""
250
- embedding_matrix1 = compute_embeddings(
251
- texts1,
252
- batch_size=64,
253
- progress=progress,
254
- desc="Computing embeddings for Dataset 1",
255
- )
256
-
257
- # Compute embeddings for Dataset 2
258
- status = "Computing embeddings for Dataset 2..."
259
- yield status, ""
260
- embedding_matrix2 = compute_embeddings(
261
- texts2,
262
- batch_size=64,
263
- progress=progress,
264
- desc="Computing embeddings for Dataset 2",
265
- )
266
-
267
- # Deduplicate across datasets
268
- status = "Deduplicating embeddings across datasets..."
269
  yield status, ""
270
- duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
271
- embedding_matrix1, embedding_matrix2, threshold, progress=progress
272
- )
273
 
274
- num_duplicates = len(duplicate_indices_in_ds2)
275
- num_total_ds2 = len(texts2)
276
- num_unique_ds2 = num_total_ds2 - num_duplicates
277
-
278
- result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
279
- result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
280
- result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
281
-
282
- # Show deduplicated examples
283
- if num_duplicates > 0:
284
- result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
285
- num_examples = min(5, num_duplicates)
286
- for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
287
- original_idx = duplicate_to_original_mapping[duplicate_idx]
288
- original_text = texts1[original_idx]
289
- duplicate_text = texts2[duplicate_idx]
290
- differences = display_word_differences(original_text, duplicate_text)
291
- result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
292
- result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
293
- result_text += f"**Differences:**\n{differences}\n"
294
- result_text += "-" * 50 + "\n\n"
295
- else:
296
- result_text += "No duplicates found."
297
-
298
- # Final status
299
- status = "Deduplication completed."
300
- yield status, result_text
301
 
302
  except Exception as e:
303
  yield f"An error occurred: {e}", ""
304
  raise e
305
 
306
- def deduplicate_across_datasets(
307
- embedding_matrix_1: np.ndarray,
308
- embedding_matrix_2: np.ndarray,
309
- threshold: float,
310
- batch_size: int = 1024,
311
- progress=None
312
- ) -> tuple[list[int], dict[int, int]]:
313
- # Building the index from Dataset 1
314
- progress(0, desc="Building search index from Dataset 1...")
315
- reach = Reach(
316
- vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
317
- )
318
-
319
- duplicate_indices_in_test = []
320
- duplicate_to_original_mapping = {}
321
-
322
- # Finding nearest neighbors between datasets
323
- progress(0, desc="Finding nearest neighbors between datasets...")
324
- results = reach.nearest_neighbor_threshold(
325
- embedding_matrix_2,
326
- threshold=threshold,
327
- batch_size=batch_size,
328
- show_progressbar=False, # Disable internal progress bar
329
- )
330
-
331
- total_items = len(embedding_matrix_2)
332
- # Processing duplicates with a progress bar
333
- for i, similar_items in enumerate(
334
- progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
335
- ):
336
- similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
337
-
338
- if similar_indices:
339
- duplicate_indices_in_test.append(i)
340
- duplicate_to_original_mapping[i] = similar_indices[0]
341
-
342
- return duplicate_indices_in_test, duplicate_to_original_mapping
343
-
344
  # Adjust the height of the status_output component using custom CSS
345
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
346
  gr.Markdown("# Semantic Deduplication")
@@ -401,3 +275,367 @@ with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
401
  )
402
 
403
  demo.launch()
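Aside: the commented-out tqdm patching deleted above was an attempt to forward model2vec's progress to the UI. The app instead leans on Gradio's built-in progress plumbing: the event handler is a generator that yields status strings to a Markdown output, and a gr.Progress object (passed as a default parameter with track_tqdm=True) drives the progress bar via progress.tqdm. A minimal self-contained sketch of that pattern, assuming a recent Gradio release; slow_task, the range(10) loop, and time.sleep are illustrative placeholders, not code from this commit:

import time
import gradio as gr

def slow_task(progress=gr.Progress(track_tqdm=True)):
    # Generator handler: each yielded string replaces the Markdown output below.
    yield "Starting..."
    # progress.tqdm wraps an iterable and updates Gradio's progress bar per item.
    for _ in progress.tqdm(range(10), desc="Processing items"):
        time.sleep(0.1)  # placeholder for real work, e.g. encoding a batch
    yield "Done."

with gr.Blocks() as demo:
    status = gr.Markdown()
    gr.Button("Run").click(fn=slow_task, inputs=None, outputs=status)

# demo.launch()  # uncomment to serve the sketch locally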
1
  import gradio as gr
2
  from datasets import load_dataset
3
  import numpy as np
 
4
  import model2vec
5
  from reach import Reach
6
  from difflib import ndiff
7
+ import time
8
 
9
  # Load the model at startup
10
  model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
 
21
  ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
22
  ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
23
 
24
  def batch_iterable(iterable, batch_size):
25
  """Helper function to create batches from an iterable."""
26
  for i in range(0, len(iterable), batch_size):
27
  yield iterable[i:i + batch_size]
28
 
29
+ def log_time(message, start_time=None):
30
+ """Helper function to log the start and end times."""
31
+ current_time = time.time()
32
+ if start_time is not None:
33
+ elapsed = current_time - start_time
34
+ return f"{message} - Took {elapsed:.2f} seconds"
35
+ return f"{message} - Started"
36
+
37
  def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
38
  embeddings = []
39
  total_batches = (len(texts) + batch_size - 1) // batch_size
 
88
  diff = ndiff(x.split(), y.split())
89
  return " ".join([word for word in diff if word.startswith(("+", "-"))])
90
 
 
91
  def encode_texts(texts, progress=None):
92
  embedding_matrix = model.encode(texts, show_progressbar=False)
93
  return embedding_matrix
 
112
 
113
  if deduplication_type == "Single dataset":
114
  # Load Dataset 1
115
+ start_time = time.time()
116
+ status = log_time("Loading Dataset 1")
117
  yield status, ""
118
  if (
119
  dataset1_name == default_dataset1_name
 
122
  ds = ds_default1
123
  else:
124
  ds = load_dataset(dataset1_name, split=dataset1_split)
125
+ status = log_time("Loading Dataset 1 completed", start_time)
126
+ yield status, ""
127
 
128
  # Extract texts
129
+ start_time = time.time()
130
+ status = log_time("Extracting texts from Dataset 1")
131
  yield status, ""
132
  texts = [example[dataset1_text_column] for example in ds]
133
+ status = log_time("Extracting texts from Dataset 1 completed", start_time)
134
+ yield status, ""
135
+
136
  # Compute embeddings
137
+ start_time = time.time()
138
+ status = log_time("Computing embeddings for Dataset 1")
139
  yield status, ""
140
  embedding_matrix = encode_texts(texts, progress=progress)
141
+ status = log_time("Computing embeddings for Dataset 1 completed", start_time)
142
+ yield status, ""
143
 
144
  # Deduplicate
145
+ start_time = time.time()
146
+ status = log_time("Deduplicating embeddings")
147
  yield status, ""
148
  deduplicated_indices, duplicate_to_original_mapping = deduplicate(
149
  embedding_matrix, threshold, progress=progress
150
  )
151
+ status = log_time("Deduplication completed", start_time)
152
+ yield status, ""
153
 
154
  # Prepare the results
155
  num_duplicates = len(duplicate_to_original_mapping)
 
178
  result_text += "No duplicates found."
179
 
180
  # Final status
181
+ status = log_time("Deduplication process finished")
182
  yield status, result_text
183
 
184
  elif deduplication_type == "Cross-dataset":
185
+ # Similar code for cross-dataset deduplication with time logging
186
+ start_time = time.time()
187
+ status = log_time("Loading Dataset 1")
188
  yield status, ""
189
  if (
190
  dataset1_name == default_dataset1_name
 
193
  ds1 = ds_default1
194
  else:
195
  ds1 = load_dataset(dataset1_name, split=dataset1_split)
196
+ status = log_time("Loading Dataset 1 completed", start_time)
197
+ yield status, ""
198
 
199
+ start_time = time.time()
200
+ status = log_time("Loading Dataset 2")
201
  yield status, ""
202
  if (
203
  dataset2_name == default_dataset2_name
 
206
  ds2 = ds_default2
207
  else:
208
  ds2 = load_dataset(dataset2_name, split=dataset2_split)
209
+ status = log_time("Loading Dataset 2 completed", start_time)
210
  yield status, ""
211
 
212
+ # Similar time logging for embedding computations and deduplication steps
213
 
214
  except Exception as e:
215
  yield f"An error occurred: {e}", ""
216
  raise e
217
 
218
  # Adjust the height of the status_output component using custom CSS
219
  with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
220
  gr.Markdown("# Semantic Deduplication")
 
275
  )
276
 
277
  demo.launch()
278
+
279
+ # import gradio as gr
280
+ # from datasets import load_dataset
281
+ # import numpy as np
282
+ # #from model2vec import StaticModel
283
+ # import model2vec
284
+ # from reach import Reach
285
+ # from difflib import ndiff
286
+
287
+
288
+ # # Load the model at startup
289
+ # model = model2vec.StaticModel.from_pretrained("minishlab/M2V_base_output")
290
+
291
+ # # Default dataset parameters
292
+ # default_dataset1_name = "sst2"
293
+ # default_dataset1_split = "train"
294
+ # default_dataset2_name = "sst2"
295
+ # default_dataset2_split = "validation"
296
+ # default_text_column = "sentence"
297
+ # default_threshold = 0.9
298
+
299
+ # # Load the default datasets at startup
300
+ # ds_default1 = load_dataset(default_dataset1_name, split=default_dataset1_split)
301
+ # ds_default2 = load_dataset(default_dataset2_name, split=default_dataset2_split)
302
+
303
+
304
+ # def batch_iterable(iterable, batch_size):
305
+ # """Helper function to create batches from an iterable."""
306
+ # for i in range(0, len(iterable), batch_size):
307
+ # yield iterable[i:i + batch_size]
308
+
309
+ # def compute_embeddings(texts, batch_size, progress, desc="Computing embeddings"):
310
+ # embeddings = []
311
+ # total_batches = (len(texts) + batch_size - 1) // batch_size
312
+ # for i, batch_texts in enumerate(batch_iterable(texts, batch_size)):
313
+ # batch_embeddings = model.encode(batch_texts, show_progressbar=False)
314
+ # embeddings.append(batch_embeddings)
315
+ # progress((i + 1) / total_batches, desc=desc)
316
+ # return np.concatenate(embeddings, axis=0)
317
+
318
+ # def deduplicate(
319
+ # embedding_matrix: np.ndarray,
320
+ # threshold: float,
321
+ # batch_size: int = 1024,
322
+ # progress=None
323
+ # ) -> tuple[np.ndarray, dict[int, int]]:
324
+ # # Building the index
325
+ # progress(0, desc="Building search index...")
326
+ # reach = Reach(
327
+ # vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))]
328
+ # )
329
+
330
+ # deduplicated_indices = set(range(len(embedding_matrix)))
331
+ # duplicate_to_original_mapping = {}
332
+
333
+ # # Finding nearest neighbors
334
+ # progress(0, desc="Finding nearest neighbors...")
335
+ # results = reach.nearest_neighbor_threshold(
336
+ # embedding_matrix,
337
+ # threshold=threshold,
338
+ # batch_size=batch_size,
339
+ # show_progressbar=False, # Disable internal progress bar
340
+ # )
341
+
342
+ # # Processing duplicates with a progress bar
343
+ # total_items = len(embedding_matrix)
344
+ # for i, similar_items in enumerate(
345
+ # progress.tqdm(results, desc="Processing duplicates", total=total_items)
346
+ # ):
347
+ # if i not in deduplicated_indices:
348
+ # continue
349
+
350
+ # similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]
351
+
352
+ # for sim_idx in similar_indices:
353
+ # if sim_idx in deduplicated_indices:
354
+ # deduplicated_indices.remove(sim_idx)
355
+ # duplicate_to_original_mapping[sim_idx] = i
356
+
357
+ # return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
358
+
359
+ # def display_word_differences(x: str, y: str) -> str:
360
+ # diff = ndiff(x.split(), y.split())
361
+ # return " ".join([word for word in diff if word.startswith(("+", "-"))])
362
+
363
+
364
+ # def encode_texts(texts, progress=None):
365
+ # embedding_matrix = model.encode(texts, show_progressbar=False)
366
+ # return embedding_matrix
367
+
368
+ # def perform_deduplication(
369
+ # deduplication_type,
370
+ # dataset1_name,
371
+ # dataset1_split,
372
+ # dataset1_text_column,
373
+ # dataset2_name="",
374
+ # dataset2_split="",
375
+ # dataset2_text_column="",
376
+ # threshold=default_threshold,
377
+ # progress=gr.Progress(track_tqdm=True),
378
+ # ):
379
+ # try:
380
+ # # Convert threshold to float
381
+ # threshold = float(threshold)
382
+
383
+ # # Initialize status message
384
+ # status = ""
385
+
386
+ # if deduplication_type == "Single dataset":
387
+ # # Load Dataset 1
388
+ # status = "Loading Dataset 1..."
389
+ # yield status, ""
390
+ # if (
391
+ # dataset1_name == default_dataset1_name
392
+ # and dataset1_split == default_dataset1_split
393
+ # ):
394
+ # ds = ds_default1
395
+ # else:
396
+ # ds = load_dataset(dataset1_name, split=dataset1_split)
397
+
398
+ # # Extract texts
399
+ # status = "Extracting texts from Dataset 1..."
400
+ # yield status, ""
401
+ # texts = [example[dataset1_text_column] for example in ds]
402
+ # # Compute embeddings
403
+ # status = "Computing embeddings for Dataset 1..."
404
+ # yield status, ""
405
+ # embedding_matrix = encode_texts(texts, progress=progress)
406
+ # #embedding_matrix = model.encode(texts, show_progressbar=True)
407
+ # # embedding_matrix = compute_embeddings(
408
+ # # texts,
409
+ # # batch_size=64,
410
+ # # progress=progress,
411
+ # # desc="Computing embeddings for Dataset 1",
412
+ # # )
413
+
414
+ # # Deduplicate
415
+ # status = "Deduplicating embeddings..."
416
+ # yield status, ""
417
+ # deduplicated_indices, duplicate_to_original_mapping = deduplicate(
418
+ # embedding_matrix, threshold, progress=progress
419
+ # )
420
+
421
+ # # Prepare the results
422
+ # num_duplicates = len(duplicate_to_original_mapping)
423
+ # num_total = len(texts)
424
+ # num_deduplicated = len(deduplicated_indices)
425
+
426
+ # result_text = f"**Total documents:** {num_total}\n"
427
+ # result_text += f"**Number of duplicates found:** {num_duplicates}\n"
428
+ # result_text += (
429
+ # f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
430
+ # )
431
+
432
+ # # Show deduplicated examples
433
+ # if num_duplicates > 0:
434
+ # result_text += "**Examples of duplicates found:**\n\n"
435
+ # num_examples = min(5, num_duplicates)
436
+ # for duplicate_idx, original_idx in list(duplicate_to_original_mapping.items())[:num_examples]:
437
+ # original_text = texts[original_idx]
438
+ # duplicate_text = texts[duplicate_idx]
439
+ # differences = display_word_differences(original_text, duplicate_text)
440
+ # result_text += f"**Original text:**\n{original_text}\n\n"
441
+ # result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
442
+ # result_text += f"**Differences:**\n{differences}\n"
443
+ # result_text += "-" * 50 + "\n\n"
444
+ # else:
445
+ # result_text += "No duplicates found."
446
+
447
+ # # Final status
448
+ # status = "Deduplication completed."
449
+ # yield status, result_text
450
+
451
+ # elif deduplication_type == "Cross-dataset":
452
+ # # Similar code for cross-dataset deduplication
453
+ # # Load Dataset 1
454
+ # status = "Loading Dataset 1..."
455
+ # yield status, ""
456
+ # if (
457
+ # dataset1_name == default_dataset1_name
458
+ # and dataset1_split == default_dataset1_split
459
+ # ):
460
+ # ds1 = ds_default1
461
+ # else:
462
+ # ds1 = load_dataset(dataset1_name, split=dataset1_split)
463
+
464
+ # # Load Dataset 2
465
+ # status = "Loading Dataset 2..."
466
+ # yield status, ""
467
+ # if (
468
+ # dataset2_name == default_dataset2_name
469
+ # and dataset2_split == default_dataset2_split
470
+ # ):
471
+ # ds2 = ds_default2
472
+ # else:
473
+ # ds2 = load_dataset(dataset2_name, split=dataset2_split)
474
+
475
+ # # Extract texts from Dataset 1
476
+ # status = "Extracting texts from Dataset 1..."
477
+ # yield status, ""
478
+ # texts1 = [example[dataset1_text_column] for example in ds1]
479
+
480
+ # # Extract texts from Dataset 2
481
+ # status = "Extracting texts from Dataset 2..."
482
+ # yield status, ""
483
+ # texts2 = [example[dataset2_text_column] for example in ds2]
484
+
485
+ # # Compute embeddings for Dataset 1
486
+ # status = "Computing embeddings for Dataset 1..."
487
+ # yield status, ""
488
+ # embedding_matrix1 = compute_embeddings(
489
+ # texts1,
490
+ # batch_size=64,
491
+ # progress=progress,
492
+ # desc="Computing embeddings for Dataset 1",
493
+ # )
494
+
495
+ # # Compute embeddings for Dataset 2
496
+ # status = "Computing embeddings for Dataset 2..."
497
+ # yield status, ""
498
+ # embedding_matrix2 = compute_embeddings(
499
+ # texts2,
500
+ # batch_size=64,
501
+ # progress=progress,
502
+ # desc="Computing embeddings for Dataset 2",
503
+ # )
504
+
505
+ # # Deduplicate across datasets
506
+ # status = "Deduplicating embeddings across datasets..."
507
+ # yield status, ""
508
+ # duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
509
+ # embedding_matrix1, embedding_matrix2, threshold, progress=progress
510
+ # )
511
+
512
+ # num_duplicates = len(duplicate_indices_in_ds2)
513
+ # num_total_ds2 = len(texts2)
514
+ # num_unique_ds2 = num_total_ds2 - num_duplicates
515
+
516
+ # result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
517
+ # result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
518
+ # result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
519
+
520
+ # # Show deduplicated examples
521
+ # if num_duplicates > 0:
522
+ # result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
523
+ # num_examples = min(5, num_duplicates)
524
+ # for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
525
+ # original_idx = duplicate_to_original_mapping[duplicate_idx]
526
+ # original_text = texts1[original_idx]
527
+ # duplicate_text = texts2[duplicate_idx]
528
+ # differences = display_word_differences(original_text, duplicate_text)
529
+ # result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
530
+ # result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
531
+ # result_text += f"**Differences:**\n{differences}\n"
532
+ # result_text += "-" * 50 + "\n\n"
533
+ # else:
534
+ # result_text += "No duplicates found."
535
+
536
+ # # Final status
537
+ # status = "Deduplication completed."
538
+ # yield status, result_text
539
+
540
+ # except Exception as e:
541
+ # yield f"An error occurred: {e}", ""
542
+ # raise e
543
+
544
+ # def deduplicate_across_datasets(
545
+ # embedding_matrix_1: np.ndarray,
546
+ # embedding_matrix_2: np.ndarray,
547
+ # threshold: float,
548
+ # batch_size: int = 1024,
549
+ # progress=None
550
+ # ) -> tuple[list[int], dict[int, int]]:
551
+ # # Building the index from Dataset 1
552
+ # progress(0, desc="Building search index from Dataset 1...")
553
+ # reach = Reach(
554
+ # vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))]
555
+ # )
556
+
557
+ # duplicate_indices_in_test = []
558
+ # duplicate_to_original_mapping = {}
559
+
560
+ # # Finding nearest neighbors between datasets
561
+ # progress(0, desc="Finding nearest neighbors between datasets...")
562
+ # results = reach.nearest_neighbor_threshold(
563
+ # embedding_matrix_2,
564
+ # threshold=threshold,
565
+ # batch_size=batch_size,
566
+ # show_progressbar=False, # Disable internal progress bar
567
+ # )
568
+
569
+ # total_items = len(embedding_matrix_2)
570
+ # # Processing duplicates with a progress bar
571
+ # for i, similar_items in enumerate(
572
+ # progress.tqdm(results, desc="Processing duplicates across datasets", total=total_items)
573
+ # ):
574
+ # similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
575
+
576
+ # if similar_indices:
577
+ # duplicate_indices_in_test.append(i)
578
+ # duplicate_to_original_mapping[i] = similar_indices[0]
579
+
580
+ # return duplicate_indices_in_test, duplicate_to_original_mapping
581
+
582
+ # # Adjust the height of the status_output component using custom CSS
583
+ # with gr.Blocks(css="#status_output { height: 150px; overflow: auto; }") as demo:
584
+ # gr.Markdown("# Semantic Deduplication")
585
+
586
+ # deduplication_type = gr.Radio(
587
+ # choices=["Single dataset", "Cross-dataset"],
588
+ # label="Deduplication Type",
589
+ # value="Single dataset",
590
+ # )
591
+
592
+ # with gr.Row():
593
+ # dataset1_name = gr.Textbox(value=default_dataset1_name, label="Dataset 1 Name")
594
+ # dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
595
+ # dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
596
+
597
+ # dataset2_inputs = gr.Column(visible=False)
598
+ # with dataset2_inputs:
599
+ # gr.Markdown("### Dataset 2")
600
+ # with gr.Row():
601
+ # dataset2_name = gr.Textbox(value=default_dataset2_name, label="Dataset 2 Name")
602
+ # dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
603
+ # dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")
604
+
605
+ # threshold = gr.Slider(
606
+ # minimum=0.0, maximum=1.0, value=default_threshold, label="Similarity Threshold"
607
+ # )
608
+
609
+ # compute_button = gr.Button("Compute")
610
+
611
+ # # Use 'gr.Markdown' with 'elem_id' and custom CSS to adjust height
612
+ # status_output = gr.Markdown(elem_id="status_output")
613
+ # result_output = gr.Markdown()
614
+
615
+ # # Function to update the visibility of dataset2_inputs
616
+ # def update_visibility(deduplication_type_value):
617
+ # if deduplication_type_value == "Cross-dataset":
618
+ # return gr.update(visible=True)
619
+ # else:
620
+ # return gr.update(visible=False)
621
+
622
+ # deduplication_type.change(
623
+ # update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
624
+ # )
625
+
626
+ # compute_button.click(
627
+ # fn=perform_deduplication,
628
+ # inputs=[
629
+ # deduplication_type,
630
+ # dataset1_name,
631
+ # dataset1_split,
632
+ # dataset1_text_column,
633
+ # dataset2_name,
634
+ # dataset2_split,
635
+ # dataset2_text_column,
636
+ # threshold,
637
+ # ],
638
+ # outputs=[status_output, result_output],
639
+ # )
640
+
641
+ # demo.launch()
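The substantive change in this commit is the log_time helper and the "Started" / elapsed-time status messages yielded around each step of perform_deduplication. A standalone sketch of that timing pattern, using only the standard library; pipeline() and the time.sleep stand-in are illustrative, not part of app.py:

import time

def log_time(message, start_time=None):
    # Same helper as in the commit: return a "Started" line, or an elapsed-time
    # line when the step's start timestamp is passed back in.
    current_time = time.time()
    if start_time is not None:
        return f"{message} - Took {current_time - start_time:.2f} seconds"
    return f"{message} - Started"

def pipeline():
    # Each step records its start, yields a status line, does the work,
    # then yields the elapsed-time line (mirroring perform_deduplication).
    start_time = time.time()
    yield log_time("Loading Dataset 1")
    time.sleep(0.2)  # stand-in for load_dataset(...)
    yield log_time("Loading Dataset 1 completed", start_time)

for status in pipeline():
    print(status)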