"""Minimal semantic-search service backed by SentenceTransformer and FAISS.

Texts posted to ``/embed`` are encoded with a SentenceTransformer model and
appended to an in-memory FAISS flat L2 index; ``/search`` returns the nearest
stored texts for a query. The ``/admin/...`` endpoints report the size of,
clear, export and import that in-memory state.
"""

import json
import threading
from typing import Optional

import faiss
import numpy as np
from fastapi import FastAPI, File, HTTPException, Query, Request, Response, UploadFile
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

app = FastAPI()

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# Ask the model for its embedding width (384 for paraphrase-MiniLM-L6-v2)
# rather than hard-coding it, so changing the model name keeps the index valid.
EMBEDDING_DIM = model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(EMBEDDING_DIM)

# documents[i] is the text whose embedding is row i of `index`; both must
# always have the same length and order.
documents: list[str] = []

# FastAPI runs plain `def` endpoints in a worker thread pool, so requests can
# execute concurrently. Every read-modify-write of `documents`/`index` happens
# under this lock so the two never drift apart.
_state_lock = threading.Lock()

templates = Jinja2Templates(directory=".")


class EmbedRequest(BaseModel):
    """Body of ``POST /embed``: the texts to encode and store."""

    texts: list[str]


class SearchRequest(BaseModel):
    """Body of ``POST /search``: the query text and how many hits to return."""

    text: str
    # At least one neighbour must be requested; FAISS rejects k <= 0.
    n: int = Field(5, ge=1)


def _to_faiss_matrix(embeddings) -> np.ndarray:
    """Return *embeddings* as a C-contiguous float32 array, the layout FAISS expects."""
    return np.ascontiguousarray(embeddings, dtype=np.float32)


@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` from the current working directory."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Encode ``request.texts`` and append them to the index and document list.

    Returns a message with the number of texts added and the new store size.
    """
    new_documents = list(request.texts)
    if not new_documents:
        # model.encode([]) does not yield a 2-D matrix that index.add accepts.
        with _state_lock:
            new_size = len(documents)
    else:
        # Encode before touching shared state: if encoding fails, neither
        # `documents` nor `index` has been modified.
        new_embeddings = _to_faiss_matrix(model.encode(new_documents))
        with _state_lock:
            index.add(new_embeddings)
            documents.extend(new_documents)
            new_size = len(documents)
    return {
        "message": (
            f"{len(new_documents)} new strings embedded and added to FAISS database. "
            f"New size of the database: {new_size}"
        )
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return up to ``request.n`` stored texts nearest to ``request.text``.

    The response holds parallel lists ``distances`` (squared L2), ``indices``
    and ``documents``. They may be shorter than ``n`` when fewer texts are stored.
    """
    query = _to_faiss_matrix(model.encode([request.text]))
    with _state_lock:
        distances, indices = index.search(query, request.n)
        # FAISS pads result slots it cannot fill with index -1; drop them, since
        # documents[-1] would silently return the wrong (last) document.
        hits = [
            (float(dist), int(idx))
            for dist, idx in zip(distances[0], indices[0])
            if idx >= 0
        ]
        found_documents = [documents[idx] for _, idx in hits]
    return {
        "distances": [dist for dist, _ in hits],
        "indices": [idx for _, idx in hits],
        "documents": found_documents,
    }


#########################
## database management ##
#########################
# NOTE(review): none of the /admin endpoints below have any authentication in
# this file — confirm they are protected elsewhere (proxy, middleware).


@app.get("/admin/database/length")
def get_database_length():
    """Return the number of stored documents."""
    return {"length": len(documents)}


@app.post("/admin/database/clear")
def clear_database():
    """Remove every stored document and embedding."""
    with _state_lock:
        documents.clear()
        index.reset()
    return {"message": "Database cleared"}


@app.get("/admin/documents/download")
def download_documents():
    """Download the stored texts as a JSON list (``documents.json``)."""
    with _state_lock:
        documents_json = json.dumps(documents)
    return Response(
        content=documents_json,
        media_type="application/json",
        # Ask the browser to save the payload instead of displaying it.
        headers={"Content-Disposition": "attachment; filename=documents.json"},
    )


@app.get("/admin/database/download")
def download_database():
    """Download the FAISS index serialized in FAISS's native format (``database.index``)."""
    with _state_lock:
        # Serialize in memory instead of writing a shared "database.index" file
        # to the CWD, which leaked a file and raced under concurrent downloads.
        payload = faiss.serialize_index(index).tobytes()
    return Response(
        content=payload,
        media_type="application/octet-stream",
        headers={"Content-Disposition": "attachment; filename=database.index"},
    )


@app.post("/admin/database/upload")
def upload_database(
    file: UploadFile = File(...),
    documents_file: Optional[UploadFile] = File(None),
):
    """Replace the in-memory index with an uploaded serialized FAISS index.

    The FAISS index stores only vectors, not texts. The matching texts come
    from ``documents_file`` (a JSON list of strings, e.g. the output of
    ``/admin/documents/download``). If it is omitted, the currently stored
    documents are kept.

    Raises:
        HTTPException(400): the index cannot be deserialized, its dimension
            differs from the model's, ``documents_file`` is not a JSON list of
            strings, or the document count differs from the index's vector count.
    """
    global index

    # NOTE(review): this parses untrusted bytes with FAISS's native
    # deserializer; only expose this endpoint to trusted callers.
    raw_index = np.frombuffer(file.file.read(), dtype=np.uint8)
    try:
        new_index = faiss.deserialize_index(raw_index)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid FAISS index: {exc}") from exc
    if new_index.d != EMBEDDING_DIM:
        raise HTTPException(
            status_code=400,
            detail=f"Index dimension {new_index.d} does not match model dimension {EMBEDDING_DIM}",
        )

    new_documents: Optional[list] = None
    if documents_file is not None:
        try:
            new_documents = json.loads(documents_file.file.read())
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise HTTPException(status_code=400, detail=f"Invalid documents JSON: {exc}") from exc
        if not isinstance(new_documents, list) or not all(
            isinstance(doc, str) for doc in new_documents
        ):
            raise HTTPException(
                status_code=400, detail="documents_file must contain a JSON list of strings"
            )

    with _state_lock:
        if new_documents is None:
            new_documents = list(documents)
        # Row i of the index must correspond to documents[i].
        if len(new_documents) != new_index.ntotal:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Index holds {new_index.ntotal} vectors but "
                    f"{len(new_documents)} documents were provided"
                ),
            )
        # Rebind the module-level index (the original assignment only created a
        # local) and replace the document list in place.
        index = new_index
        documents[:] = new_documents
    return {"message": "Database uploaded"}