colonelwatch commited on
Commit
7bb8a83
·
1 Parent(s): 138f923

Compute the similarity instead of the distance if applicable

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -118,6 +118,7 @@ def get_env_var[T, U](
118
  def get_model(
119
  model_name: str, params_dir: Path, trust_remote_code: bool
120
  ) -> tuple[bool, SentenceTransformer]:
 
121
  with open(params_dir / "params.json", "r") as f:
122
  params: Params = json.load(f)
123
  return params["normalize"], SentenceTransformer(
@@ -173,7 +174,9 @@ def collapse_newlines(x: str) -> str:
173
  return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
174
 
175
 
176
- def format_response(neighbors: list[Work], distances: list[float]) -> str:
 
 
177
  result_string = ""
178
  for work, distance in zip(neighbors, distances):
179
  entry_string = "## "
@@ -218,7 +221,11 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
218
  meta.append(("Cited-by count", str(work.citations)))
219
  if work.doi:
220
  meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
221
- meta.append(("Similarity", f"{distance:.2f}"))
 
 
 
 
222
  entry_string += (" " * 4).join(": ".join(tup) for tup in meta)
223
 
224
  entry_string += "*\n"
@@ -264,7 +271,10 @@ def main():
264
  works_flat = execute_request(openalex_ids_flat, mailto)
265
  works = [list(batch) for batch in batched(works_flat, k)]
266
 
267
- result_strings = [format_response(w, d) for w, d in zip(works, distances)]
 
 
 
268
 
269
  return (result_strings, )
270
 
 
118
  def get_model(
119
  model_name: str, params_dir: Path, trust_remote_code: bool
120
  ) -> tuple[bool, SentenceTransformer]:
121
+ # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
122
  with open(params_dir / "params.json", "r") as f:
123
  params: Params = json.load(f)
124
  return params["normalize"], SentenceTransformer(
 
174
  return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
175
 
176
 
177
+ def format_response(
178
+ neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
179
+ ) -> str:
180
  result_string = ""
181
  for work, distance in zip(neighbors, distances):
182
  entry_string = "## "
 
221
  meta.append(("Cited-by count", str(work.citations)))
222
  if work.doi:
223
  meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
224
+ if calculate_similarity:
225
+ # if query and result are unit vectors, the cosine sim is 1 - dist^2 / 2
226
+ meta.append(("Similarity", f"{1 - distance / 2:.2f}")) # faiss gives dist^2
227
+ else:
228
+ meta.append(("Distance", f"{distance:.2f}"))
229
  entry_string += (" " * 4).join(": ".join(tup) for tup in meta)
230
 
231
  entry_string += "*\n"
 
271
  works_flat = execute_request(openalex_ids_flat, mailto)
272
  works = [list(batch) for batch in batched(works_flat, k)]
273
 
274
+ result_strings = [
275
+ format_response(w, d, calculate_similarity=normalize)
276
+ for w, d in zip(works, distances)
277
+ ]
278
 
279
  return (result_strings, )
280