Commit · a228730
Parent(s): 2a39b6d
Implement request batching
app.py CHANGED

@@ -10,10 +10,12 @@ from typing import TypedDict, Self, Any
 import json
 from pathlib import Path
 from dataclasses import dataclass
+from itertools import batched, chain

 MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
 DIR = Path("index")
 SEARCH_TIME_S = 1 # TODO: optimize
+K = 20


 class IndexParameters(TypedDict):

@@ -119,12 +121,15 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:


 def execute_request(ids: list[str]) -> list[Work]:
+    if len(ids) > 100:
+        raise ValueError("querying /works endpoint with more than 100 works")
+
     # query with the /works endpoint with a specific list of IDs and fields
     search_filter = f"openalex_id:{"|".join(ids)}"
     search_select = ",".join(["id"] + Work.get_raw_fields())
     response = requests.get(
         "https://api.openalex.org/works",
-        {"filter": search_filter, "select": search_select}
+        {"filter": search_filter, "select": search_select, "per-page": 100}
     )
     response.raise_for_status()
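The guard above caps a single call to execute_request at 100 IDs, matching the one page of up to 100 results it now asks the /works endpoint for. If a caller ever needed more than 100 works at once, one option would be to chunk the ID list and issue one request per chunk. A minimal sketch, reusing the file's execute_request and Work and the itertools imports added in this commit; the wrapper name execute_request_chunked is hypothetical and not part of the commit:

from itertools import batched, chain

def execute_request_chunked(ids: list[str]) -> list[Work]:
    # hypothetical helper: one /works request per chunk of at most 100 IDs,
    # flattened back into a single list (itertools.batched needs Python 3.12+)
    return list(chain.from_iterable(
        execute_request(list(chunk)) for chunk in batched(ids, 100)
    ))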
@@ -196,17 +201,19 @@ model = get_model(MODEL_NAME, "cpu")
 index = get_index(DIR, SEARCH_TIME_S)


-
-
-
+# function signature: (expanded tuple of input batches) -> tuple of output batches
+def search(query: list[str]) -> tuple[list[str]]:
     query_embedding = model.encode(query)
-    distances, faiss_ids = index.
-    openalex_ids = index[faiss_ids]["idxs"]
+    distances, faiss_ids = index.search_batch("embeddings", query_embedding, K)

-
-
+    faiss_ids_flat = list(chain(*faiss_ids))
+    openalex_ids_flat = index[faiss_ids_flat]["idxs"]
+    works_flat = execute_request(openalex_ids_flat)
+    works = [list(batch) for batch in batched(works_flat, K)]

-
+    result_strings = [format_response(w, d) for w, d in zip(works, distances)]
+
+    return (result_strings, )


 with gr.Blocks() as demo:
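The new search handler gets one list of K FAISS hits per query from search_batch, flattens them with chain so a single /works request covers the whole batch, then regroups the fetched works K at a time with batched so each query gets its own results back. A toy round trip of that flatten-and-regroup pattern, with made-up data rather than the app's objects:

from itertools import batched, chain

K = 3
hits_per_query = [["a", "b", "c"], ["d", "e", "f"]]  # one list of K hits per query

flat = list(chain(*hits_per_query))              # flattened for a single bulk lookup
regrouped = [list(g) for g in batched(flat, K)]  # regrouped K at a time

assert regrouped == hits_per_query               # per-query grouping is restored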
@@ -233,8 +240,8 @@ with gr.Blocks() as demo:
         container=True,
     )

-    query.submit(search, inputs=[query], outputs=[results])
-    btn.click(search, inputs=[query], outputs=[results])
+    query.submit(search, inputs=[query], outputs=[results], batch=True)
+    btn.click(search, inputs=[query], outputs=[results], batch=True)

-demo.queue(
+demo.queue()
 demo.launch()
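With batch=True, Gradio's queue gathers concurrent requests and calls the handler once with lists of inputs, which is why search now takes list[str] and returns a one-element tuple of lists. A standalone sketch of that contract with a toy echo handler (not the app's code; the max_batch_size value is an assumption):

import gradio as gr

def echo_batch(texts: list[str]) -> tuple[list[str]]:
    # called with up to max_batch_size queued inputs at once;
    # returns one list per output component, wrapped in a tuple
    return ([t.upper() for t in texts],)

with gr.Blocks() as demo:
    box = gr.Textbox(label="text")
    out = gr.Textbox(label="result")
    box.submit(echo_batch, inputs=[box], outputs=[out], batch=True, max_batch_size=4)

demo.queue()
demo.launch()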