colonelwatch committed
Commit 2db96ca · 1 Parent(s): 67dc9b0

Merge the index on disk to keep file sizes under 4GB

Files changed (2):
  1. .gitignore +1 -0
  2. app.py +14 -3
.gitignore ADDED
@@ -0,0 +1 @@
+temp.ivfdata
app.py CHANGED
@@ -12,6 +12,7 @@ from typing import TypedDict, Self, Any, Callable
 from datasets import Dataset
 from datasets.search import FaissIndex
 import faiss
+from faiss.contrib.ondisk import merge_ondisk
 import gradio as gr
 import requests
 from sentence_transformers import SentenceTransformer
@@ -127,12 +128,22 @@ def get_model(
     )
 
 
+def merge_shards(dir: Path) -> faiss.Index:
+    empty_path = dir / "empty.faiss"
+    shard_paths = [str(p) for p in dir.glob("shard_*.faiss")]
+    merged_ivfdata_path = Path("temp.ivfdata")
+
+    index = faiss.read_index(str(empty_path))
+    merged_ivfdata_path.unlink(missing_ok=True)  # overwrite previous if it exists (TODO: do I need this?)
+    merge_ondisk(index, shard_paths, str(merged_ivfdata_path))
+
+    return index
+
+
 def get_index(dir: Path, search_time_s: float) -> Dataset:
     # NOTE: a private attr is used to get the faiss.IO_FLAG_ONDISK_SAME_DIR flag!
     index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
-    faiss_index: faiss.Index = faiss.read_index(
-        str(dir / "index.faiss"), faiss.IO_FLAG_ONDISK_SAME_DIR
-    )
+    faiss_index = merge_shards(dir / "shards")
     index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)
 
     with open(dir / "params.json", "r") as f:
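
For context: faiss's merge_ondisk takes a trained-but-empty IVF index plus a list of populated shard indexes, and writes their combined inverted lists into a single .ivfdata file that the merged index memory-maps at search time, so no single index file has to exceed 4GB. This commit assumes empty.faiss and the shard_*.faiss files already exist under dir / "shards". A minimal sketch of how such files could be produced, following the standard faiss on-disk recipe (the build_shards helper, factory string, and shard count are illustrative assumptions, not part of this commit):

    # Hypothetical shard-building step; not part of this commit.
    from pathlib import Path

    import faiss
    import numpy as np

    def build_shards(embeddings: np.ndarray, dir: Path, n_shards: int = 4) -> None:
        d = embeddings.shape[1]  # embeddings assumed float32, shape (n, d)

        # Train an IVF index, then write it out *empty*: merge_ondisk later
        # repopulates its inverted lists from the shards.
        index = faiss.index_factory(d, "IVF4096,Flat")  # factory string is a guess
        index.train(embeddings)
        faiss.write_index(index, str(dir / "empty.faiss"))

        # Add each slice of the data to a fresh clone of the empty index and
        # write it out, keeping every shard file comfortably under 4GB.
        for i, shard in enumerate(np.array_split(embeddings, n_shards)):
            shard_index = faiss.clone_index(index)
            shard_index.add(shard)
            faiss.write_index(shard_index, str(dir / f"shard_{i}.faiss"))

With shards in that layout, merge_shards above reads empty.faiss, merges the shards into temp.ivfdata (hence the new .gitignore entry), and returns the merged index for get_index to wrap.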