Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7bb8a83
1
Parent(s):
138f923
Compute the similarity instead of the distance if applicable
Browse files
app.py
CHANGED
@@ -118,6 +118,7 @@ def get_env_var[T, U](
|
|
118 |
def get_model(
|
119 |
model_name: str, params_dir: Path, trust_remote_code: bool
|
120 |
) -> tuple[bool, SentenceTransformer]:
|
|
|
121 |
with open(params_dir / "params.json", "r") as f:
|
122 |
params: Params = json.load(f)
|
123 |
return params["normalize"], SentenceTransformer(
|
@@ -173,7 +174,9 @@ def collapse_newlines(x: str) -> str:
|
|
173 |
return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
|
174 |
|
175 |
|
176 |
-
def format_response(
|
|
|
|
|
177 |
result_string = ""
|
178 |
for work, distance in zip(neighbors, distances):
|
179 |
entry_string = "## "
|
@@ -218,7 +221,11 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
|
|
218 |
meta.append(("Cited-by count", str(work.citations)))
|
219 |
if work.doi:
|
220 |
meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
|
221 |
-
|
|
|
|
|
|
|
|
|
222 |
entry_string += (" " * 4).join(": ".join(tup) for tup in meta)
|
223 |
|
224 |
entry_string += "*\n"
|
@@ -264,7 +271,10 @@ def main():
|
|
264 |
works_flat = execute_request(openalex_ids_flat, mailto)
|
265 |
works = [list(batch) for batch in batched(works_flat, k)]
|
266 |
|
267 |
-
result_strings = [
|
|
|
|
|
|
|
268 |
|
269 |
return (result_strings, )
|
270 |
|
|
|
118 |
def get_model(
|
119 |
model_name: str, params_dir: Path, trust_remote_code: bool
|
120 |
) -> tuple[bool, SentenceTransformer]:
|
121 |
+
# TODO: how should params["normalize"] be handled for models like all-MiniLM-L6-v2, which already normalize their output embeddings?
|
122 |
with open(params_dir / "params.json", "r") as f:
|
123 |
params: Params = json.load(f)
|
124 |
return params["normalize"], SentenceTransformer(
|
|
|
174 |
return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
|
175 |
|
176 |
|
177 |
+
def format_response(
|
178 |
+
neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
|
179 |
+
) -> str:
|
180 |
result_string = ""
|
181 |
for work, distance in zip(neighbors, distances):
|
182 |
entry_string = "## "
|
|
|
221 |
meta.append(("Cited-by count", str(work.citations)))
|
222 |
if work.doi:
|
223 |
meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
|
224 |
+
if calculate_similarity:
|
225 |
+
# if query and result are unit vectors, the cosine sim is 1 - dist^2 / 2
|
226 |
+
meta.append(("Similarity", f"{1 - distance / 2:.2f}")) # faiss gives dist^2
|
227 |
+
else:
|
228 |
+
meta.append(("Distance", f"{distance:.2f}"))
|
229 |
entry_string += (" " * 4).join(": ".join(tup) for tup in meta)
|
230 |
|
231 |
entry_string += "*\n"
|
|
|
271 |
works_flat = execute_request(openalex_ids_flat, mailto)
|
272 |
works = [list(batch) for batch in batched(works_flat, k)]
|
273 |
|
274 |
+
result_strings = [
|
275 |
+
format_response(w, d, calculate_similarity=normalize)
|
276 |
+
for w, d in zip(works, distances)
|
277 |
+
]
|
278 |
|
279 |
return (result_strings, )
|
280 |
|