Almaatla committed on
Commit
53897dc
·
verified ·
1 Parent(s): 23f0ebc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -13,7 +13,7 @@ import json
13
  app = FastAPI()
14
  model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
15
  index = faiss.IndexFlatL2(384) # 384 is the dimensionality of the MiniLM model
16
-
17
 
18
  templates = Jinja2Templates(directory=".")
19
 
@@ -35,6 +35,7 @@ def embed_strings(request: EmbedRequest):
35
  new_embeddings = model.encode(new_documents)
36
  index.add(np.array(new_embeddings))
37
  new_size = index.ntotal
 
38
  return {
39
  "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
40
  }
@@ -44,11 +45,14 @@ def embed_strings(request: EmbedRequest):
44
  def search_string(request: SearchRequest):
45
  embedding = model.encode([request.text])
46
  distances, indices = index.search(np.array(embedding), request.n)
47
- found_documents = index.reconstruct_n(indices[0], int(request.n))
 
 
 
48
  return {
49
  "distances": distances[0].tolist(),
50
  "indices": indices[0].tolist(),
51
- "documents": found_documents.tolist()
52
  }
53
 
54
  #########################
@@ -63,19 +67,19 @@ def reset_database():
63
  index.reset()
64
  return {"message": "Database reset"}
65
 
66
- @app.get("/admin/documents/download")
67
- def download_documents():
68
- # Reconstruct the documents from the FAISS index
69
- documents = index.reconstruct_n(0, index.ntotal)
70
 
71
- # Convert the documents list to a JSON string
72
- documents_json = json.dumps(documents.tolist())
73
 
74
  # Create a response with the JSON string as the content
75
- response = Response(content=documents_json, media_type="application/json")
76
 
77
  # Set the content disposition header to trigger a download
78
- response.headers["Content-Disposition"] = "attachment; filename=documents.json"
79
 
80
  return response
81
 
 
13
  app = FastAPI()
14
  model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
15
  index = faiss.IndexFlatL2(384) # 384 is the dimensionality of the MiniLM model
16
+ documents = []
17
 
18
  templates = Jinja2Templates(directory=".")
19
 
 
35
  new_embeddings = model.encode(new_documents)
36
  index.add(np.array(new_embeddings))
37
  new_size = index.ntotal
38
+ documents.extend(new_documents)
39
  return {
40
  "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
41
  }
 
45
  def search_string(request: SearchRequest):
46
  embedding = model.encode([request.text])
47
  distances, indices = index.search(np.array(embedding), request.n)
48
+
49
+ # Get the documents associated with the returned indices
50
+ found_documents = [documents[i] for i in indices[0]]
51
+
52
  return {
53
  "distances": distances[0].tolist(),
54
  "indices": indices[0].tolist(),
55
+ "documents": found_documents
56
  }
57
 
58
  #########################
 
67
  index.reset()
68
  return {"message": "Database reset"}
69
 
70
+ @app.get("/admin/embeddings/download")
71
+ def download_embeddings():
72
+ # Reconstruct the embeddings from the FAISS index
73
+ embeddings = index.reconstruct_n(0, index.ntotal)
74
 
75
+ # Convert the embeddings list to a JSON string
76
+ embeddings_json = json.dumps(embeddings.tolist())
77
 
78
  # Create a response with the JSON string as the content
79
+ response = Response(content=embeddings_json, media_type="application/json")
80
 
81
  # Set the content disposition header to trigger a download
82
+ response.headers["Content-Disposition"] = "attachment; filename=embeddings.json"
83
 
84
  return response
85