colonelwatch committed on
Commit
95aa608
·
1 Parent(s): 4751a57

Normalize column names

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -128,8 +128,8 @@ def get_model(
128
 
129
  def get_index(dir: Path, search_time_s: float) -> Dataset:
130
  index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet")) # type: ignore
131
- index.load_faiss_index("embeddings", dir / "index.faiss", None)
132
- faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
133
 
134
  with open(dir / "params.json", "r") as f:
135
  params: Params = json.load(f)
@@ -248,10 +248,10 @@ def main():
248
  query_embedding = model.encode(
249
  query, prompt_name, normalize_embeddings=normalize
250
  )
251
- distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)
252
 
253
  faiss_ids_flat = list(chain(*faiss_ids))
254
- openalex_ids_flat = index[faiss_ids_flat]["idxs"]
255
  works_flat = execute_request(openalex_ids_flat, mailto)
256
  works = [list(batch) for batch in batched(works_flat, k)]
257
 
 
128
 
129
  def get_index(dir: Path, search_time_s: float) -> Dataset:
130
  index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet")) # type: ignore
131
+ index.load_faiss_index("embedding", dir / "index.faiss", None)
132
+ faiss_index: faiss.Index = index.get_index("embedding").faiss_index # type: ignore
133
 
134
  with open(dir / "params.json", "r") as f:
135
  params: Params = json.load(f)
 
248
  query_embedding = model.encode(
249
  query, prompt_name, normalize_embeddings=normalize
250
  )
251
+ distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
252
 
253
  faiss_ids_flat = list(chain(*faiss_ids))
254
+ openalex_ids_flat = index[faiss_ids_flat]["id"]
255
  works_flat = execute_request(openalex_ids_flat, mailto)
256
  works = [list(batch) for batch in batched(works_flat, k)]
257