colonelwatch commited on
Commit
09f2d58
·
1 Parent(s): 628360c

Initial rework to handle new index format

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. app.py +28 -14
  4. index.faiss +0 -3
  5. index.ivfdata +0 -3
  6. openalex_ids.txt +0 -3
  7. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📝
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: cc0-1.0
 
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
  license: cc0-1.0
app.py CHANGED
@@ -5,19 +5,36 @@ import requests
5
  from sentence_transformers import SentenceTransformer
6
  import faiss
7
  import gradio as gr
 
 
 
8
 
9
  from markdown_it import MarkdownIt # used for overriding default markdown renderer
10
 
 
 
 
 
 
 
 
 
 
11
  model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
12
 
13
- works_ids_path = 'openalex_ids.txt'
14
- with open(works_ids_path) as f:
15
- idxs = f.read().splitlines()
16
- index = faiss.read_index('index.faiss')
 
 
 
 
 
17
 
18
  ps = faiss.ParameterSpace()
19
- ps.initialize(index)
20
- ps.set_index_parameters(index, 'nprobe=16,ht=512')
21
 
22
 
23
  def _recover_abstract(inverted_index):
@@ -33,16 +50,13 @@ def _recover_abstract(inverted_index):
33
  return abstract
34
 
35
  def search(query):
36
- global model, index, idxs
37
 
 
38
  query_embedding = model.encode(query)
39
- query_embedding = query_embedding.reshape(1, -1)
40
- distances, faiss_ids = index.search(query_embedding, 20)
41
-
42
- distances = distances[0]
43
- faiss_ids = faiss_ids[0]
44
 
45
- openalex_ids = [idxs[faiss_id] for faiss_id in faiss_ids]
46
  search_filter = f'openalex_id:{"|".join(openalex_ids)}'
47
  search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
48
 
@@ -148,7 +162,7 @@ with gr.Blocks() as demo:
148
  response_var = gr.State()
149
  query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False)
150
  btn = gr.Button('Search')
151
- with gr.Box():
152
  results = gr.Markdown()
153
 
154
  md = MarkdownIt('js-default', {'linkify': True, 'typographer': True}) # don't render html or latex!
 
5
  from sentence_transformers import SentenceTransformer
6
  import faiss
7
  import gradio as gr
8
+ from datasets import Dataset
9
+ from typing import TypedDict
10
+ import json
11
 
12
  from markdown_it import MarkdownIt # used for overriding default markdown renderer
13
 
14
+ SEARCH_TIME_S = 1 # TODO: optimize
15
+
16
+
17
+ class IndexParameters(TypedDict):
18
+ recall: float # in this case 10-recall@10
19
+ exec_time: float # seconds (raw faiss measure is in milliseconds)
20
+ param_string: str # pass directly to faiss index
21
+
22
+
23
  model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
24
 
25
+ index: Dataset = Dataset.from_parquet("index/ids.parquet") # type: ignore
26
+ index.load_faiss_index("embeddings", "index/index.faiss", None)
27
+ faiss_index: faiss.Index = index.get_index("embeddings").faiss_index # type: ignore
28
+
29
+ with open("index/params.json", "r") as f:
30
+ params: list[IndexParameters] = json.load(f)
31
+ params = [p for p in params if p["exec_time"] < SEARCH_TIME_S]
32
+ param = max(params, key=(lambda p: p["recall"]))
33
+ param_string = param["param_string"]
34
 
35
  ps = faiss.ParameterSpace()
36
+ ps.initialize(faiss_index)
37
+ ps.set_index_parameters(faiss_index, param_string)
38
 
39
 
40
  def _recover_abstract(inverted_index):
 
50
  return abstract
51
 
52
  def search(query):
53
+ global model, index
54
 
55
+ # TODO: pass in param string directly?
56
  query_embedding = model.encode(query)
57
+ distances, faiss_ids = index.search("embeddings", query_embedding, 20)
 
 
 
 
58
 
59
+ openalex_ids = index[faiss_ids]["idxs"]
60
  search_filter = f'openalex_id:{"|".join(openalex_ids)}'
61
  search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'
62
 
 
162
  response_var = gr.State()
163
  query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False)
164
  btn = gr.Button('Search')
165
+ with gr.Group():
166
  results = gr.Markdown()
167
 
168
  md = MarkdownIt('js-default', {'linkify': True, 'typographer': True}) # don't render html or latex!
index.faiss DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf114f0cc57e5674b171e113553e182aa705be7df530493862d21bb48b7ecf9b
3
- size 138019116
 
 
 
 
index.ivfdata DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4d1bbbbb3094702c549d2a2576280a2587622241d0e7ec2765955a3461d89bae
3
- size 6868108080
 
 
 
 
openalex_ids.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:eeb2e963d4b3e1026ad550ea8a2e4fca92a1b59aa4d6fd005953ec5505415396
3
- size 3141761935
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  sentence-transformers
2
  faiss-cpu
 
 
1
  sentence-transformers
2
  faiss-cpu
3
+ datasets