colonelwatch committed on
Commit
22ed20d
·
1 Parent(s): 2c87116

Run model on GPU and add fp16 and trust_remote_code options

Browse files
Files changed (2) hide show
  1. app.py +12 -3
  2. requirements.txt +1 -1
app.py CHANGED
@@ -12,6 +12,8 @@ from pathlib import Path
12
  from dataclasses import dataclass
13
  from itertools import batched, chain
14
  import os
 
 
15
 
16
 
17
  class IndexParameters(TypedDict):
@@ -105,8 +107,8 @@ def get_env_var[T, U](
105
  return var
106
 
107
 
108
- def get_model(model_name: str, device: str) -> SentenceTransformer:
109
- return SentenceTransformer(model_name, device=device)
110
 
111
 
112
  def get_index(dir: Path, search_time_s: float) -> Dataset:
@@ -207,13 +209,20 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
207
  def main():
208
  # TODO: figure out some better defaults?
209
  model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
 
 
210
  dir = get_env_var("DIR", Path, default=Path("index"))
211
  search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
212
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
213
  mailto = get_env_var("MAILTO", str, None)
214
 
215
- model = get_model(model_name, "cpu")
216
  index = get_index(dir, search_time_s)
 
 
 
 
 
217
 
218
  # function signature: (expanded tuple of input batches) -> tuple of output batches
219
  def search(query: list[str]) -> tuple[list[str]]:
 
12
  from dataclasses import dataclass
13
  from itertools import batched, chain
14
  import os
15
+ import torch
16
+ from sys import stderr
17
 
18
 
19
  class IndexParameters(TypedDict):
 
107
  return var
108
 
109
 
110
+ def get_model(model_name: str, trust_remote_code: bool) -> SentenceTransformer:
111
+ return SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
112
 
113
 
114
  def get_index(dir: Path, search_time_s: float) -> Dataset:
 
209
  def main():
210
  # TODO: figure out some better defaults?
211
  model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
212
+ trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
213
+ fp16 = get_env_var("FP16", bool, default=False)
214
  dir = get_env_var("DIR", Path, default=Path("index"))
215
  search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
216
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
217
  mailto = get_env_var("MAILTO", str, None)
218
 
219
+ model = get_model(model_name, trust_remote_code)
220
  index = get_index(dir, search_time_s)
221
+ if torch.cuda.is_available():
222
+ model = model.half().cuda() if fp16 else model.bfloat16().cuda()
223
+ # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
224
+ elif fp16:
225
+ print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
226
 
227
  # function signature: (expanded tuple of input batches) -> tuple of output batches
228
  def search(query: list[str]) -> tuple[list[str]]:
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  sentence-transformers
2
- faiss-cpu
3
  datasets
 
1
  sentence-transformers
2
+ faiss-gpu
3
  datasets