colonelwatch commited on
Commit
abd37c8
·
1 Parent(s): becad39

Create and clean up function declarations

Browse files
Files changed (1) hide show
  1. app.py +44 -29
app.py CHANGED
@@ -8,9 +8,12 @@ import gradio as gr
8
  from datasets import Dataset
9
  from typing import TypedDict
10
  import json
 
11
 
12
  from markdown_it import MarkdownIt # used for overriding default markdown renderer
13
 
 
 
14
  SEARCH_TIME_S = 1 # TODO: optimize
15
 
16
 
@@ -20,24 +23,29 @@ class IndexParameters(TypedDict):
20
  param_string: str # pass directly to faiss index
21
 
22
 
23
- model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
 
24
 
25
- index: Dataset = Dataset.from_parquet("index/ids.parquet") # type: ignore
26
- index.load_faiss_index("embeddings", "index/index.faiss", None)
27
- faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
28
 
29
- with open("index/params.json", "r") as f:
30
- params: list[IndexParameters] = json.load(f)
31
- params = [p for p in params if p["exec_time"] < SEARCH_TIME_S]
32
- param = max(params, key=(lambda p: p["recall"]))
33
- param_string = param["param_string"]
34
 
35
- ps = faiss.ParameterSpace()
36
- ps.initialize(faiss_index)
37
- ps.set_index_parameters(faiss_index, param_string)
 
 
38
 
 
 
 
39
 
40
- def _recover_abstract(inverted_index):
 
 
 
41
  abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1
42
 
43
  abstract = [None]*abstract_size
@@ -49,26 +57,12 @@ def _recover_abstract(inverted_index):
49
  abstract = ' '.join(abstract)
50
  return abstract
51
 
52
- def search(query):
53
- global model, index
54
-
55
- # TODO: pass in param string directly?
56
- query_embedding = model.encode(query)
57
- distances, faiss_ids = index.search("embeddings", query_embedding, 20)
58
-
59
- openalex_ids = index[faiss_ids]["idxs"]
60
- search_filter = f'openalex_id:{"|".join(openalex_ids)}'
61
- search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
62
-
63
- neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
64
- request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'
65
-
66
- return neighbors, request_str
67
 
68
  def execute_request(request_str):
69
  response = requests.get(request_str).json()
70
  return response
71
 
 
72
  def format_response(neighbors, response):
73
  response = {doc['id']: doc for doc in response['results']}
74
 
@@ -92,7 +86,7 @@ def format_response(neighbors, response):
92
  if abstract_inverted_index is None: # edge case: no abstract
93
  abstract = 'No abstract'
94
  else:
95
- abstract = _recover_abstract(abstract_inverted_index)
96
  abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')
97
 
98
  # try to get journal name or else set it to None
@@ -146,6 +140,27 @@ def format_response(neighbors, response):
146
 
147
  return result_string
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  with gr.Blocks() as demo:
150
  gr.Markdown('# abstracts-index')
151
  gr.Markdown(
 
8
  from datasets import Dataset
9
  from typing import TypedDict
10
  import json
11
+ from pathlib import Path
12
 
13
  from markdown_it import MarkdownIt # used for overriding default markdown renderer
14
 
15
+ MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
16
+ DIR = Path("index")
17
  SEARCH_TIME_S = 1 # TODO: optimize
18
 
19
 
 
23
  param_string: str # pass directly to faiss index
24
 
25
 
26
+ def get_model(model_name: str, device: str) -> SentenceTransformer:
27
+ return SentenceTransformer(model_name, device=device)
28
 
 
 
 
29
 
30
+ def get_index(dir: Path, search_time_s: float) -> Dataset:
31
+ index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet")) # type: ignore
32
+ index.load_faiss_index("embeddings", dir / "index.faiss", None)
33
+ faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
 
34
 
35
+ with open(dir / "params.json", "r") as f:
36
+ params: list[IndexParameters] = json.load(f)
37
+ params = [p for p in params if p["exec_time"] < search_time_s]
38
+ param = max(params, key=(lambda p: p["recall"]))
39
+ param_string = param["param_string"]
40
 
41
+ ps = faiss.ParameterSpace()
42
+ ps.initialize(faiss_index)
43
+ ps.set_index_parameters(faiss_index, param_string)
44
 
45
+ return index
46
+
47
+
48
+ def recover_abstract(inverted_index):
49
  abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1
50
 
51
  abstract = [None]*abstract_size
 
57
  abstract = ' '.join(abstract)
58
  return abstract
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def execute_request(request_str):
62
  response = requests.get(request_str).json()
63
  return response
64
 
65
+
66
  def format_response(neighbors, response):
67
  response = {doc['id']: doc for doc in response['results']}
68
 
 
86
  if abstract_inverted_index is None: # edge case: no abstract
87
  abstract = 'No abstract'
88
  else:
89
+ abstract = recover_abstract(abstract_inverted_index)
90
  abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')
91
 
92
  # try to get journal name or else set it to None
 
140
 
141
  return result_string
142
 
143
+
144
+ model = get_model(MODEL_NAME, "cpu")
145
+ index = get_index(DIR, SEARCH_TIME_S)
146
+
147
+
148
+ def search(query):
149
+ global model, index
150
+
151
+ # TODO: pass in param string directly?
152
+ query_embedding = model.encode(query)
153
+ distances, faiss_ids = index.search("embeddings", query_embedding, 20)
154
+
155
+ openalex_ids = index[faiss_ids]["idxs"]
156
+ search_filter = f'openalex_id:{"|".join(openalex_ids)}'
157
+ search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
158
+
159
+ neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
160
+ request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'
161
+
162
+ return neighbors, request_str
163
+
164
  with gr.Blocks() as demo:
165
  gr.Markdown('# abstracts-index')
166
  gr.Markdown(