Spaces: Running on Zero

Commit · abd37c8
Parent(s): becad39
Create and clean up function declarations

app.py CHANGED
@@ -8,9 +8,12 @@ import gradio as gr
 from datasets import Dataset
 from typing import TypedDict
 import json
+from pathlib import Path
 
 from markdown_it import MarkdownIt # used for overriding default markdown renderer
 
+MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
+DIR = Path("index")
 SEARCH_TIME_S = 1 # TODO: optimize
 
 
@@ -20,24 +23,29 @@ class IndexParameters(TypedDict):
     param_string: str # pass directly to faiss index
 
 
-model = SentenceTransformer("all-MiniLM-L6-v2")
+def get_model(model_name: str, device: str) -> SentenceTransformer:
+    return SentenceTransformer(model_name, device=device)
 
-index: Dataset = Dataset.from_parquet("index/ids.parquet") # type: ignore
-index.load_faiss_index("embeddings", "index/index.faiss", None)
-faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
 
-with open("index/params.json", "r") as f:
-    params: list[IndexParameters] = json.load(f)
-params = [p for p in params if p["exec_time"] < SEARCH_TIME_S]
-param = max(params, key=(lambda p: p["recall"]))
-param_string = param["param_string"]
+def get_index(dir: Path, search_time_s: float) -> Dataset:
+    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet")) # type: ignore
+    index.load_faiss_index("embeddings", dir / "index.faiss", None)
+    faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
 
-ps = faiss.ParameterSpace()
-ps.initialize(faiss_index)
-ps.set_index_parameters(faiss_index, param_string)
+    with open(dir / "params.json", "r") as f:
+        params: list[IndexParameters] = json.load(f)
+    params = [p for p in params if p["exec_time"] < search_time_s]
+    param = max(params, key=(lambda p: p["recall"]))
+    param_string = param["param_string"]
 
+    ps = faiss.ParameterSpace()
+    ps.initialize(faiss_index)
+    ps.set_index_parameters(faiss_index, param_string)
 
-def _recover_abstract(inverted_index):
+    return index
+
+
+def recover_abstract(inverted_index):
     abstract_size = max([max(appearances) for appearances in inverted_index.values()])+1
 
     abstract = [None]*abstract_size
@@ -49,26 +57,12 @@ def _recover_abstract(inverted_index):
     abstract = ' '.join(abstract)
     return abstract
 
-def search(query):
-    global model, index
-
-    # TODO: pass in param string directly?
-    query_embedding = model.encode(query)
-    distances, faiss_ids = index.search("embeddings", query_embedding, 20)
-
-    openalex_ids = index[faiss_ids]["idxs"]
-    search_filter = f'openalex_id:{"|".join(openalex_ids)}'
-    search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
-
-    neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
-    request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'
-
-    return neighbors, request_str
 
 def execute_request(request_str):
     response = requests.get(request_str).json()
     return response
 
+
 def format_response(neighbors, response):
     response = {doc['id']: doc for doc in response['results']}
 
@@ -92,7 +86,7 @@ def format_response(neighbors, response):
         if abstract_inverted_index is None: # edge case: no abstract
             abstract = 'No abstract'
         else:
-            abstract = _recover_abstract(abstract_inverted_index)
+            abstract = recover_abstract(abstract_inverted_index)
         abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')
 
         # try to get journal name or else set it to None
@@ -146,6 +140,27 @@ def format_response(neighbors, response):
 
     return result_string
 
+
+model = get_model(MODEL_NAME, "cpu")
+index = get_index(DIR, SEARCH_TIME_S)
+
+
+def search(query):
+    global model, index
+
+    # TODO: pass in param string directly?
+    query_embedding = model.encode(query)
+    distances, faiss_ids = index.search("embeddings", query_embedding, 20)
+
+    openalex_ids = index[faiss_ids]["idxs"]
+    search_filter = f'openalex_id:{"|".join(openalex_ids)}'
+    search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
+
+    neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
+    request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'
+
+    return neighbors, request_str
+
 with gr.Blocks() as demo:
     gr.Markdown('# abstracts-index')
     gr.Markdown(