File size: 1,727 Bytes
04903ac
 
 
cb8f228
5c96576
 
cb8f228
 
5c96576
cb8f228
5c96576
cb8f228
5c96576
04903ac
 
cb8f228
 
 
04903ac
cb8f228
 
 
 
04903ac
cb8f228
04903ac
 
 
cb8f228
04903ac
 
 
cb8f228
04903ac
cb8f228
04903ac
 
 
5cd8d39
 
04903ac
5cd8d39
 
 
 
04903ac
cb8f228
f16d80b
5c96576
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from accelerate import Accelerator
import gradio as gr

# Inicializa el Accelerator
accelerator = Accelerator()

# Cargar el modelo y colocarlo en el dispositivo adecuado
model = SentenceTransformer('Maite89/Roberta_finetuning_semantic_similarity_stsb_multi_mt')
model, _ = accelerator.prepare(model, model)

# Funci贸n para obtener embeddings del modelo
def get_embeddings(sentences):
    # Preparar los datos para ejecuci贸n acelerada
    sentences = accelerator.prepare(sentences)
    return model.encode(sentences, show_progress_bar=False, convert_to_tensor=True)

# Funci贸n para calcular la similitud
def calculate_similarity(arguments):
    source_embedding, compare_embedding = arguments
    return cosine_similarity([source_embedding], [compare_embedding])[0][0]

# Funci贸n para comparar oraciones
def compare(source_sentence, compare_sentences):
    compare_list = compare_sentences.split("--")
    
    # Obtener todos los embeddings de una vez para acelerar el proceso
    all_sentences = [source_sentence] + compare_list
    all_embeddings = get_embeddings(all_sentences)
    
    # No se necesita multiprocesamiento si usamos Accelerate ya que esto se maneja internamente
    source_embedding = all_embeddings[0]
    similarities = [calculate_similarity((source_embedding, emb)) for emb in all_embeddings[1:]]
    
    return ', '.join([str(sim) for sim in similarities])


# Define las interfaces de entrada y salida de Gradio
iface = gr.Interface(
    fn=compare, 
    inputs=["text", "text"], 
    outputs="text",
    live=False  
)
# Iniciar la interfaz de Gradio
iface.launch()