File size: 5,321 Bytes
e26d32e
5914320
daedc24
e9edc55
23f0ebc
e9edc55
a0edacc
3ab82e8
 
 
e9edc55
08d2180
3eec3b2
00a8910
3ab82e8
3a2c9fc
 
53897dc
3eec3b2
3f61915
5914320
a0edacc
 
 
f238fcb
 
 
a0edacc
5914320
 
 
3ab82e8
744d14e
eb810c1
a0edacc
744d14e
 
 
daedc24
53897dc
744d14e
 
 
5914320
daedc24
eb810c1
f238fcb
 
 
53897dc
 
 
 
744d14e
 
 
53897dc
744d14e
 
 
 
 
 
 
daedc24
744d14e
daedc24
 
3a2c9fc
3b0c68a
3a2c9fc
f621dca
daedc24
744d14e
dbfd408
 
 
ea83f7b
744d14e
 
dbfd408
744d14e
 
dbfd408
744d14e
 
 
3b0c68a
bee75ef
 
 
 
 
 
 
 
 
 
90e94e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bee75ef
 
 
 
 
 
 
 
 
3a2c9fc
744d14e
 
66d1715
 
 
24633c7
 
 
 
 
 
 
 
 
cc2cdb0
 
 
 
 
66d1715
cc2cdb0
 
 
 
 
66d1715
cc2cdb0
 
 
744d14e
cc2cdb0
744d14e
 
 
cc2cdb0
744d14e
90e94e1
744d14e
 
90e94e1
 
744d14e
90e94e1
 
 
744d14e
4b98830
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import io
import json

import faiss
import numpy as np
from fastapi import FastAPI, Request, Query
from fastapi import File, UploadFile
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Web application object; every route in this module is registered on it.
app = FastAPI()

# Templates are loaded from the current directory (index.html is served at "/").
templates = Jinja2Templates(directory=".")

# Sentence encoder used to turn texts into vectors.
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
embedding_dimension = 384  # 384 is the dimensionality of the MiniLM model

# In-memory vector store: a flat L2 FAISS index plus the texts that were added,
# kept in insertion order so a FAISS row id can be used as a list position.
index = faiss.IndexFlatL2(embedding_dimension)
documents = []

class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    # Texts to encode and add to the FAISS index.
    texts: list[str]

class SearchRequest(BaseModel):
    """Request body for POST /search."""

    # Query text that is encoded and compared against the index.
    text: str
    # Number of nearest neighbours requested from the index (defaults to 5).
    n: int = 5
    
@app.get("/")
def read_root(request: Request):
    """Render the ``index.html`` template as the landing page."""
    # The template context carries the incoming request object.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Encode the submitted texts and append them to the FAISS index.

    Args:
        request: Body whose ``texts`` are embedded with the module-level
            ``model`` and added to ``index``; the texts themselves are
            appended to ``documents`` in the same order.

    Returns:
        A dict with a ``message`` stating how many strings were added and the
        resulting number of vectors in the index.
    """
    new_documents = request.texts

    # Skip the encoder and the index entirely for an empty batch: there is
    # nothing to add, and FAISS expects a 2-D (n, dim) float32 matrix.
    if new_documents:
        new_embeddings = model.encode(new_documents)
        index.add(np.asarray(new_embeddings, dtype=np.float32))
        # Extend only after the vectors were accepted so both stores grow
        # together.
        documents.extend(new_documents)

    new_size = index.ntotal
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return the stored documents closest to the query text.

    Args:
        request: Body with the query ``text`` and the number ``n`` of
            neighbours wanted.

    Returns:
        A dict with three parallel lists: ``distances``, ``indices`` (row ids
        in the FAISS index) and ``documents``. Fewer than ``n`` entries are
        returned when the index holds fewer vectors; all three lists are
        empty when the index is empty or ``n`` is not positive. A document is
        ``None`` when the index contains a vector for which no text is
        stored.
    """
    # Never ask FAISS for more neighbours than it holds: it pads the missing
    # ones with label -1, which would silently index documents[-1].
    k = min(request.n, index.ntotal)
    if k <= 0:
        return {"distances": [], "indices": [], "documents": []}

    embedding = model.encode([request.text])
    distances, indices = index.search(np.array(embedding), k)

    # Keep only real hits, converting NumPy scalars to plain Python values so
    # the three output lists stay aligned and JSON-serialisable.
    hits = [
        (float(distance), int(row))
        for distance, row in zip(distances[0], indices[0])
        if row >= 0
    ]

    # The index and the documents list can be out of sync (for example after
    # an index upload), so guard the lookup instead of raising IndexError.
    known = len(documents)
    return {
        "distances": [distance for distance, _ in hits],
        "indices": [row for _, row in hits],
        "documents": [documents[row] if row < known else None for _, row in hits],
    }

#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many vectors the FAISS index currently holds."""
    vector_count = index.ntotal
    return {"length": vector_count}

@app.post("/admin/database/reset")
def reset_database():
    """Replace the FAISS index and the documents list with empty ones."""
    global index, documents
    # Both globals are rebound to fresh objects rather than emptied in place.
    index, documents = faiss.IndexFlatL2(embedding_dimension), []
    return {"message": "Database reset"}

@app.get("/admin/documents/download")
def download_documents():
    """Send the stored documents to the client as a ``documents.json`` file.

    The body is a JSON object of the form ``{"texts": [...]}``.
    """
    payload = {"texts": documents}
    response = Response(content=json.dumps(payload), media_type="application/json")

    # Marking the body as an attachment makes browsers save it to disk.
    response.headers["Content-Disposition"] = "attachment; filename=documents.json"
    return response

@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
    """Append documents from an uploaded JSON file without embedding them.

    Args:
        file: Uploaded JSON file shaped like ``{"texts": ["...", ...]}``.

    Returns:
        A dict with a ``message`` stating how many documents were appended.

    Raises:
        HTTPException: 400 when the file is not valid JSON or does not
            contain a ``texts`` list of strings.
    """
    contents = file.file.read()

    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses,
    # so one handler covers malformed JSON and undecodable bytes.
    try:
        data = json.loads(contents)
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    # Reject anything but a list of strings: list.extend would otherwise
    # accept a bare string and append it one character at a time.
    new_documents = data.get("texts") if isinstance(data, dict) else None
    if not isinstance(new_documents, list) or not all(
        isinstance(text, str) for text in new_documents
    ):
        raise HTTPException(
            status_code=400,
            detail='Uploaded JSON must be an object with a "texts" list of strings',
        )

    # Only the texts are stored here; the FAISS index is left untouched.
    documents.extend(new_documents)

    return {"message": f"{len(new_documents)} new documents uploaded"}

@app.post("/admin/documents/embed")
def embed_documents(file: UploadFile = File(...)):
    """Embed documents from an uploaded JSON file and add them to the index.

    Args:
        file: Uploaded JSON file shaped like ``{"texts": ["...", ...]}``.

    Returns:
        A dict with a ``message`` stating how many documents were embedded.

    Raises:
        HTTPException: 400 when the file is not valid JSON or does not
            contain a ``texts`` list of strings.
    """
    contents = file.file.read()

    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses,
    # so one handler covers malformed JSON and undecodable bytes.
    try:
        data = json.loads(contents)
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not valid JSON"
        ) from err

    new_documents = data.get("texts") if isinstance(data, dict) else None
    if not isinstance(new_documents, list) or not all(
        isinstance(text, str) for text in new_documents
    ):
        raise HTTPException(
            status_code=400,
            detail='Uploaded JSON must be an object with a "texts" list of strings',
        )

    # An empty batch is a no-op: FAISS expects a 2-D (n, dim) float32 matrix,
    # which an empty encode result does not provide.
    if new_documents:
        new_embeddings = model.encode(new_documents)
        index.add(np.asarray(new_embeddings, dtype=np.float32))
        # Extend only after the vectors were accepted so both stores grow
        # together.
        documents.extend(new_documents)

    return {"message": f"{len(new_documents)} new documents uploaded and embedded"}


@app.get("/admin/database/download")
def download_database():
    """Send the FAISS index to the client as a ``database.index`` file.

    The index is serialised in memory instead of being written to a fixed
    path in the working directory, so concurrent downloads cannot read a
    file that another request is still writing and no file is left behind.

    Returns:
        A binary response whose body is the serialised index.
    """
    # serialize_index returns a uint8 NumPy array in the format that
    # faiss.write_index stores on disk.
    payload = faiss.serialize_index(index).tobytes()

    response = Response(content=payload, media_type="application/octet-stream")

    # Set the content disposition header to trigger a download
    response.headers["Content-Disposition"] = "attachment; filename=database.index"

    return response


@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the FAISS index with one loaded from an uploaded index file.

    The upload is deserialised straight from memory, so nothing is written to
    a path derived from the client-supplied filename. The ``documents`` list
    is left untouched: an index file carries vectors only, so the matching
    texts have to be supplied separately via ``/admin/documents/upload``.

    Args:
        file: Uploaded file containing a serialised FAISS index, as produced
            by ``faiss.write_index`` or ``faiss.serialize_index``.

    Returns:
        A dict with a ``message`` stating how many embeddings the new index
        holds.

    Raises:
        HTTPException: 400 when the file cannot be read as a FAISS index or
            its vector dimension differs from ``embedding_dimension``.
    """
    global index

    contents = file.file.read()

    # deserialize_index expects the raw bytes as a uint8 array; the
    # read_index_binary loader takes a file name and targets binary indexes,
    # so it cannot load this data.
    try:
        uploaded_index = faiss.deserialize_index(np.frombuffer(contents, dtype=np.uint8))
    except (RuntimeError, ValueError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid FAISS index"
        ) from err

    # Queries are encoded with embedding_dimension components, so an index of
    # any other width could never be searched.
    if uploaded_index.d != embedding_dimension:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Index dimension {uploaded_index.d} does not match "
                f"the expected dimension {embedding_dimension}"
            ),
        )

    # Swap the global only after every check has passed, so a bad upload
    # leaves the current index in place.
    index = uploaded_index

    return {"message": f"Database uploaded with {index.ntotal} embeddings"}


# The same path used to be registered a second time under this name; only the
# first registration was ever reachable. The name is kept as a plain alias so
# existing references still resolve, without adding a duplicate route.
upload_database_0 = upload_database