Spaces:
Paused
Paused
File size: 1,727 Bytes
04903ac cb8f228 5c96576 cb8f228 5c96576 cb8f228 5c96576 cb8f228 5c96576 04903ac cb8f228 04903ac cb8f228 04903ac cb8f228 04903ac cb8f228 04903ac cb8f228 04903ac cb8f228 04903ac 5cd8d39 04903ac 5cd8d39 04903ac cb8f228 f16d80b 5c96576 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from accelerate import Accelerator
import gradio as gr
# Inicializa el Accelerator
accelerator = Accelerator()
# Cargar el modelo y colocarlo en el dispositivo adecuado
model = SentenceTransformer('Maite89/Roberta_finetuning_semantic_similarity_stsb_multi_mt')
model, _ = accelerator.prepare(model, model)
# Funci贸n para obtener embeddings del modelo
def get_embeddings(sentences):
# Preparar los datos para ejecuci贸n acelerada
sentences = accelerator.prepare(sentences)
return model.encode(sentences, show_progress_bar=False, convert_to_tensor=True)
# Funci贸n para calcular la similitud
def calculate_similarity(arguments):
source_embedding, compare_embedding = arguments
return cosine_similarity([source_embedding], [compare_embedding])[0][0]
# Funci贸n para comparar oraciones
def compare(source_sentence, compare_sentences):
compare_list = compare_sentences.split("--")
# Obtener todos los embeddings de una vez para acelerar el proceso
all_sentences = [source_sentence] + compare_list
all_embeddings = get_embeddings(all_sentences)
# No se necesita multiprocesamiento si usamos Accelerate ya que esto se maneja internamente
source_embedding = all_embeddings[0]
similarities = [calculate_similarity((source_embedding, emb)) for emb in all_embeddings[1:]]
return ', '.join([str(sim) for sim in similarities])
# Define las interfaces de entrada y salida de Gradio
iface = gr.Interface(
fn=compare,
inputs=["text", "text"],
outputs="text",
live=False
)
# Iniciar la interfaz de Gradio
iface.launch()
|