Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -83,9 +83,9 @@ def initialize_faiss():
|
|
83 |
def save_faiss_index(index):
|
84 |
try:
|
85 |
if torch.cuda.is_available():
|
86 |
-
print("Moving FAISS index back to CPU before saving.")
|
87 |
-
|
88 |
-
|
89 |
print(f"Saving FAISS index to {faiss_index_file}.")
|
90 |
faiss.write_index(index, faiss_index_file)
|
91 |
print(f"FAISS index successfully saved to {faiss_index_file}.")
|
@@ -101,7 +101,7 @@ save_faiss_index(index)
|
|
101 |
# Load document store and populate FAISS index
|
102 |
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledge_text.txt") # Ensure this path is correct
|
103 |
|
104 |
-
def
|
105 |
"""Loads knowledgebase.txt into a dictionary where FAISS IDs map to text and embeddings"""
|
106 |
global document_store
|
107 |
document_store = {} # Reset document store
|
@@ -121,15 +121,36 @@ def load_document_store():
|
|
121 |
else:
|
122 |
print("Error: knowledgebase.txt not found!")
|
123 |
|
124 |
-
# Generate embeddings for all documents
|
125 |
-
embeddings = bertmodel.encode(all_texts)
|
126 |
-
|
127 |
-
|
128 |
# Add embeddings to FAISS index
|
129 |
index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
|
130 |
print(f"Added {len(all_texts)} document embeddings to FAISS index.")
|
131 |
|
|
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
# Function to upload document
|
134 |
def upload_document(file_path, embed_model):
|
135 |
try:
|
@@ -315,10 +336,12 @@ def generate_response(user_input, model_id):
|
|
315 |
|
316 |
# Append chat history
|
317 |
func_caller = []
|
318 |
-
|
319 |
query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
|
320 |
D, I = index.search(query_vector, 1)
|
321 |
|
|
|
|
|
322 |
# Retrieve document
|
323 |
retrieved_id = I[0][0]
|
324 |
retrieved_knowledge = (
|
|
|
83 |
def save_faiss_index(index):
|
84 |
try:
|
85 |
if torch.cuda.is_available():
|
86 |
+
print("Moving FAISS index back to CPU before saving.")
|
87 |
+
res = faiss.StandardGpuResources() # Allocate GPU resources
|
88 |
+
index = faiss.index_cpu_to_gpu(res, 0, index) # Move to GPU 0
|
89 |
print(f"Saving FAISS index to {faiss_index_file}.")
|
90 |
faiss.write_index(index, faiss_index_file)
|
91 |
print(f"FAISS index successfully saved to {faiss_index_file}.")
|
|
|
101 |
# Load document store and populate FAISS index
|
102 |
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledge_text.txt") # Ensure this path is correct
|
103 |
|
104 |
+
def add_document_store_to_index():
|
105 |
"""Loads knowledgebase.txt into a dictionary where FAISS IDs map to text and embeddings"""
|
106 |
global document_store
|
107 |
document_store = {} # Reset document store
|
|
|
121 |
else:
|
122 |
print("Error: knowledgebase.txt not found!")
|
123 |
|
124 |
+
# Generate embeddings for all documents
|
125 |
+
embeddings = bertmodel.encode(all_texts, batch_size=32, convert_to_numpy=True).astype("float32")
|
126 |
+
|
|
|
127 |
# Add embeddings to FAISS index
|
128 |
index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
|
129 |
print(f"Added {len(all_texts)} document embeddings to FAISS index.")
|
130 |
|
131 |
+
add_document_store_to_index()
|
132 |
|
133 |
+
def load_document_store():
|
134 |
+
"""Loads knowledgebase.txt into a dictionary where FAISS IDs map to text and embeddings"""
|
135 |
+
global document_store
|
136 |
+
document_store = {} # Reset document store
|
137 |
+
all_texts = []
|
138 |
+
|
139 |
+
if os.path.exists(knowledgebase_file):
|
140 |
+
with open(knowledgebase_file, "r", encoding="utf-8") as f:
|
141 |
+
lines = f.readlines()
|
142 |
+
|
143 |
+
for i, line in enumerate(lines):
|
144 |
+
text = line.strip()
|
145 |
+
if text:
|
146 |
+
document_store[i] = {"text": text} # Store text mapped to FAISS ID
|
147 |
+
all_texts.append(text) # Collect all texts for embedding
|
148 |
+
|
149 |
+
print(f"Loaded {len(document_store)} documents into document_store.")
|
150 |
+
else:
|
151 |
+
print("Error: knowledgebase.txt not found!")
|
152 |
+
|
153 |
+
|
154 |
# Function to upload document
|
155 |
def upload_document(file_path, embed_model):
|
156 |
try:
|
|
|
336 |
|
337 |
# Append chat history
|
338 |
func_caller = []
|
339 |
+
|
340 |
query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
|
341 |
D, I = index.search(query_vector, 1)
|
342 |
|
343 |
+
load_document_store()
|
344 |
+
|
345 |
# Retrieve document
|
346 |
retrieved_id = I[0][0]
|
347 |
retrieved_knowledge = (
|