colonelwatch committed on
Commit 303669c · 1 Parent(s): 65eeead

Minimize ZeroGPU utilization time by cutting out SentenceTransformer overhead

Files changed (1)
app.py +40 -8
app.py CHANGED
@@ -267,6 +267,20 @@ def main():
     normalize, model = get_model(model_name, dir, trust_remote_code)
     index = get_index(dir, search_time_s)
 
+    # follow model.encode logic for acquiring the prompt
+    if prompt_name is None and model.default_prompt_name is not None:
+        prompt_name = model.default_prompt_name
+    if not isinstance(prompt_name, str):
+        raise TypeError("invalid prompt name type")
+    prompt: str | None = model.prompts[prompt_name] if prompt_name is not None else None
+
+    # follow model.encode logic for setting extra_features
+    extra_features: dict[str, Any] = {}
+    if prompt is not None:
+        tokenized = model.tokenize([prompt])
+        if "input_ids" in tokenized:
+            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1
+
     model.eval()
     if torch.cuda.is_available():
         model = model.half().cuda() if fp16 else model.bfloat16().cuda()
@@ -275,17 +289,35 @@ def main():
         print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
     model.compile(mode="reduce-overhead")
 
-    # TODO: use something like the encode_faster function from the main repo to minimize
-    # alloc'd GPU time
-    def encode(query: str) -> npt.NDArray[np.float16 | np.float32]:
-        return model.encode(
-            query, prompt_name, convert_to_numpy=True, normalize_embeddings=normalize
-        )
+    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
+        # move the tokenized features (a dict of tensors) to the device, non-blocking
+        features = {
+            k: v.to(model.device, non_blocking=True) for k, v in features.items()
+        } | extra_features
+
+        with torch.no_grad():
+            out_features = model.forward(features)
+            embeddings = out_features["sentence_embedding"]
+
+        embeddings = embeddings[0]
+        if model.truncate_dim:
+            embeddings = embeddings[:model.truncate_dim]
+        if normalize:
+            embeddings = torch.nn.functional.normalize(embeddings, dim=0)
+
+        return embeddings.cpu().float().numpy()  # faiss expects a CPU float32 numpy arr
+
     if spaces:
-        encode = spaces.GPU(encode)
+        encode_tokens = spaces.GPU(encode_tokens)
+
+    def encode_string(query: str) -> npt.NDArray[np.float32]:
+        if prompt:
+            query = prompt + query
+        tokens = model.tokenize([query])
+        return encode_tokens(tokens)
 
     def search(query: str) -> str:
-        query_embedding = encode(query)
+        query_embedding = encode_string(query)
         distances, faiss_ids = index.search("embedding", query_embedding, k)
 
         openalex_ids = index[faiss_ids]["id"]
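
Note: ZeroGPU only holds a GPU while a spaces.GPU-wrapped function is running, and model.encode re-tokenizes the query, resolves the prompt, and converts tensors on every call, so wrapping encode kept all of that overhead inside the held-GPU window. This commit hoists prompt resolution to startup, keeps tokenization on the CPU in encode_string, and wraps only the tensor-in/tensor-out encode_tokens. Below is a minimal, self-contained sketch of the same pattern, not the Space's actual code: the checkpoint name is an arbitrary example, and spaces is imported defensively since it exists only on a ZeroGPU Space.

# Minimal sketch of the pattern in this commit, not the Space's actual code.
# Assumptions: sentence-transformers and torch installed; the checkpoint name
# is an arbitrary example; `spaces` exists only on a ZeroGPU Space.
from typing import Any

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.eval()
if torch.cuda.is_available():
    model = model.cuda()

def encode_tokens(features: dict[str, Any]) -> np.ndarray:
    # Only this function touches the GPU, so only it gets wrapped below.
    features = {k: v.to(model.device, non_blocking=True) for k, v in features.items()}
    with torch.no_grad():
        embedding = model.forward(features)["sentence_embedding"][0]
    return embedding.cpu().float().numpy()  # CPU float32, e.g. for faiss

try:
    import spaces
    encode_tokens = spaces.GPU(encode_tokens)  # GPU held only for this call
except ImportError:
    pass  # running outside a ZeroGPU Space

def encode_string(query: str) -> np.ndarray:
    # Tokenization stays on the CPU, outside the GPU-wrapped call.
    return encode_tokens(model.tokenize([query]))

print(encode_string("attention is all you need").shape)

In the actual commit the wrapped function also merges in the precomputed prompt_length feature and applies truncation and normalization before leaving the GPU, but the effect is the same: tokenization and prompt handling are no longer part of the held-GPU window.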