# app.py
# Loads all completed shards and finds the most similar vector to a given query vector.
#
# Configuration comes from environment variables (see `main`). Query embeddings are
# searched in a faiss index, the hits are resolved to publication metadata through
# the OpenAlex /works API, and the results are rendered as Markdown in a Gradio app.

from dataclasses import dataclass
from itertools import chain
import json
import os
from math import log10
from pathlib import Path
from sys import stderr
from typing import TypedDict, TypeVar, Any, Callable

from datasets import Dataset
from datasets.search import FaissIndex
import faiss
from huggingface_hub import snapshot_download
import numpy as np
import numpy.typing as npt
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer
import torch

try:
    import spaces  # optional: when present, spaces.GPU wraps the encoder in main()
except ImportError:
    spaces = None


T = TypeVar("T")
U = TypeVar("U")

# Spellings accepted by `parse_bool`, compared after stripping and lowercasing.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})
_FALSE_STRINGS = frozenset({"", "0", "false", "no", "off"})


class IndexParameters(TypedDict):
    recall: float  # in this case 10-recall@10
    exec_time: float  # seconds (raw faiss measure is in milliseconds)
    param_string: str  # pass directly to faiss index


class Params(TypedDict):
    # Schema of the params.json file stored alongside the index.
    dimensions: int | None  # passed to SentenceTransformer as truncate_dim
    normalize: bool  # whether query embeddings are L2-normalized before searching
    optimal_params: list[IndexParameters]


@dataclass
class Work:
    """Publication metadata for one OpenAlex work, as shown in the search results."""

    title: str | None
    abstract: str | None  # recovered from abstract_inverted_index
    authors: list[str]  # takes raw_author_name field from Authorship objects
    journal_name: str | None  # display_name of the primary location's source
    year: int  # publication_year
    citations: int  # cited_by_count
    doi: str | None

    def __post_init__(self):
        """Validate the field types, raising ValueError on the first mismatch."""
        self._check_type(self.title, str, nullable=True)
        self._check_type(self.abstract, str, nullable=True)
        self._check_type(self.authors, list)
        for author in self.authors:
            self._check_type(author, str)
        self._check_type(self.journal_name, str, nullable=True)
        self._check_type(self.year, int)
        self._check_type(self.citations, int)
        self._check_type(self.doi, str, nullable=True)

    @classmethod
    def from_dict(cls, d: dict) -> "Work":
        """Build a Work from one result object of the OpenAlex /works endpoint.

        `d` must contain every field listed by `get_raw_fields`.
        """
        inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"]
        abstract = cls._recover_abstract(inverted_index) if inverted_index else None

        try:
            journal_name = d["primary_location"]["source"]["display_name"]
        except (TypeError, KeyError):  # key didn't exist or a value was null
            journal_name = None

        return cls(
            title=d["title"],
            abstract=abstract,
            authors=[authorship["raw_author_name"] for authorship in d["authorships"]],
            journal_name=journal_name,
            year=d["publication_year"],
            citations=d["cited_by_count"],
            doi=d["doi"],
        )

    @staticmethod
    def get_raw_fields() -> list[str]:
        """Return the OpenAlex fields read by `from_dict` (used for `select=`)."""
        return [
            "title",
            "abstract_inverted_index",
            "authorships",
            "primary_location",
            "publication_year",
            "cited_by_count",
            "doi",
        ]

    @staticmethod
    def _check_type(v: Any, t: type, nullable: bool = False):
        """Raise ValueError unless `v` is a `t` (or None, when `nullable`)."""
        if not ((nullable and v is None) or isinstance(v, t)):
            v_type_name = f"{type(v)}" if v is not None else "None"
            t_name = f"{t}"
            if nullable:
                t_name += " | None"
            raise ValueError(f"expected {t_name}, got {v_type_name}")

    @staticmethod
    def _recover_abstract(inverted_index: dict[str, list[int]]) -> str:
        """Rebuild the abstract text from an OpenAlex abstract_inverted_index.

        The index maps each word to the positions where it occurs. Positions that
        no word claims are skipped.
        """
        # Skip entries with no positions: max() of an empty list would raise.
        abstract_size = max(
            (max(locs) for locs in inverted_index.values() if locs), default=-1
        ) + 1
        abstract_words: list[str | None] = [None] * abstract_size
        for word, locs in inverted_index.items():
            for loc in locs:
                abstract_words[loc] = word
        return " ".join(word for word in abstract_words if word is not None)


def parse_bool(value: str) -> bool:
    """Parse a boolean flag given as a string, e.g. from an environment variable.

    `bool(value)` can't be used for this: it is True for every non-empty string,
    including "0" and "false".

    Raises:
        ValueError: if `value` isn't one of the recognized spellings.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise ValueError(f"cannot interpret {value!r} as a boolean")


def get_env_var(
    key: str, type_: Callable[[str], T] = str, default: U = None
) -> T | U:
    """Read environment variable `key` and convert it with `type_`.

    Returns `default` (unconverted) when the variable is unset. For boolean flags,
    pass `parse_bool` rather than `bool`.
    """
    var = os.getenv(key)
    if var is None:
        return default
    return type_(var)


def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Load the query-embedding model as configured by `params_dir / "params.json"`.

    Returns:
        `(normalize, model)`. `normalize` says whether query embeddings must be
        L2-normalized. `model` is created with `truncate_dim` set to the
        configured dimensions.
    """
    # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
    with open(params_dir / "params.json", "r", encoding="utf-8") as f:
        params: Params = json.load(f)
    return params["normalize"], SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=params["dimensions"],
    )


def open_ondisk(dir: Path) -> faiss.Index:
    """Read `dir / "index.faiss"`, resolving its on-disk data relative to `dir`."""
    # without IO_FLAG_ONDISK_SAME_DIR, read_index gets on-disk indices in working dir
    return faiss.read_index(str(dir / "index.faiss"), faiss.IO_FLAG_ONDISK_SAME_DIR)


def get_index(dir: Path, search_time_s: float) -> Dataset:
    """Load the IDs and faiss index stored in `dir`, tuned for a search-time budget.

    The index is configured with the highest-recall entry of
    `params.json["optimal_params"]` whose `exec_time` is under `search_time_s`.
    If no entry meets the budget, the fastest entry is used and a warning is
    printed.

    Returns:
        The Dataset loaded from `ids.parquet`, with the faiss index attached under
        the name "embedding".
    """
    # NOTE: use a private attr to load the index with IO_FLAG_ONDISK_SAME_DIR!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = open_ondisk(dir)
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    with open(dir / "params.json", "r", encoding="utf-8") as f:
        params: Params = json.load(f)
    under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
    if under:
        optimal = max(under, key=lambda p: p["recall"])
    else:
        # An empty candidate list would make max() crash, so degrade gracefully.
        optimal = min(params["optimal_params"], key=lambda p: p["exec_time"])
        print(
            f"warning: no index parameters run under {search_time_s} s, "
            "using the fastest available...",
            file=stderr,
        )
    optimal_string = optimal["param_string"]

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, optimal_string)

    return index


def execute_request(
    ids: list[str], mailto: str | None, timeout: float = 30.0
) -> list[Work]:
    """Fetch metadata for `ids` from the OpenAlex /works endpoint.

    Args:
        ids: OpenAlex work IDs, at most 100 (the page size requested).
        mailto: value for OpenAlex's `mailto` query parameter, if any.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        One Work per ID, in the same order as `ids`.

    Raises:
        ValueError: if more than 100 IDs are given.
        requests.HTTPError: if OpenAlex responds with an error status.
        KeyError: if a requested work is missing from the response.
    """
    if len(ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")
    if not ids:
        return []  # nothing to look up, so don't send an empty filter

    # query with the /works endpoint with a specific list of IDs and fields
    search_filter = f'openalex_id:{"|".join(ids)}'
    search_select = ",".join(["id"] + Work.get_raw_fields())
    params: dict[str, Any] = {
        "filter": search_filter,
        "select": search_select,
        "per-page": 100,
    }
    if mailto is not None:
        params["mailto"] = mailto

    response = requests.get("https://api.openalex.org/works", params, timeout=timeout)
    response.raise_for_status()

    # the response is not necessarily ordered, so order them
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    return [works_by_id[id_] for id_ in ids]


def collapse_newlines(x: str) -> str:
    """Replace every CRLF, LF, or CR line break in `x` with a single space."""
    return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")


def format_response(
    neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
) -> str:
    """Render search results as Markdown, one section per work.

    Args:
        neighbors: the works to show, in rank order.
        distances: the faiss distance of each work, paired with `neighbors` by
            position.
        calculate_similarity: show a cosine similarity derived from the distance
            instead of the raw distance. This is only meaningful when the query
            and indexed vectors are unit-length.
    """
    entries: list[str] = []
    for work, distance in zip(neighbors, distances):
        # heading: the title, linked to the DOI when there is one
        entry_string = "## "
        if work.title and work.doi:
            entry_string += f"[{collapse_newlines(work.title)}]({work.doi})"
        elif work.title:
            entry_string += f"{collapse_newlines(work.title)}"
        elif work.doi:
            entry_string += f"[No title]({work.doi})"
        else:
            entry_string += "No title"

        # bold byline: authors, year, and journal
        entry_string += "\n\n**"
        if len(work.authors) > 3:  # truncate to 3 if necessary
            entry_string += ", ".join(work.authors[:3]) + ", ..."
        elif work.authors:
            entry_string += ", ".join(work.authors)
        else:
            entry_string += "No author"
        entry_string += f", {work.year}"
        if work.journal_name:
            # a stray newline would break the surrounding bold markers
            entry_string += " - " + collapse_newlines(work.journal_name)
        entry_string += "**\n\n"

        if work.abstract:
            abstract = collapse_newlines(work.abstract)
            if len(abstract) > 2000:
                abstract = abstract[:2000] + "..."
            entry_string += abstract
        else:
            entry_string += "No abstract"

        # italic footer of "label: value" pairs
        entry_string += "\n\n*"
        meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" on someone's work
            meta.append(("Cited-by count", str(work.citations)))
        if work.doi:
            meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
        if calculate_similarity:
            # if query and result are unit vectors, the cosine sim is 1 - dist^2 / 2
            meta.append(("Similarity", f"{1 - distance / 2:.2f}"))  # faiss gives dist^2
        else:
            meta.append(("Distance", f"{distance:.2f}"))
        entry_string += (" " * 4).join(": ".join(tup) for tup in meta)
        entry_string += "*\n"

        entries.append(entry_string)
    return "".join(entries)


def describe_quantity(n: int) -> str:
    """Describe a count in words for display, e.g. 95_412_345 -> "95.4 million".

    Counts of a thousand or more are scaled to thousands, millions, or billions.
    Anything larger is still expressed in billions. The scaled value keeps one
    decimal when `floor(log10(n)) % 3 == 1` (e.g. 10-99 million) and is otherwise
    rounded to an integer.
    """
    if n <= 0:
        return str(n)  # log10 is undefined here
    n_digits = int(log10(n))
    divisor, postfix = {
        0: (1, ""),
        1: (1000, " thousand"),
        2: (1000000, " million"),
        3: (1000000000, " billion"),
    }[min(n_digits // 3, 3)]
    significand = round(n / divisor, 1 if (n_digits % 3 == 1) else None)
    return str(significand) + postfix


def get_model_card(model_name: str) -> tuple[str, str]:
    """Return `(human_name, url)` for a model's page on the Hugging Face Hub.

    A bare name such as "all-MiniLM-L6-v2" is placed under the
    "sentence-transformers" organization. This mirrors how SentenceTransformer
    resolves such names.
    """
    repo_id = model_name if "/" in model_name else f"sentence-transformers/{model_name}"
    human_name = repo_id.rsplit("/", 1)[-1]
    return human_name, f"https://huggingface.co/{repo_id}"


def main():
    """Load the model and index, then serve the search UI with Gradio.

    Environment variables:
        MODEL_NAME: SentenceTransformer model used to embed queries.
        PROMPT_NAME: key into the model's prompts; the prompt is prepended to queries.
        TRUST_REMOTE_CODE, FP16: boolean flags (e.g. "1"/"0", "true"/"false").
        DIR: local directory holding the index; takes precedence over REPO.
        REPO: Hugging Face dataset repo whose "index" folder is downloaded.
        SEARCH_TIME_S: search-time budget used to choose the index parameters.
        K: number of results to return per query.
        MAILTO: value for OpenAlex's `mailto` parameter.
    """
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", parse_bool, default=False)
    fp16 = get_env_var("FP16", parse_bool, default=False)
    index_dir = get_env_var("DIR", Path)
    repo = get_env_var("REPO", str)
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    if index_dir is None:  # acquire the index if it's not local
        if repo is None:
            repo = "colonelwatch/abstracts-faiss"
        index_dir = Path(snapshot_download(repo, repo_type="dataset")) / "index"
    elif repo is not None:
        print('warning: used "REPO" and also "DIR", ignoring "REPO"...', file=stderr)

    normalize, model = get_model(model_name, index_dir, trust_remote_code)
    index = get_index(index_dir, search_time_s)

    # follow model.encode logic for acquiring the prompt
    if prompt_name is None and model.default_prompt_name is not None:
        prompt_name = model.default_prompt_name
    if prompt_name is not None and not isinstance(prompt_name, str):
        raise TypeError("invalid prompt name type")
    prompt: str | None = model.prompts[prompt_name] if prompt_name is not None else None

    # follow model.encode logic for setting extra_features
    extra_features: dict[str, Any] = {}
    if prompt is not None:
        tokenized = model.tokenize([prompt])
        if "input_ids" in tokenized:
            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1

    model.eval()
    if torch.cuda.is_available():
        model = model.half().cuda() if fp16 else model.bfloat16().cuda()
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")

    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
        """Embed one already-tokenized query and return it as a 1-D float32 array."""
        # non-blocking transfer of the tokenizer output to the model's device
        features = {
            key: v.to(model.device, non_blocking=True) for key, v in features.items()
        } | extra_features
        with torch.no_grad():
            out_features = model.forward(features)
            embeddings = out_features["sentence_embedding"]
            embeddings = embeddings[0]  # only one query per call
            if model.truncate_dim:
                # forward() doesn't apply truncate_dim, so truncate manually
                embeddings = embeddings[:model.truncate_dim]
            if normalize:
                embeddings = torch.nn.functional.normalize(embeddings, dim=0)
            return embeddings.cpu().float().numpy()  # faiss expected CPU float32 numpy arr

    if spaces:
        encode_tokens = spaces.GPU(encode_tokens)

    def encode_string(query: str) -> npt.NDArray[np.float32]:
        """Prepend the configured prompt (if any), tokenize, and embed `query`."""
        if prompt:
            query = prompt + query
        tokens = model.tokenize([query])
        return encode_tokens(tokens)

    def search(query: str) -> str:
        """Embed `query`, find its k nearest works, and render them as Markdown."""
        query_embedding = encode_string(query)
        distances, faiss_ids = index.search("embedding", query_embedding, k)
        # faiss pads the result with id -1 when it finds fewer than k neighbors
        found = faiss_ids >= 0
        distances, faiss_ids = distances[found], faiss_ids[found]
        if len(faiss_ids) == 0:
            return "No results"
        openalex_ids = index[faiss_ids.tolist()]["id"]
        works = execute_request(openalex_ids, mailto)
        return format_response(works, distances, calculate_similarity=normalize)

    # figure out the words to describe the quantity, and the model's Hub page
    quantity = describe_quantity(len(index))
    model_human_name, model_link = get_model_card(model_name)

    with gr.Blocks() as demo:
        gr.Markdown("# abstracts-index")
        gr.Markdown(
            f"Explore {quantity} academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset (as of January 1st, 2025) with "
            "semantic search, not keyword search. This project is an index of the "
            "embeddings generated from their titles and abstracts. The embeddings were "
            f"generated using the [{model_human_name}]({model_link}) model, and the "
            "index was built using the "
            "[faiss](https://github.com/facebookresearch/faiss) module. The build "
            "scripts and more information are available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )

        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )

        # NOTE: ZeroGPU doesn't seem to support batching
        query.submit(search, inputs=[query], outputs=[results])
        btn.click(search, inputs=[query], outputs=[results])

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()