fastAPI / app.py
Almaatla's picture
added database management API
744d14e verified
raw
history blame
3.36 kB
import json

import faiss
import numpy as np
from fastapi import FastAPI, File, HTTPException, Request, Query, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Checkpoint name handed to SentenceTransformer for every embedding.
MODEL_NAME = 'paraphrase-MiniLM-L6-v2'
# Vector width the FAISS index is created with (384 for the MiniLM model).
EMBEDDING_DIM = 384

app = FastAPI()
model = SentenceTransformer(MODEL_NAME)
index = faiss.IndexFlatL2(EMBEDDING_DIM)

# In-memory store of the raw texts. /embed appends texts here in the same
# order it adds their vectors to `index`, and /search looks texts up by the
# ids FAISS returns, so list position and FAISS id must stay aligned.
documents = []

# Templates are looked up in the current working directory.
templates = Jinja2Templates(directory=".")
class EmbedRequest(BaseModel):
    """Request body accepted by ``POST /embed``."""

    # Strings to encode and append to the document store.
    texts: list[str]
class SearchRequest(BaseModel):
    """Request body accepted by ``POST /search``."""

    # Query string that is encoded and compared against the index.
    text: str
    # How many nearest neighbours to ask FAISS for; 5 when omitted.
    n: int = 5
@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` as the landing page.

    The incoming request object is passed to the template under the
    ``"request"`` key.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the submitted texts and append them to the FAISS database.

    Args:
        request: Body carrying ``texts``, the strings to embed.

    Returns:
        A dict with a single ``"message"`` entry reporting how many strings
        were added and the resulting number of stored documents.
    """
    new_documents = request.texts

    # An empty batch has nothing to encode; skip the model and the index
    # entirely and just report the unchanged size.
    if new_documents:
        # Encode and index first, and only then record the texts. If either
        # step raises, `documents` is left untouched, so list positions keep
        # matching the ids FAISS assigns.
        new_embeddings = np.asarray(model.encode(new_documents), dtype=np.float32)
        index.add(new_embeddings)
        documents.extend(new_documents)

    new_size = len(documents)
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
@app.post("/search")
def search_string(request: SearchRequest):
    """Return the stored documents closest to the query text.

    Args:
        request: Body carrying the query ``text`` and the number of
            neighbours ``n`` to return.

    Returns:
        A dict with three parallel lists: ``"distances"``, ``"indices"``
        and ``"documents"``. Fewer than ``n`` entries are returned when the
        index holds fewer vectors; all three lists are empty when nothing
        can be searched.
    """
    # Never ask FAISS for more neighbours than it stores: it pads missing
    # results with the label -1, and documents[-1] would silently hand back
    # the last document instead of failing.
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    embedding = np.asarray(model.encode([request.text]), dtype=np.float32)
    distances, indices = index.search(embedding, k)

    # Keep only labels that map to a stored document, so the three output
    # lists stay aligned even if a padded or stale label slips through.
    hits = [
        (float(distance), int(label))
        for distance, label in zip(distances[0], indices[0])
        if 0 <= label < len(documents)
    ]
    return {
        "distances": [distance for distance, _ in hits],
        "indices": [label for _, label in hits],
        "documents": [documents[label] for _, label in hits],
    }
#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many documents are currently held in memory."""
    count = len(documents)
    return {"length": count}
@app.post("/admin/database/clear")
def clear_database():
    """Empty the document list and reset the FAISS index."""
    # Both stores are emptied together so they stay in step.
    documents.clear()
    index.reset()
    return dict(message="Database cleared")
@app.get("/admin/documents/download")
def download_documents():
    """Send the stored documents as a downloadable JSON file.

    Returns:
        A ``Response`` whose body is the JSON-encoded ``documents`` list,
        served as ``application/json`` with a ``Content-Disposition``
        header naming the attachment ``documents.json``.
    """
    documents_json = json.dumps(documents)
    return Response(
        content=documents_json,
        media_type="application/json",
        # The attachment disposition makes browsers save the body to disk.
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )
@app.get("/admin/database/download")
def download_database():
    """Write the FAISS index to disk and send it as a file download.

    The index is saved to ``database.index`` in the working directory on
    every call, then that file is streamed back.

    Returns:
        A ``FileResponse`` for ``database.index`` with a
        ``Content-Disposition`` header naming the attachment.
    """
    faiss.write_index(index, "database.index")
    return FileResponse(
        "database.index",
        # The index is opaque binary data; state that explicitly rather than
        # leaving the type to be guessed from the ".index" extension.
        media_type="application/octet-stream",
        headers={"Content-Disposition": "attachment; filename=database.index"},
    )
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the live FAISS index with one uploaded by the client.

    Args:
        file: Uploaded index file, in the format produced by the
            ``/admin/database/download`` endpoint.

    Returns:
        ``{"message": "Database uploaded"}`` on success.

    Raises:
        HTTPException: 400 when the upload cannot be parsed as a FAISS
            index or its dimensionality differs from the current index.
    """
    # Without this declaration the assignment below would only bind a local
    # name and every other endpoint would keep using the old index.
    global index

    contents = file.file.read()
    try:
        # deserialize_index takes the serialized bytes as a uint8 array;
        # read_index_binary expects a file name and loads binary indexes.
        new_index = faiss.deserialize_index(np.frombuffer(contents, dtype=np.uint8))
    except RuntimeError as err:
        # NOTE(review): FAISS parse failures are expected to surface as
        # RuntimeError -- confirm against the installed faiss version.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err

    # Vectors produced by the model must match the index width, otherwise
    # later add/search calls fail.
    if new_index.d != index.d:
        raise HTTPException(
            status_code=400,
            detail=f"Index dimension {new_index.d} does not match expected {index.d}",
        )

    index = new_index

    # An index file stores vectors only, so the original texts cannot be
    # recovered from it. Keep one placeholder per vector so list positions
    # still line up with FAISS ids for any texts embedded afterwards.
    # NOTE(review): a matching documents upload is needed to restore texts.
    documents.clear()
    documents.extend([None] * new_index.ntotal)
    return {"message": "Database uploaded"}