import json

from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

app = FastAPI()
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
index = faiss.IndexFlatL2(384)  # 384 is the embedding dimension of paraphrase-MiniLM-L6-v2
# In-memory list of the raw texts, kept aligned with the vectors in the index
documents = []

templates = Jinja2Templates(directory=".")

class EmbedRequest(BaseModel):
    texts: list[str]

class SearchRequest(BaseModel):
    text: str
    n: int = 5
    
@app.get("/")
def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    # Add the new texts to the documents list
    new_documents = request.texts
    documents.extend(new_documents)

    # Encode the new documents and add them to the FAISS database
    new_embeddings = model.encode(new_documents)
    index.add(np.array(new_embeddings))

    # Get the new size of the FAISS database
    new_size = len(documents)

    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
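
# Example embed call (a sketch; assumes the server is reachable at localhost:8000):
#   curl -X POST http://localhost:8000/embed \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world", "fastapi with faiss"]}'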

@app.post("/search")
def search_string(request: SearchRequest):
    embedding = model.encode([request.text])
    distances, indices = index.search(np.array(embedding), request.n)

    # Get the documents associated with the returned indices
    found_documents = [documents[i] for i in indices[0]]

    return {
        "distances": distances[0].tolist(),
        "indices": indices[0].tolist(),
        "documents": found_documents
    }
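
# Example search call (a sketch; assumes the server is reachable at localhost:8000):
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"text": "greeting", "n": 2}'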

#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    return {"length": len(documents)}

@app.post("/admin/database/clear")
def clear_database():
    documents.clear()
    index.reset()
    return {"message": "Database cleared"}

@app.get("/admin/documents/download")
def download_documents():
    # Convert the documents list to a JSON string
    documents_json = json.dumps(documents)

    # Create a response with the JSON string as the content
    response = Response(content=documents_json, media_type="application/json")

    # Set the content disposition header to trigger a download
    response.headers["Content-Disposition"] = "attachment; filename=documents.json"

    return response

@app.get("/admin/database/download")
def download_database():
    # Save the FAISS index to a file
    faiss.write_index(index, "database.index")

    # Create a response with the index file as the content
    response = FileResponse("database.index")

    # Set the content disposition header to trigger a download
    response.headers["Content-Disposition"] = "attachment; filename=database.index"

    return response

@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    # Read the contents of the uploaded file
    contents = file.file.read()

    # Load the FAISS index from the file contents
    index = faiss.read_index_binary(contents)

    # Clear the existing documents and add the new ones
    documents.clear()
    documents.extend(index.reconstruct_n(0, index.ntotal))

    return {"message": "Database uploaded"}