Spaces:
Running
Running
# serve.py | |
# Loads all completed shards and finds the most similar vector to a given query vector. | |
import requests | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import gradio as gr | |
from markdown_it import MarkdownIt # used for overriding default markdown renderer | |
model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu') | |
works_ids_path = 'openalex_ids.txt' | |
with open(works_ids_path) as f: | |
idxs = f.read().splitlines() | |
index = faiss.read_index('index.faiss') | |
ps = faiss.ParameterSpace() | |
ps.initialize(index) | |
ps.set_index_parameters(index, 'nprobe=16,ht=512') | |
def _recover_abstract(inverted_index): | |
abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1 | |
abstract = [None]*abstract_size | |
for word, appearances in inverted_index.items(): # yes, this is a second iteration over inverted_index | |
for appearance in appearances: | |
abstract[appearance] = word | |
abstract = [word for word in abstract if word is not None] | |
abstract = ' '.join(abstract) | |
return abstract | |
def search(query): | |
global model, index, idxs | |
query_embedding = model.encode(query) | |
query_embedding = query_embedding.reshape(1, -1) | |
distances, faiss_ids = index.search(query_embedding, 20) | |
distances = distances[0] | |
faiss_ids = faiss_ids[0] | |
openalex_ids = [idxs[faiss_id] for faiss_id in faiss_ids] | |
search_filter = f'openalex_id:{"|".join(openalex_ids)}' | |
search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi' | |
neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)] | |
request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}' | |
return neighbors, request_str | |
def execute_request(request_str): | |
response = requests.get(request_str).json() | |
return response | |
def format_response(neighbors, response): | |
response = {doc['id']: doc for doc in response['results']} | |
result_string = '' | |
for distance, openalex_id in neighbors: | |
doc = response[openalex_id] | |
# collect attributes from openalex doc for the given openalex_id | |
title = doc['title'] | |
# abstract = _recover_abstract(doc['abstract_inverted_index']) | |
abstract_inverted_index = doc['abstract_inverted_index'] | |
author_names = [authorship['author']['display_name'] for authorship in doc['authorships']] | |
# journal_name = doc['primary_location']['source']['display_name'] | |
publication_year = doc['publication_year'] | |
citation_count = doc['cited_by_count'] | |
doi = doc['doi'] | |
if title is None: # edge case: no title | |
title = 'No title' | |
if abstract_inverted_index is None: # edge case: no abstract | |
abstract = 'No abstract' | |
else: | |
abstract = _recover_abstract(abstract_inverted_index) | |
abstract = abstract.replace('\n', '\\n').replace('\r', '\\r') | |
# try to get journal name or else set it to None | |
try: | |
journal_name = doc['primary_location']['source']['display_name'] | |
except (TypeError, KeyError): | |
journal_name = None | |
# title: knock out escape sequences | |
title = title.replace('\n', '\\n').replace('\r', '\\r') | |
# abstract: knock out escape sequences, then truncate to 1500 characters if necessary | |
abstract = abstract.replace('\n', '\\n').replace('\r', '\\r') | |
if len(abstract) > 2000: | |
abstract = abstract[:2000] + '...' | |
# authors: cover no name edge case, truncate to 3 authors if necessary | |
author_names = [author_name if author_name else 'No name' for author_name in author_names] | |
if len(author_names) >= 3: | |
authors_str = ', '.join(author_names[:3]) + ', ...' | |
else: | |
authors_str = ', '.join(author_names) | |
entry_string = '' | |
if doi: # edge case: for now, no doi -> no link | |
entry_string += f'## [{title}]({doi})\n\n' | |
else: | |
entry_string += f'## {title}\n\n' | |
if journal_name: | |
entry_string += f'**{authors_str} - {journal_name}, {publication_year}**\n\n' | |
else: | |
entry_string += f'**{authors_str}, {publication_year}**\n\n' | |
entry_string += f'{abstract}\n\n' | |
if citation_count: # edge case: we shouldn't tack "Cited-by count: 0" onto someone's paper | |
entry_string += f'*Cited-by count: {citation_count}*' | |
entry_string += ' ' | |
if doi: # list the doi if it exists | |
entry_string += f'*DOI: {doi.replace("https://doi.org/", "")}*' | |
entry_string += ' ' | |
entry_string += f'*Similarity: {distance:.2f}*' | |
entry_string += ' \n' | |
result_string += entry_string | |
return result_string | |
with gr.Blocks() as demo: | |
gr.Markdown('# abstracts-index') | |
gr.Markdown( | |
'Explore 95 million academic publications selected from the [OpenAlex](https://openalex.org) dataset. This ' | |
'project is an index of the embeddings generated from their titles and abstracts. The embeddings were ' | |
'generated using the `all-MiniLM-L6-v2` model provided by the [sentence-transformers](https://www.sbert.net/) ' | |
'module, and the index was built using the [faiss](https://github.com/facebookresearch/faiss) module. The build ' | |
'scripts and more information available at the main repo ' | |
'[abstracts-search](https://github.com/colonelwatch/abstracts-search) on Github.' | |
) | |
neighbors_var = gr.State() | |
request_str_var = gr.State() | |
response_var = gr.State() | |
query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False) | |
btn = gr.Button('Search') | |
with gr.Box(): | |
results = gr.Markdown() | |
md = MarkdownIt('js-default', {'linkify': True, 'typographer': True}) # don't render html or latex! | |
results.md = md | |
query.submit(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \ | |
.success(execute_request, inputs=[request_str_var], outputs=[response_var]) \ | |
.success(format_response, inputs=[neighbors_var, response_var], outputs=[results]) | |
btn.click(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \ | |
.success(execute_request, inputs=[request_str_var], outputs=[response_var]) \ | |
.success(format_response, inputs=[neighbors_var, response_var], outputs=[results]) | |
demo.queue(2) | |
demo.launch() |