import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import multiprocessing
import chromadb
import hashlib
# Load the model
model = SentenceTransformer('Maite89/Roberta_finetuning_semantic_similarity_stsb_multi_mt')
# Create the ChromaDB client and collection (set up here, but the cache below uses SQLite instead)
chroma_client = chromadb.Client()
collection = chroma_client.create_collection(name="my_collection")
def generate_hash(text):
    return hashlib.md5(text.encode('utf-8')).hexdigest()
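# (Hypothetical use, not wired up in this script: the MD5 hash could serve as a stable id
# when adding a sentence to the ChromaDB collection, e.g.
#   collection.add(ids=[generate_hash(s)], documents=[s], embeddings=[vec])
# where vec is that sentence's embedding as a list of floats.)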
import sqlite3
# Initialize the database and create the table if it does not exist
# (check_same_thread=False lets Gradio's worker threads reuse this connection)
conn = sqlite3.connect('embeddings.db', check_same_thread=False)
c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS embeddings
             (sentence TEXT PRIMARY KEY, embedding BLOB)''')
conn.commit()
# Function to get embeddings from the model
def get_embeddings(sentences):
    # Try to retrieve cached embeddings from the database;
    # None marks a sentence whose embedding still has to be computed
    embeddings = [None] * len(sentences)
    new_sentences = []
    new_indices = []
    for i, sentence in enumerate(sentences):
        c.execute('SELECT embedding FROM embeddings WHERE sentence=?', (sentence,))
        result = c.fetchone()
        if result:
            embeddings[i] = np.frombuffer(result[0], dtype=np.float32)
        else:
            new_sentences.append(sentence)
            new_indices.append(i)
    # If there are new sentences, compute their embeddings, place them back in
    # input order and store them in the database
    if new_sentences:
        new_embeddings = model.encode(new_sentences, show_progress_bar=False)
        for i, emb in zip(new_indices, new_embeddings):
            embeddings[i] = emb
        c.executemany('INSERT OR REPLACE INTO embeddings VALUES (?,?)',
                      [(sent, emb.tobytes()) for sent, emb in zip(new_sentences, new_embeddings)])
        conn.commit()
    return embeddings
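# A quick illustration (hypothetical call): get_embeddings(["hola", "adios"]) returns two
# float32 vectors, one per sentence and in the same order as the input; on a later call
# the vectors are read back from embeddings.db instead of being re-encoded by the model.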
# Compute the cosine similarity between the source embedding and one comparison embedding
def calculate_similarity(args):
    source_embedding, compare_embedding = args
    return cosine_similarity(source_embedding.reshape(1, -1), compare_embedding.reshape(1, -1))[0][0]
def compare(source_sentence, compare_sentences):
    # The comparison sentences arrive as a single string separated by "--"
    compare_list = compare_sentences.split("--")
    # Fetch all embeddings at once to speed up the process
    all_sentences = [source_sentence] + compare_list
    all_embeddings = get_embeddings(all_sentences)
    # Prepare the data for multiprocessing
    source_embedding = all_embeddings[0]
    data_for_multiprocessing = [(source_embedding, emb) for emb in all_embeddings[1:]]
    # Use a process pool to compute the similarities in parallel
    # (assumes a fork-based start method, as on Linux)
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        similarities = pool.map(calculate_similarity, data_for_multiprocessing)
    return ', '.join([str(sim) for sim in similarities])
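# A quick illustration (hypothetical call): compare("El gato duerme.", "El felino descansa.--Hoy llueve.")
# splits the second argument on "--", scores each comparison sentence against the source
# sentence with cosine similarity, and returns the scores as a comma-separated string.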
# Define the Gradio input and output interfaces
iface = gr.Interface(
    fn=compare,
    inputs=["text", "text"],
    outputs="text",
    live=False
)
# Launch the Gradio interface
iface.launch()
# Close the database connection once the interface has been shut down
conn.close()