Spaces: Running on Zero
Commit · 303669c
Parent(s): 65eeead
Minimize ZeroGPU utilization time by cutting out SentenceTransformer overhead
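For context: spaces.GPU holds a ZeroGPU allocation for as long as the wrapped callable runs, so any CPU-side work inside the wrapper (prompt handling, tokenization) counts against GPU time. This commit therefore shrinks the wrapped region down to the device transfer and forward pass. A minimal sketch of that pattern, assuming a generic SentenceTransformer model; the model ID and the helper names forward_only and embed are illustrative, not from the app:

# Illustrative ZeroGPU pattern: tokenize on CPU, hold the GPU allocation
# only for the forward pass. Names below are placeholders, not app code.
import spaces
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model

@spaces.GPU  # ZeroGPU is allocated only while this function executes
def forward_only(features: dict) -> torch.Tensor:
    with torch.no_grad():
        return model.forward(features)["sentence_embedding"]

def embed(text: str):
    features = model.tokenize([text])  # CPU-side work, outside the GPU window
    return forward_only(features)[0].cpu().float().numpy()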
app.py CHANGED
@@ -267,6 +267,20 @@ def main():
     normalize, model = get_model(model_name, dir, trust_remote_code)
     index = get_index(dir, search_time_s)
 
+    # follow model.encode logic for acquiring the prompt
+    if prompt_name is None and model.default_prompt_name is not None:
+        prompt_name = model.default_prompt_name
+    if not isinstance(prompt_name, str):
+        raise TypeError("invalid prompt name type")
+    prompt: str | None = model.prompts[prompt_name] if prompt_name is not None else None
+
+    # follow model.encode logic for setting extra_features
+    extra_features: dict[str, Any] = {}
+    if prompt is not None:
+        tokenized = model.tokenize([prompt])
+        if "input_ids" in tokenized:
+            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1
+
     model.eval()
     if torch.cuda.is_available():
         model = model.half().cuda() if fp16 else model.bfloat16().cuda()
@@ -275,17 +289,35 @@ def main():
         print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
     model.compile(mode="reduce-overhead")
 
-
-
-
-
-
-
+    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
+        # Tokenize (which yields a dict) then do a non-blocking transfer
+        features = {
+            k: v.to(model.device, non_blocking=True) for k, v in features.items()
+        } | extra_features
+
+        with torch.no_grad():
+            out_features = model.forward(features)
+            embeddings = out_features["sentence_embedding"]
+
+        embeddings = embeddings[0]
+        if model.truncate_dim:
+            embeddings = embeddings[:model.truncate_dim]
+        if normalize:
+            embeddings = torch.nn.functional.normalize(embeddings, dim=0)
+
+        return embeddings.cpu().float().numpy()  # faiss expected CPU float32 numpy arr
+
     if spaces:
-
+        encode_tokens = spaces.GPU(encode_tokens)
+
+    def encode_string(query: str) -> npt.NDArray[np.float32]:
+        if prompt:
+            query = prompt + query
+        tokens = model.tokenize([query])
+        return encode_tokens(tokens)
 
     def search(query: str) -> str:
-        query_embedding =
+        query_embedding = encode_string(query)
         distances, faiss_ids = index.search("embedding", query_embedding, k)
 
         openalex_ids = index[faiss_ids]["id"]
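Two details in the new code are worth noting. First, encode_string keeps prompt concatenation and tokenization on the CPU, so the spaces.GPU-wrapped encode_tokens pays only for the tensor transfer and the forward pass; the final .cpu().float().numpy() conversion matters because faiss searches over CPU float32 numpy arrays even when the model runs in fp16/bf16. Second, the prompt_length entry mirrors what SentenceTransformer.encode computes internally: a Pooling module configured with include_prompt=False uses features["prompt_length"] to exclude prompt tokens from mean pooling. A short sketch of that derivation, with an illustrative instruction-style prompt rather than the app's actual one:

# prompt_length derivation, mirroring SentenceTransformer.encode; the
# prompt string here is illustrative only.
prompt = "Represent this sentence for searching relevant passages: "
tokenized = model.tokenize([prompt])
# drop the trailing special token so only the prompt body is excluded from pooling
prompt_length = tokenized["input_ids"].shape[-1] - 1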