from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json

app = FastAPI()

# Alternative, smaller model:
# model = SentenceTransformer("paraphrase-MiniLM-L6-v2")  # 384-dimensional embeddings

# 1. Specify the preferred embedding dimensionality
embedding_dimension = 512
# 2. Load the model, truncating its output to that dimensionality
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", truncate_dim=embedding_dimension)

index = faiss.IndexFlatL2(embedding_dimension)
documents = []
templates = Jinja2Templates(directory=".")

class EmbedRequest(BaseModel):
    texts: list[str]

class SearchRequest(BaseModel):
    text: str
    n: int = 5

@app.get("/")
def read_root(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/embed")
def embed_strings(request: EmbedRequest):
new_documents = request.texts
print(f"Start embedding of {len(new_documents)} docs")
batch_size = 20
# Split the new_documents list into batches of 10 documents
batches = [new_documents[i:i+batch_size] for i in range(0, len(new_documents), batch_size)]
# Perform embedding for each batch
new_embeddings = []
for batch in batches:
batch_embeddings = model.encode(batch)
new_embeddings.extend(batch_embeddings)
print(f"embeded {batch_size} docs")
# Handle remaining documents less than batch_size
remaining_docs = len(new_documents) % batch_size
print(f"embedind remaining {remaining_docs} docs")
if remaining_docs > 0:
remaining_batch = new_documents[-remaining_docs:]
remaining_embeddings = model.encode(remaining_batch)
new_embeddings.extend(remaining_embeddings)
index.add(np.array(new_embeddings))
new_size = index.ntotal
documents.extend(new_documents)
print(f"End embedding {len(new_documents)} docs, new DB size: {new_size}")
return {
"message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
}
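
# Example /embed request (hypothetical payload), assuming the app is served on
# localhost:8000:
#   curl -X POST http://localhost:8000/embed \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["first document", "second document"]}'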

# Earlier, unbatched version of /embed, kept for reference (not routed).
def embed_strings_v0(request: EmbedRequest):
    new_documents = request.texts
    new_embeddings = model.encode(new_documents)
    index.add(np.array(new_embeddings))
    new_size = index.ntotal
    documents.extend(new_documents)
    return {
        "message": f"{len(new_documents)} new strings embedded and added to FAISS database. New size of the database: {new_size}"
    }
@app.post("/search")
def search_string(request: SearchRequest):
embedding = model.encode([request.text])
distances, indices = index.search(np.array(embedding), request.n)
# Get the documents associated with the returned indices
found_documents = [documents[i] for i in indices[0]]
return {
"distances": distances[0].tolist(),
"indices": indices[0].tolist(),
"documents": found_documents
}
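
# Example /search request (hypothetical payload). The index is an IndexFlatL2,
# so smaller distances mean closer matches:
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"text": "query string", "n": 3}'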

#########################
## database management ##
#########################

@app.get("/admin/database/length")
def get_database_length():
    return {"length": index.ntotal}

@app.post("/admin/database/reset")
def reset_database():
    global index
    global documents
    index = faiss.IndexFlatL2(embedding_dimension)
    documents = []
    return {"message": "Database reset"}
@app.get("/admin/documents/download")
def download_documents():
# Convert the documents list to a JSON string
documents_json = json.dumps({"texts": documents})
# Create a response with the JSON string as the content
response = Response(content=documents_json, media_type="application/json")
# Set the content disposition header to trigger a download
response.headers["Content-Disposition"] = "attachment; filename=documents.json"
return response
@app.post("/admin/documents/upload")
def upload_documents(file: UploadFile = File(...)):
# Read the contents of the uploaded file
contents = file.file.read()
# Load the JSON data from the file contents
data = json.loads(contents)
# Get the list of documents from the JSON data
new_documents = data["texts"]
# Add the new documents to the documents list
documents.extend(new_documents)
return {"message": f"{len(new_documents)} new documents uploaded"}
@app.post("/admin/documents/embed")
def embed_documents(file: UploadFile = File(...)):
# Read the contents of the uploaded file
contents = file.file.read()
# Load the JSON data from the file contents
data = json.loads(contents)
# Get the list of documents from the JSON data
new_documents = data["texts"]
# Encode the new documents and add them to the FAISS database
new_embeddings = model.encode(new_documents)
index.add(np.array(new_embeddings))
# Add the new documents to the documents list
documents.extend(new_documents)
return {"message": f"{len(new_documents)} new documents uploaded and embedded"}
@app.get("/admin/database/download")
def download_database():
# Save the FAISS index to a file
faiss.write_index(index, "database.index")
# Create a response with the index file as the content
response = FileResponse("database.index", media_type="application/octet-stream")
# Set the content disposition header to trigger a download
response.headers["Content-Disposition"] = "attachment; filename=database.index"
return response
@app.post("/admin/database/upload")
def upload_database(file: UploadFile = File(...)):
# Read the contents of the uploaded file
#contents = file.file.read()
# Open the uploaded file as a binary file object
with open(file.filename, "wb") as f:
f.write(file.file.read())
# Load the FAISS index from the file contents
global index
index = faiss.read_index(file.filename)
return {"message": f"Database uploaded with {index.ntotal} embeddings"}

# Earlier draft of /admin/database/upload, kept for reference (not routed).
def upload_database_1(file: UploadFile = File(...)):
    # Persist the upload to disk; faiss.read_index takes a file path, and
    # faiss.read_index_binary is only for binary (Hamming-space) indexes.
    with open(file.filename, "wb") as f:
        f.write(file.file.read())
    global index
    index = faiss.read_index(file.filename)
    # The original texts cannot be recovered from the index: reconstruct_n
    # returns stored vectors, not documents, so `documents` is left untouched.
    return {"message": f"Database uploaded with {index.ntotal} embeddings"}

# First draft, kept for reference (not routed). faiss.read_index cannot parse
# raw bytes, so the upload is written to a scratch file (name is arbitrary)
# before loading.
def upload_database_0(file: UploadFile = File(...)):
    # Read the contents of the uploaded file
    contents = file.file.read()
    with open("uploaded.index", "wb") as f:
        f.write(contents)
    # Load the FAISS index from the scratch file
    global index
    index = faiss.read_index("uploaded.index")
    return {"message": f"Database uploaded with {index.ntotal} embeddings"}