Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
22ed20d
1
Parent(s):
2c87116
Run model on GPU and add fp16 and trust_remote_code options
Browse files
- app.py +12 -3
- requirements.txt +1 -1
app.py
CHANGED
@@ -12,6 +12,8 @@ from pathlib import Path
|
|
12 |
from dataclasses import dataclass
|
13 |
from itertools import batched, chain
|
14 |
import os
|
|
|
|
|
15 |
|
16 |
|
17 |
class IndexParameters(TypedDict):
|
@@ -105,8 +107,8 @@ def get_env_var[T, U](
|
|
105 |
return var
|
106 |
|
107 |
|
108 |
-
def get_model(model_name: str) -> SentenceTransformer:
|
109 |
-
return SentenceTransformer(model_name)
|
110 |
|
111 |
|
112 |
def get_index(dir: Path, search_time_s: float) -> Dataset:
|
@@ -207,13 +209,20 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
|
|
207 |
def main():
|
208 |
# TODO: figure out some better defaults?
|
209 |
model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
|
|
|
|
|
210 |
dir = get_env_var("DIR", Path, default=Path("index"))
|
211 |
search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
|
212 |
k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
|
213 |
mailto = get_env_var("MAILTO", str, None)
|
214 |
|
215 |
-
model = get_model(model_name)
|
216 |
index = get_index(dir, search_time_s)
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
# function signature: (expanded tuple of input batches) -> tuple of output batches
|
219 |
def search(query: list[str]) -> tuple[list[str]]:
|
|
|
12 |
from dataclasses import dataclass
|
13 |
from itertools import batched, chain
|
14 |
import os
|
15 |
+
import torch
|
16 |
+
from sys import stderr
|
17 |
|
18 |
|
19 |
class IndexParameters(TypedDict):
|
|
|
107 |
return var
|
108 |
|
109 |
|
110 |
+
def get_model(model_name: str, trust_remote_code: bool) -> SentenceTransformer:
|
111 |
+
return SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
|
112 |
|
113 |
|
114 |
def get_index(dir: Path, search_time_s: float) -> Dataset:
|
|
|
209 |
def main():
|
210 |
# TODO: figure out some better defaults?
|
211 |
model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
|
212 |
+
trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
|
213 |
+
fp16 = get_env_var("FP16", bool, default=False)
|
214 |
dir = get_env_var("DIR", Path, default=Path("index"))
|
215 |
search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
|
216 |
k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
|
217 |
mailto = get_env_var("MAILTO", str, None)
|
218 |
|
219 |
+
model = get_model(model_name, trust_remote_code)
|
220 |
index = get_index(dir, search_time_s)
|
221 |
+
if torch.cuda.is_available():
|
222 |
+
model = model.half().cuda() if fp16 else model.bfloat16().cuda()
|
223 |
+
# TODO: if huggingface datasets exposes an fp16 gpu option, use it here
|
224 |
+
elif fp16:
|
225 |
+
print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
|
226 |
|
227 |
# function signature: (expanded tuple of input batches) -> tuple of output batches
|
228 |
def search(query: list[str]) -> tuple[list[str]]:
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
sentence-transformers
|
2 |
-
faiss-cpu
|
3 |
datasets
|
|
|
1 |
sentence-transformers
|
2 |
+
faiss-gpu
|
3 |
datasets
|