colonelwatch commited on
Commit
2428d17
Β·
1 Parent(s): 9af2095

Add index and app.py, update 2023-05-10

Browse files
Files changed (6) hide show
  1. .gitattributes +3 -0
  2. README.md +4 -4
  3. app.py +152 -0
  4. idxs.txt +3 -0
  5. index.faiss +3 -0
  6. requirements.txt +3 -0
.gitattributes CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.txt filter=lfs diff=lfs merge=lfs -text
2
+ *.faiss filter=lfs diff=lfs merge=lfs -text
3
+
4
  *.7z filter=lfs diff=lfs merge=lfs -text
5
  *.arrow filter=lfs diff=lfs merge=lfs -text
6
  *.bin filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Abstracts Index
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.29.0
8
  app_file: app.py
 
1
  ---
2
+ title: abstracts-index
3
+ emoji: πŸ“
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.29.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # serve.py
2
+ # Loads all completed shards and finds the most similar vector to a given query vector.
3
+
4
+ import requests
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ import gradio as gr
8
+
9
+ from markdown_it import MarkdownIt # used for overriding default markdown renderer
10
+
11
+ model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
12
+
13
+ works_ids_path = 'idxs.txt'
14
+ with open(works_ids_path) as f:
15
+ idxs = f.read().splitlines()
16
+ index = faiss.read_index('index.faiss')
17
+
18
+ ps = faiss.ParameterSpace()
19
+ ps.initialize(index)
20
+ ps.set_index_parameters(index, 'nprobe=32,ht=128')
21
+
22
+
23
+ def _recover_abstract(inverted_index):
24
+ abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1
25
+
26
+ abstract = [None]*abstract_size
27
+ for word, appearances in inverted_index.items(): # yes, this is a second iteration over inverted_index
28
+ for appearance in appearances:
29
+ abstract[appearance] = word
30
+
31
+ abstract = [word for word in abstract if word is not None]
32
+ abstract = ' '.join(abstract)
33
+ return abstract
34
+
35
+ def search(query):
36
+ global model, index, idxs
37
+
38
+ query_embedding = model.encode(query)
39
+ query_embedding = query_embedding.reshape(1, -1)
40
+ distances, faiss_ids = index.search(query_embedding, 10)
41
+
42
+ distances = distances[0]
43
+ faiss_ids = faiss_ids[0]
44
+
45
+ openalex_ids = [idxs[faiss_id] for faiss_id in faiss_ids]
46
+ search_filter = f'openalex_id:{"|".join(openalex_ids)}'
47
+ search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
48
+
49
+ neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
50
+ request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'
51
+
52
+ return neighbors, request_str
53
+
54
+ def execute_request(request_str):
55
+ response = requests.get(request_str).json()
56
+ return response
57
+
58
+ def format_response(neighbors, response):
59
+ response = {doc['id']: doc for doc in response['results']}
60
+
61
+ result_string = ''
62
+ for distance, openalex_id in neighbors:
63
+ doc = response[openalex_id]
64
+
65
+ # collect attributes from openalex doc for the given openalex_id
66
+ title = doc['title']
67
+ abstract = _recover_abstract(doc['abstract_inverted_index'])
68
+ author_names = [authorship['author']['display_name'] for authorship in doc['authorships']]
69
+ # journal_name = doc['primary_location']['source']['display_name']
70
+ publication_year = doc['publication_year']
71
+ citation_count = doc['cited_by_count']
72
+ doi = doc['doi']
73
+
74
+ # try to get journal name or else set it to None
75
+ try:
76
+ journal_name = doc['primary_location']['source']['display_name']
77
+ except (TypeError, KeyError):
78
+ journal_name = None
79
+
80
+ # title: knock out escape sequences
81
+ title = title.replace('\n', '\\n').replace('\r', '\\r')
82
+
83
+ # abstract: knock out escape sequences, then truncate to 1500 characters if necessary
84
+ abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')
85
+ if len(abstract) > 2000:
86
+ abstract = abstract[:2000] + '...'
87
+
88
+ # authors: truncate to 3 authors if necessary
89
+ if len(author_names) >= 3:
90
+ authors_str = ', '.join(author_names[:3]) + ', ...'
91
+ else:
92
+ authors_str = ', '.join(author_names)
93
+
94
+
95
+ entry_string = ''
96
+
97
+ if doi: # edge case: for now, no doi -> no link
98
+ entry_string += f'## [{title}]({doi})\n\n'
99
+ else:
100
+ entry_string += f'## {title}\n\n'
101
+
102
+ if journal_name:
103
+ entry_string += f'**{authors_str} - {journal_name}, {publication_year}**\n'
104
+ else:
105
+ entry_string += f'**{authors_str}, {publication_year}**\n'
106
+
107
+ entry_string += f'{abstract}\n\n'
108
+
109
+ if citation_count: # edge case: we shouldn't tack "Cited-by count: 0" onto someone's paper
110
+ entry_string += f'*Cited-by count: {citation_count}*'
111
+ entry_string += '    '
112
+
113
+ if doi: # list the doi if it exists
114
+ entry_string += f'*DOI: {doi.replace("https://doi.org/", "")}*'
115
+ entry_string += '    '
116
+
117
+ entry_string += f'*Similarity: {distance:.2f}*'
118
+ entry_string += '    \n'
119
+
120
+ result_string += entry_string
121
+
122
+ return result_string
123
+
124
+ with gr.Blocks() as demo:
125
+ gr.Markdown('# abstracts-search demo')
126
+ gr.Markdown(
127
+ 'Explore 95 million academic publications selected from the [OpenAlex](https://openalex.org) dataset. This '
128
+ 'project is an index of the embeddings generated from their titles and abstracts. The embeddings were '
129
+ 'generated using the `all-MiniLM-L6-v2` model provided by the [sentence-transformers](https://www.sbert.net/) '
130
+ 'module, and the index was built using the [faiss](https://github.com/facebookresearch/faiss) module.'
131
+ )
132
+
133
+ neighbors_var = gr.State()
134
+ request_str_var = gr.State()
135
+ response_var = gr.State()
136
+ query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False)
137
+ btn = gr.Button('Search')
138
+ with gr.Box():
139
+ results = gr.Markdown()
140
+
141
+ md = MarkdownIt('js-default', {'linkify': True, 'typographer': True}) # don't render html or latex!
142
+ results.md = md
143
+
144
+ query.submit(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \
145
+ .success(execute_request, inputs=[request_str_var], outputs=[response_var]) \
146
+ .success(format_response, inputs=[neighbors_var, response_var], outputs=[results])
147
+ btn.click(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \
148
+ .success(execute_request, inputs=[request_str_var], outputs=[response_var]) \
149
+ .success(format_response, inputs=[neighbors_var, response_var], outputs=[results])
150
+
151
+ demo.queue()
152
+ demo.launch()
idxs.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeb2e963d4b3e1026ad550ea8a2e4fca92a1b59aa4d6fd005953ec5505415396
3
+ size 3141761935
index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:febbf3e28b7ccd726714de3efebbbeb5bbb0ec1de67928cce3d306ff1b803bcd
3
+ size 2304082443
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e05f71c99a0770d093e10449a16987a81c4c11bba93dbab6da1540e65b76e042
3
+ size 38