Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ def save_embeddings(sentences, filename):
|
|
8 |
torch.save(embeddings, filename)
|
9 |
|
10 |
def load_embeddings(filename):
|
11 |
-
return torch.load(filename)
|
12 |
|
13 |
def preprocess_model_descriptions(file_path):
|
14 |
encodings = ['utf-8', 'latin-1', 'utf-16']
|
@@ -44,13 +44,13 @@ embeddings = load_embeddings('embeddings_hf_spaces_descriptions.pt')
|
|
44 |
|
45 |
with gr.Blocks() as demo:
|
46 |
input = gr.Textbox(label="Enter your query")
|
47 |
-
|
48 |
|
49 |
df_output = gr.Dataframe(label="Similarity Results", wrap=True)
|
50 |
|
51 |
def search(query):
|
52 |
query_embedding = model.encode([query], convert_to_tensor=True)
|
53 |
-
return perform_similarity_search(query_embedding, embeddings, model_ids, descriptions)
|
54 |
|
55 |
input.submit(search, inputs=input, outputs=df_output)
|
56 |
button.click(search, inputs=input, outputs=df_output)
|
|
|
8 |
torch.save(embeddings, filename)
|
9 |
|
10 |
def load_embeddings(filename):
|
11 |
+
return torch.load(filename, map_location=torch.device('cpu'))
|
12 |
|
13 |
def preprocess_model_descriptions(file_path):
|
14 |
encodings = ['utf-8', 'latin-1', 'utf-16']
|
|
|
44 |
|
45 |
with gr.Blocks() as demo:
|
46 |
input = gr.Textbox(label="Enter your query")
|
47 |
+
num_results = gr.Slider(10, 100, value=10, label="Number of results")
|
48 |
|
49 |
df_output = gr.Dataframe(label="Similarity Results", wrap=True)
|
50 |
|
51 |
def search(query):
|
52 |
query_embedding = model.encode([query], convert_to_tensor=True)
|
53 |
+
return perform_similarity_search(query_embedding, embeddings, model_ids, descriptions, top_k=num_results)
|
54 |
|
55 |
input.submit(search, inputs=input, outputs=df_output)
|
56 |
button.click(search, inputs=input, outputs=df_output)
|