Spaces:
Build error
Build error
from sentence_transformers import util | |
import json | |
from safetensors.numpy import load_file | |
import numpy as np | |
import torch | |
from safetensors.torch import save_file | |
def dedup_similar_sentences(sentences, threshold=0.9, batch_size=64,
                            embeddings=None, device=None):
    """Greedily drop sentences whose embedding is too close to an earlier one.

    Sentences are visited in their original order. A sentence is kept unless
    its cosine similarity to some previously *kept* sentence is strictly
    greater than ``threshold``, so the first member of every group of
    near-duplicates (exact duplicates included) survives.

    Args:
        sentences: Sequence of sentences; item ``i`` must correspond to row
            ``i`` of the embedding matrix.
        threshold: Cosine similarity above which a later sentence is treated
            as a duplicate of an earlier kept one.
        batch_size: Number of rows scored per matrix multiplication.
        embeddings: Optional ``(len(sentences), dim)`` tensor or array-like.
            When omitted, the ``text_embed`` entry of
            ``/root/autodl-tmp/dedup_audio_text.safetensors`` is loaded.
        device: Optional torch device used for the computation. Defaults to
            CUDA when it is available and to the CPU otherwise.

    Returns:
        A tuple ``(unique_sentences, unique_embeddings)``: the surviving
        sentences as a list and the matching embedding rows as a tensor on
        ``device``, both in the original order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, if the number of
            embedding rows differs from ``len(sentences)``, or if the
            embeddings are not a 2-D matrix.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    if embeddings is None:
        file = load_file("/root/autodl-tmp/dedup_audio_text.safetensors")
        embeddings = file['text_embed']
    if device is None:
        # Fall back to the CPU instead of crashing on GPU-less machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    sentence_embeddings = torch.as_tensor(embeddings).to(device)

    total = len(sentences)
    if sentence_embeddings.shape[0] != total:
        raise ValueError(
            f"got {total} sentences but {sentence_embeddings.shape[0]} embedding rows")
    if total == 0:
        return list(sentences), sentence_embeddings
    if sentence_embeddings.dim() != 2:
        raise ValueError("embeddings must be a 2-D (num_sentences, dim) matrix")

    # Unit-length rows turn the dot product into the cosine similarity.
    # normalize() clamps the norm, so all-zero rows give 0 rather than NaN.
    unit = torch.nn.functional.normalize(sentence_embeddings.float(), dim=1)

    # Indices stay fixed for the whole run; removals are only recorded here.
    removed = torch.zeros(total, dtype=torch.bool)
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        # Only columns at or after `start` matter: a sentence can only be
        # eliminated by one that precedes it.
        scores = unit[start:stop] @ unit[start:].T
        hits = (scores > threshold).cpu()
        for row in range(stop - start):
            pos = start + row
            if removed[pos]:
                # An already discarded sentence must not eliminate others.
                continue
            # Strictly later columns only, so the self-match is never used.
            removed[pos + 1:] |= hits[row, row + 1:]
        # Progress: number of sentences still surviving after this batch.
        print(total - int(removed.sum()))

    keep = ~removed
    uq_sentences = [text for text, kept in zip(sentences, keep.tolist()) if kept]
    sentence_embeddings = sentence_embeddings[keep.to(sentence_embeddings.device)]
    print(sentence_embeddings.shape)
    return uq_sentences, sentence_embeddings
def read_default_prompt(path='/root/autodl-tmp/dedup_audio_text.json'):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Location of the JSON file. Defaults to the corpus file
            ``/root/autodl-tmp/dedup_audio_text.json``.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.
    """
    # JSON is UTF-8 text; do not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
if __name__ == '__main__':
    # Script entry point: read the text corpus, deduplicate it at a 0.8
    # similarity threshold, then write the surviving sentences and their
    # embeddings next to the input under the "_80" suffix.
    kept_texts, kept_vectors = dedup_similar_sentences(
        read_default_prompt(),
        threshold=0.8,
    )

    with open("/root/autodl-tmp/dedup_audio_text_80.json", "w") as json_out:
        json.dump(kept_texts, json_out)

    # The embeddings are stored under the same key the loader reads.
    save_file(
        {"text_embed": kept_vectors},
        "/root/autodl-tmp/dedup_audio_text_80.safetensors",
    )