colonelwatch commited on
Commit
fb6a589
·
1 Parent(s): c65ad60

Replace broken submodule with a download scheme

Browse files
Files changed (3) hide show
  1. .gitmodules +0 -3
  2. app.py +10 -3
  3. faiss +0 -1
.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "faiss"]
2
- path = faiss
3
- url = [email protected]:colonelwatch/abstracts-index-faiss
 
 
 
 
app.py CHANGED
@@ -10,9 +10,10 @@ from pathlib import Path
10
  from sys import stderr
11
  from typing import TypedDict, Self, Any, Callable
12
 
13
- from datasets import Dataset, disable_caching
14
  from datasets.search import FaissIndex
15
  import faiss
 
16
  import gradio as gr
17
  import requests
18
  from sentence_transformers import SentenceTransformer
@@ -242,12 +243,18 @@ def main():
242
  prompt_name = get_env_var("PROMPT_NAME")
243
  trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
244
  fp16 = get_env_var("FP16", bool, default=False)
245
- dir = get_env_var("DIR", Path, default=Path("faiss/index"))
 
246
  search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
247
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
248
  mailto = get_env_var("MAILTO", str, None)
249
 
250
- disable_caching() # disable caching in the datasets library
 
 
 
 
 
251
 
252
  normalize, model = get_model(model_name, dir, trust_remote_code)
253
  index = get_index(dir, search_time_s)
 
10
  from sys import stderr
11
  from typing import TypedDict, Self, Any, Callable
12
 
13
+ from datasets import Dataset
14
  from datasets.search import FaissIndex
15
  import faiss
16
+ from huggingface_hub import snapshot_download
17
  import gradio as gr
18
  import requests
19
  from sentence_transformers import SentenceTransformer
 
243
  prompt_name = get_env_var("PROMPT_NAME")
244
  trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
245
  fp16 = get_env_var("FP16", bool, default=False)
246
+ dir = get_env_var("DIR", Path)
247
+ repo = get_env_var("REPO", str)
248
  search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
249
  k = get_env_var("K", int, default=20) # TODO: can't go higher than 20 yet
250
  mailto = get_env_var("MAILTO", str, None)
251
 
252
+ if dir is None: # acquire the index if it's not local
253
+ if repo is None:
254
+ repo = "colonelwatch/abstracts-faiss"
255
+ dir = Path(snapshot_download(repo, repo_type="dataset")) / "index"
256
+ elif repo is not None:
257
+ print('warning: used "REPO" and also "DIR", ignoring "REPO"...', file=stderr)
258
 
259
  normalize, model = get_model(model_name, dir, trust_remote_code)
260
  index = get_index(dir, search_time_s)
faiss DELETED
@@ -1 +0,0 @@
1
- Subproject commit 652454ca78e307c0d9262da278619570f4ad120d