similitud / app.py
antagonico's picture
Create app.py
a4c55a2
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import multiprocessing
import chromadb
import hashlib
# Carga el modelo
model = SentenceTransformer('Maite89/Roberta_finetuning_semantic_similarity_stsb_multi_mt')
# Crea el cliente ChromaDB
chroma_client = chromadb.Client()
collection = chroma_client.create_collection(name="my_collection")
def generate_hash(text):
return hashlib.md5(text.encode('utf-8')).hexdigest()
# Funci贸n para obtener embeddings del modelo
def get_embeddings(sentences):
embeddings = []
for sentence in sentences:
sentence_hash = generate_hash(sentence)
# Verificar si el embedding ya est谩 en la base de datos
results = collection.query(query_texts=[sentence], n_results=1)
if results and 'embedding' in results[0]:
embeddings.append(np.array(results[0]['embedding']))
else:
# Si no est谩 en la base de datos, calcula el embedding y lo almacena
embedding = model.encode(sentence, show_progress_bar=False)
collection.add(
embeddings=[embedding.tolist()],
documents=[sentence],
metadatas=[{"source": "my_source"}],
ids=[sentence_hash] # Usa el hash como ID
)
embeddings.append(embedding)
return np.array(embeddings)
# Funci贸n para comparar las sentencias
def calculate_similarity(args):
source_embedding, compare_embedding = args
return cosine_similarity(source_embedding.reshape(1, -1), compare_embedding.reshape(1, -1))[0][0]
def compare(source_sentence, compare_sentences):
compare_list = compare_sentences.split("--")
# Obtiene todos los embeddings a la vez para acelerar el proceso
all_sentences = [source_sentence] + compare_list
all_embeddings = get_embeddings(all_sentences)
# Prepara los datos para el multiprocesamiento
source_embedding = all_embeddings[0]
data_for_multiprocessing = [(source_embedding, emb) for emb in all_embeddings[1:]]
# Utiliza un pool de procesos para calcular las similitudes en paralelo
with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
similarities = pool.map(calculate_similarity, data_for_multiprocessing)
return ', '.join([str(sim) for sim in similarities])
# Define las interfaces de entrada y salida de Gradio
iface = gr.Interface(
fn=compare,
inputs=["text", "text"],
outputs="text",
live=False
)
# Inicia la interfaz de Gradio
iface.launch()