atrytone commited on
Commit
5c7bf4d
Β·
1 Parent(s): 2e2ab19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -1,16 +1,22 @@
1
  import gradio as gr
2
  from langchain.vectorstores import FAISS
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
- # import torch
5
 
6
 
7
  def get_matches(query, db_name="miread_contrastive"):
 
 
 
8
  matches = vecdbs[index_names.index(
9
  db_name)].similarity_search_with_score(query, k=60)
10
  return matches
11
 
12
 
13
  def inference(query, model="miread_contrastive"):
 
 
 
 
14
  matches = get_matches(query, model)
15
  auth_counts = {}
16
  j_bucket = {}
@@ -28,12 +34,12 @@ def inference(query, model="miread_contrastive"):
28
  date = doc.metadata.get('date', 'None')
29
  link = doc.metadata.get('link', 'None')
30
  submitter = doc.metadata.get('submitter', 'None')
31
- # journal = doc.metadata.get('journal', 'None').strip()
32
  journal = doc.metadata['journal']
33
  if (journal is None or journal.strip() == ''):
34
  journal = 'None'
35
  else:
36
  journal = journal.strip()
 
37
  # For journals
38
  if journal not in j_bucket:
39
  j_bucket[journal] = score
@@ -87,7 +93,9 @@ model_names = [
87
  model_kwargs = {'device': 'cpu'}
88
  encode_kwargs = {'normalize_embeddings': False}
89
  faiss_embedders = [HuggingFaceEmbeddings(
90
- name, model_kwargs, encode_kwargs) for name in model_names]
 
 
91
 
92
  vecdbs = [FAISS.load_local(index_name, faiss_embedder)
93
  for index_name, faiss_embedder in zip(index_names, faiss_embedders)]
 
1
  import gradio as gr
2
  from langchain.vectorstores import FAISS
3
  from langchain.embeddings import HuggingFaceEmbeddings
 
4
 
5
 
6
  def get_matches(query, db_name="miread_contrastive"):
7
+ """
8
+ Wrapper to call the similarity search on the required index
9
+ """
10
  matches = vecdbs[index_names.index(
11
  db_name)].similarity_search_with_score(query, k=60)
12
  return matches
13
 
14
 
15
  def inference(query, model="miread_contrastive"):
16
+ """
17
+ This function processes information retrieved by the get_matches() function
18
+ Returns - Gradio update commands for the authors, abstracts and journals tablular output
19
+ """
20
  matches = get_matches(query, model)
21
  auth_counts = {}
22
  j_bucket = {}
 
34
  date = doc.metadata.get('date', 'None')
35
  link = doc.metadata.get('link', 'None')
36
  submitter = doc.metadata.get('submitter', 'None')
 
37
  journal = doc.metadata['journal']
38
  if (journal is None or journal.strip() == ''):
39
  journal = 'None'
40
  else:
41
  journal = journal.strip()
42
+
43
  # For journals
44
  if journal not in j_bucket:
45
  j_bucket[journal] = score
 
93
  model_kwargs = {'device': 'cpu'}
94
  encode_kwargs = {'normalize_embeddings': False}
95
  faiss_embedders = [HuggingFaceEmbeddings(
96
+ model_name=name,
97
+ model_kwargs=model_kwargs,
98
+ encode_kwargs=encode_kwargs) for name in model_names]
99
 
100
  vecdbs = [FAISS.load_local(index_name, faiss_embedder)
101
  for index_name, faiss_embedder in zip(index_names, faiss_embedders)]