import json

from fastapi import FastAPI, File, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

app = FastAPI()
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# paraphrase-MiniLM-L6-v2 produces 384-dimensional embeddings
# (model.get_sentence_embedding_dimension() returns the same value).
index = faiss.IndexFlatL2(384)

# In-memory list of the raw texts, kept in the same order as the vectors in the index
documents = []

templates = Jinja2Templates(directory=".")
class EmbedRequest(BaseModel):
    texts: list[str]


class SearchRequest(BaseModel):
    text: str
    n: int = 5
@app.get("/")
def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/embed")
def embed_strings(request: EmbedRequest):
# Add the new texts to the documents list
new_documents = request.texts
documents.extend(new_documents)
# Encode the new documents and add them to the FAISS database
new_embeddings = model.encode(new_documents)
index.add(np.array(new_embeddings))
# Get the new size of the FAISS database
new_size = len(documents)
return {
"message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
}
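# Example request for /embed (a sketch; assumes the app is reachable at
# http://localhost:8000 -- host and port are assumptions, not part of this file):
#
#   curl -X POST http://localhost:8000/embed \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world", "goodbye world"]}'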
@app.post("/search")
def search_string(request: SearchRequest):
embedding = model.encode([request.text])
distances, indices = index.search(np.array(embedding), request.n)
# Get the documents associated with the returned indices
found_documents = [documents[i] for i in indices[0]]
return {
"distances": distances[0].tolist(),
"indices": indices[0].tolist(),
"documents": found_documents
}
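# Example request for /search (same assumed host and port as above):
#
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hi there", "n": 2}'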
#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
return {"length": len(documents)}
@app.post("/admin/database/clear")
def clear_database():
documents.clear()
index.reset()
return {"message": "Database cleared"}
@app.get("/admin/documents/download")
def download_documents():
# Convert the documents list to a JSON string
documents_json = json.dumps(documents)
# Create a response with the JSON string as the content
response = Response(content=documents_json, media_type="application/json")
# Set the content disposition header to trigger a download
response.headers["Content-Disposition"] = "attachment; filename=documents.json"
return response
@app.get("/admin/database/download")
def download_database():
# Save the FAISS index to a file
faiss.write_index(index, "database.index")
# Create a response with the index file as the content
response = FileResponse("database.index")
# Set the content disposition header to trigger a download
response.headers["Content-Disposition"] = "attachment; filename=database.index"
return response
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
# Read the contents of the uploaded file
contents = file.file.read()
# Load the FAISS index from the file contents
index = faiss.read_index_binary(contents)
# Clear the existing documents and add the new ones
documents.clear()
documents.extend(index.reconstruct_n(0, index.ntotal))
return {"message": "Database uploaded"}