Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,16 +1,22 @@
|
|
1 |
import gradio as gr
|
2 |
from langchain.vectorstores import FAISS
|
3 |
from langchain.embeddings import HuggingFaceEmbeddings
|
4 |
-
# import torch
|
5 |
|
6 |
|
7 |
def get_matches(query, db_name="miread_contrastive"):
|
|
|
|
|
|
|
8 |
matches = vecdbs[index_names.index(
|
9 |
db_name)].similarity_search_with_score(query, k=60)
|
10 |
return matches
|
11 |
|
12 |
|
13 |
def inference(query, model="miread_contrastive"):
|
|
|
|
|
|
|
|
|
14 |
matches = get_matches(query, model)
|
15 |
auth_counts = {}
|
16 |
j_bucket = {}
|
@@ -28,12 +34,12 @@ def inference(query, model="miread_contrastive"):
|
|
28 |
date = doc.metadata.get('date', 'None')
|
29 |
link = doc.metadata.get('link', 'None')
|
30 |
submitter = doc.metadata.get('submitter', 'None')
|
31 |
-
# journal = doc.metadata.get('journal', 'None').strip()
|
32 |
journal = doc.metadata['journal']
|
33 |
if (journal is None or journal.strip() == ''):
|
34 |
journal = 'None'
|
35 |
else:
|
36 |
journal = journal.strip()
|
|
|
37 |
# For journals
|
38 |
if journal not in j_bucket:
|
39 |
j_bucket[journal] = score
|
@@ -87,7 +93,9 @@ model_names = [
|
|
87 |
model_kwargs = {'device': 'cpu'}
|
88 |
encode_kwargs = {'normalize_embeddings': False}
|
89 |
faiss_embedders = [HuggingFaceEmbeddings(
|
90 |
-
name,
|
|
|
|
|
91 |
|
92 |
vecdbs = [FAISS.load_local(index_name, faiss_embedder)
|
93 |
for index_name, faiss_embedder in zip(index_names, faiss_embedders)]
|
|
|
1 |
import gradio as gr
|
2 |
from langchain.vectorstores import FAISS
|
3 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
4 |
|
5 |
|
6 |
def get_matches(query, db_name="miread_contrastive"):
|
7 |
+
"""
|
8 |
+
Wrapper to call the similarity search on the required index
|
9 |
+
"""
|
10 |
matches = vecdbs[index_names.index(
|
11 |
db_name)].similarity_search_with_score(query, k=60)
|
12 |
return matches
|
13 |
|
14 |
|
15 |
def inference(query, model="miread_contrastive"):
|
16 |
+
"""
|
17 |
+
This function processes information retrieved by the get_matches() function
|
18 |
+
Returns - Gradio update commands for the authors, abstracts and journals tabular output
|
19 |
+
"""
|
20 |
matches = get_matches(query, model)
|
21 |
auth_counts = {}
|
22 |
j_bucket = {}
|
|
|
34 |
date = doc.metadata.get('date', 'None')
|
35 |
link = doc.metadata.get('link', 'None')
|
36 |
submitter = doc.metadata.get('submitter', 'None')
|
|
|
37 |
journal = doc.metadata['journal']
|
38 |
if (journal is None or journal.strip() == ''):
|
39 |
journal = 'None'
|
40 |
else:
|
41 |
journal = journal.strip()
|
42 |
+
|
43 |
# For journals
|
44 |
if journal not in j_bucket:
|
45 |
j_bucket[journal] = score
|
|
|
93 |
model_kwargs = {'device': 'cpu'}
|
94 |
encode_kwargs = {'normalize_embeddings': False}
|
95 |
faiss_embedders = [HuggingFaceEmbeddings(
|
96 |
+
model_name=name,
|
97 |
+
model_kwargs=model_kwargs,
|
98 |
+
encode_kwargs=encode_kwargs) for name in model_names]
|
99 |
|
100 |
vecdbs = [FAISS.load_local(index_name, faiss_embedder)
|
101 |
for index_name, faiss_embedder in zip(index_names, faiss_embedders)]
|