Spaces: Running on Zero
Commit · 4751a57
1 Parent(s): a8b8684
Handle new params.json format, including truncation and normalization
app.py CHANGED
@@ -23,6 +23,12 @@ class IndexParameters(TypedDict):
     param_string: str # pass directly to faiss index


+class Params(TypedDict):
+    dimensions: int | None
+    normalize: bool
+    optimal_params: list[IndexParameters]
+
+
 @dataclass
 class Work:
     title: str | None
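The new Params TypedDict mirrors the params.json file stored next to the index. Purely as an illustration (none of these values come from the repository; the exec_time, recall, and param_string keys are the IndexParameters fields referenced in get_index below, and the parameter strings are only examples of the comma-separated name=value syntax that faiss.ParameterSpace accepts), such a file could look like:

{
  "dimensions": 256,
  "normalize": true,
  "optimal_params": [
    {"exec_time": 0.015, "recall": 0.89, "param_string": "nprobe=16,quantizer_efSearch=64"},
    {"exec_time": 0.040, "recall": 0.95, "param_string": "nprobe=64,quantizer_efSearch=128"},
    {"exec_time": 0.120, "recall": 0.99, "param_string": "nprobe=256,quantizer_efSearch=512"}
  ]
}

A null "dimensions" would leave the model's native embedding size untouched, since get_model passes the value straight to truncate_dim.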
@@ -108,8 +114,16 @@ def get_env_var[T, U](
     return var


-def get_model(
-
+def get_model(
+    model_name: str, params_dir: Path, trust_remote_code: bool
+) -> tuple[bool, SentenceTransformer]:
+    with open(params_dir / "params.json", "r") as f:
+        params: Params = json.load(f)
+    return params["normalize"], SentenceTransformer(
+        model_name,
+        trust_remote_code=trust_remote_code,
+        truncate_dim=params["dimensions"]
+    )


 def get_index(dir: Path, search_time_s: float) -> Dataset:
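get_model now returns the normalize flag together with the model and forwards the truncation dimension to SentenceTransformer via truncate_dim. A minimal sketch of calling it, assuming the illustrative params.json above lives in ./index and using a placeholder model name (in app.py the real name and directory are resolved from the environment in main()):

from pathlib import Path

# Hypothetical call: the model name and directory are placeholders for this sketch.
normalize, model = get_model(
    "sentence-transformers/all-MiniLM-L6-v2", Path("index"), trust_remote_code=False
)
emb = model.encode(["dark matter halos"], normalize_embeddings=normalize)
print(emb.shape)  # (1, 256) under the illustrative "dimensions" above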
@@ -118,14 +132,14 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:
     faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore

     with open(dir / "params.json", "r") as f:
-        params:
-
-
-
+        params: Params = json.load(f)
+    under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
+    optimal = max(under, key=(lambda p: p["recall"]))
+    optimal_string = optimal["param_string"]

     ps = faiss.ParameterSpace()
     ps.initialize(faiss_index)
-    ps.set_index_parameters(faiss_index,
+    ps.set_index_parameters(faiss_index, optimal_string)

     return index

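get_index therefore picks, from the pre-benchmarked operating points in optimal_params, the one with the best recall among those whose exec_time fits the search-time budget, and applies its parameter string through faiss.ParameterSpace. The same rule in isolation, run against the illustrative entries above:

# Standalone sketch of the selection rule, reusing the illustrative entries above.
optimal_params = [
    {"exec_time": 0.015, "recall": 0.89, "param_string": "nprobe=16,quantizer_efSearch=64"},
    {"exec_time": 0.040, "recall": 0.95, "param_string": "nprobe=64,quantizer_efSearch=128"},
    {"exec_time": 0.120, "recall": 0.99, "param_string": "nprobe=256,quantizer_efSearch=512"},
]
search_time_s = 0.05  # latency budget in seconds
under = [p for p in optimal_params if p["exec_time"] < search_time_s]
print(max(under, key=lambda p: p["recall"])["param_string"])
# -> nprobe=64,quantizer_efSearch=128

If no entry fits the budget, under is empty and max() raises ValueError; the diff does not show whether that case is guarded elsewhere.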
@@ -218,9 +232,10 @@ def main():
     k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
     mailto = get_env_var("MAILTO", str, None)

-    model = get_model(model_name, trust_remote_code)
+    normalize, model = get_model(model_name, dir, trust_remote_code)
     index = get_index(dir, search_time_s)

+    model.eval()
     if torch.cuda.is_available():
         model = model.half().cuda() if fp16 else model.bfloat16().cuda()
         # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
@@ -230,7 +245,9 @@

     # function signature: (expanded tuple of input batches) -> tuple of output batches
     def search(query: list[str]) -> tuple[list[str]]:
-        query_embedding = model.encode(
+        query_embedding = model.encode(
+            query, prompt_name, normalize_embeddings=normalize
+        )
         distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)

         faiss_ids_flat = list(chain(*faiss_ids))
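One note on the normalization side of the commit: for unit-length vectors, inner product and cosine similarity are the same number, so storing a normalize flag next to the index keeps query-time encoding consistent with however the indexed embeddings were produced. A standalone check of that identity (not code from app.py):

import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=256), rng.normal(size=256)
a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
# inner product of unit vectors equals cosine similarity of the originals
assert np.isclose(np.dot(a_unit, b_unit), cosine)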