import gradio as gr
from datasets import load_dataset
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm


def deduplicate(embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices along with a mapping
    of each removed index to the original index it duplicates.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

    # Start with every index marked as deduplicated; duplicates are removed as they are found
    deduplicated_indices = set(range(len(embedding_matrix)))
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results)):
        if i not in deduplicated_indices:
            continue  # Skip items already marked as duplicates

        # Similar items are returned as (index, score) pairs; we only need the index
        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

        # Mark similar documents as duplicates and map them back to the original
        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i  # Map duplicate to original

    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
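
# Usage sketch (illustrative only: the tiny vectors below are made up, and this assumes
# Reach's similarity scores the two identical rows above the chosen threshold):
#   emb = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
#   kept, mapping = deduplicate(emb, threshold=0.9)
#   # kept would contain indices 0 and 2; mapping would be {1: 0}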


def deduplicate_across_datasets(embedding_matrix_1: np.ndarray, embedding_matrix_2: np.ndarray, threshold: float, batch_size: int = 1024) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate embeddings across two datasets: return the indices of items in the
    second dataset that duplicate an item in the first, plus a mapping from each
    duplicate index to the matching index in the first dataset.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    # Keep track of duplicates found in the second dataset
    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

    # Find nearest neighbors of the second (test) set within the first (train) set
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=True,
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results)):
        # Similar items are returned as (index, score) pairs; keep only those above the threshold
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

        # If there is any sufficiently similar item in the train set, mark this test item as a duplicate
        if similar_indices:
            duplicate_indices_in_test.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]  # Map duplicate in test to original in train

    return duplicate_indices_in_test, duplicate_to_original_mapping
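
# Usage sketch (illustrative, with the same made-up vectors and similarity assumption as above):
#   train = np.array([[1.0, 0.0], [0.0, 1.0]])
#   test = np.array([[1.0, 0.0], [0.5, 0.5]])
#   dup_idx, mapping = deduplicate_across_datasets(train, test, threshold=0.9)
#   # dup_idx would be [0] and mapping {0: 0}: test[0] duplicates train[0]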


def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset2_name,
    dataset2_split,
    threshold,
):
    # The slider already delivers a float, but be defensive about the type
    threshold = float(threshold)

    if deduplication_type == "Single dataset":
        # Load the dataset
        ds = load_dataset(dataset1_name, split=dataset1_split)

        # Extract texts
        texts = [example['text'] for example in ds]

        # Compute embeddings
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix = model.encode(texts, show_progressbar=True)

        # Deduplicate
        deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)

        # Prepare the results (blank lines between entries so the Markdown renders each on its own line)
        num_duplicates = len(duplicate_to_original_mapping)
        num_total = len(texts)
        num_deduplicated = len(deduplicated_indices)

        result_text = f"**Total documents:** {num_total}\n\n"
        result_text += f"**Number of duplicates found:** {num_duplicates}\n\n"
        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"
        result_text += f"**Deduplicated indices:** {deduplicated_indices.tolist()}\n\n"
        result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
        return result_text

    elif deduplication_type == "Cross-dataset":
        # Load both datasets
        ds1 = load_dataset(dataset1_name, split=dataset1_split)
        ds2 = load_dataset(dataset2_name, split=dataset2_split)

        # Extract texts
        texts1 = [example['text'] for example in ds1]
        texts2 = [example['text'] for example in ds2]

        # Compute embeddings
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix1 = model.encode(texts1, show_progressbar=True)
        embedding_matrix2 = model.encode(texts2, show_progressbar=True)

        # Deduplicate dataset 2 against dataset 1
        duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(embedding_matrix1, embedding_matrix2, threshold)

        num_duplicates = len(duplicate_indices_in_ds2)
        num_total_ds2 = len(texts2)
        num_unique_ds2 = num_total_ds2 - num_duplicates

        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n\n"
        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n\n"
        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"
        result_text += f"**Duplicate indices in {dataset2_name}/{dataset2_split}:** {duplicate_indices_in_ds2}\n\n"
        result_text += f"**Duplicate to original mapping:** {duplicate_to_original_mapping}\n"
        return result_text
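
# Example invocation (what the "Compute" button runs with the UI defaults below;
# note it downloads ag_news and computes embeddings, so it is not instant):
#   perform_deduplication("Single dataset", "ag_news", "train", "ag_news", "test", 0.8)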


with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")

    deduplication_type = gr.Radio(
        choices=["Single dataset", "Cross-dataset"],
        label="Deduplication Type",
        value="Single dataset",
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")

    # Second dataset inputs are hidden unless cross-dataset deduplication is selected
    dataset2_row = gr.Row(visible=False)
    with dataset2_row:
        dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
        dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")

    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
    compute_button = gr.Button("Compute")
    output = gr.Markdown()

    # Show or hide dataset2_row depending on the selected deduplication type
    def update_visibility(deduplication_type):
        if deduplication_type == "Cross-dataset":
            return {dataset2_row: gr.update(visible=True)}
        else:
            return {dataset2_row: gr.update(visible=False)}

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])

    compute_button.click(
        fn=perform_deduplication,
        inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, threshold],
        outputs=output,
    )

demo.launch()