YALCINKAYA committed
Commit 30dde8d · verified · 1 Parent(s): 16a530a

Update app.py

Files changed (1)
  1. app.py +64 -0
app.py CHANGED
@@ -1,5 +1,9 @@
 import os
 import torch
+import uuid
+import shutil
+import numpy as np
+import faiss
 from flask import Flask, jsonify, request
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
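The new imports pull in FAISS and NumPy for vector search (the `faiss` module typically ships as the `faiss-cpu` or `faiss-gpu` package). As an illustrative sanity check of the index structure this commit relies on, a minimal round-trip with `IndexIDMap` over `IndexFlatL2` might look like the sketch below; the dimension 384 matches the setup in the next hunk, and the ID value is arbitrary:

    import faiss
    import numpy as np

    # Same structure the commit builds: an ID-mapped flat L2 index.
    index = faiss.IndexIDMap(faiss.IndexFlatL2(384))

    # Add one random vector under an explicit 64-bit ID.
    vec = np.random.rand(1, 384).astype("float32")
    index.add_with_ids(vec, np.array([42], dtype=np.int64))

    # search returns distances D and the stored IDs I; an empty index
    # would return -1 in I instead of a real ID.
    D, I = index.search(vec, 1)
    print(D[0][0], I[0][0])  # ~0.0 42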
@@ -30,6 +34,55 @@ accelerator = Accelerator()
 highest_label = None
 loaded_models = {}
 
+# FAISS Index Setup
+UPLOAD_DIR = "./uploads"
+faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+document_store = {}
+
+if os.path.exists(faiss_index_file):
+    try:
+        index = faiss.read_index(faiss_index_file)
+        if index.ntotal > 0:
+            print(f"FAISS index loaded with {index.ntotal} vectors.")
+            index.reset()
+            index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
+        else:
+            index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
+    except Exception as e:
+        print(f"Error loading FAISS index: {e}, reinitializing.")
+        index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
+else:
+    index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
+
+# Function to upload document
+def upload_document(file_path, embed_model):
+    doc_id = uuid.uuid4().int % (2**63 - 1)
+    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
+    shutil.copy(file_path, file_location)
+
+    with open(file_location, "r", encoding="utf-8") as f:
+        text = f.read()
+
+    vector = embed_model.encode(text).astype("float32")
+    index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
+    document_store[doc_id] = {"path": file_location, "text": text}
+
+    faiss.write_index(index, faiss_index_file)
+    print(f"Document uploaded with doc_id: {doc_id}")
+
+@app.route("/upload", methods=["POST"])
+def handle_upload():
+    if "file" not in request.files:
+        return jsonify({"error": "No file provided"}), 400
+
+    file = request.files["file"]
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+    file.save(file_path)
+
+    upload_document(file_path, bertmodel)
+    return jsonify({"message": "File uploaded successfully"})
+
 def get_model_and_tokenizer(model_id: str):
     """
     Load and cache the model and tokenizer for the given model_id.
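The handler passes `bertmodel` into `upload_document`, but this diff never defines it; given the 384-dimensional `IndexFlatL2`, it is presumably a sentence-embedding model defined elsewhere in app.py. A hedged sketch of a compatible encoder follows; the model name is an assumption, not something this commit shows:

    from sentence_transformers import SentenceTransformer

    # Assumption: any encoder with 384-dimensional output fits this index,
    # e.g. all-MiniLM-L6-v2. The commit does not show bertmodel's definition.
    bertmodel = SentenceTransformer("all-MiniLM-L6-v2")

    vector = bertmodel.encode("hello world")  # numpy array of shape (384,)
    assert vector.shape[0] == 384  # must match faiss.IndexFlatL2(384)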
@@ -108,6 +161,17 @@ def generate_response(user_input, model_id):
 
     # Append chat history
     func_caller = []
+
+    query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
+    D, I = index.search(query_vector, 1)
+
+    retrieved_knowledge = document_store.get(I[0][0], {}).get("text", "No relevant information found.")
+
+    # Construct the knowledge prompt
+    prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
+
+    # Add the retrieved knowledge to the prompt
+    func_caller.append({"role": "system", "content": prompt})
 
     for msg in chat_history:
         func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
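With the route in place, a document can be pushed into the index over HTTP, and subsequent `generate_response` calls will prepend the nearest document's text as a system message. A minimal client sketch, assuming a default local Flask run on port 5000 and a plain-text file (the handler reads the saved file as UTF-8):

    import requests

    # Upload a plain-text document to the new /upload endpoint.
    with open("notes.txt", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:5000/upload",  # assumption: default Flask host/port
            files={"file": ("notes.txt", f, "text/plain")},
        )
    print(resp.json())  # {"message": "File uploaded successfully"}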