Pringled committed on
Commit 6188d2c · 1 Parent(s): 2827b8a

Updated app with code for deduplication

Files changed (1): app.py (+122 -82)
app.py CHANGED
@@ -1,18 +1,14 @@
-# import gradio as gr
-
-# def greet(name):
-#     return "Hello " + name + "!!"
-
-# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-# demo.launch()
-
-
 import gradio as gr
 from datasets import load_dataset
 import numpy as np
 from model2vec import StaticModel
 from reach import Reach
 from tqdm import tqdm
+import difflib
+
+def display_word_differences(x: str, y: str) -> str:
+    diff = difflib.ndiff(x.split(), y.split())
+    return " ".join([word for word in diff if word.startswith(('+', '-'))])

 def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
     """
@@ -28,11 +24,11 @@ def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int
         embedding_matrix,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=True
+        show_progressbar=False  # Disable internal progress bar
     )

     # Process duplicates
-    for i, similar_items in enumerate(tqdm(results)):
+    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
         if i not in deduplicated_indices:
             continue  # Skip already marked duplicates

@@ -62,11 +58,11 @@ def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix
         embedding_matrix_2,
         threshold=threshold,
         batch_size=batch_size,
-        show_progressbar=True
+        show_progressbar=False  # Disable internal progress bar
     )

     # Process duplicates
-    for i, similar_items in enumerate(tqdm(results)):
+    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
         # Similar items are returned as (index, score), we are only interested in the index
         similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]  # Keep those above the threshold

@@ -83,100 +79,144 @@ def perform_deduplication(
     dataset1_split,
     dataset2_name,
     dataset2_split,
+    text_column_name,
     threshold
 ):
     # Convert threshold to float
     threshold = float(threshold)
-
-    if deduplication_type == "Single dataset":
-        # Load the dataset
-        ds = load_dataset(dataset1_name, split=dataset1_split)
-
-        # Extract texts
-        texts = [example['text'] for example in ds]
-
-        # Compute embeddings
-        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-        embedding_matrix = model.encode(texts, show_progressbar=True)
-
-        # Deduplicate
-        deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)
-
-        # Prepare the results
-        num_duplicates = len(duplicate_to_original_mapping)
-        num_total = len(texts)
-        num_deduplicated = len(deduplicated_indices)
-
-        result_text = f"**Total documents:** {num_total}\n"
-        result_text += f"**Number of duplicates found:** {num_duplicates}\n"
-        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
-        result_text += f"**Deduplicated indices:** {deduplicated_indices.tolist()}\n\n"
-        result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
-
-        return result_text
-
-    elif deduplication_type == "Cross-dataset":
-        # Load datasets
-        ds1 = load_dataset(dataset1_name, split=dataset1_split)
-        ds2 = load_dataset(dataset2_name, split=dataset2_split)
-
-        # Extract texts
-        texts1 = [example['text'] for example in ds1]
-        texts2 = [example['text'] for example in ds2]
-
-        # Compute embeddings
-        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
-        embedding_matrix1 = model.encode(texts1, show_progressbar=True)
-        embedding_matrix2 = model.encode(texts2, show_progressbar=True)
-
-        # Deduplicate across datasets
-        duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)
-
-        num_duplicates = len(duplicate_indices_in_ds2)
-        num_total_ds2 = len(texts2)
-        num_unique_ds2 = num_total_ds2 - num_duplicates
-
-        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
-        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
-        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
-        result_text += f"**Duplicate indices in {dataset2_name}/{dataset2_split}:** {duplicate_indices_in_ds2}\n\n"
-        result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
-
-        return result_text
+
+    with gr.Progress(track_tqdm=True) as progress:
+        if deduplication_type == "Single dataset":
+            # Load the dataset
+            ds = load_dataset(dataset1_name, split=dataset1_split)
+
+            # Extract texts
+            try:
+                texts = [example[text_column_name] for example in ds]
+            except KeyError:
+                return f"Error: Text column '{text_column_name}' not found in dataset."
+
+            # Compute embeddings
+            progress(0.1, desc="Loading model and computing embeddings...")
+            model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+            embedding_matrix = model.encode(texts, show_progressbar=False)
+
+            # Deduplicate
+            progress(0.5, desc="Performing deduplication...")
+            deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)
+
+            # Prepare the results
+            num_duplicates = len(duplicate_to_original_mapping)
+            num_total = len(texts)
+            num_deduplicated = len(deduplicated_indices)
+
+            result_text = f"**Total documents:** {num_total}\n"
+            result_text += f"**Number of duplicates found:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
+
+            # Show sample duplicates
+            result_text += "### Sample Duplicate Pairs with Differences:\n\n"
+            num_examples = min(5, num_duplicates)
+            if num_examples > 0:
+                sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
+                for duplicate_idx, original_idx in sample_duplicates:
+                    original_text = texts[original_idx]
+                    duplicate_text = texts[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original Text (Index {original_idx}):**\n{original_text}\n\n"
+                    result_text += f"**Duplicate Text (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n\n"
+                    result_text += "---\n\n"
+            else:
+                result_text += "No duplicates found.\n"
+
+            return result_text
+
+        elif deduplication_type == "Cross-dataset":
+            # Load datasets
+            ds1 = load_dataset(dataset1_name, split=dataset1_split)
+            ds2 = load_dataset(dataset2_name, split=dataset2_split)
+
+            # Extract texts
+            try:
+                texts1 = [example[text_column_name] for example in ds1]
+                texts2 = [example[text_column_name] for example in ds2]
+            except KeyError:
+                return f"Error: Text column '{text_column_name}' not found in one of the datasets."
+
+            # Compute embeddings
+            progress(0.1, desc="Computing embeddings for Dataset 1...")
+            model = StaticModel.from_pretrained("minishlab/M2V_base_output")
+            embedding_matrix1 = model.encode(texts1, show_progressbar=False)
+
+            progress(0.5, desc="Computing embeddings for Dataset 2...")
+            embedding_matrix2 = model.encode(texts2, show_progressbar=False)
+
+            # Deduplicate across datasets
+            progress(0.7, desc="Performing cross-dataset deduplication...")
+            duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)
+
+            num_duplicates = len(duplicate_indices_in_ds2)
+            num_total_ds2 = len(texts2)
+            num_unique_ds2 = num_total_ds2 - num_duplicates
+
+            result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
+            result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
+            result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
+
+            # Show sample duplicates
+            result_text += "### Sample Duplicate Pairs with Differences:\n\n"
+            num_examples = min(5, num_duplicates)
+            if num_examples > 0:
+                sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
+                for duplicate_idx, original_idx in sample_duplicates:
+                    original_text = texts1[original_idx]
+                    duplicate_text = texts2[duplicate_idx]
+                    differences = display_word_differences(original_text, duplicate_text)
+                    result_text += f"**Original Text in {dataset1_name}/{dataset1_split} (Index {original_idx}):**\n{original_text}\n\n"
+                    result_text += f"**Duplicate Text in {dataset2_name}/{dataset2_split} (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
+                    result_text += f"**Differences:**\n{differences}\n\n"
+                    result_text += "---\n\n"
+            else:
+                result_text += "No duplicates found.\n"
+
+            return result_text

 with gr.Blocks() as demo:
     gr.Markdown("# Semantic Deduplication")
-
+
     deduplication_type = gr.Radio(choices=["Single dataset", "Cross-dataset"], label="Deduplication Type", value="Single dataset")
-
+
     with gr.Row():
         dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
         dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")
-
-    dataset2_row = gr.Row(visible=False)
+
+    dataset2_row = gr.Column(visible=False)
     with dataset2_row:
         dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
         dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")
-
+
+    text_column_name = gr.Textbox(value="text", label="Text Column Name")
+
     threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
-
+
     compute_button = gr.Button("Compute")
-
+
     output = gr.Markdown()
-
+
     # Function to update the visibility of dataset2_row
-    def update_visibility(deduplication_type):
-        if deduplication_type == "Cross-dataset":
+    def update_visibility(choice):
+        if choice == "Cross-dataset":
             return {dataset2_row: gr.update(visible=True)}
         else:
             return {dataset2_row: gr.update(visible=False)}

     deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])
-
+
     compute_button.click(
         fn=perform_deduplication,
-        inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, threshold],
+        inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, text_column_name, threshold],
         outputs=output
     )
-
+
     demo.launch()
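
The new display_word_differences helper compares two texts word by word via difflib.ndiff and keeps only the words that differ. A minimal sketch of its behavior on an invented near-duplicate pair (the sentences below are illustrative, not taken from the app):

import difflib

def display_word_differences(x: str, y: str) -> str:
    # ndiff prefixes words unique to x with "- " and words unique to y with "+ "
    diff = difflib.ndiff(x.split(), y.split())
    return " ".join([word for word in diff if word.startswith(('+', '-'))])

a = "the quick brown fox jumps over the lazy dog"
b = "the quick brown fox leaps over the lazy dog"
print(display_word_differences(a, b))  # -> "- jumps + leaps"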
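The bodies of deduplicate and deduplicate_across_datasets are elided by the diff; the visible fragments suggest both delegate a thresholded nearest-neighbor query to reach and then keep the first occurrence of each near-duplicate cluster. For intuition, a brute-force NumPy sketch of the same single-dataset logic (quadratic in corpus size, so an illustration under that assumption only, not the app's batched implementation):

import numpy as np

def deduplicate_bruteforce(embeddings: np.ndarray, threshold: float) -> tuple[np.ndarray, dict[int, int]]:
    # Normalize rows so dot products are cosine similarities
    norms = np.clip(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-12, None)
    normalized = embeddings / norms
    sims = normalized @ normalized.T

    keep: list[int] = []
    duplicate_to_original: dict[int, int] = {}
    for i in range(len(embeddings)):
        # A row is a duplicate if it is close enough to any already-kept row
        original = next((j for j in keep if sims[i, j] >= threshold), None)
        if original is None:
            keep.append(i)
        else:
            duplicate_to_original[i] = original  # duplicate index -> original index
    return np.array(keep), duplicate_to_original

This matches the interface the app relies on: an array of surviving indices plus a duplicate-to-original mapping.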
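For reference, Gradio also documents a parameter-injection form of gr.Progress, where the instance is declared as a default argument of the event handler rather than entered as a context manager; whether the with-statement form above works may depend on the installed version. A hedged sketch of the injected pattern (slow_task and its inputs are invented for illustration):

import time
import gradio as gr

def slow_task(n, progress=gr.Progress(track_tqdm=True)):
    # Gradio injects the Progress instance; progress.tqdm drives the UI bar
    for _ in progress.tqdm(range(int(n)), desc="Working..."):
        time.sleep(0.01)
    return f"Processed {int(n)} steps"

demo = gr.Interface(fn=slow_task, inputs=gr.Number(value=100), outputs="text")
# demo.launch()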