File size: 7,168 Bytes
e26d32e
5914320
daedc24
e9edc55
23f0ebc
e9edc55
a0edacc
3ab82e8
 
 
e9edc55
08d2180
3eec3b2
00a8910
3a1f579
 
 
 
 
 
 
 
3a2c9fc
53897dc
3eec3b2
3f61915
5914320
a0edacc
 
 
f238fcb
 
 
a0edacc
5914320
 
 
3ab82e8
744d14e
eb810c1
a0edacc
97f5451
9ffa1c2
97f5451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744d14e
 
 
daedc24
53897dc
744d14e
 
 
5914320
daedc24
eb810c1
f238fcb
 
 
53897dc
 
 
 
744d14e
 
 
53897dc
744d14e
 
 
 
 
 
 
daedc24
744d14e
daedc24
 
3a2c9fc
3b0c68a
3a2c9fc
f621dca
daedc24
744d14e
dbfd408
 
 
ea83f7b
744d14e
 
dbfd408
744d14e
 
dbfd408
744d14e
 
 
3b0c68a
bee75ef
 
 
 
 
 
 
 
 
 
90e94e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bee75ef
 
 
 
 
 
 
 
 
3a2c9fc
744d14e
 
66d1715
 
 
24633c7
 
 
 
 
 
 
 
 
cc2cdb0
 
5e7e86a
d79e1d6
 
 
 
 
5e7e86a
 
 
4a1a39e
5e7e86a
 
 
 
 
 
cc2cdb0
 
 
66d1715
cc2cdb0
 
 
 
 
66d1715
cc2cdb0
 
 
744d14e
cc2cdb0
744d14e
 
cc2cdb0
744d14e
90e94e1
744d14e
 
90e94e1
 
744d14e
90e94e1
 
 
744d14e
4b98830
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from fastapi import FastAPI, Request, Query
from fastapi.templating import Jinja2Templates
from fastapi import File, UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import Response

from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import io

app = FastAPI()
# Previous model choice, kept for reference:
#model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
#embedding_dimension = 384 # 384 is the dimensionality of the MiniLM model
# 1. Specify preferred (truncated) embedding dimensionality (Matryoshka truncation)
embedding_dimension = 512
# 2. load model, truncating its native output down to embedding_dimension
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=embedding_dimension)


# In-memory vector store: exact (brute-force) L2 search over the embeddings.
index = faiss.IndexFlatL2(embedding_dimension)  
# Parallel list of raw texts; position i corresponds to vector i in `index`.
documents = []

# Templates are loaded from the current working directory (expects index.html).
templates = Jinja2Templates(directory=".")

class EmbedRequest(BaseModel):
    """Request body for POST /embed: the list of texts to embed and index."""
    texts: list[str]

class SearchRequest(BaseModel):
    """Request body for POST /search: query text and neighbour count (default 5)."""
    text: str
    n: int = 5
    
@app.get("/")
def read_root(request: Request):
    """Serve the index.html template from the working directory."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.post("/embed")
def embed_strings(request: EmbedRequest):
    """Embed the given texts in batches and add them to the FAISS index.

    The raw texts are appended to the module-level ``documents`` list so
    search results can be mapped back to their source strings; vector i in
    ``index`` must stay aligned with ``documents[i]``.
    """
    new_documents = request.texts
    if not new_documents:
        # Guard: index.add on an empty array would raise; report a no-op.
        return {
            "message": f"0 new strings embedded and added to FAISS database. New size of the database: {index.ntotal}"
        }

    print(f"Start embedding of {len(new_documents)} docs")
    batch_size = 20

    # Slicing with [start:start+batch_size] already covers EVERY document,
    # including a short final batch — no separate remainder pass is needed.
    # (The previous version re-embedded the last len % batch_size documents
    # a second time, inserting duplicate vectors and desynchronizing the
    # index from the documents list.)
    new_embeddings = []
    for start in range(0, len(new_documents), batch_size):
        batch = new_documents[start:start + batch_size]
        new_embeddings.extend(model.encode(batch))
        print(f"embedded {len(batch)} docs")

    index.add(np.array(new_embeddings))
    documents.extend(new_documents)
    new_size = index.ntotal
    print(f"End embedding {len(new_documents)} docs, new DB size: {new_size}")
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
    
def embed_strings_v0(request: EmbedRequest):
    """Legacy single-shot variant of /embed: encode all texts at once (no batching)."""
    texts = request.texts
    vectors = np.array(model.encode(texts))
    index.add(vectors)
    documents.extend(texts)
    total = index.ntotal
    return {
        "message": f"{len(texts)} new strings embedded and added to FAISS database. New size of the database: {total}"
    }


@app.post("/search")
def search_string(request: SearchRequest):
    """Return the n nearest stored documents to the query text.

    Response keys: ``distances`` (L2 distances), ``indices`` (positions in
    the FAISS index), ``documents`` (the matching source texts).
    """
    embedding = model.encode([request.text])
    distances, indices = index.search(np.array(embedding), request.n)

    # FAISS pads results with index -1 when fewer than n vectors exist;
    # drop those slots — otherwise documents[-1] silently returns the
    # LAST stored document for every padded hit.
    hits = [
        (dist, idx)
        for dist, idx in zip(distances[0].tolist(), indices[0].tolist())
        if idx >= 0
    ]

    return {
        "distances": [dist for dist, _ in hits],
        "indices": [idx for _, idx in hits],
        "documents": [documents[idx] for _, idx in hits],
    }

#########################
## database management ##
#########################
@app.get("/admin/database/length")
def get_database_length():
    """Report how many vectors the FAISS index currently holds."""
    count = index.ntotal
    return {"length": count}

@app.post("/admin/database/reset")
def reset_database():
    """Drop all embeddings and documents, rebinding both stores to empty ones."""
    global index, documents
    index = faiss.IndexFlatL2(embedding_dimension)
    documents = []
    return {"message": "Database reset"}

@app.get("/admin/documents/download")
def download_documents():
    """Send the in-memory documents list as a downloadable JSON file."""
    payload = json.dumps({"texts": documents})
    # Content-Disposition makes the browser download rather than display it.
    download_headers = {"Content-Disposition": "attachment; filename=documents.json"}
    return Response(
        content=payload,
        media_type="application/json",
        headers=download_headers,
    )

@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
    """Append texts from an uploaded JSON file ({"texts": [...]}) to the
    documents list WITHOUT embedding them (see /admin/documents/embed)."""
    raw = file.file.read()
    parsed = json.loads(raw)
    incoming = parsed["texts"]
    documents.extend(incoming)
    return {"message": f"{len(incoming)} new documents uploaded"}

@app.post("/admin/documents/embed")
def embed_documents(file: UploadFile = File(...)):
    """Read a JSON file ({"texts": [...]}), embed its texts, and add both
    the vectors (to the FAISS index) and the texts (to the documents list)."""
    payload = json.loads(file.file.read())
    texts = payload["texts"]
    vectors = np.array(model.encode(texts))
    index.add(vectors)
    documents.extend(texts)
    return {"message": f"{len(texts)} new documents uploaded and embedded"}


@app.get("/admin/database/download")
def download_database():
    """Serialize the FAISS index to disk and send it as a file download."""
    # faiss can only serialize to a filesystem path, so snapshot first.
    faiss.write_index(index, "database.index")
    download_headers = {"Content-Disposition": "attachment; filename=database.index"}
    return FileResponse(
        "database.index",
        media_type="application/octet-stream",
        headers=download_headers,
    )


@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
    """Replace the in-memory FAISS index with one read from the uploaded file.

    Security: the previous version wrote to ``file.filename`` — a
    client-controlled string — allowing path traversal (e.g. a filename of
    ``../../something``). We now write to a server-chosen temporary path,
    since faiss.read_index needs a real file on disk, and remove it after.

    NOTE(review): the ``documents`` list is NOT replaced here, so texts and
    vectors may be out of sync after an upload — confirm the intended flow
    pairs this with /admin/documents/upload.
    """
    import os
    import tempfile

    fd, tmp_path = tempfile.mkstemp(suffix=".index")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(file.file.read())
        global index
        index = faiss.read_index(tmp_path)
    finally:
        # Always clean up the temp file, even if read_index fails.
        os.remove(tmp_path)

    return {"message": f"Database uploaded with {index.ntotal} embeddings"}



def upload_database_1(file: UploadFile = File(...)):
    """Unused draft of /admin/database/upload (no route decorator); kept for reference.

    NOTE(review): this draft appears broken — see inline notes — do not
    wire it to a route without fixing it.
    """
    # Persist the upload under the client-supplied filename.
    # NOTE(review): file.filename comes from the client — path traversal risk
    # if this were ever exposed as an endpoint.
    with open(file.filename, "wb") as f:
        f.write(file.file.read())

    # Open the file as a binary file object
    with open(file.filename, "rb") as f:
        # NOTE(review): faiss.read_index_binary expects a filename string, not
        # a Python file object, and the rest of this file uses a (non-binary)
        # IndexFlatL2 — this call presumably fails at runtime; confirm.
        global index
        index = faiss.read_index_binary(f)

    # NOTE(review): reconstruct_n returns embedding VECTORS, not the original
    # text documents — storing them in `documents` looks wrong; verify intent.
    global documents
    documents = index.reconstruct_n(0, index.ntotal).tolist()

    return {"message": f"Database uploaded with {len(documents)} documents"}


def upload_database_0(file: UploadFile = File(...)):
    """Unused first draft of /admin/database/upload (no route decorator); kept for reference."""
    # Read the contents of the uploaded file
    contents = file.file.read()

    # NOTE(review): faiss.read_index_binary expects a filename, not raw bytes,
    # so this call presumably raises; the live endpoint writes to disk first.
    global index
    index = faiss.read_index_binary(contents)

    # Clear the existing documents list and add the new documents
    #global documents
    #documents = index.reconstruct_n(0, index.ntotal).tolist()

    return {"message": f"Database uploaded with {index.ntotal} embeddings"}