# audio_img / data.py
# Author: pengdaqian - commit 171f55b ("add more")
# Exported from a web source viewer (raw / history / blame views; file size 2.12 kB).
from sentence_transformers import util
import json
from safetensors.numpy import load_file
import numpy as np
import torch
from safetensors.torch import save_file
def dedup_similar_sentences(sentences, threshold=0.9, batch_size=64, *,
                            embeddings=None,
                            embeddings_path="/root/autodl-tmp/dedup_audio_text.safetensors",
                            device=None, verbose=False):
    """Greedily drop sentences whose embedding is too similar to an earlier kept one.

    Sentences are scanned in order. A sentence is kept unless its cosine
    similarity to some earlier *kept* sentence is strictly greater than
    ``threshold``. The first member of every near-duplicate group therefore
    survives, and exact duplicates (similarity 1.0) are removed too.

    Args:
        sentences: Sequence of sentences; ``sentences[i]`` must correspond to
            row ``i`` of the embedding matrix.
        threshold: Cosine-similarity cutoff; pairs strictly above it count as
            duplicates.
        batch_size: Number of anchor rows compared against all rows at once;
            bounds the size of the (batch_size, n) similarity matrix.
        embeddings: Optional 2-D array/tensor of row embeddings. When omitted,
            the ``'text_embed'`` entry of ``embeddings_path`` is loaded.
        embeddings_path: safetensors file read when ``embeddings`` is None.
        device: Torch device used for the computation; defaults to CUDA when
            available, otherwise CPU.
        verbose: If true, print the running count of kept rows after each batch.

    Returns:
        ``(unique_sentences, unique_embeddings)``: the kept sentences as a list
        in original order, and their embedding rows (input dtype, on ``device``).

    Raises:
        ValueError: if the embeddings are not a 2-D matrix with one row per
            sentence.
    """
    if embeddings is None:
        embeddings = load_file(embeddings_path)['text_embed']
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.tensor(embeddings)
    emb = embeddings.to(device)

    n = len(sentences)
    if emb.ndim != 2 or emb.shape[0] != n:
        raise ValueError(
            f"expected a ({n}, dim) embedding matrix, got shape {tuple(emb.shape)}")

    # Unit-normalise once (in float32 for stability) so a matmul gives cosine
    # similarity; zero vectors become zero and never count as duplicates.
    normed = torch.nn.functional.normalize(emb.float(), dim=1)
    keep = torch.ones(n, dtype=torch.bool, device=device)
    cols = torch.arange(n, device=device)

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        rows = cols[start:end]
        # Only a *later* sentence can be a duplicate of an earlier one. This
        # excludes each row's self-match without comparing floats to 1.0 and
        # never touches already-processed sentences.
        dup = (normed[start:end] @ normed.T > threshold) & (cols[None, :] > rows[:, None])

        # Resolve the batch in order, so a sentence already removed by an
        # earlier row of the same batch cannot itself remove anything.
        alive = keep[start:end].tolist()
        inner = dup[:, start:end].tolist()
        for r in range(end - start):
            if alive[r]:
                for c in range(r + 1, end - start):
                    if inner[r][c]:
                        alive[c] = False
        alive_mask = torch.tensor(alive, dtype=torch.bool, device=device)
        keep[start:end] = alive_mask
        # Surviving anchors remove their near-duplicates further down the list.
        keep &= ~dup[alive_mask].any(dim=0)

        if verbose:
            print(int(keep.sum()))

    keep_list = keep.tolist()
    unique_sentences = [s for s, k in zip(sentences, keep_list) if k]
    return unique_sentences, emb[keep]
def read_default_prompt(path='/root/autodl-tmp/dedup_audio_text.json', encoding='utf-8'):
    """Load and return the JSON prompt corpus.

    Args:
        path: JSON file to read; defaults to the original hard-coded corpus path.
        encoding: Text encoding of the file. JSON text is UTF-8 by specification,
            so it is pinned here instead of relying on the platform's locale default.

    Returns:
        The decoded JSON document, exactly as ``json.load`` produces it.
    """
    with open(path, 'r', encoding=encoding) as f:
        return json.load(f)
if __name__ == '__main__':
    # Deduplicate the prompt corpus at cosine threshold 0.8, then write the
    # surviving texts (JSON) and their embeddings (safetensors, key "text_embed").
    kept_texts, kept_embeds = dedup_similar_sentences(read_default_prompt(), threshold=0.8)

    with open("/root/autodl-tmp/dedup_audio_text_80.json", "w") as fp:
        json.dump(kept_texts, fp)

    save_file({"text_embed": kept_embeds}, "/root/autodl-tmp/dedup_audio_text_80.safetensors")