colonelwatch committed on
Commit
4751a57
·
1 Parent(s): a8b8684

Handle new params.json format, including truncation and normalization

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -23,6 +23,12 @@ class IndexParameters(TypedDict):
23
  param_string: str # pass directly to faiss index
24
 
25
 
 
 
 
 
 
 
26
  @dataclass
27
  class Work:
28
  title: str | None
@@ -108,8 +114,16 @@ def get_env_var[T, U](
108
  return var
109
 
110
 
111
- def get_model(model_name: str, trust_remote_code: bool) -> SentenceTransformer:
112
- return SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
 
 
 
 
 
 
 
 
113
 
114
 
115
  def get_index(dir: Path, search_time_s: float) -> Dataset:
@@ -118,14 +132,14 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:
118
  faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
119
 
120
  with open(dir / "params.json", "r") as f:
121
- params: list[IndexParameters] = json.load(f)
122
- params = [p for p in params if p["exec_time"] < search_time_s]
123
- param = max(params, key=(lambda p: p["recall"]))
124
- param_string = param["param_string"]
125
 
126
  ps = faiss.ParameterSpace()
127
  ps.initialize(faiss_index)
128
- ps.set_index_parameters(faiss_index, param_string)
129
 
130
  return index
131
 
@@ -218,9 +232,10 @@ def main():
218
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
219
  mailto = get_env_var("MAILTO", str, None)
220
 
221
- model = get_model(model_name, trust_remote_code)
222
  index = get_index(dir, search_time_s)
223
 
 
224
  if torch.cuda.is_available():
225
  model = model.half().cuda() if fp16 else model.bfloat16().cuda()
226
  # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
@@ -230,7 +245,9 @@ def main():
230
 
231
  # function signature: (expanded tuple of input batches) -> tuple of output batches
232
  def search(query: list[str]) -> tuple[list[str]]:
233
- query_embedding = model.encode(query, prompt_name)
 
 
234
  distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)
235
 
236
  faiss_ids_flat = list(chain(*faiss_ids))
 
23
  param_string: str # pass directly to faiss index
24
 
25
 
26
class Params(TypedDict):
    """Parsed contents of the index's params.json file."""

    # Embedding truncation size handed to SentenceTransformer as
    # truncate_dim; None means the model's full output dimension is kept.
    dimensions: int | None
    # Whether query embeddings must be L2-normalized before searching.
    normalize: bool
    # Candidate FAISS runtime settings, each with a measured
    # exec_time/recall trade-off and a param_string to apply.
    optimal_params: list[IndexParameters]
30
+
31
+
32
  @dataclass
33
  class Work:
34
  title: str | None
 
114
  return var
115
 
116
 
117
def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Load the embedding model as configured by params.json.

    Args:
        model_name: Model id or local path passed to SentenceTransformer.
        params_dir: Directory containing params.json, whose "dimensions"
            and "normalize" keys configure the model.
        trust_remote_code: Forwarded verbatim to SentenceTransformer.

    Returns:
        A ``(normalize, model)`` pair: whether query embeddings should be
        normalized at encode time, and the model with its output truncated
        to the configured dimension (no truncation when "dimensions" is
        null in the JSON).
    """
    # JSON is UTF-8 by spec; pin the encoding so the read does not depend
    # on the platform's locale default.
    with open(params_dir / "params.json", "r", encoding="utf-8") as f:
        params: Params = json.load(f)
    model = SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        # truncate_dim=None leaves the embedding dimension unchanged.
        truncate_dim=params["dimensions"],
    )
    return params["normalize"], model
127
 
128
 
129
  def get_index(dir: Path, search_time_s: float) -> Dataset:
 
132
  faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
133
 
134
  with open(dir / "params.json", "r") as f:
135
+ params: Params = json.load(f)
136
+ under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
137
+ optimal = max(under, key=(lambda p: p["recall"]))
138
+ optimal_string = optimal["param_string"]
139
 
140
  ps = faiss.ParameterSpace()
141
  ps.initialize(faiss_index)
142
+ ps.set_index_parameters(faiss_index, optimal_string)
143
 
144
  return index
145
 
 
232
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
233
  mailto = get_env_var("MAILTO", str, None)
234
 
235
+ normalize, model = get_model(model_name, dir, trust_remote_code)
236
  index = get_index(dir, search_time_s)
237
 
238
+ model.eval()
239
  if torch.cuda.is_available():
240
  model = model.half().cuda() if fp16 else model.bfloat16().cuda()
241
  # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
 
245
 
246
  # function signature: (expanded tuple of input batches) -> tuple of output batches
247
  def search(query: list[str]) -> tuple[list[str]]:
248
+ query_embedding = model.encode(
249
+ query, prompt_name, normalize_embeddings=normalize
250
+ )
251
  distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)
252
 
253
  faiss_ids_flat = list(chain(*faiss_ids))