Spaces:
Sleeping
Sleeping
# app.py
# Loads all completed shards and finds the most similar vector to a given query vector.
from dataclasses import dataclass | |
from itertools import batched, chain | |
import json | |
import os | |
from pathlib import Path | |
from sys import stderr | |
from typing import TypedDict, Self, Any, Callable | |
from datasets import Dataset | |
from datasets.search import FaissIndex | |
import faiss | |
from faiss.contrib.ondisk import merge_ondisk | |
import gradio as gr | |
import requests | |
from sentence_transformers import SentenceTransformer | |
import torch | |
class IndexParameters(TypedDict):
    """One measured operating point of the faiss index.

    ``get_index`` picks among these entries by ``exec_time`` and ``recall`` and
    hands the winner's ``param_string`` to faiss.
    """

    # in this case 10-recall@10
    recall: float
    # seconds (raw faiss measure is in milliseconds)
    exec_time: float
    # pass directly to faiss index
    param_string: str
class Params(TypedDict):
    """Shape of the ``params.json`` file read by ``get_model`` and ``get_index``."""

    # forwarded to SentenceTransformer as truncate_dim
    dimensions: int | None
    # forwarded to model.encode as normalize_embeddings
    normalize: bool
    # candidate faiss search settings; one is chosen per search-time budget
    optimal_params: list[IndexParameters]
class Work: | |
title: str | None | |
abstract: str | None # recovered from abstract_inverted_index | |
authors: list[str] # takes raw_author_name field from Authorship objects | |
journal_name: str | None # takes the display_name field of the first location | |
year: int | |
citations: int | |
doi: str | None | |
def __post_init__(self): | |
self._check_type(self.title, str, nullable=True) | |
self._check_type(self.abstract, str, nullable=True) | |
self._check_type(self.authors, list) | |
for author in self.authors: | |
self._check_type(author, str) | |
self._check_type(self.journal_name, str, nullable=True) | |
self._check_type(self.year, int) | |
self._check_type(self.citations, int) | |
self._check_type(self.doi, str, nullable=True) | |
def from_dict(cls, d: dict) -> Self: | |
inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"] | |
abstract = cls._recover_abstract(inverted_index) if inverted_index else None | |
try: | |
journal_name = d["primary_location"]["source"]["display_name"] | |
except (TypeError, KeyError): # key didn't exist or a value was null | |
journal_name = None | |
return cls( | |
title=d["title"], | |
abstract=abstract, | |
authors=[authorship["raw_author_name"] for authorship in d["authorships"]], | |
journal_name=journal_name, | |
year=d["publication_year"], | |
citations=d["cited_by_count"], | |
doi=d["doi"], | |
) | |
def get_raw_fields() -> list[str]: | |
return [ | |
"title", | |
"abstract_inverted_index", | |
"authorships", | |
"primary_location", | |
"publication_year", | |
"cited_by_count", | |
"doi" | |
] | |
def _check_type(v: Any, t: type, nullable: bool = False): | |
if not ((nullable and v is None) or isinstance(v, t)): | |
v_type_name = f"{type(v)}" if v is not None else "None" | |
t_name = f"{t}" | |
if nullable: | |
t_name += " | None" | |
raise ValueError(f"expected {t_name}, got {v_type_name}") | |
def _recover_abstract(inverted_index: dict[str, list[int]]) -> str: | |
abstract_size = max(max(locs) for locs in inverted_index.values())+1 | |
abstract_words: list[str | None] = [None] * abstract_size | |
for word, locs in inverted_index.items(): | |
for loc in locs: | |
abstract_words[loc] = word | |
return " ".join(word for word in abstract_words if word is not None) | |
def get_env_var[T, U]( | |
key: str, type_: Callable[[str], T] = str, default: U = None | |
) -> T | U: | |
var = os.getenv(key) | |
if var is not None: | |
var = type_(var) | |
else: | |
var = default | |
return var | |
def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Create the embedding model described by ``params_dir / "params.json"``.

    Args:
        model_name: Model identifier handed to ``SentenceTransformer``.
        params_dir: Directory that contains ``params.json``.
        trust_remote_code: Forwarded to ``SentenceTransformer``.

    Returns:
        A ``(normalize, model)`` pair: the ``normalize`` flag stored in the
        params file and the model built with ``truncate_dim`` set to the
        file's ``dimensions`` entry.
    """
    params: Params = json.loads((params_dir / "params.json").read_text())
    normalize = params["normalize"]
    model = SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=params["dimensions"],
    )
    return normalize, model
def merge_shards(dir: Path) -> faiss.Index:
    """Combine every ``shard_*.faiss`` file in ``dir`` into one on-disk index.

    The index stored in ``dir / "empty.faiss"`` is loaded and passed to
    ``merge_ondisk`` together with the shard paths; the merged inverted-list
    data goes to ``temp.ivfdata`` in the current working directory.

    Args:
        dir: Directory holding ``empty.faiss`` and the shard files.

    Returns:
        The index object after the merge.
    """
    ivfdata_file = Path("temp.ivfdata")
    shard_files = list(map(str, dir.glob("shard_*.faiss")))

    merged = faiss.read_index(str(dir / "empty.faiss"))

    # overwrite previous if it exists (TODO: do I need this?)
    ivfdata_file.unlink(missing_ok=True)
    merge_ondisk(merged, shard_files, str(ivfdata_file))
    return merged
def get_index(dir: Path, search_time_s: float) -> Dataset:
    """Load the id table, attach the merged faiss index and tune its search.

    Among the ``optimal_params`` entries of ``dir / "params.json"`` whose
    ``exec_time`` is strictly below ``search_time_s``, the one with the highest
    ``recall`` is applied to the faiss index.

    Args:
        dir: Directory holding ``ids.parquet``, ``params.json`` and ``shards/``.
        search_time_s: Upper bound (exclusive) on the recorded ``exec_time``.

    Returns:
        The dataset with the faiss index registered under ``"embedding"``.

    Raises:
        ValueError: No entry of ``optimal_params`` fits within ``search_time_s``.
    """
    # Pick the search parameters first: this is cheap, and failing here avoids
    # merging all shards only to discover the time budget cannot be met.
    with open(dir / "params.json", "r") as f:
        params: Params = json.load(f)
    under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
    if not under:
        raise ValueError(
            f"no entry in optimal_params has exec_time below {search_time_s} s"
        )
    optimal = max(under, key=(lambda p: p["recall"]))
    optimal_string = optimal["param_string"]

    # NOTE: a private attr is used to get the faiss.IO_FLAG_ONDISK_SAME_DIR flag!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = merge_shards(dir / "shards")
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, optimal_string)

    return index
def execute_request(
    ids: list[str], mailto: str | None, timeout: float = 30.0
) -> list[Work]:
    """Fetch the given works from the OpenAlex ``/works`` endpoint.

    Args:
        ids: OpenAlex work ids, at most 100 (a single result page).
        mailto: Optional e-mail address sent as the ``mailto`` query parameter.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot block the caller indefinitely.

    Returns:
        One ``Work`` per entry of ``ids``, in the same order as ``ids``.

    Raises:
        ValueError: More than 100 ids were given.
        requests.RequestException: The request failed, timed out or returned
            an error status.
        KeyError: An id in ``ids`` is missing from the response.
    """
    if len(ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")
    if not ids:
        return []  # nothing to look up; skip the request with an empty filter

    # query with the /works endpoint with a specific list of IDs and fields
    search_filter = "openalex_id:" + "|".join(ids)
    search_select = ",".join(["id"] + Work.get_raw_fields())
    params: dict[str, str | int] = {
        "filter": search_filter,
        "select": search_select,
        "per-page": 100,
    }
    if mailto is not None:
        params["mailto"] = mailto

    response = requests.get(
        "https://api.openalex.org/works", params=params, timeout=timeout
    )
    response.raise_for_status()

    # the response is not necessarily ordered, so order them
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    return [works_by_id[id_] for id_ in ids]
def collapse_newlines(x: str) -> str:
    """Return ``x`` with each line break replaced by a single space.

    ``"\\r\\n"`` is handled before the lone ``"\\n"`` and ``"\\r"`` so that a
    Windows line ending becomes one space rather than two.
    """
    for line_break in ("\r\n", "\n", "\r"):
        x = x.replace(line_break, " ")
    return x
def format_response(neighbors: list[Work], distances: list[float]) -> str:
    """Render search hits as Markdown, one section per work.

    Each section consists of a ``##`` heading (the title, linked to the DOI
    when both exist), a bold line with up to three authors, the year and the
    journal name, the abstract cut to 2000 characters, and an italic footer
    with the citation count, DOI and similarity score.

    Args:
        neighbors: Works to render, in display order.
        distances: Similarity score for each work, paired by position.

    Returns:
        The concatenated Markdown for all works ("" if there are none).
    """
    entries: list[str] = []
    for work, distance in zip(neighbors, distances):
        entry_string = "## "

        if work.title and work.doi:
            entry_string += f"[{collapse_newlines(work.title)}]({work.doi})"
        elif work.title:
            entry_string += f"{collapse_newlines(work.title)}"
        elif work.doi:
            entry_string += f"[No title]({work.doi})"
        else:
            entry_string += "No title"

        entry_string += "\n\n**"

        # Only add the ellipsis when authors were actually dropped; a work
        # with exactly three authors is shown in full.
        if len(work.authors) > 3:
            entry_string += ", ".join(work.authors[:3]) + ", ..."
        elif work.authors:
            entry_string += ", ".join(work.authors)
        else:
            entry_string += "No author"

        entry_string += f", {work.year}"

        if work.journal_name:
            entry_string += " - " + work.journal_name

        entry_string += "**\n\n"

        if work.abstract:
            abstract = collapse_newlines(work.abstract)
            if len(abstract) > 2000:
                abstract = abstract[:2000] + "..."
            entry_string += abstract
        else:
            entry_string += "No abstract"

        entry_string += "\n\n*"

        meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" on someone's work
            meta.append(("Cited-by count", str(work.citations)))
        if work.doi:
            meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
        meta.append(("Similarity", f"{distance:.2f}"))
        entry_string += (" " * 4).join(": ".join(tup) for tup in meta)

        entry_string += "*\n"

        entries.append(entry_string)

    return "".join(entries)
def main():
    """Read settings from the environment, load model and index, serve the UI."""
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
    fp16 = get_env_var("FP16", bool, default=False)
    index_dir = get_env_var("DIR", Path, default=Path("index"))
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    normalize, model = get_model(model_name, index_dir, trust_remote_code)
    index = get_index(index_dir, search_time_s)

    model.eval()
    if torch.cuda.is_available():
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
        if fp16:
            model = model.half().cuda()
        else:
            model = model.bfloat16().cuda()
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")

    # function signature: (expanded tuple of input batches) -> tuple of output batches
    def search(query: list[str]) -> tuple[list[str]]:
        embeddings = model.encode(query, prompt_name, normalize_embeddings=normalize)
        distances, faiss_ids = index.search_batch("embedding", embeddings, k)

        # one OpenAlex request covers the hits of every query in the batch
        flat_faiss_ids = list(chain.from_iterable(faiss_ids))
        flat_openalex_ids = index[flat_faiss_ids]["id"]
        flat_works = execute_request(flat_openalex_ids, mailto)

        # regroup the flat list into k works per query before rendering
        rendered = [
            format_response(list(group), query_distances)
            for group, query_distances in zip(batched(flat_works, k), distances)
        ]
        return (rendered,)

    description = (
        "Explore 95 million academic publications selected from the "
        "[OpenAlex](https://openalex.org) dataset. This project is an index of the "
        "embeddings generated from their titles and abstracts. The embeddings were "
        f"generated using the {model_name} model provided by the "
        "[sentence-transformers](https://www.sbert.net/) module, and the index was "
        "built using the [faiss](https://github.com/facebookresearch/faiss) "
        "module. The build scripts and more information available at the main repo "
        "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
        "Github."
    )
    dollar_delimiters = [
        {"left": "$$", "right": "$$", "display": False},
        {"left": "$", "right": "$", "display": False},
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# abstracts-index")
        gr.Markdown(description)
        query_box = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        search_button = gr.Button("Search")
        results = gr.Markdown(latex_delimiters=dollar_delimiters, container=True)

        # pressing Enter in the textbox and clicking the button run the same search
        for register in (query_box.submit, search_button.click):
            register(search, inputs=[query_box], outputs=[results], batch=True)

    demo.queue()
    demo.launch()
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()