colonelwatch committed on
Commit
65eeead
·
1 Parent(s): 561e8f7

Drop batching because ZeroGPU crashes with it enabled

Browse files
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -275,35 +275,23 @@ def main():
275
  print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
276
  model.compile(mode="reduce-overhead")
277
 
278
- def encode(query: list[str]) -> npt.NDArray[np.float16 | np.float32]:
279
- return model.encode(query, prompt_name, normalize_embeddings=normalize)
 
 
 
 
280
  if spaces:
281
  encode = spaces.GPU(encode)
282
 
283
- # function signature: (expanded tuple of input batches) -> tuple of output batches
284
- def search(query: list[str]) -> tuple[list[str]]:
285
  query_embedding = encode(query)
286
- distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
287
-
288
- faiss_ids_flat = list(chain(*faiss_ids))
289
- openalex_ids_flat = index[faiss_ids_flat]["id"]
290
- works_flat = execute_request(openalex_ids_flat, mailto)
291
-
292
- temp: list[Work] = []
293
- works: list[list[Work]] = []
294
- for work in works_flat:
295
- temp.append(work)
296
- if len(temp) == k:
297
- works.append(temp)
298
- temp = []
299
- assert not temp, "request a multiple of k IDs, did not get a multiple back"
300
-
301
- result_strings = [
302
- format_response(w, d, calculate_similarity=normalize)
303
- for w, d in zip(works, distances)
304
- ]
305
 
306
- return (result_strings, )
307
 
308
  with gr.Blocks() as demo:
309
  # figure out the words to describe the quantity
@@ -349,8 +337,9 @@ def main():
349
  container=True,
350
  )
351
 
352
- query.submit(search, inputs=[query], outputs=[results], batch=True)
353
- btn.click(search, inputs=[query], outputs=[results], batch=True)
 
354
 
355
  demo.queue()
356
  demo.launch()
 
275
  print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
276
  model.compile(mode="reduce-overhead")
277
 
278
+ # TODO: use something like the encode_faster function from the main repo to minimize
279
+ # alloc'd GPU time
280
+ def encode(query: str) -> npt.NDArray[np.float16 | np.float32]:
281
+ return model.encode(
282
+ query, prompt_name, convert_to_numpy=True, normalize_embeddings=normalize
283
+ )
284
  if spaces:
285
  encode = spaces.GPU(encode)
286
 
287
+ def search(query: str) -> str:
 
288
  query_embedding = encode(query)
289
+ distances, faiss_ids = index.search("embedding", query_embedding, k)
290
+
291
+ openalex_ids = index[faiss_ids]["id"]
292
+ works = execute_request(openalex_ids, mailto)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
+ return format_response(works, distances, calculate_similarity=normalize)
295
 
296
  with gr.Blocks() as demo:
297
  # figure out the words to describe the quantity
 
337
  container=True,
338
  )
339
 
340
+ # NOTE: ZeroGPU doesn't seem to support batching
341
+ query.submit(search, inputs=[query], outputs=[results])
342
+ btn.click(search, inputs=[query], outputs=[results])
343
 
344
  demo.queue()
345
  demo.launch()