Spaces:
Sleeping
Sleeping
File size: 13,533 Bytes
bb8e493 30dde8d bb8e493 ddf8ec6 bb8e493 e35273c 231d2b5 bb8e493 231d2b5 bb8e493 231d2b5 bb8e493 231d2b5 d3914c1 a47c058 30dde8d d3914c1 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d e7602b9 30dde8d 284c0f7 231d2b5 4721a1c 231d2b5 ddf8ec6 231d2b5 4721a1c 8cf19de 231d2b5 4721a1c 8cf19de 231d2b5 30dde8d 231d2b5 d632349 231d2b5 d632349 231d2b5 9f05250 231d2b5 8cf19de 231d2b5 1cd29ab 8cf19de 231d2b5 e384a9f 8cf19de e384a9f 8cf19de 3fb34ca 233b98c 8cf19de d469f0d 8cf19de d469f0d 284c0f7 d469f0d 231d2b5 d469f0d 8cf19de e384a9f bb8e493 ddf8ec6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 |
import os
import torch
import uuid
import shutil
import numpy as np
import faiss
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# Set the HF_HOME environment variable to a writable directory.
# NOTE(review): huggingface_hub reads HF_HOME when it is first imported, and
# transformers / sentence_transformers are imported above this line, so this
# assignment most likely does not change the cache location for this process.
# To take effect it should be set before those imports (or in the container env).
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Browser origins allowed to call the API. Both POST endpoints (/send_message and
# /upload) are served to the same front-ends, so they share one origin list;
# without /upload here, cross-origin uploads from these origins are blocked.
_ALLOWED_ORIGINS = ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]
CORS(app, resources={
    r"/send_message": {"origins": _ALLOWED_ORIGINS},
    r"/upload": {"origins": _ALLOWED_ORIGINS},
})

# Zero-shot classifier used by classify_intent() to route user input.
# NOTE(review): no model is named, so transformers falls back to its default
# checkpoint for this task; consider pinning one for reproducible deployments.
classifier = pipeline("zero-shot-classification")

# Sentence-BERT encoder used for document embeddings and response re-ranking.
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight, efficient model; choose larger if needed

# Most recently loaded causal LM / tokenizer (written by get_model_and_tokenizer).
model = None
tokenizer = None
accelerator = Accelerator()
# NOTE(review): not assigned anywhere in the visible code; kept for compatibility.
highest_label = None
# Cache of model_id -> (model, tokenizer).
loaded_models = {}
# Directory for uploaded documents and the persisted FAISS index file.
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Make sure the upload directory exists and is writable at startup. A
# PermissionError is only reported, so the app still starts; later writes into
# the directory may then fail.
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        # Add owner rwx to the existing mode rather than chmod 0o777: only the
        # owner (or root) may chmod, and the owner needs nothing beyond u+rwx, so a
        # world-writable upload directory buys nothing but lets any user plant files.
        os.chmod(UPLOAD_DIR, os.stat(UPLOAD_DIR).st_mode | 0o700)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")
# In-memory mapping doc_id -> {"path": ..., "text": ...} for indexed documents.
# It is not persisted, so it starts empty on every run.
document_store = {}

# Width of the embeddings produced by bertmodel (384 for all-MiniLM-L6-v2).
# Deriving it keeps the index valid if a larger encoder is swapped in.
EMBEDDING_DIM = bertmodel.get_sentence_embedding_dimension() or 384


def _new_faiss_index():
    """Return an empty ID-mapped, flat-L2 FAISS index sized for bertmodel embeddings."""
    return faiss.IndexIDMap(faiss.IndexFlatL2(EMBEDDING_DIM))


def _report_saved_index(path):
    """Log what a previously saved index at *path* contained; its vectors are discarded."""
    if not os.path.exists(path):
        return
    try:
        saved = faiss.read_index(path)
    except Exception as e:
        print(f"Error loading FAISS index: {e}, reinitializing.")
        return
    if saved.ntotal > 0:
        print(f"FAISS index loaded with {saved.ntotal} vectors.")


# document_store is not persisted, so vectors restored from disk would point at
# doc_ids with no text; the live index therefore always starts empty.
_report_saved_index(faiss_index_file)
index = _new_faiss_index()
# Function to upload document
def upload_document(file_path, embed_model):
    """Ingest a UTF-8 text file into the retrieval index.

    The file is copied into UPLOAD_DIR under its basename (skipped when it is
    already that file), read as UTF-8 text, embedded with ``embed_model``,
    added to the module-level FAISS ``index`` under a random id, recorded in
    ``document_store``, and the index is written to ``faiss_index_file``.

    Args:
        file_path: Path of the file to ingest.
        embed_model: Encoder with an ``encode(text)`` method returning a numpy
            vector (the module passes its SentenceTransformer).

    Returns:
        int: The doc_id under which the document was indexed.

    Raises:
        OSError, UnicodeDecodeError: The file could not be copied or read as
            UTF-8 text.
        Exception: Any error from embedding, indexing or saving the index
            (logged, then re-raised so the caller can report the failure).
    """
    # Random non-negative id below 2**63 - 1, so it fits FAISS's int64 ids.
    doc_id = uuid.uuid4().int % (2**63 - 1)
    # Keep only the basename so the stored copy always lands inside UPLOAD_DIR.
    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    try:
        shutil.copy(file_path, file_location)
    except shutil.SameFileError:
        # The caller already saved the file at its final location (handle_upload
        # writes straight into UPLOAD_DIR); there is nothing to copy.
        pass

    try:
        with open(file_location, "r", encoding="utf-8") as f:
            text = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file {file_location}: {e}")
        raise

    try:
        vector = embed_model.encode(text).astype("float32")
        index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
        document_store[doc_id] = {"path": file_location, "text": text}
        # Persist the index after every successful add.
        faiss.write_index(index, faiss_index_file)
    except Exception as e:
        print(f"Error during document upload: {e}")
        raise
    print(f"Document uploaded with doc_id: {doc_id}")
    return doc_id
@app.route("/upload", methods=["POST"])
def handle_upload():
    """Accept a multipart file upload and ingest it into the retrieval index.

    Expects the file in the form field ``file``. It is saved into UPLOAD_DIR
    and passed to upload_document() together with ``bertmodel``.

    Returns:
        200 JSON ``{"message": ...}`` on success; 400 JSON ``{"error": ...}``
        when no file or an unusable file name is sent; 500 JSON ``{"error": ...}``
        for directory, save or processing failures.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]

    # file.filename is client-controlled: keep only its last path component so
    # names like "../../etc/passwd" cannot escape UPLOAD_DIR, and refuse names
    # that are empty, special, or would overwrite the FAISS index file.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", "..") or filename == os.path.basename(faiss_index_file):
        return jsonify({"error": "Invalid file name"}), 400
    file_path = os.path.join(UPLOAD_DIR, filename)

    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        if not os.access(UPLOAD_DIR, os.W_OK):
            # Grant the owner rwx only; never make the upload directory world-writable.
            os.chmod(UPLOAD_DIR, os.stat(UPLOAD_DIR).st_mode | 0o700)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 500

    try:
        file.save(file_path)
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 500

    try:
        upload_document(file_path, bertmodel)
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 500

    return jsonify({"message": "File uploaded and processed successfully"}), 200
def get_model_and_tokenizer(model_id: str):
    """Load (once) and return the causal LM and tokenizer for ``model_id``.

    Loaded pairs are cached in ``loaded_models``. The module-level ``model`` and
    ``tokenizer`` globals are set to the pair being returned, on every call.

    Returns:
        tuple: ``(model, tokenizer)``; the model has been passed through
        ``accelerator.prepare``.

    Raises:
        Exception: Whatever loading raises, after the traceback is logged.
    """
    global model, tokenizer
    if model_id not in loaded_models:
        try:
            # Load into locals first so a failure part-way through cannot leave
            # the globals holding a tokenizer and a model from different checkpoints.
            new_tokenizer = AutoTokenizer.from_pretrained(model_id)
            new_model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(model_id))
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # Logs the full error traceback
            raise
        loaded_models[model_id] = (new_model, new_tokenizer)
    model, tokenizer = loaded_models[model_id]
    return model, tokenizer
# Text after the first "sentence:" / "sentence " marker (case-insensitive).
# DOTALL lets the capture continue across newlines so multi-line input is kept whole.
_CORE_SENTENCE_RE = re.compile(r"(?<=sentence[: ]).+", re.IGNORECASE | re.DOTALL)


def extract_core_sentence(user_input):
    """Extract the core sentence needing grammar correction from ``user_input``.

    Returns the stripped text following the first ``sentence:`` or ``sentence ``
    marker. If there is no marker, or only whitespace follows it, the input is
    returned unchanged so callers never receive an empty prompt.
    """
    match = _CORE_SENTENCE_RE.search(user_input)
    if match:
        core = match.group(0).strip()
        if core:
            return core
    return user_input
def classify_intent(user_input):
    """Return the intent label that best matches ``user_input``.

    The module-level zero-shot ``classifier`` scores the input against a fixed
    list of candidate intents; the label with the highest score is returned
    (the earliest one on ties). Only the return value carries the result: the
    module-level ``highest_label`` is not touched.
    """
    intents = [
        "grammar correction",
        "information request",
        "task completion",
        "dialog continuation",
        "personal opinion",
        "product inquiry",
        "feedback request",
        "recommendation request",
        "clarification request",
        "affirmation or agreement",
        "real-time data request",
        "current information",
    ]
    prediction = classifier(user_input, intents)
    scores = prediction['scores']
    winner = max(range(len(scores)), key=scores.__getitem__)
    return prediction['labels'][winner]
# Reformulate the prompt based on intent
# Function to generate reformulated prompts
def reformulate_prompt(user_input, intent_label):
"""
Reformulate the prompt based on the classified intent.
"""
core_sentence = extract_core_sentence(user_input)
prompt_templates = {
"grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
"information request": f"Provide information about: {core_sentence}",
"dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
"personal opinion": f"What is your personal opinion on: {core_sentence}?",
"product inquiry": f"Provide details about the product: {core_sentence}",
"feedback request": f"Please provide feedback on: {core_sentence}",
"recommendation request": f"Recommend something related to: {core_sentence}",
"clarification request": f"Clarify the following: {core_sentence}",
"affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
}
return prompt_templates.get(intent_label, "Input does not require a defined action.")
# Seed conversation included as few-shot context in every prompt.
# Each entry is a (user_message, assistant_reply) tuple; generate_response reads
# entries positionally (msg[0] / msg[1]), so anything appended must keep that shape.
# Every reply should actually answer its prompt: an off-topic example teaches the
# model to answer off-topic.
chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "Why did the scarecrow win an award? Because he was outstanding in his field!"),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]
def _retrieve_knowledge(query):
    """Return the text of the indexed document nearest to ``query``, or a placeholder."""
    query_vector = bertmodel.encode(query).reshape(1, -1).astype("float32")
    _, ids = index.search(query_vector, 1)
    # FAISS reports id -1 when the index is empty; that id is never in document_store.
    return document_store.get(int(ids[0][0]), {}).get("text", "No relevant information found.")


def _build_generation_config(intent, tokenizer, max_new_tokens):
    """Build decoding settings: sampling for open-ended intents, greedy otherwise."""
    if intent == "dialog continuation":
        temperature = 0.7
    elif intent == "recommendation request":
        temperature = 0.2
    else:
        temperature = None
    return GenerationConfig(
        do_sample=intent in ("dialog continuation", "recommendation request"),
        temperature=temperature,
        top_k=5 if intent == "recommendation request" else None,
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.2,
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )


def _encode_prompt(tokenizer, model, text, max_new_tokens, device):
    """Tokenize ``text`` and keep only the most recent tokens that fit the model context.

    Returns ``(input_ids, attention_mask)`` on ``device``.
    """
    encoded = tokenizer(text, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    context = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if context and context > max_new_tokens:
        budget = context - max_new_tokens
        # Drop the oldest tokens (retrieved knowledge / early history) so the
        # user's latest message at the end of the prompt is always kept.
        input_ids = input_ids[:, -budget:]
        attention_mask = attention_mask[:, -budget:]
    return input_ids.to(device), attention_mask.to(device)


def _pick_best_sentence(generated, reference):
    """Split ``generated`` into sentences; return the one closest to ``reference`` ("" if none)."""
    sentences = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', generated)]
    sentences = [s for s in sentences if s]
    if not sentences:
        return ""
    reference_embedding = bertmodel.encode(reference, convert_to_tensor=True)
    candidate_embeddings = bertmodel.encode(sentences, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(reference_embedding, candidate_embeddings)[0]
    return sentences[int(similarities.argmax())]


def generate_response(user_input, model_id):
    """Generate a chat reply to ``user_input`` with the causal LM ``model_id``.

    The prompt is built from the nearest indexed document, the shared
    ``chat_history`` turns, and an intent-specific rewrite of the input. The
    model generates up to 50 new tokens. The generated sentence most similar to
    the prompt is returned (for "dialog continuation", at most three
    '. '-separated pieces of it), and the turn is appended to ``chat_history``.

    Returns:
        str: The selected reply ("" if the model produced no text).

    Raises:
        Exception: Anything raised along the way, after the traceback is logged.
    """
    max_new_tokens = 50
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device

        knowledge = _retrieve_knowledge(user_input)
        messages = [{"role": "system", "content": f"Use the following knowledge:\n{knowledge}"}]
        # NOTE(review): chat_history is module-level, so all clients share one
        # conversation and it grows on every call; left truncation below keeps
        # the prompt within the context window.
        for user_turn, assistant_turn in chat_history:
            messages.append({"role": "user", "content": str(user_turn)})
            messages.append({"role": "assistant", "content": str(assistant_turn)})

        # Use the intent computed here for decoding; the module-level
        # highest_label is never assigned and would always be None.
        intent = classify_intent(user_input)
        messages.append({"role": "user", "content": reformulate_prompt(user_input, intent)})
        formatted_prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        input_ids, attention_mask = _encode_prompt(
            tokenizer, model, formatted_prompt, max_new_tokens, device
        )
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            generation_config=_build_generation_config(intent, tokenizer, max_new_tokens),
        )
        # Decode only the newly generated tokens; stripping the prompt text from the
        # full decoded output fails whenever tokenization does not round-trip exactly.
        generated = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

        # NOTE(review): the encoder truncates long inputs, so for long prompts the
        # re-ranking reference is mostly the retrieved knowledge, not the question.
        best_response = _pick_best_sentence(generated, formatted_prompt)
        if intent == "dialog continuation":
            best_response = '. '.join(best_response.split('. ')[:3])

        # Keep chat_history's (user, assistant) tuple shape; appending dicts here
        # makes the msg[0]/msg[1] unpacking fail on the next request.
        chat_history.append((user_input, best_response))
        return best_response
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    """Handle a chat request.

    Expects a JSON object ``{"inputs": <str>, "model_id": <str, optional>}``;
    ``model_id`` defaults to ``"openai-community/gpt2-large"``.

    Returns:
        200 JSON with ``received_message`` (the reply), ``model_id`` and ``status``;
        400 JSON ``{"error": ...}`` for a missing/non-object body, a missing or
        blank ``inputs`` or a non-string ``model_id``; 500 JSON ``{"error": ...}``
        when generation fails.
    """
    # silent=True: a malformed body or wrong Content-Type yields None (-> 400)
    # instead of raising and being reported as a server error.
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    message = data.get("inputs")
    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "No message provided"}), 400

    model_id = data.get("model_id", "openai-community/gpt2-large")
    if not isinstance(model_id, str) or not model_id:
        return jsonify({"error": "model_id must be a non-empty string"}), 400
    # NOTE(review): model_id comes straight from the client, so any caller can make
    # the server download and cache arbitrary checkpoints; consider an allow-list.

    try:
        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "received_message": model_response,
        "model_id": model_id,
        "status": "POST request successful!"
    })
# Script entry point: serve the app with Flask's built-in server, listening on
# every interface at port 7860.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")
|