anzorq committed on
Commit
d49588a
·
1 Parent(s): 400d3a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sentence_transformers import SentenceTransformer, util
3
+ import pandas as pd
4
+ import gradio as gr
5
+
6
def save_embeddings(sentences, filename):
    """Encode *sentences* with the module-level ``model`` and save the tensor.

    Args:
        sentences: Texts passed to ``model.encode``.
        filename: Destination handed to ``torch.save``.
    """
    # convert_to_tensor=True makes encode() return a tensor, which is what
    # gets serialised to disk.
    torch.save(model.encode(sentences, convert_to_tensor=True), filename)
9
+
10
def load_embeddings(filename, map_location=None):
    """Load embeddings previously written with ``torch.save``.

    Args:
        filename: Path (or file-like object) handed to ``torch.load``.
        map_location: Optional device remapping forwarded to ``torch.load``.
            Pass ``"cpu"`` to read embeddings that were saved from a GPU on a
            CPU-only machine. ``None`` (the default) keeps torch's own
            placement, exactly as before this parameter existed.

    Returns:
        The object stored in the file.
    """
    # SECURITY: torch.load relies on pickle under the hood; only point this
    # at files that come from a trusted source.
    return torch.load(filename, map_location=map_location)
12
+
13
def preprocess_model_descriptions(file_path, encodings=('utf-8', 'latin-1', 'utf-16')):
    """Read a two-column TSV of model ids and descriptions.

    The file is read without a header row; the columns are named
    ``model_id`` and ``description``. Rows whose description is missing are
    dropped. Each encoding in *encodings* is tried in order and the first one
    that decodes the file wins.

    Args:
        file_path: Path to the tab-separated file.
        encodings: Candidate text encodings, tried in order. Note that
            latin-1 maps every byte to a character, so with the default
            order it acts as a catch-all and later entries are not reached.

    Returns:
        A ``(model_ids, descriptions)`` tuple of two equally long lists.

    Raises:
        UnicodeDecodeError: If none of the encodings can decode the file.
    """
    last_error = None
    for encoding in encodings:
        try:
            # The candidate encoding must actually be handed to pandas;
            # otherwise every attempt repeats the same default decode.
            df = pd.read_csv(
                file_path,
                sep='\t',
                header=None,
                names=['model_id', 'description'],
                encoding=encoding,
            )
        except UnicodeDecodeError as err:
            last_error = err
            continue
        break
    else:
        reason = "Unable to decode the file using the available encodings."
        # UnicodeDecodeError takes five arguments (encoding, object, start,
        # end, reason); a single message string would raise TypeError.
        if last_error is None:
            raise UnicodeDecodeError("unknown", b"", 0, 1, reason)
        raise UnicodeDecodeError(
            last_error.encoding,
            last_error.object,
            last_error.start,
            last_error.end,
            reason,
        ) from last_error

    df = df.dropna(subset=['description'])
    return df['model_id'].tolist(), df['description'].tolist()
28
+
29
def perform_similarity_search(query_embeddings, embeddings, model_ids, descriptions, top_k=10):
    """Rank stored descriptions by cosine similarity to each query.

    Args:
        query_embeddings: Embedding(s) of the query text(s), passed as the
            first argument to ``util.cos_sim``.
        embeddings: Embeddings of the stored descriptions, passed as the
            second argument to ``util.cos_sim``.
        model_ids: Identifiers paired positionally with the scores.
        descriptions: Descriptions paired positionally with the scores.
        top_k: Number of best matches kept per query.

    Returns:
        A ``pandas.DataFrame`` with columns ``model_id``, ``description`` and
        ``score``. The best *top_k* rows of each query are appended one query
        after another, each group sorted by descending score.
    """
    # One row of scores per query, one score per stored embedding.
    score_rows = util.cos_sim(query_embeddings, embeddings).tolist()

    results = []
    # Walk the score rows themselves rather than the query tensor: a single
    # 1-D query vector still produces exactly one row here, whereas iterating
    # the tensor would step through its individual dimensions.
    for scores in score_rows:
        # zip() stops at the shortest input, so ids, descriptions and scores
        # are paired purely by position.
        ranked = sorted(
            zip(model_ids, descriptions, scores),
            key=lambda row: row[2],
            reverse=True,
        )
        results.extend(ranked[:top_k])

    return pd.DataFrame(results, columns=["model_id", "description", "score"])
39
+
40
# Sentence-embedding model used to encode incoming queries.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Ids/descriptions come from the TSV; their embeddings are loaded from disk.
# NOTE(review): the search pairs these purely by position, so the .pt file is
# presumably row-aligned with the TSV after rows without a description are
# dropped -- confirm against how the embeddings file was generated.
model_ids, descriptions = preprocess_model_descriptions('model_descriptions.tsv')
embeddings = load_embeddings('embeddings_model_descriptions.pt')

with gr.Blocks() as demo:
    # Named query_box rather than `input` so the builtin is not shadowed.
    query_box = gr.Textbox(label="Enter your query")
    # A Button's visible text is its value (first positional argument),
    # not `label`.
    button = gr.Button("Search")

    df_output = gr.Dataframe(label="Similarity Results", wrap=True)

    def search(query):
        """Embed *query* and return its best matches as a DataFrame."""
        query_embedding = model.encode([query], convert_to_tensor=True)
        return perform_similarity_search(query_embedding, embeddings, model_ids, descriptions)

    # Submitting the textbox and clicking the button run the same search.
    query_box.submit(search, inputs=query_box, outputs=df_output)
    button.click(search, inputs=query_box, outputs=df_output)

demo.launch()