colonelwatch commited on
Commit
a228730
·
1 Parent(s): 2a39b6d

Implement request batching

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -10,10 +10,12 @@ from typing import TypedDict, Self, Any
10
  import json
11
  from pathlib import Path
12
  from dataclasses import dataclass
 
13
 
14
  MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
15
  DIR = Path("index")
16
  SEARCH_TIME_S = 1 # TODO: optimize
 
17
 
18
 
19
  class IndexParameters(TypedDict):
@@ -119,12 +121,15 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:
119
 
120
 
121
  def execute_request(ids: list[str]) -> list[Work]:
 
 
 
122
  # query with the /works endpoint with a specific list of IDs and fields
123
  search_filter = f"openalex_id:{"|".join(ids)}"
124
  search_select = ",".join(["id"] + Work.get_raw_fields())
125
  response = requests.get(
126
  "https://api.openalex.org/works",
127
- {"filter": search_filter, "select": search_select}
128
  )
129
  response.raise_for_status()
130
 
@@ -196,17 +201,19 @@ model = get_model(MODEL_NAME, "cpu")
196
  index = get_index(DIR, SEARCH_TIME_S)
197
 
198
 
199
- def search(query: str) -> str:
200
- global model, index
201
-
202
  query_embedding = model.encode(query)
203
- distances, faiss_ids = index.search("embeddings", query_embedding, 20)
204
- openalex_ids = index[faiss_ids]["idxs"]
205
 
206
- works = execute_request(openalex_ids)
207
- result_string = format_response(works, distances)
 
 
208
 
209
- return result_string
 
 
210
 
211
 
212
  with gr.Blocks() as demo:
@@ -233,8 +240,8 @@ with gr.Blocks() as demo:
233
  container=True,
234
  )
235
 
236
- query.submit(search, inputs=[query], outputs=[results])
237
- btn.click(search, inputs=[query], outputs=[results])
238
 
239
- demo.queue(2)
240
  demo.launch()
 
10
  import json
11
  from pathlib import Path
12
  from dataclasses import dataclass
13
+ from itertools import batched, chain
14
 
15
  MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
16
  DIR = Path("index")
17
  SEARCH_TIME_S = 1 # TODO: optimize
18
+ K = 20
19
 
20
 
21
  class IndexParameters(TypedDict):
 
121
 
122
 
123
  def execute_request(ids: list[str]) -> list[Work]:
124
+ if len(ids) > 100:
125
+ raise ValueError("querying /works endpoint with more than 100 works")
126
+
127
  # query with the /works endpoint with a specific list of IDs and fields
128
  search_filter = f"openalex_id:{"|".join(ids)}"
129
  search_select = ",".join(["id"] + Work.get_raw_fields())
130
  response = requests.get(
131
  "https://api.openalex.org/works",
132
+ {"filter": search_filter, "select": search_select, "per-page": 100}
133
  )
134
  response.raise_for_status()
135
 
 
201
  index = get_index(DIR, SEARCH_TIME_S)
202
 
203
 
204
+ # function signature: (expanded tuple of input batches) -> tuple of output batches
205
+ def search(query: list[str]) -> tuple[list[str]]:
 
206
  query_embedding = model.encode(query)
207
+ distances, faiss_ids = index.search_batch("embeddings", query_embedding, K)
 
208
 
209
+ faiss_ids_flat = list(chain(*faiss_ids))
210
+ openalex_ids_flat = index[faiss_ids_flat]["idxs"]
211
+ works_flat = execute_request(openalex_ids_flat)
212
+ works = [list(batch) for batch in batched(works_flat, K)]
213
 
214
+ result_strings = [format_response(w, d) for w, d in zip(works, distances)]
215
+
216
+ return (result_strings, )
217
 
218
 
219
  with gr.Blocks() as demo:
 
240
  container=True,
241
  )
242
 
243
+ query.submit(search, inputs=[query], outputs=[results], batch=True)
244
+ btn.click(search, inputs=[query], outputs=[results], batch=True)
245
 
246
+ demo.queue()
247
  demo.launch()