"""FastAPI service exposing an in-memory FAISS index of sentence embeddings.

Texts are embedded with a SentenceTransformer model and stored in a flat L2
FAISS index. The raw texts are kept in the parallel ``documents`` list so that
row ``i`` of the index corresponds to ``documents[i]``.
"""

import base64
import io
import json

import faiss
import numpy as np
from fastapi import FastAPI, File, HTTPException, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

app = FastAPI()

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
embedding_dimension = 384  # 384 is the dimensionality of the MiniLM model
index = faiss.IndexFlatL2(embedding_dimension)
documents = []

templates = Jinja2Templates(directory=".")


class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``: the texts to embed and store."""

    texts: list[str]


class SearchRequest(BaseModel):
    """Request body for ``POST /search``: query text and number of hits."""

    text: str
    n: int = 5


def _encode(texts: list[str]) -> np.ndarray:
    """Return the embeddings of *texts* as a float32 array (FAISS input dtype)."""
    return np.asarray(model.encode(texts), dtype=np.float32)


def _add_documents(new_documents: list[str]) -> None:
    """Embed *new_documents* and append them to the index and document list.

    An empty list is a no-op: encoding nothing yields an array without the
    (n, d) shape that ``index.add`` requires.
    """
    if not new_documents:
        return
    index.add(_encode(new_documents))
    documents.extend(new_documents)


def _string_list(payload: object, key: str) -> list[str]:
    """Return ``payload[key]`` if it is a list of strings, else raise HTTP 400."""
    texts = payload.get(key) if isinstance(payload, dict) else None
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        raise HTTPException(
            status_code=400,
            detail=f'Expected a JSON object with a "{key}" list of strings',
        )
    return texts


@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` from the templates directory."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the submitted texts and add them to the database."""
    new_documents = request.texts
    _add_documents(new_documents)
    new_size = index.ntotal
    return {
        "message": (
            f"{len(new_documents)} new strings embedded and added to FAISS database. "
            f"New size of the database: {new_size}"
        )
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return up to ``request.n`` stored documents nearest to ``request.text``.

    Returns:
        A dict with aligned ``distances``, ``indices`` and ``documents`` lists.
        The lists are shorter than ``n`` when the database holds fewer entries.

    Raises:
        HTTPException: 400 if ``request.n`` is not a positive integer.
    """
    if request.n < 1:
        raise HTTPException(status_code=400, detail="n must be a positive integer")
    if index.ntotal == 0:
        return {"distances": [], "indices": [], "documents": []}

    # Never ask for more neighbours than exist: FAISS pads missing results
    # with index -1, which would otherwise select documents[-1].
    k = min(request.n, index.ntotal)
    distances, indices = index.search(_encode([request.text]), k)

    # Defensive filter in case any -1 padding is still present.
    hits = [
        (float(distance), int(position))
        for distance, position in zip(distances[0], indices[0])
        if position >= 0
    ]
    return {
        "distances": [distance for distance, _ in hits],
        "indices": [position for _, position in hits],
        "documents": [documents[position] for _, position in hits],
    }


#########################
## database management ##
#########################

@app.get("/admin/database/length")
def get_database_length():
    """Return the number of vectors currently stored in the index."""
    return {"length": index.ntotal}


@app.post("/admin/database/reset")
def reset_database():
    """Replace the index and document list with empty ones."""
    global index
    global documents
    index = faiss.IndexFlatL2(embedding_dimension)
    documents = []
    return {"message": "Database reset"}


@app.get("/admin/documents/download")
def download_documents():
    """Return the stored documents as a downloadable ``documents.json``."""
    documents_json = json.dumps({"texts": documents})
    response = Response(content=documents_json, media_type="application/json")
    # Content-Disposition "attachment" makes browsers download instead of display.
    response.headers["Content-Disposition"] = "attachment; filename=documents.json"
    return response


##### TESTING ######

@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
    """Embed and store the documents from an uploaded ``{"texts": [...]}`` file.

    Raises:
        HTTPException: 400 if the file is not JSON or lacks a ``texts`` list
            of strings.
    """
    try:
        data = json.loads(file.file.read())
    except ValueError as err:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    new_documents = _string_list(data, "texts")
    _add_documents(new_documents)
    return {"message": f"{len(new_documents)} new documents uploaded and embedded"}


@app.get("/admin/database/download")
def download_database():
    """Write the FAISS index to ``database.index`` and return it as a download."""
    faiss.write_index(index, "database.index")
    response = FileResponse("database.index")
    response.headers["Content-Disposition"] = "attachment; filename=database.index"
    return response


def download_database_0():
    """Write index and documents to ``database.json`` and return it as a download.

    The index is stored under ``"index"`` as the base64 text of
    ``faiss.serialize_index`` so that it fits in JSON; ``upload_database``
    reads this format back. Not registered as a route.
    """
    serialized = faiss.serialize_index(index)  # uint8 numpy array
    data = {
        "index": base64.b64encode(serialized.tobytes()).decode("ascii"),
        "documents": documents,
    }
    with open("database.json", "w", encoding="utf-8") as f:
        json.dump(data, f)

    response = FileResponse("database.json")
    response.headers["Content-Disposition"] = "attachment; filename=database.json"
    return response


@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the index and documents with those from an uploaded JSON file.

    The file must be in the format written by ``download_database_0``: a JSON
    object with ``"index"`` (base64 of a serialized FAISS index) and
    ``"documents"`` (list of strings).

    Raises:
        HTTPException: 400 if the file is malformed, the index cannot be
            decoded, or the index does not match the documents/dimension.
    """
    global index

    try:
        contents = json.load(file.file)
    except ValueError as err:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    uploaded_documents = _string_list(contents, "documents")
    encoded_index = contents.get("index")
    if not isinstance(encoded_index, str):
        raise HTTPException(
            status_code=400, detail='Expected "index" to be a base64 string'
        )

    try:
        raw = base64.b64decode(encoded_index, validate=True)
        # np.frombuffer on bytes is a read-only view; pass FAISS its own copy.
        restored = faiss.deserialize_index(np.frombuffer(raw, dtype=np.uint8).copy())
    except (ValueError, RuntimeError) as err:
        raise HTTPException(
            status_code=400, detail='"index" is not a valid serialized FAISS index'
        ) from err

    # Validate before touching global state so a bad upload changes nothing.
    if restored.d != embedding_dimension:
        raise HTTPException(
            status_code=400,
            detail=f"Index dimension {restored.d} != {embedding_dimension}",
        )
    if restored.ntotal != len(uploaded_documents):
        raise HTTPException(
            status_code=400,
            detail="Number of vectors in the index does not match the documents",
        )

    index = restored
    documents.clear()
    documents.extend(uploaded_documents)
    return {"message": "Database uploaded", "documents": documents}