Spaces:
Running on Zero
Commit · 030ddc3
1
Parent(s):
a228730
Create main function, add generic get_env_var function
Browse files
app.py
CHANGED
@@ -6,16 +6,12 @@ from sentence_transformers import SentenceTransformer
|
|
6 |
import faiss
|
7 |
import gradio as gr
|
8 |
from datasets import Dataset
|
9 |
-
from typing import TypedDict, Self, Any
|
10 |
import json
|
11 |
from pathlib import Path
|
12 |
from dataclasses import dataclass
|
13 |
from itertools import batched, chain
|
14 |
-
|
15 |
-
MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
|
16 |
-
DIR = Path("index")
|
17 |
-
SEARCH_TIME_S = 1 # TODO: optimize
|
18 |
-
K = 20
|
19 |
|
20 |
|
21 |
class IndexParameters(TypedDict):
|
@@ -98,6 +94,17 @@ class Work:
|
|
98 |
return " ".join(word for word in abstract_words if word is not None)
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
def get_model(model_name: str, device: str) -> SentenceTransformer:
|
102 |
return SentenceTransformer(model_name, device=device)
|
103 |
|
@@ -107,6 +114,7 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:
|
|
107 |
index.load_faiss_index("embeddings", dir / "index.faiss", None)
|
108 |
faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
|
109 |
|
|
|
110 |
with open(dir / "params.json", "r") as f:
|
111 |
params: list[IndexParameters] = json.load(f)
|
112 |
params = [p for p in params if p["exec_time"] < search_time_s]
|
@@ -197,51 +205,62 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
|
|
197 |
return result_string
|
198 |
|
199 |
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
gr.Markdown(
|
222 |
-
"Explore 95 million academic publications selected from the "
|
223 |
-
"[OpenAlex](https://openalex.org) dataset. This project is an index of the "
|
224 |
-
"embeddings generated from their titles and abstracts. The embeddings were "
|
225 |
-
"generated using the `all-MiniLM-L6-v2` model provided by the "
|
226 |
-
"[sentence-transformers](https://www.sbert.net/) module, and the index was "
|
227 |
-
"built using the [faiss](https://github.com/facebookresearch/faiss) module. "
|
228 |
-
"The build scripts and more information available at the main repo "
|
229 |
-
"[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
|
230 |
-
"Github."
|
231 |
-
)
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
results = gr.Markdown(
|
236 |
-
latex_delimiters=[
|
237 |
-
{"left": "$$", "right": "$$", "display": False},
|
238 |
-
{"left": "$", "right": "$", "display": False},
|
239 |
-
],
|
240 |
-
container=True,
|
241 |
-
)
|
242 |
|
243 |
-
query.submit(search, inputs=[query], outputs=[results], batch=True)
|
244 |
-
btn.click(search, inputs=[query], outputs=[results], batch=True)
|
245 |
|
246 |
-
|
247 |
-
|
|
|
6 |
import faiss
|
7 |
import gradio as gr
|
8 |
from datasets import Dataset
|
9 |
+
from typing import TypedDict, Self, Any, Callable
|
10 |
import json
|
11 |
from pathlib import Path
|
12 |
from dataclasses import dataclass
|
13 |
from itertools import batched, chain
|
14 |
+
import os
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
class IndexParameters(TypedDict):
|
|
|
94 |
return " ".join(word for word in abstract_words if word is not None)
|
95 |
|
96 |
|
97 |
+
def get_env_var[T, U](
|
98 |
+
key: str, type_: Callable[[str], T] = str, default: U = None
|
99 |
+
) -> T | U:
|
100 |
+
var = os.getenv(key)
|
101 |
+
if var is not None:
|
102 |
+
var = type_(var)
|
103 |
+
else:
|
104 |
+
var = default
|
105 |
+
return var
|
106 |
+
|
107 |
+
|
108 |
def get_model(model_name: str, device: str) -> SentenceTransformer:
    """Load the sentence-transformers model *model_name* onto *device*."""
    model = SentenceTransformer(model_name, device=device)
    return model
|
110 |
|
|
|
114 |
index.load_faiss_index("embeddings", dir / "index.faiss", None)
|
115 |
faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
|
116 |
|
117 |
+
# TODO: search with the parameters that minimize distance from the utopia point
|
118 |
with open(dir / "params.json", "r") as f:
|
119 |
params: list[IndexParameters] = json.load(f)
|
120 |
params = [p for p in params if p["exec_time"] < search_time_s]
|
|
|
205 |
return result_string
|
206 |
|
207 |
|
208 |
+
def main():
    """Configure, build, and launch the Gradio semantic-search app.

    All settings are read from environment variables (via ``get_env_var``)
    so the Space can be reconfigured without code changes.
    """
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    # renamed from `dir` to avoid shadowing the builtin dir()
    index_dir = get_env_var("DIR", Path, default=Path("index"))
    # parse as float, not int: get_index() takes the budget in (fractional)
    # seconds, so values like "0.5" are valid; "1" still parses to 1.0
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)

    model = get_model(model_name, "cpu")
    index = get_index(index_dir, search_time_s)

    # function signature: (expanded tuple of input batches) -> tuple of output batches
    def search(query: list[str]) -> tuple[list[str]]:
        query_embedding = model.encode(query)
        distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)

        # resolve faiss ids to OpenAlex works in one flat request, then
        # re-chunk into per-query groups of k results
        faiss_ids_flat = list(chain(*faiss_ids))
        openalex_ids_flat = index[faiss_ids_flat]["idxs"]
        works_flat = execute_request(openalex_ids_flat)
        works = [list(batch) for batch in batched(works_flat, k)]

        result_strings = [format_response(w, d) for w, d in zip(works, distances)]

        return (result_strings, )

    with gr.Blocks() as demo:
        gr.Markdown("# abstracts-index")
        gr.Markdown(
            "Explore 95 million academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset. This project is an index of the "
            "embeddings generated from their titles and abstracts. The embeddings were "
            f"generated using the {model_name} model provided by the "
            "[sentence-transformers](https://www.sbert.net/) module, and the index was "
            "built using the [faiss](https://github.com/facebookresearch/faiss) "
            "module. The build scripts and more information available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )

        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )

        # batch=True: Gradio collects concurrent submissions into one call,
        # matching search()'s list-in/list-out signature
        query.submit(search, inputs=[query], outputs=[results], batch=True)
        btn.click(search, inputs=[query], outputs=[results], batch=True)

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()