colonelwatch commited on
Commit
a8b8684
·
1 Parent(s): c574006

Add prompt name as an env var option

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -210,6 +210,7 @@ def format_response(neighbors: list[Work], distances: list[float]) -> str:
210
  def main():
211
  # TODO: figure out some better defaults?
212
  model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
 
213
  trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
214
  fp16 = get_env_var("FP16", bool, default=False)
215
  dir = get_env_var("DIR", Path, default=Path("index"))
@@ -229,7 +230,7 @@ def main():
229
 
230
  # function signature: (expanded tuple of input batches) -> tuple of output batches
231
  def search(query: list[str]) -> tuple[list[str]]:
232
- query_embedding = model.encode(query)
233
  distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)
234
 
235
  faiss_ids_flat = list(chain(*faiss_ids))
 
210
  def main():
211
  # TODO: figure out some better defaults?
212
  model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
213
+ prompt_name = get_env_var("PROMPT_NAME")
214
  trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
215
  fp16 = get_env_var("FP16", bool, default=False)
216
  dir = get_env_var("DIR", Path, default=Path("index"))
 
230
 
231
  # function signature: (expanded tuple of input batches) -> tuple of output batches
232
  def search(query: list[str]) -> tuple[list[str]]:
233
+ query_embedding = model.encode(query, prompt_name)
234
  distances, faiss_ids = index.search_batch("embeddings", query_embedding, k)
235
 
236
  faiss_ids_flat = list(chain(*faiss_ids))