# app.py
# Loads all completed index shards and serves a Gradio app that returns the
# works most similar to a given text query.
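# Configuration is via environment variables, read in main() below; e.g. (the
# values here are only examples, matching the defaults):
#   MODEL_NAME=all-MiniLM-L6-v2 DIR=index SEARCH_TIME_S=1 K=20 python app.py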
from dataclasses import dataclass
from itertools import batched, chain
import json
import os
from pathlib import Path
from sys import stderr
from typing import TypedDict, Self, Any, Callable
from datasets import Dataset
from datasets.search import FaissIndex
import faiss
from faiss.contrib.ondisk import merge_ondisk
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer
import torch
class IndexParameters(TypedDict):
recall: float # in this case 10-recall@10
exec_time: float # seconds (raw faiss measure is in milliseconds)
param_string: str # pass directly to faiss index
class Params(TypedDict):
dimensions: int | None
normalize: bool
optimal_params: list[IndexParameters]
@dataclass
class Work:
title: str | None
abstract: str | None # recovered from abstract_inverted_index
authors: list[str] # takes raw_author_name field from Authorship objects
journal_name: str | None # takes the display_name field of the first location
year: int
citations: int
doi: str | None
def __post_init__(self):
self._check_type(self.title, str, nullable=True)
self._check_type(self.abstract, str, nullable=True)
self._check_type(self.authors, list)
for author in self.authors:
self._check_type(author, str)
self._check_type(self.journal_name, str, nullable=True)
self._check_type(self.year, int)
self._check_type(self.citations, int)
self._check_type(self.doi, str, nullable=True)
@classmethod
def from_dict(cls, d: dict) -> Self:
inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"]
abstract = cls._recover_abstract(inverted_index) if inverted_index else None
try:
journal_name = d["primary_location"]["source"]["display_name"]
except (TypeError, KeyError): # key didn't exist or a value was null
journal_name = None
return cls(
title=d["title"],
abstract=abstract,
authors=[authorship["raw_author_name"] for authorship in d["authorships"]],
journal_name=journal_name,
year=d["publication_year"],
citations=d["cited_by_count"],
doi=d["doi"],
)
@staticmethod
def get_raw_fields() -> list[str]:
return [
"title",
"abstract_inverted_index",
"authorships",
"primary_location",
"publication_year",
"cited_by_count",
"doi"
]
@staticmethod
def _check_type(v: Any, t: type, nullable: bool = False):
if not ((nullable and v is None) or isinstance(v, t)):
v_type_name = f"{type(v)}" if v is not None else "None"
t_name = f"{t}"
if nullable:
t_name += " | None"
raise ValueError(f"expected {t_name}, got {v_type_name}")
@staticmethod
def _recover_abstract(inverted_index: dict[str, list[int]]) -> str:
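        # e.g. {"Hello": [0], "world": [1, 2]} recovers to "Hello world world"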
abstract_size = max(max(locs) for locs in inverted_index.values())+1
abstract_words: list[str | None] = [None] * abstract_size
for word, locs in inverted_index.items():
for loc in locs:
abstract_words[loc] = word
return " ".join(word for word in abstract_words if word is not None)
def get_env_var[T, U](
key: str, type_: Callable[[str], T] = str, default: U = None
) -> T | U:
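    # e.g. get_env_var("K", int, default=20) -> int(os.environ["K"]) if "K" is
    # set, else 20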
var = os.getenv(key)
if var is not None:
var = type_(var)
else:
var = default
return var
def get_model(
model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
with open(params_dir / "params.json", "r") as f:
params: Params = json.load(f)
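    # a truncate_dim of None leaves embeddings at the model's full size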
return params["normalize"], SentenceTransformer(
model_name,
trust_remote_code=trust_remote_code,
truncate_dim=params["dimensions"]
)
def merge_shards(dir: Path) -> faiss.Index:
empty_path = dir / "empty.faiss"
shard_paths = [str(p) for p in dir.glob("shard_*.faiss")]
merged_ivfdata_path = Path("temp.ivfdata")
index = faiss.read_index(str(empty_path))
merged_ivfdata_path.unlink(missing_ok=True) # overwrite previous if it exists (TODO: do I need this?)
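    # merge_ondisk rewrites the shards' inverted lists into a single
    # memory-mapped .ivfdata file, so the merged index need not fit in RAM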
merge_ondisk(index, shard_paths, str(merged_ivfdata_path))
return index
def get_index(dir: Path, search_time_s: float) -> Dataset:
# NOTE: a private attr is used to get the faiss.IO_FLAG_ONDISK_SAME_DIR flag!
index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet")) # type: ignore
faiss_index = merge_shards(dir / "shards")
index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)
with open(dir / "params.json", "r") as f:
params: Params = json.load(f)
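    # among the operating points benchmarked offline, apply the highest-recall
    # parameter string that stays under the search-time budget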
under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
optimal = max(under, key=(lambda p: p["recall"]))
optimal_string = optimal["param_string"]
ps = faiss.ParameterSpace()
ps.initialize(faiss_index)
ps.set_index_parameters(faiss_index, optimal_string)
return index
def execute_request(ids: list[str], mailto: str | None) -> list[Work]:
if len(ids) > 100:
raise ValueError("querying /works endpoint with more than 100 works")
# query with the /works endpoint with a specific list of IDs and fields
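    # e.g. filter=openalex_id:W2741809807|W2100837269 (hypothetical IDs)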
search_filter = f"openalex_id:{"|".join(ids)}"
search_select = ",".join(["id"] + Work.get_raw_fields())
params = {"filter": search_filter, "select": search_select, "per-page": 100}
if mailto is not None:
params["mailto"] = mailto
response = requests.get("https://api.openalex.org/works", params)
response.raise_for_status()
    # the response is not necessarily ordered, so reorder to match the requested IDs
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    return [works_by_id[id_] for id_ in ids]
def collapse_newlines(x: str) -> str:
return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
def format_response(neighbors: list[Work], distances: list[float]) -> str:
result_string = ""
for work, distance in zip(neighbors, distances):
entry_string = "## "
if work.title and work.doi:
entry_string += f"[{collapse_newlines(work.title)}]({work.doi})"
elif work.title:
entry_string += f"{collapse_newlines(work.title)}"
elif work.doi:
entry_string += f"[No title]({work.doi})"
else:
entry_string += "No title"
entry_string += "\n\n**"
if len(work.authors) >= 3: # truncate to 3 if necessary
entry_string += ", ".join(work.authors[:3]) + ", ..."
elif work.authors:
entry_string += ", ".join(work.authors)
else:
entry_string += "No author"
entry_string += f", {work.year}"
if work.journal_name:
entry_string += " - " + work.journal_name
entry_string += "**\n\n"
if work.abstract:
abstract = collapse_newlines(work.abstract)
if len(abstract) > 2000:
abstract = abstract[:2000] + "..."
entry_string += abstract
else:
entry_string += "No abstract"
entry_string += "\n\n*"
meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" onto someone's work
meta.append(("Cited-by count", str(work.citations)))
if work.doi:
meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
meta.append(("Similarity", f"{distance:.2f}"))
entry_string += ("&nbsp;" * 4).join(": ".join(tup) for tup in meta)
entry_string += "*\n"
result_string += entry_string
return result_string
def main():
# TODO: figure out some better defaults?
model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
prompt_name = get_env_var("PROMPT_NAME")
    # bool("false") is True, so parse boolean env vars explicitly
    parse_bool = lambda s: s.strip().lower() in ("1", "true", "yes")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", parse_bool, default=False)
    fp16 = get_env_var("FP16", parse_bool, default=False)
dir = get_env_var("DIR", Path, default=Path("index"))
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1.0)
k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
mailto = get_env_var("MAILTO", str, None)
normalize, model = get_model(model_name, dir, trust_remote_code)
index = get_index(dir, search_time_s)
model.eval()
if torch.cuda.is_available():
model = model.half().cuda() if fp16 else model.bfloat16().cuda()
# TODO: if huggingface datasets exposes an fp16 gpu option, use it here
elif fp16:
print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
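    # "reduce-overhead" trades a longer warmup for lower per-call latency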
model.compile(mode="reduce-overhead")
# function signature: (expanded tuple of input batches) -> tuple of output batches
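    # gradio batching (batch=True below) passes a list of queries and expects a
    # tuple of output lists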
def search(query: list[str]) -> tuple[list[str]]:
query_embedding = model.encode(
query, prompt_name, normalize_embeddings=normalize
)
distances, faiss_ids = index.search_batch("embedding", query_embedding, k)
faiss_ids_flat = list(chain(*faiss_ids))
openalex_ids_flat = index[faiss_ids_flat]["id"]
works_flat = execute_request(openalex_ids_flat, mailto)
works = [list(batch) for batch in batched(works_flat, k)]
result_strings = [format_response(w, d) for w, d in zip(works, distances)]
return (result_strings, )
with gr.Blocks() as demo:
gr.Markdown("# abstracts-index")
gr.Markdown(
"Explore 95 million academic publications selected from the "
"[OpenAlex](https://openalex.org) dataset. This project is an index of the "
"embeddings generated from their titles and abstracts. The embeddings were "
f"generated using the {model_name} model provided by the "
"[sentence-transformers](https://www.sbert.net/) module, and the index was "
"built using the [faiss](https://github.com/facebookresearch/faiss) "
"module. The build scripts and more information available at the main repo "
"[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
"Github."
)
query = gr.Textbox(
lines=1, placeholder="Enter your query here", show_label=False
)
btn = gr.Button("Search")
results = gr.Markdown(
latex_delimiters=[
{"left": "$$", "right": "$$", "display": False},
{"left": "$", "right": "$", "display": False},
],
container=True,
)
query.submit(search, inputs=[query], outputs=[results], batch=True)
btn.click(search, inputs=[query], outputs=[results], batch=True)
demo.queue()
demo.launch()
if __name__ == "__main__":
main()