antagonico commited on
Commit
04903ac
1 Parent(s): d86605b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ import numpy as np
5
+ import multiprocessing
6
+ import chromadb
7
+ import hashlib
8
+
9
+ # Carga el modelo
10
+ model = SentenceTransformer('Maite89/Roberta_finetuning_semantic_similarity_stsb_multi_mt')
11
+
12
+ # Crea el cliente ChromaDB
13
+ chroma_client = chromadb.Client()
14
+ collection = chroma_client.create_collection(name="my_collection")
15
+
16
+ def generate_hash(text):
17
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
18
+
19
+ # Funci贸n para obtener embeddings del modelo
20
+ def get_embeddings(sentences):
21
+ embeddings = []
22
+ for sentence in sentences:
23
+ sentence_hash = generate_hash(sentence)
24
+ # Verificar si el embedding ya est谩 en la base de datos
25
+ results = collection.query(query_texts=[sentence], n_results=1)
26
+ if results:
27
+ embeddings.append(np.array(results[0]['embedding']))
28
+ else:
29
+ # Si no est谩 en la base de datos, calcula el embedding y lo almacena
30
+ embedding = model.encode(sentence, show_progress_bar=False)
31
+ collection.add(
32
+ embeddings=[embedding.tolist()],
33
+ documents=[sentence],
34
+ metadatas=[{"source": "my_source"}],
35
+ ids=[sentence_hash] # Usa el hash como ID
36
+ )
37
+ embeddings.append(embedding)
38
+ return np.array(embeddings)
39
+
40
+ # Funci贸n para comparar las sentencias
41
+ def calculate_similarity(args):
42
+ source_embedding, compare_embedding = args
43
+ return cosine_similarity(source_embedding.reshape(1, -1), compare_embedding.reshape(1, -1))[0][0]
44
+
45
+ def compare(source_sentence, compare_sentences):
46
+ compare_list = compare_sentences.split("--")
47
+
48
+ # Obtiene todos los embeddings a la vez para acelerar el proceso
49
+ all_sentences = [source_sentence] + compare_list
50
+ all_embeddings = get_embeddings(all_sentences)
51
+
52
+ # Prepara los datos para el multiprocesamiento
53
+ source_embedding = all_embeddings[0]
54
+ data_for_multiprocessing = [(source_embedding, emb) for emb in all_embeddings[1:]]
55
+
56
+ # Utiliza un pool de procesos para calcular las similitudes en paralelo
57
+ with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
58
+ similarities = pool.map(calculate_similarity, data_for_multiprocessing)
59
+
60
+ return ', '.join([str(sim) for sim in similarities])
61
+
62
+ # Define las interfaces de entrada y salida de Gradio
63
+ iface = gr.Interface(
64
+ fn=compare,
65
+ inputs=["text", "text"],
66
+ outputs="text",
67
+ live=False
68
+ )
69
+
70
+ # Inicia la interfaz de Gradio
71
+ iface.launch()