Pringled committed
Commit 7a1cd7a · 1 Parent(s): 6188d2c

Updated app with code for deduplication

Files changed (1): app.py (+130 -130)
app.py CHANGED
@@ -4,11 +4,7 @@ import numpy as np
 from model2vec import StaticModel
 from reach import Reach
 from tqdm import tqdm
-import difflib
-
-def display_word_differences(x: str, y: str) -> str:
-    diff = difflib.ndiff(x.split(), y.split())
-    return " ".join([word for word in diff if word.startswith(('+', '-'))])
+from difflib import ndiff
 
 def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
     """
@@ -24,11 +20,11 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
         embedding_matrix,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=False  # Disable internal progress bar
+        show_progressbar=True
     )
 
     # Process duplicates
-    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
+    for i, similar_items in enumerate(tqdm(results)):
         if i not in deduplicated_indices:
             continue  # Skip already marked duplicates
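
For context, a minimal sketch of calling deduplicate outside the app. It assumes the signature shown in this diff and that the function defined above is in scope; the texts and threshold are illustrative:

    import numpy as np
    from model2vec import StaticModel

    texts = ["breaking news story", "weather update", "breaking news story"]
    model = StaticModel.from_pretrained("minishlab/M2V_base_output")
    embedding_matrix = np.asarray(model.encode(texts, show_progressbar=True))

    # Returns the indices to keep plus a mapping of each duplicate to its original
    deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold=0.8)
    unique_texts = [texts[i] for i in deduplicated_indices]
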
 
@@ -58,11 +54,11 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
         embedding_matrix_2,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=False  # Disable internal progress bar
+        show_progressbar=True
     )
 
     # Process duplicates
-    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
+    for i, similar_items in enumerate(tqdm(results)):
         # Similar items are returned as (index, score), we are only interested in the index
         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]  # Keep those above the threshold
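
The (index, score) filtering above, in isolation with illustrative values:

    threshold = 0.8
    # Each query's neighbors arrive as (index, similarity score) pairs
    similar_items = [(0, 0.97), (42, 0.83), (7, 0.61)]
    similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]
    # similar_indices == [0, 42]; the 0.61 neighbor falls below the threshold
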
 
@@ -73,150 +69,154 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
 
     return duplicate_indices_in_test, duplicate_to_original_mapping
 
+def display_word_differences(x: str, y: str) -> str:
+    diff = ndiff(x.split(), y.split())
+    return " ".join([f"{word}" for word in diff if word.startswith(('+', '-'))])
+
 def perform_deduplication(
     deduplication_type,
     dataset1_name,
     dataset1_split,
+    dataset1_text_column,
     dataset2_name,
     dataset2_split,
-    text_column_name,
+    dataset2_text_column,
     threshold
 ):
     # Convert threshold to float
     threshold = float(threshold)
-
-    with gr.Progress(track_tqdm=True) as progress:
-        if deduplication_type == "Single dataset":
-            # Load the dataset
-            ds = load_dataset(dataset1_name, split=dataset1_split)
-
-            # Extract texts
-            try:
-                texts = [example[text_column_name] for example in ds]
-            except KeyError:
-                return f"Error: Text column '{text_column_name}' not found in dataset."
-
-            # Compute embeddings
-            progress(0.1, desc="Loading model and computing embeddings...")
-            model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-            embedding_matrix = model.encode(texts, show_progressbar=False)
-
-            # Deduplicate
-            progress(0.5, desc="Performing deduplication...")
+
+    if deduplication_type == "Single dataset":
+        # Load the dataset
+        ds = load_dataset(dataset1_name, split=dataset1_split)
+
+        # Extract texts
+        texts = [example[dataset1_text_column] for example in ds]
+
+        # Compute embeddings
+        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+        embedding_matrix = model.encode(texts, show_progressbar=True)
+
+        # Deduplicate
+        with gr.Progress(track_tqdm=True):
             deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)
-
-            # Prepare the results
-            num_duplicates = len(duplicate_to_original_mapping)
-            num_total = len(texts)
-            num_deduplicated = len(deduplicated_indices)
-
-            result_text = f"**Total documents:** {num_total}\n"
-            result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-
-            # Show sample duplicates
-            result_text += "### Sample Duplicate Pairs with Differences:\n\n"
-            num_examples = min(5, num_duplicates)
-            if num_examples > 0:
-                sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
-                for duplicate_idx, original_idx in sample_duplicates:
-                    original_text = texts[original_idx]
-                    duplicate_text = texts[duplicate_idx]
-                    differences = display_word_differences(original_text, duplicate_text)
-                    result_text += f"**Original Text (Index {original_idx}):**\n{original_text}\n\n"
-                    result_text += f"**Duplicate Text (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
-                    result_text += f"**Differences:**\n{differences}\n\n"
-                    result_text += "---\n\n"
-            else:
-                result_text += "No duplicates found.\n"
-
-            return result_text
-
-        elif deduplication_type == "Cross-dataset":
-            # Load datasets
-            ds1 = load_dataset(dataset1_name, split=dataset1_split)
-            ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-            # Extract texts
-            try:
-                texts1 = [example[text_column_name] for example in ds1]
-                texts2 = [example[text_column_name] for example in ds2]
-            except KeyError:
-                return f"Error: Text column '{text_column_name}' not found in one of the datasets."
-
-            # Compute embeddings
-            progress(0.1, desc="Computing embeddings for Dataset 1...")
-            model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-            embedding_matrix1 = model.encode(texts1, show_progressbar=False)
-
-            progress(0.5, desc="Computing embeddings for Dataset 2...")
-            embedding_matrix2 = model.encode(texts2, show_progressbar=False)
-
-            # Deduplicate across datasets
-            progress(0.7, desc="Performing cross-dataset deduplication...")
+
+        # Prepare the results
+        num_duplicates = len(duplicate_to_original_mapping)
+        num_total = len(texts)
+        num_deduplicated = len(deduplicated_indices)
+
+        result_text = f"**Total documents:** {num_total}\n"
+        result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+        # Show deduplicated examples
+        result_text += "**Examples of duplicates found:**\n\n"
+        num_examples = min(5, num_duplicates)
+        examples_shown = 0
+        for duplicate_idx, original_idx in duplicate_to_original_mapping.items():
+            if examples_shown >= num_examples:
+                break
+            original_text = texts[original_idx]
+            duplicate_text = texts[duplicate_idx]
+            differences = display_word_differences(original_text, duplicate_text)
+            result_text += f"**Original text:**\n{original_text}\n\n"
+            result_text += f"**Duplicate text:**\n{duplicate_text}\n\n"
+            result_text += f"**Differences:**\n{differences}\n"
+            result_text += "-" * 50 + "\n\n"
+            examples_shown += 1
+
+        return result_text
+
+    elif deduplication_type == "Cross-dataset":
+        # Load datasets
+        ds1 = load_dataset(dataset1_name, split=dataset1_split)
+        ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+        # Extract texts
+        texts1 = [example[dataset1_text_column] for example in ds1]
+        texts2 = [example[dataset2_text_column] for example in ds2]
+
+        # Compute embeddings
+        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+        embedding_matrix1 = model.encode(texts1, show_progressbar=True)
+        embedding_matrix2 = model.encode(texts2, show_progressbar=True)
+
+        # Deduplicate across datasets
+        with gr.Progress(track_tqdm=True):
             duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)
-
-            num_duplicates = len(duplicate_indices_in_ds2)
-            num_total_ds2 = len(texts2)
-            num_unique_ds2 = num_total_ds2 - num_duplicates
-
-            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-
-            # Show sample duplicates
-            result_text += "### Sample Duplicate Pairs with Differences:\n\n"
-            num_examples = min(5, num_duplicates)
-            if num_examples > 0:
-                sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
-                for duplicate_idx, original_idx in sample_duplicates:
-                    original_text = texts1[original_idx]
-                    duplicate_text = texts2[duplicate_idx]
-                    differences = display_word_differences(original_text, duplicate_text)
-                    result_text += f"**Original Text in {dataset1_name}/{dataset1_split} (Index {original_idx}):**\n{original_text}\n\n"
-                    result_text += f"**Duplicate Text in {dataset2_name}/{dataset2_split} (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
-                    result_text += f"**Differences:**\n{differences}\n\n"
-                    result_text += "---\n\n"
-            else:
-                result_text += "No duplicates found.\n"
-
-            return result_text
+
+        num_duplicates = len(duplicate_indices_in_ds2)
+        num_total_ds2 = len(texts2)
+        num_unique_ds2 = num_total_ds2 - num_duplicates
+
+        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+        # Show deduplicated examples
+        result_text += "**Examples of duplicates found in Dataset 2:**\n\n"
+        num_examples = min(5, num_duplicates)
+        examples_shown = 0
+        for duplicate_idx in duplicate_indices_in_ds2[:num_examples]:
+            original_idx = duplicate_to_original_mapping[duplicate_idx]
+            original_text = texts1[original_idx]
+            duplicate_text = texts2[duplicate_idx]
+            differences = display_word_differences(original_text, duplicate_text)
+            result_text += f"**Original text (Dataset 1):**\n{original_text}\n\n"
+            result_text += f"**Duplicate text (Dataset 2):**\n{duplicate_text}\n\n"
+            result_text += f"**Differences:**\n{differences}\n"
+            result_text += "-" * 50 + "\n\n"
+            examples_shown += 1
+
+        return result_text
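
Likewise, a minimal sketch of the cross-dataset path with deduplicate_across_datasets in scope (signature as in this diff; texts and threshold illustrative). Only entries of the second dataset are dropped:

    from model2vec import StaticModel

    texts1 = ["story shared by both datasets", "only in dataset 1"]
    texts2 = ["story shared by both datasets", "only in dataset 2"]

    model = StaticModel.from_pretrained("minishlab/M2V_base_output")
    embedding_matrix1 = model.encode(texts1, show_progressbar=True)
    embedding_matrix2 = model.encode(texts2, show_progressbar=True)

    # Indices in dataset 2 that duplicate a dataset 1 entry, plus the mapping back
    duplicate_indices_in_ds2, mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, 0.8)
    texts2_clean = [t for i, t in enumerate(texts2) if i not in duplicate_indices_in_ds2]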
 
 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
-
+
     deduplication_type = gr.Radio(choices=["Single dataset", "Cross-dataset"], label="Deduplication Type", value="Single dataset")
-
-    with gr.Row():
-        dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
-        dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
-
-    dataset2_row = gr.Column(visible=False)
-    with dataset2_row:
-        dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
-        dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
-
-    text_column_name = gr.Textbox(value="text", label="Text Column Name")
-
+
+    with gr.Tab("Dataset 1"):
+        with gr.Row():
+            dataset1_name = gr.Textbox(value="ag_news", label="Dataset Name")
+            dataset1_split = gr.Textbox(value="train", label="Split")
+            dataset1_text_column = gr.Textbox(value="text", label="Text Column Name")
+
+    dataset2_tab = gr.Tab("Dataset 2", visible=False)
+    with dataset2_tab:
+        with gr.Row():
+            dataset2_name = gr.Textbox(value="ag_news", label="Dataset Name")
+            dataset2_split = gr.Textbox(value="test", label="Split")
+            dataset2_text_column = gr.Textbox(value="text", label="Text Column Name")
+
     threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
-
+
     compute_button = gr.Button("Compute")
-
+
     output = gr.Markdown()
-
-    # Function to update the visibility of dataset2_row
-    def update_visibility(choice):
-        if choice == "Cross-dataset":
-            return {dataset2_row: gr.update(visible=True)}
+
+    # Function to update the visibility of dataset2_tab
+    def update_visibility(deduplication_type):
+        if deduplication_type == "Cross-dataset":
+            return {dataset2_tab: gr.update(visible=True)}
         else:
-            return {dataset2_row: gr.update(visible=False)}
-
-    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])
-
+            return {dataset2_tab: gr.update(visible=False)}
+
+    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_tab])
+
     compute_button.click(
         fn=perform_deduplication,
-        inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, text_column_name, threshold],
+        inputs=[
+            deduplication_type,
+            dataset1_name,
+            dataset1_split,
+            dataset1_text_column,
+            dataset2_name,
+            dataset2_split,
+            dataset2_text_column,
+            threshold
+        ],
         outputs=output
     )
-
+
 demo.launch()
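
The tab show/hide wiring above, reduced to a self-contained sketch with hypothetical components; the gr.update mechanism is the same one this commit uses:

    import gradio as gr

    with gr.Blocks() as demo:
        mode = gr.Radio(["Single dataset", "Cross-dataset"], value="Single dataset", label="Mode")
        extra = gr.Textbox(label="Only needed for Cross-dataset", visible=False)

        def update_visibility(choice):
            # gr.update changes properties of an existing component in place
            return gr.update(visible=(choice == "Cross-dataset"))

        mode.change(update_visibility, inputs=mode, outputs=extra)

    demo.launch()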