fastAPI / app.py
Almaatla's picture
Update app.py
3a1f579 verified
raw
history blame
5.92 kB
from fastapi import FastAPI, Request, Query
from fastapi.templating import Jinja2Templates
from fastapi import File, UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import Response
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import io
# Application object; the route decorators below register their handlers on it.
app = FastAPI()

# Alternative model configuration, currently disabled:
#model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
#embedding_dimension = 384 # 384 is the dimensionality of the MiniLM model

# 1. Specify the preferred embedding dimension. The same value is handed to
#    the model (truncate_dim) and to the FAISS index below so the two agree.
embedding_dimension = 512
# 2. Load the model
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=embedding_dimension)
# FAISS IndexFlatL2 index sized to embedding_dimension; holds the embeddings.
index = faiss.IndexFlatL2(embedding_dimension)
# Texts kept in memory in insertion order; search_string looks up
# documents[i] for each index i returned by FAISS.
documents = []
# Templates (index.html) are looked up in the current working directory.
templates = Jinja2Templates(directory=".")
class EmbedRequest(BaseModel):
    """Request body of ``POST /embed``."""

    # Strings to embed and append to the document store.
    texts: list[str]
class SearchRequest(BaseModel):
    """Request body of ``POST /search``."""

    # Query string that is embedded and searched for in the FAISS index.
    text: str
    # Number of results requested from the index search (defaults to 5).
    n: int = 5
@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` as the landing page.

    The incoming request is passed to the template under the ``request`` key.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the submitted strings and append them to the FAISS index.

    Args:
        request: Body carrying ``texts``, the strings to embed.

    Returns:
        A dict with a ``message`` reporting how many strings were added and
        the resulting size of the index.
    """
    new_documents = request.texts

    # An empty batch has nothing to encode; skipping it avoids handing FAISS
    # an array that is not a 2-D (count, dimension) matrix.
    if new_documents:
        # FAISS works on float32 matrices, so coerce the encoder output.
        new_embeddings = np.asarray(model.encode(new_documents), dtype=np.float32)
        index.add(new_embeddings)
        # Extend the text list only after the vectors were added, so that
        # positions in `documents` keep matching the new FAISS ids.
        documents.extend(new_documents)

    new_size = index.ntotal
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
@app.post("/search")
def search_string(request: SearchRequest):
    """Return the stored documents closest to the query text.

    Args:
        request: Body carrying ``text`` (the query) and ``n`` (the maximum
            number of results).

    Returns:
        A dict with three parallel lists: ``distances``, ``indices`` and
        ``documents``. Fewer than ``n`` entries are returned when the index
        holds fewer vectors; all lists are empty when the index is empty or
        ``n`` is not positive. A ``documents`` entry is ``None`` when the
        index contains a vector for which no text is stored.
    """
    # Never ask FAISS for more neighbours than it holds: it pads the result
    # with index -1, which would silently map to documents[-1].
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    embedding = np.asarray(model.encode([request.text]), dtype=np.float32)
    distances, indices = index.search(embedding, k)

    # Drop any remaining -1 padding defensively and keep the lists aligned.
    hits = [
        (float(distance), int(position))
        for distance, position in zip(distances[0], indices[0])
        if position >= 0
    ]

    return {
        "distances": [distance for distance, _ in hits],
        "indices": [position for _, position in hits],
        # The index can be replaced independently of the text list (see the
        # database upload endpoint), so a hit may have no stored text.
        "documents": [
            documents[position] if position < len(documents) else None
            for _, position in hits
        ],
    }
#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many embeddings the FAISS index currently holds."""
    stored_vectors = index.ntotal
    return {"length": stored_vectors}
@app.post("/admin/database/reset")
def reset_database():
    """Replace the FAISS index and the document list with fresh, empty ones."""
    global index, documents

    # Rebind both globals to new objects instead of clearing them in place.
    index = faiss.IndexFlatL2(embedding_dimension)
    documents = []

    return dict(message="Database reset")
@app.get("/admin/documents/download")
def download_documents():
    """Send the stored documents as a downloadable JSON file.

    The body has the shape ``{"texts": [...]}``.
    """
    payload = json.dumps({"texts": documents})
    return Response(
        content=payload,
        media_type="application/json",
        # The attachment disposition makes browsers download the body.
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )
@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
    """Append documents from an uploaded JSON file to the document list.

    Only the text list is extended; no embeddings are computed or added to
    the FAISS index by this endpoint.

    Args:
        file: Uploaded JSON file of the form ``{"texts": ["...", ...]}``.

    Returns:
        A dict with a ``message`` reporting how many documents were added.

    Raises:
        HTTPException: 400 when the file is not valid JSON or does not have
            the expected shape.
    """
    # Function-scope import keeps this handler self-contained.
    from fastapi import HTTPException

    try:
        data = json.loads(file.file.read())
    except ValueError as err:  # covers JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    new_documents = data.get("texts") if isinstance(data, dict) else None
    # Reject anything but a list of strings: extending with a bare string
    # would silently add one "document" per character.
    if not isinstance(new_documents, list) or not all(
        isinstance(text, str) for text in new_documents
    ):
        raise HTTPException(
            status_code=400,
            detail='Expected a JSON object of the form {"texts": ["...", ...]}',
        )

    documents.extend(new_documents)
    return {"message": f"{len(new_documents)} new documents uploaded"}
@app.post("/admin/documents/embed")
def embed_documents(file: UploadFile = File(...)):
    """Embed documents from an uploaded JSON file and store them.

    The texts are encoded, their embeddings are added to the FAISS index and
    the texts are appended to the document list.

    Args:
        file: Uploaded JSON file of the form ``{"texts": ["...", ...]}``.

    Returns:
        A dict with a ``message`` reporting how many documents were added.

    Raises:
        HTTPException: 400 when the file is not valid JSON or does not have
            the expected shape.
    """
    # Function-scope import keeps this handler self-contained.
    from fastapi import HTTPException

    try:
        data = json.loads(file.file.read())
    except ValueError as err:  # covers JSONDecodeError and UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    new_documents = data.get("texts") if isinstance(data, dict) else None
    # Reject anything but a list of strings before it reaches the encoder.
    if not isinstance(new_documents, list) or not all(
        isinstance(text, str) for text in new_documents
    ):
        raise HTTPException(
            status_code=400,
            detail='Expected a JSON object of the form {"texts": ["...", ...]}',
        )

    # An empty batch has nothing to encode; skipping it avoids handing FAISS
    # an array that is not a 2-D (count, dimension) matrix.
    if new_documents:
        # FAISS works on float32 matrices, so coerce the encoder output.
        new_embeddings = np.asarray(model.encode(new_documents), dtype=np.float32)
        index.add(new_embeddings)
        # Extend the text list only after the vectors were added, so that
        # positions in `documents` keep matching the new FAISS ids.
        documents.extend(new_documents)

    return {"message": f"{len(new_documents)} new documents uploaded and embedded"}
@app.get("/admin/database/download")
def download_database():
    """Write the FAISS index to ``database.index`` and send it as a download."""
    index_path = "database.index"

    # Persist the in-memory index so it can be served as a file.
    faiss.write_index(index, index_path)

    download = FileResponse(index_path, media_type="application/octet-stream")
    # The attachment disposition makes browsers download the body.
    download.headers["Content-Disposition"] = "attachment; filename=database.index"
    return download
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory FAISS index with an uploaded index file.

    Only the index is replaced; the document list is left untouched.

    Args:
        file: Uploaded file produced by ``faiss.write_index`` (for example
            via the database download endpoint).

    Returns:
        A dict with a ``message`` reporting the number of embeddings in the
        newly loaded index.

    Raises:
        HTTPException: 400 when the file cannot be read as a FAISS index or
            its dimension differs from ``embedding_dimension``.
    """
    # Function-scope imports keep this handler self-contained.
    import os
    import tempfile

    from fastapi import HTTPException

    global index

    # Write to a server-chosen temporary path. The client-supplied filename
    # is untrusted and could point outside the working directory or at an
    # existing file, so it is never used as a path.
    tmp = tempfile.NamedTemporaryFile(suffix=".index", delete=False)
    try:
        with tmp:
            tmp.write(file.file.read())
        # Load only after the handle is closed so the file can be reopened
        # by name on every platform.
        new_index = faiss.read_index(tmp.name)
    except RuntimeError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable FAISS index"
        ) from err
    finally:
        os.unlink(tmp.name)

    # Queries are encoded with embedding_dimension components, so an index
    # of another dimension could not be searched.
    if new_index.d != embedding_dimension:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Index dimension {new_index.d} does not match the "
                f"expected dimension {embedding_dimension}"
            ),
        )

    # Swap the global only once the upload is known to be usable.
    index = new_index
    return {"message": f"Database uploaded with {index.ntotal} embeddings"}
# NOTE(review): no @app route decorator precedes this function, so nothing in
# this file exposes it as an endpoint. It looks like an earlier draft of
# upload_database - confirm before reviving or deleting it.
def upload_database_1(file: UploadFile = File(...)):
    """Save the upload to disk, load it as the FAISS index and rebuild ``documents``.

    Replaces the global ``index`` and the global ``documents`` and returns a
    dict with a ``message`` reporting ``len(documents)``.
    """
    # Open the uploaded file as a binary file object
    # NOTE(review): the path is the client-supplied file.filename, written
    # relative to the working directory - treat as untrusted input.
    with open(file.filename, "wb") as f:
        f.write(file.file.read())
    # Open the file as a binary file object
    with open(file.filename, "rb") as f:
        # Load the FAISS index from the file object
        # NOTE(review): read_index_binary receives an open file object here,
        # whereas upload_database passes a file name to read_index - verify
        # that this call is supported for the index type in use.
        global index
        index = faiss.read_index_binary(f)
    # Clear the existing documents list and add the new documents
    # NOTE(review): presumably reconstruct_n yields the stored vectors, so
    # `documents` would hold lists of numbers rather than texts - confirm.
    global documents
    documents = index.reconstruct_n(0, index.ntotal).tolist()
    return {"message": f"Database uploaded with {len(documents)} documents"}
# NOTE(review): no @app route decorator precedes this function, so nothing in
# this file exposes it as an endpoint. It looks like an earlier draft of
# upload_database - confirm before reviving or deleting it.
def upload_database_0(file: UploadFile = File(...)):
    """Load the uploaded bytes as the FAISS index without writing them to disk.

    Replaces the global ``index`` and returns a dict with a ``message``
    reporting ``index.ntotal``; ``documents`` is not modified.
    """
    # Read the contents of the uploaded file
    contents = file.file.read()
    # Load the FAISS index from the file contents
    # NOTE(review): read_index_binary receives the raw bytes here, whereas
    # upload_database passes a file name to read_index - verify that this
    # call is supported for the index type in use.
    global index
    index = faiss.read_index_binary(contents)
    # Clear the existing documents list and add the new documents
    # (disabled: the two statements below are commented out)
    #global documents
    #documents = index.reconstruct_n(0, index.ntotal).tolist()
    return {"message": f"Database uploaded with {index.ntotal} embeddings"}