Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
import os
|
2 |
import torch
|
|
|
|
|
|
|
|
|
3 |
from flask import Flask, jsonify, request
|
4 |
from flask_cors import CORS
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
|
@@ -30,6 +34,55 @@ accelerator = Accelerator()
|
|
30 |
highest_label = None
|
31 |
loaded_models = {}
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
def get_model_and_tokenizer(model_id: str):
|
34 |
"""
|
35 |
Load and cache the model and tokenizer for the given model_id.
|
@@ -108,6 +161,17 @@ def generate_response(user_input, model_id):
|
|
108 |
|
109 |
# Append chat history
|
110 |
func_caller = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
for msg in chat_history:
|
113 |
func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
+
import uuid
|
4 |
+
import shutil
|
5 |
+
import numpy as np
|
6 |
+
import faiss
|
7 |
from flask import Flask, jsonify, request
|
8 |
from flask_cors import CORS
|
9 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
|
|
|
34 |
highest_label = None
|
35 |
loaded_models = {}
|
36 |
|
# --- FAISS Index Setup ---
# document_store maps doc_id -> {"path": ..., "text": ...}. It is held in
# memory only and is NOT persisted, so vectors in a previously saved index
# file would point at documents we no longer have. We therefore always
# start from an empty index; a stale index file on disk is reported but
# its contents are deliberately not reused.
UPLOAD_DIR = "./uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")
os.makedirs(UPLOAD_DIR, exist_ok=True)
document_store = {}

# 384 is the embedding width produced by the encoder used below
# (typical of MiniLM-class sentence transformers) — TODO confirm
# against the actual `bertmodel`.
_EMBEDDING_DIM = 384

if os.path.exists(faiss_index_file):
    try:
        stale = faiss.read_index(faiss_index_file)
        if stale.ntotal > 0:
            # Previously the code printed "loaded" and then immediately
            # reset the index, which was misleading; say what really happens.
            print(f"Found stale FAISS index with {stale.ntotal} vectors; starting fresh (document store is not persisted).")
    except Exception as e:
        print(f"Error loading FAISS index: {e}, reinitializing.")

# Single initialization point instead of one per branch.
index = faiss.IndexIDMap(faiss.IndexFlatL2(_EMBEDDING_DIM))
# --- Document upload / indexing ---
def upload_document(file_path, embed_model):
    """
    Copy *file_path* into UPLOAD_DIR, embed its text content, and register
    it in the global FAISS ``index`` and ``document_store``.

    Parameters:
        file_path: path to a UTF-8 text file to ingest.
        embed_model: object exposing ``encode(text)`` returning a 1-D
            numeric vector (e.g. a SentenceTransformer) — TODO confirm.

    Returns:
        The integer ``doc_id`` assigned to the document.
    """
    # IndexIDMap ids must fit in a signed 64-bit integer.
    doc_id = uuid.uuid4().int % (2**63 - 1)
    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))

    # BUG FIX: the /upload route already saves the file inside UPLOAD_DIR,
    # so shutil.copy(src, dst) with src == dst raised shutil.SameFileError.
    # Only copy when the source actually lives elsewhere.
    if os.path.abspath(file_path) != os.path.abspath(file_location):
        shutil.copy(file_path, file_location)

    with open(file_location, "r", encoding="utf-8") as f:
        text = f.read()

    vector = embed_model.encode(text).astype("float32")
    index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
    document_store[doc_id] = {"path": file_location, "text": text}

    # Persist the vectors so a later start-up can detect prior state.
    faiss.write_index(index, faiss_index_file)
    print(f"Document uploaded with doc_id: {doc_id}")
    return doc_id
@app.route("/upload", methods=["POST"])
def handle_upload():
    """
    Accept a multipart upload under the form field ``"file"``, save it into
    UPLOAD_DIR, and index it via :func:`upload_document`.

    Returns 400 when the field is missing or the filename is empty,
    otherwise 200 with a success message.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]
    # SECURITY FIX: the client-supplied filename may contain path
    # separators ("../../etc/passwd"); keep only the base name so the
    # saved file cannot escape UPLOAD_DIR.
    filename = os.path.basename(file.filename or "")
    if not filename:
        return jsonify({"error": "Empty filename"}), 400

    file_path = os.path.join(UPLOAD_DIR, filename)
    file.save(file_path)

    upload_document(file_path, bertmodel)
    return jsonify({"message": "File uploaded successfully"})
86 |
def get_model_and_tokenizer(model_id: str):
|
87 |
"""
|
88 |
Load and cache the model and tokenizer for the given model_id.
|
|
|
161 |
|
162 |
# Append chat history
|
163 |
func_caller = []
|
164 |
+
|
165 |
+
query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
|
166 |
+
D, I = index.search(query_vector, 1)
|
167 |
+
|
168 |
+
retrieved_knowledge = document_store.get(I[0][0], {}).get("text", "No relevant information found.")
|
169 |
+
|
170 |
+
# Construct the knowledge prompt
|
171 |
+
prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
|
172 |
+
|
173 |
+
# Add the retrieved knowledge to the prompt
|
174 |
+
func_caller.append({"role": "system", "content": prompt})
|
175 |
|
176 |
for msg in chat_history:
|
177 |
func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
|