colonelwatch commited on
Commit
030ddc3
·
1 Parent(s): a228730

Create main function, add generic get_env_var function

Browse files
Files changed (1) hide show
  1. app.py +68 -49
app.py CHANGED
@@ -6,16 +6,12 @@ from sentence_transformers import SentenceTransformer
6
  import faiss
7
  import gradio as gr
8
  from datasets import Dataset
9
- from typing import TypedDict, Self, Any
10
  import json
11
  from pathlib import Path
12
  from dataclasses import dataclass
13
  from itertools import batched, chain
14
-
15
- MODEL_NAME = "all-MiniLM-L6-v2" # TODO: make configurable
16
- DIR = Path("index")
17
- SEARCH_TIME_S = 1 # TODO: optimize
18
- K = 20
19
 
20
 
21
  class IndexParameters(TypedDict):
@@ -98,6 +94,17 @@ class Work:
98
  return " ".join(word for word in abstract_words if word is not None)
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
101
  def get_model(model_name: str, device: str) -> SentenceTransformer:
102
  return SentenceTransformer(model_name, device=device)
103
 
@@ -107,6 +114,7 @@ def get_index(dir: Path, search_time_s: float) -> Dataset:
107
  index.load_faiss_index("embeddings", dir / "index.faiss", None)
108
  faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
109
 
 
110
  with open(dir / "params.json", "r") as f:
111
  params: list[IndexParameters] = json.load(f)
112
  params = [p for p in params if p["exec_time"] < search_time_s]
@@ -197,51 +205,62 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
197
  return result_string
198
 
199
 
200
- model = get_model(MODEL_NAME, "cpu")
201
- index = get_index(DIR, SEARCH_TIME_S)
202
-
203
-
204
- # function signature: (expanded tuple of input batches) -> tuple of output batches
205
- def search(query: list[str]) -> tuple[list[str]]:
206
- query_embedding = model.encode(query)
207
- distances, faiss_ids = index.search_batch("embeddings", query_embedding, K)
208
-
209
- faiss_ids_flat = list(chain(*faiss_ids))
210
- openalex_ids_flat = index[faiss_ids_flat]["idxs"]
211
- works_flat = execute_request(openalex_ids_flat)
212
- works = [list(batch) for batch in batched(works_flat, K)]
213
-
214
- result_strings = [format_response(w, d) for w, d in zip(works, distances)]
215
-
216
- return (result_strings, )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
- with gr.Blocks() as demo:
220
- gr.Markdown("# abstracts-index")
221
- gr.Markdown(
222
- "Explore 95 million academic publications selected from the "
223
- "[OpenAlex](https://openalex.org) dataset. This project is an index of the "
224
- "embeddings generated from their titles and abstracts. The embeddings were "
225
- "generated using the `all-MiniLM-L6-v2` model provided by the "
226
- "[sentence-transformers](https://www.sbert.net/) module, and the index was "
227
- "built using the [faiss](https://github.com/facebookresearch/faiss) module. "
228
- "The build scripts and more information available at the main repo "
229
- "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
230
- "Github."
231
- )
232
 
233
- query = gr.Textbox(lines=1, placeholder="Enter your query here", show_label=False)
234
- btn = gr.Button("Search")
235
- results = gr.Markdown(
236
- latex_delimiters=[
237
- {"left": "$$", "right": "$$", "display": False},
238
- {"left": "$", "right": "$", "display": False},
239
- ],
240
- container=True,
241
- )
242
 
243
- query.submit(search, inputs=[query], outputs=[results], batch=True)
244
- btn.click(search, inputs=[query], outputs=[results], batch=True)
245
 
246
- demo.queue()
247
- demo.launch()
 
6
  import faiss
7
  import gradio as gr
8
  from datasets import Dataset
9
+ from typing import TypedDict, Self, Any, Callable
10
  import json
11
  from pathlib import Path
12
  from dataclasses import dataclass
13
  from itertools import batched, chain
14
+ import os
 
 
 
 
15
 
16
 
17
  class IndexParameters(TypedDict):
 
94
  return " ".join(word for word in abstract_words if word is not None)
95
 
96
 
97
+ def get_env_var[T, U](
98
+ key: str, type_: Callable[[str], T] = str, default: U = None
99
+ ) -> T | U:
100
+ var = os.getenv(key)
101
+ if var is not None:
102
+ var = type_(var)
103
+ else:
104
+ var = default
105
+ return var
106
+
107
+
108
def get_model(model_name: str, device: str) -> SentenceTransformer:
    """Instantiate the named sentence-transformers model on the given device."""
    model = SentenceTransformer(model_name, device=device)
    return model
110
 
 
114
  index.load_faiss_index("embeddings", dir / "index.faiss", None)
115
  faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
116
 
117
+ # TODO: search for what minimized distance from utopia point
118
  with open(dir / "params.json", "r") as f:
119
  params: list[IndexParameters] = json.load(f)
120
  params = [p for p in params if p["exec_time"] < search_time_s]
 
205
  return result_string
206
 
207
 
208
def main():
    """Entry point: read config from the environment, load the model and
    FAISS-backed index, and serve the Gradio search UI."""
    # Defaults reproduce the previously hard-coded constants.
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    index_dir = get_env_var("DIR", Path, default=Path("index"))  # renamed: don't shadow builtin dir()
    # get_index takes a float time budget; int() would reject e.g. "0.5"
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # neighbors returned per query

    model = get_model(model_name, "cpu")
    index = get_index(index_dir, search_time_s)

    # function signature: (expanded tuple of input batches) -> tuple of output batches
    def search(query: list[str]) -> tuple[list[str]]:
        query_embedding = model.encode(query)
        distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)

        # Flatten ids for one dataset lookup / one OpenAlex request,
        # then re-batch into per-query groups of k.
        faiss_ids_flat = list(chain(*faiss_ids))
        openalex_ids_flat = index[faiss_ids_flat]["idxs"]
        works_flat = execute_request(openalex_ids_flat)
        works = [list(batch) for batch in batched(works_flat, k)]

        result_strings = [format_response(w, d) for w, d in zip(works, distances)]

        return (result_strings, )

    with gr.Blocks() as demo:
        gr.Markdown("# abstracts-index")
        gr.Markdown(
            "Explore 95 million academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset. This project is an index of the "
            "embeddings generated from their titles and abstracts. The embeddings were "
            f"generated using the {model_name} model provided by the "
            "[sentence-transformers](https://www.sbert.net/) module, and the index was "
            "built using the [faiss](https://github.com/facebookresearch/faiss) "
            "module. The build scripts and more information available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )

        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )

        # batch=True: gradio collects queued requests and calls search once per batch
        query.submit(search, inputs=[query], outputs=[results], batch=True)
        btn.click(search, inputs=[query], outputs=[results], batch=True)

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()