File size: 5,951 Bytes
2428d17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70fd5de
2428d17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
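# serve.py
# Gradio demo: embeds a text query, finds its nearest neighbors in a faiss index of
# OpenAlex work embeddings, fetches their metadata from the OpenAlex API, and renders
# the results as Markdown.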

import requests
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr

from markdown_it import MarkdownIt # used for overriding default markdown renderer

# Query encoder. The UI text below says the indexed embeddings were generated with
# this same model, so queries must be embedded with it too. Inference runs on CPU.
model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')

# Line i of idxs.txt holds the OpenAlex work id for faiss vector id i
# (search() maps results through idxs[faiss_id]). Decode explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
works_ids_path = 'idxs.txt'
with open(works_ids_path, encoding='utf-8') as f:
    idxs = f.read().splitlines()
index = faiss.read_index('index.faiss')

# Search-time tuning via faiss's generic ParameterSpace API.
# NOTE(review): presumably nprobe = number of IVF lists visited per query and
# ht = polysemous Hamming threshold; confirm against how index.faiss was built.
ps = faiss.ParameterSpace()
ps.initialize(index)
ps.set_index_parameters(index, 'nprobe=32,ht=512')


def _recover_abstract(inverted_index):
    abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1

    abstract = [None]*abstract_size
    for word, appearances in inverted_index.items(): # yes, this is a second iteration over inverted_index
        for appearance in appearances:
            abstract[appearance] = word

    abstract = [word for word in abstract if word is not None]
    abstract = ' '.join(abstract)
    return abstract

def search(query):
    """Embed ``query`` and find its 10 nearest neighbors in the faiss index.

    Reads the module-level ``model``, ``index`` and ``idxs`` loaded at import time.

    Args:
        query: Free-text search string.

    Returns:
        A ``(neighbors, request_str)`` tuple. ``neighbors`` is a list of
        ``(distance, openalex_id)`` pairs in faiss result order. ``request_str``
        is an OpenAlex ``/works`` URL that selects metadata for exactly those ids.
    """
    query_embedding = model.encode(query).reshape(1, -1)
    distances, faiss_ids = index.search(query_embedding, 10)

    # faiss pads the result with id -1 when it finds fewer than k neighbors.
    # idxs[-1] would silently map such a slot to the last work in the file,
    # so drop those slots instead.
    neighbors = [
        (distance, idxs[faiss_id])
        for distance, faiss_id in zip(distances[0], faiss_ids[0])
        if faiss_id >= 0
    ]
    openalex_ids = [openalex_id for _, openalex_id in neighbors]

    search_filter = f'openalex_id:{"|".join(openalex_ids)}'
    search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
    request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'

    return neighbors, request_str

def execute_request(request_str, timeout=30):
    """Fetch ``request_str`` from the OpenAlex API and return the decoded JSON body.

    Args:
        request_str: Fully formed request URL, as built by ``search``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The parsed JSON response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This makes
            the failure surface here rather than later, as a missing ``'results'``
            key when the response is formatted.
        requests.Timeout: If the server does not respond within ``timeout`` seconds.
    """
    response = requests.get(request_str, timeout=timeout)
    response.raise_for_status()
    return response.json()

def _format_entry(distance, doc):
    """Render one OpenAlex work dict as a Markdown entry (heading, byline, abstract, metadata line)."""
    # Any of these fields may be null/missing in the API response; fall back to safe
    # values so the string operations below cannot fail.
    title = doc.get('title') or 'Untitled'
    inverted_index = doc.get('abstract_inverted_index')
    abstract = _recover_abstract(inverted_index) if inverted_index else ''
    author_names = []
    for authorship in doc.get('authorships') or []:
        name = (authorship.get('author') or {}).get('display_name')
        if name:
            author_names.append(name)
    publication_year = doc.get('publication_year')
    citation_count = doc.get('cited_by_count')
    doi = doc.get('doi')

    # primary_location and/or its source can be null, so this lookup is best-effort.
    try:
        journal_name = doc['primary_location']['source']['display_name']
    except (TypeError, KeyError):
        journal_name = None

    # Replace raw newlines with literal "\n"/"\r" text so a multi-line title or
    # abstract cannot break the Markdown heading/paragraph structure.
    title = title.replace('\n', '\\n').replace('\r', '\\r')
    abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')

    # Truncate long abstracts to 2000 characters.
    max_abstract_chars = 2000
    if len(abstract) > max_abstract_chars:
        abstract = abstract[:max_abstract_chars] + '...'

    # Show at most three authors; the trailing ellipsis only appears when some were omitted.
    if len(author_names) > 3:
        authors_str = ', '.join(author_names[:3]) + ', ...'
    else:
        authors_str = ', '.join(author_names)

    parts = []
    # No DOI -> plain heading, since there is nothing to link to.
    parts.append(f'## [{title}]({doi})\n\n' if doi else f'## {title}\n\n')
    if journal_name:
        parts.append(f'**{authors_str} - {journal_name}, {publication_year}**\n')
    else:
        parts.append(f'**{authors_str}, {publication_year}**\n')
    parts.append(f'{abstract}\n\n')

    # Metadata items end in four spaces (Markdown hard line break when followed by "\n").
    # A zero/missing citation count is omitted rather than shown as "Cited-by count: 0".
    if citation_count:
        parts.append(f'*Cited-by count: {citation_count}*    ')
    if doi:
        parts.append(f'*DOI: {doi.replace("https://doi.org/", "")}*    ')
    # NOTE(review): labelled "Similarity" but this is the raw faiss score; whether
    # higher means closer depends on the index metric — confirm.
    parts.append(f'*Similarity: {distance:.2f}*    \n')
    return ''.join(parts)


def format_response(neighbors, response):
    """Render OpenAlex search results as one Markdown string.

    Args:
        neighbors: ``(distance, openalex_id)`` pairs in display order, as produced by ``search``.
        response: Decoded JSON from the OpenAlex ``/works`` endpoint. Its ``'results'``
            works are matched to ``neighbors`` by their ``'id'``.

    Returns:
        The concatenated Markdown entries, one per neighbor present in the response.
        Neighbors the API did not return are skipped rather than raising KeyError.
    """
    docs_by_id = {doc['id']: doc for doc in response['results']}

    entries = []
    for distance, openalex_id in neighbors:
        doc = docs_by_id.get(openalex_id)
        if doc is None:
            continue
        entries.append(_format_entry(distance, doc))
    return ''.join(entries)

with gr.Blocks() as demo:
    gr.Markdown('# abstracts-search demo')
    gr.Markdown(
        'Explore 95 million academic publications selected from the [OpenAlex](https://openalex.org) dataset. This '
        'project is an index of the embeddings generated from their titles and abstracts. The embeddings were '
        'generated using the `all-MiniLM-L6-v2` model provided by the [sentence-transformers](https://www.sbert.net/) '
        'module, and the index was built using the [faiss](https://github.com/facebookresearch/faiss) module.'
    )

    # Per-session values handed from one pipeline step to the next.
    neighbors_var = gr.State()
    request_str_var = gr.State()
    response_var = gr.State()

    query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False)
    btn = gr.Button('Search')
    with gr.Box():
        results = gr.Markdown()

    # Replace the results component's renderer with a markdown-it instance built from
    # the 'js-default' preset (intended to keep raw HTML/LaTeX from rendering), with
    # linkify and typographer turned on.
    md = MarkdownIt('js-default', {'linkify': True, 'typographer': True})
    results.md = md

    # Pressing Enter in the textbox and clicking Search run the same three-step
    # pipeline: embed + faiss lookup -> OpenAlex request -> Markdown rendering.
    # Each .success step runs only if the step before it finished without an error.
    for trigger in (query.submit, btn.click):
        trigger(search, inputs=[query], outputs=[neighbors_var, request_str_var]).success(
            execute_request, inputs=[request_str_var], outputs=[response_var]
        ).success(
            format_response, inputs=[neighbors_var, response_var], outputs=[results]
        )

demo.queue()
demo.launch()