Almaatla committed on
Commit
744d14e
·
verified ·
1 Parent(s): f238fcb

added database management API

Browse files
Files changed (1) hide show
  1. app.py +79 -4
app.py CHANGED
@@ -8,6 +8,8 @@ import numpy as np
8
  app = FastAPI()
9
  model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
10
  index = faiss.IndexFlatL2(384) # 384 is the dimensionality of the MiniLM model
 
 
11
 
12
  templates = Jinja2Templates(directory=".")
13
 
@@ -22,14 +24,87 @@ class SearchRequest(BaseModel):
22
  def read_root(request: Request):
23
  return templates.TemplateResponse("index.html", {"request": request})
24
 
 
25
  @app.post("/embed")
26
  def embed_strings(request: EmbedRequest):
27
- embeddings = model.encode(request.texts)
28
- index.add(np.array(embeddings))
29
- return {"message": "Strings embedded and added to FAISS database"}
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @app.post("/search")
32
  def search_string(request: SearchRequest):
33
  embedding = model.encode([request.text])
34
  distances, indices = index.search(np.array(embedding), request.n)
35
- return {"distances": distances[0].tolist(), "indices": indices[0].tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Application object that the route decorators in this module attach to.
app = FastAPI()

# Sentence-embedding model; MiniLM produces 384-dimensional vectors.
_EMBEDDING_DIM = 384
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Flat L2 index sized to the model's output dimensionality.
index = faiss.IndexFlatL2(_EMBEDDING_DIM)

# Raw texts in insertion order; the search endpoint looks documents up by the
# ids FAISS returns, so position i here must correspond to vector i in `index`.
documents = list()

# Templates are resolved relative to the directory ".".
templates = Jinja2Templates(directory=".")
15
 
 
24
  def read_root(request: Request):
25
  return templates.TemplateResponse("index.html", {"request": request})
26
 
27
+
@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the submitted texts and append them to the FAISS database.

    Args:
        request: Payload whose ``texts`` attribute holds the strings to embed
            (presumably a list of str -- confirm against ``EmbedRequest``).

    Returns:
        A dict with a single ``"message"`` key reporting how many strings were
        added and the resulting size of the database.
    """
    new_documents = list(request.texts)

    if new_documents:
        # Encode and index BEFORE touching `documents`: if either step raises,
        # the text list must not grow, otherwise position i in `documents`
        # would no longer match vector i in the index.
        new_embeddings = np.asarray(model.encode(new_documents), dtype=np.float32)
        index.add(new_embeddings)
        documents.extend(new_documents)
    # An empty request is a no-op instead of handing FAISS an empty,
    # one-dimensional array.

    new_size = len(documents)

    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
44
 
@app.post("/search")
def search_string(request: SearchRequest):
    """Return the stored documents closest to the query text.

    Args:
        request: Payload providing the query ``text`` and the number of
            neighbours ``n`` to retrieve.

    Returns:
        A dict with three parallel lists: ``"distances"``, ``"indices"`` and
        ``"documents"``. The lists hold at most ``n`` entries and are empty
        when the database is empty or ``n`` is not positive.
    """
    # FAISS pads missing neighbours with id -1 when asked for more results than
    # it holds; indexing `documents[-1]` would then silently return the last
    # document. Never request more than the index contains.
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    embedding = np.asarray(model.encode([request.text]), dtype=np.float32)
    distances, indices = index.search(embedding, k)

    # Keep only ids that map to a stored document so the three result lists
    # always stay aligned, even if the index and `documents` have drifted apart.
    hits = [
        (float(distance), int(doc_id))
        for distance, doc_id in zip(distances[0], indices[0])
        if 0 <= doc_id < len(documents)
    ]

    return {
        "distances": [distance for distance, _ in hits],
        "indices": [doc_id for _, doc_id in hits],
        "documents": [documents[doc_id] for _, doc_id in hits],
    }
58
+
59
+ #########################
60
+ ## database management ##
61
+ #########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many texts are currently held in ``documents``."""
    count = len(documents)
    return dict(length=count)
65
+
@app.post("/admin/database/clear")
def clear_database():
    """Empty the document list and reset the FAISS index."""
    # Both stores are emptied together so they stay aligned.
    del documents[:]
    index.reset()
    return dict(message="Database cleared")
71
+
@app.get("/admin/documents/download")
def download_documents():
    """Serve the stored documents as a downloadable JSON attachment."""
    # The Content-Disposition header makes the client save the payload as a
    # file rather than render it.
    attachment = {"Content-Disposition": "attachment; filename=documents.json"}
    return Response(
        content=json.dumps(documents),
        media_type="application/json",
        headers=attachment,
    )
84
+
@app.get("/admin/database/download")
def download_database():
    """Write the FAISS index to ``database.index`` and serve it as a download."""
    index_path = "database.index"

    # Persist the in-memory index so it can be streamed from disk.
    faiss.write_index(index, index_path)

    reply = FileResponse(index_path)
    # Ask the client to save the payload as a file.
    reply.headers["Content-Disposition"] = "attachment; filename=database.index"
    return reply
97
+
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory FAISS index with an uploaded index file.

    Args:
        file: Uploaded file holding a serialized FAISS index, e.g. one written
            with ``faiss.write_index``.

    Returns:
        ``{"message": "Database uploaded"}`` on success.

    Raises:
        HTTPException: 400 if the upload cannot be parsed as a FAISS index or
            its vector dimensionality differs from the current index.
    """
    # Rebind the module-level index; a plain assignment would only create a
    # local and leave the live index untouched.
    global index

    # Local import so this block does not depend on the file's import list.
    from fastapi import HTTPException

    contents = file.file.read()

    # `read_index_binary` expects a file NAME of a binary index; the uploaded
    # bytes have to be deserialized from memory instead.
    try:
        loaded = faiss.deserialize_index(np.frombuffer(contents, dtype=np.uint8))
    except (RuntimeError, ValueError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err

    # Query embeddings are produced at the current index's dimensionality, so
    # an index of another size could never be searched.
    if loaded.d != index.d:
        raise HTTPException(
            status_code=400,
            detail=f"Index dimension {loaded.d} does not match expected {index.d}",
        )

    index = loaded

    # A FAISS index stores vectors only, so the original texts cannot be
    # recovered from it. Keep one placeholder per vector so `documents` stays
    # positionally aligned with the index and remains JSON-serializable.
    documents.clear()
    documents.extend([None] * index.ntotal)

    return {"message": "Database uploaded"}