scooter7 commited on
Commit
befbdb6
·
verified ·
1 Parent(s): 5124f8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -85,14 +85,14 @@ for doc in documents:
85
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
86
  chunk_embeddings = embedding_model.encode(all_chunks)
87
  embedding_dim = chunk_embeddings.shape[1]
88
- index = faiss.IndexFlatL2(embedding_dim)
89
- index.add(np.array(chunk_embeddings))
90
 
91
  generator = pipeline("text-generation", model="gpt2", max_length=256)
92
 
93
  def retrieve_context(query: str, k: int = 5) -> List[str]:
94
  query_embedding = embedding_model.encode([query])
95
- distances, indices = index.search(np.array(query_embedding), k)
96
  return [all_chunks[idx] for idx in indices[0] if idx < len(all_chunks)]
97
 
98
  def generate_answer(query: str) -> str:
@@ -238,7 +238,7 @@ async def chat_endpoint(payload: dict):
238
  return {"answer": answer}
239
 
240
  @app.get("/")
241
- async def index():
242
  index_path = current_dir / "index.html"
243
  html_content = index_path.read_text()
244
  return HTMLResponse(content=html_content)
 
85
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
86
  chunk_embeddings = embedding_model.encode(all_chunks)
87
  embedding_dim = chunk_embeddings.shape[1]
88
+ faiss_index = faiss.IndexFlatL2(embedding_dim)
89
+ faiss_index.add(np.array(chunk_embeddings))
90
 
91
  generator = pipeline("text-generation", model="gpt2", max_length=256)
92
 
93
  def retrieve_context(query: str, k: int = 5) -> List[str]:
94
  query_embedding = embedding_model.encode([query])
95
+ distances, indices = faiss_index.search(np.array(query_embedding), k)
96
  return [all_chunks[idx] for idx in indices[0] if idx < len(all_chunks)]
97
 
98
  def generate_answer(query: str) -> str:
 
238
  return {"answer": answer}
239
 
240
  @app.get("/")
241
+ async def index_endpoint():
242
  index_path = current_dir / "index.html"
243
  html_content = index_path.read_text()
244
  return HTMLResponse(content=html_content)