Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
65eeead
1
Parent(s):
561e8f7
Drop batching because ZeroGPU crashes with it enabled
Browse files
app.py
CHANGED
@@ -275,35 +275,23 @@ def main():
|
|
275 |
print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
|
276 |
model.compile(mode="reduce-overhead")
|
277 |
|
278 |
-
|
279 |
-
|
|
|
|
|
|
|
|
|
280 |
if spaces:
|
281 |
encode = spaces.GPU(encode)
|
282 |
|
283 |
-
|
284 |
-
def search(query: list[str]) -> tuple[list[str]]:
|
285 |
query_embedding = encode(query)
|
286 |
-
distances, faiss_ids = index.
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
works_flat = execute_request(openalex_ids_flat, mailto)
|
291 |
-
|
292 |
-
temp: list[Work] = []
|
293 |
-
works: list[list[Work]] = []
|
294 |
-
for work in works_flat:
|
295 |
-
temp.append(work)
|
296 |
-
if len(temp) == k:
|
297 |
-
works.append(temp)
|
298 |
-
temp = []
|
299 |
-
assert not temp, "request a multiple of k IDs, did not get a multiple back"
|
300 |
-
|
301 |
-
result_strings = [
|
302 |
-
format_response(w, d, calculate_similarity=normalize)
|
303 |
-
for w, d in zip(works, distances)
|
304 |
-
]
|
305 |
|
306 |
-
return (
|
307 |
|
308 |
with gr.Blocks() as demo:
|
309 |
# figure out the words to describe the quantity
|
@@ -349,8 +337,9 @@ def main():
|
|
349 |
container=True,
|
350 |
)
|
351 |
|
352 |
-
|
353 |
-
|
|
|
354 |
|
355 |
demo.queue()
|
356 |
demo.launch()
|
|
|
275 |
print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
|
276 |
model.compile(mode="reduce-overhead")
|
277 |
|
278 |
+
# TODO: use something like the encode_faster function from the main repo to minimize
|
279 |
+
# allocated GPU time
|
280 |
+
def encode(query: str) -> npt.NDArray[np.float16 | np.float32]:
|
281 |
+
return model.encode(
|
282 |
+
query, prompt_name, convert_to_numpy=True, normalize_embeddings=normalize
|
283 |
+
)
|
284 |
if spaces:
|
285 |
encode = spaces.GPU(encode)
|
286 |
|
287 |
+
def search(query: str) -> str:
|
|
|
288 |
query_embedding = encode(query)
|
289 |
+
distances, faiss_ids = index.search("embedding", query_embedding, k)
|
290 |
+
|
291 |
+
openalex_ids = index[faiss_ids]["id"]
|
292 |
+
works = execute_request(openalex_ids, mailto)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
|
294 |
+
return format_response(works, distances, calculate_similarity=normalize)
|
295 |
|
296 |
with gr.Blocks() as demo:
|
297 |
# figure out the words to describe the quantity
|
|
|
337 |
container=True,
|
338 |
)
|
339 |
|
340 |
+
# NOTE: ZeroGPU doesn't seem to support batching
|
341 |
+
query.submit(search, inputs=[query], outputs=[results])
|
342 |
+
btn.click(search, inputs=[query], outputs=[results])
|
343 |
|
344 |
demo.queue()
|
345 |
demo.launch()
|