Almaatla commited on
Commit
daedc24
·
verified ·
1 Parent(s): 744d14e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -21
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI, Request, Query
2
  from fastapi.templating import Jinja2Templates
 
3
  from pydantic import BaseModel
4
  from sentence_transformers import SentenceTransformer
5
  import faiss
@@ -8,8 +9,7 @@ import numpy as np
8
  app = FastAPI()
9
  model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
10
  index = faiss.IndexFlatL2(384) # 384 is the dimensionality of the MiniLM model
11
- # Create a list to store the documents
12
- documents = []
13
 
14
  templates = Jinja2Templates(directory=".")
15
 
@@ -27,33 +27,24 @@ def read_root(request: Request):
27
 
28
  @app.post("/embed")
29
  def embed_strings(request: EmbedRequest):
30
- # Add the new texts to the documents list
31
  new_documents = request.texts
32
- documents.extend(new_documents)
33
-
34
- # Encode the new documents and add them to the FAISS database
35
  new_embeddings = model.encode(new_documents)
36
  index.add(np.array(new_embeddings))
37
-
38
- # Get the new size of the FAISS database
39
- new_size = len(documents)
40
-
41
  return {
42
  "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
43
  }
44
 
 
45
  @app.post("/search")
46
  def search_string(request: SearchRequest):
47
  embedding = model.encode([request.text])
48
  distances, indices = index.search(np.array(embedding), request.n)
49
-
50
- # Get the documents associated with the returned indices
51
- found_documents = [documents[i] for i in indices[0]]
52
-
53
  return {
54
  "distances": distances[0].tolist(),
55
  "indices": indices[0].tolist(),
56
- "documents": found_documents
57
  }
58
 
59
  #########################
@@ -61,18 +52,20 @@ def search_string(request: SearchRequest):
61
  #########################
62
  @app.get("/admin/database/length")
63
  def get_database_length():
64
- return {"length": len(documents)}
65
 
66
- @app.post("/admin/database/clear")
67
- def clear_database():
68
- documents.clear()
69
  index.reset()
70
- return {"message": "Database cleared"}
71
 
72
  @app.get("/admin/documents/download")
73
  def download_documents():
 
 
 
74
  # Convert the documents list to a JSON string
75
- documents_json = json.dumps(documents)
76
 
77
  # Create a response with the JSON string as the content
78
  response = Response(content=documents_json, media_type="application/json")
 
1
  from fastapi import FastAPI, Request, Query
2
  from fastapi.templating import Jinja2Templates
3
+ from fastapi import File, UploadFile
4
  from pydantic import BaseModel
5
  from sentence_transformers import SentenceTransformer
6
  import faiss
 
9
  app = FastAPI()
10
  model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
11
  index = faiss.IndexFlatL2(384) # 384 is the dimensionality of the MiniLM model
12
+
 
13
 
14
  templates = Jinja2Templates(directory=".")
15
 
 
27
 
28
  @app.post("/embed")
29
  def embed_strings(request: EmbedRequest):
 
30
  new_documents = request.texts
 
 
 
31
  new_embeddings = model.encode(new_documents)
32
  index.add(np.array(new_embeddings))
33
+ new_size = index.ntotal
 
 
 
34
  return {
35
  "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
36
  }
37
 
38
+
39
  @app.post("/search")
40
  def search_string(request: SearchRequest):
41
  embedding = model.encode([request.text])
42
  distances, indices = index.search(np.array(embedding), request.n)
43
+ found_documents = index.reconstruct_n(indices[0], request.n)
 
 
 
44
  return {
45
  "distances": distances[0].tolist(),
46
  "indices": indices[0].tolist(),
47
+ "documents": found_documents.tolist()
48
  }
49
 
50
  #########################
 
52
  #########################
53
  @app.get("/admin/database/length")
54
  def get_database_length():
55
+ return {"length": index.ntotal}
56
 
57
+ @app.post("/admin/database/reset")
58
+ def reset_database():
 
59
  index.reset()
60
+ return {"message": "Database reset"}
61
 
62
  @app.get("/admin/documents/download")
63
  def download_documents():
64
+ # Reconstruct the documents from the FAISS index
65
+ documents = index.reconstruct_n(0, index.ntotal)
66
+
67
  # Convert the documents list to a JSON string
68
+ documents_json = json.dumps(documents.tolist())
69
 
70
  # Create a response with the JSON string as the content
71
  response = Response(content=documents_json, media_type="application/json")