# opsgenius3 / app.py
# Last commit: ac91e2e "Update app.py" by YALCINKAYA (verified), file size 20 kB.
import os
import torch
import uuid
import shutil
import numpy as np
import faiss
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# Point the Hugging Face cache at a writable directory.
# NOTE(review): huggingface_hub reads HF_HOME when it is first imported, and
# transformers / sentence_transformers are imported above this line, so this
# assignment most likely does not change where models are cached. Set it in
# the environment (or before those imports) for it to take effect.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for the chat endpoint only, and only for the known front-end origins.
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Sentence-BERT encoder used for both document and query embeddings.
# The FAISS index in this file is built with dimension 384, which must match
# this model's embedding size.
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')

# Most recently loaded causal LM / tokenizer (assigned by get_model_and_tokenizer)
# and the per-model_id cache that function fills.
model = None
tokenizer = None
accelerator = Accelerator()
# Module-level placeholder; classify_intent() does not assign it (it uses a
# local variable of the same name).
highest_label = None
loaded_models = {}

# Zero-shot intent classifier, pinned to a fixed model revision and placed on
# the accelerator's device.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    revision="d7645e1",
    device=accelerator.device,
)
classifier.model = accelerator.prepare(classifier.model)

# Upload directory and the FAISS index file persisted inside it.
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Best effort: make sure the upload directory exists and is writable. Failures
# are logged rather than raised so the app can still start; the upload and
# index-saving code paths report their own errors.
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        # NOTE(review): 0o777 makes the directory world-writable.
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")
except OSError as e:
    # e.g. a read-only filesystem (EROFS) is an OSError but not a PermissionError;
    # previously it escaped and aborted module import.
    print(f"Could not prepare {UPLOAD_DIR}: {e}")

# In-memory map from FAISS IDs to {"text": ...} entries ("path" is added for
# documents ingested by upload_document).
document_store = {}
def initialize_faiss(dim=384):
    """Create the empty FAISS index used for document retrieval.

    The index maps caller-chosen int64 IDs (``IndexIDMap``) onto exact L2
    search over ``dim``-dimensional float32 vectors (``IndexFlatL2``).

    If ``faiss_index_file`` exists, it is opened only to log its size. Its
    vectors are always discarded and a new, empty index is returned.
    Discarding them keeps the index consistent with ``document_store``,
    which is empty at import time.

    Args:
        dim: Embedding dimension. Defaults to 384, the dimension used
            throughout this file; it must match the embedding model's output.

    Returns:
        A new, empty ``faiss.IndexIDMap``.
    """
    def new_index():
        return faiss.IndexIDMap(faiss.IndexFlatL2(dim))

    if not os.path.exists(faiss_index_file):
        print(f"FAISS index file {faiss_index_file} does not exist, initializing a new index.")
        return new_index()

    print(f"FAISS index file {faiss_index_file} exists, attempting to load it.")
    try:
        loaded = faiss.read_index(faiss_index_file)
    except Exception as e:
        # An unreadable or corrupt file must not stop start-up.
        print(f"Error loading FAISS index: {e}, reinitializing a new index.")
        return new_index()

    if loaded.ntotal > 0:
        print(f"FAISS index loaded with {loaded.ntotal} vectors.")
        print("Index reset. Reinitializing index.")
    else:
        print("Loaded index has zero vectors, reinitializing index.")
    return new_index()
def save_faiss_index(index):
    """Write ``index`` to ``faiss_index_file``; errors are logged, not raised.

    ``faiss.write_index`` only serializes CPU indexes. A top-level GPU index
    is therefore copied back to the CPU first; such an index only exists with
    a GPU-enabled FAISS build, which exposes ``faiss.GpuIndex``. What matters
    is where the index lives, not whether torch can see CUDA.

    Args:
        index: The FAISS index to persist.
    """
    try:
        # faiss-cpu has no GpuIndex class, in which case the index is CPU-resident.
        gpu_index_type = getattr(faiss, "GpuIndex", None)
        if gpu_index_type is not None and isinstance(index, gpu_index_type):
            print("Moving FAISS index back to CPU before saving.")
            index = faiss.index_gpu_to_cpu(index)
        print(f"Saving FAISS index to {faiss_index_file}.")
        faiss.write_index(index, faiss_index_file)
        print(f"FAISS index successfully saved to {faiss_index_file}.")
    except Exception as e:
        print(f"Error saving FAISS index: {e}")
# At import time: build the FAISS index via initialize_faiss() and write it to
# disk once via save_faiss_index().
index = initialize_faiss()
save_faiss_index(index)

# Knowledge-base text file consumed by load_document_store(), which treats
# each non-blank line as one document.
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledgebase1.txt")
def load_document_store():
    """Rebuild ``document_store`` and the FAISS ``index`` from ``knowledgebase_file``.

    Each non-blank line of the file becomes one document whose ID is its
    0-based line number. The line's embedding is added to ``index`` under the
    same ID, so a search hit maps back to its text via ``document_store``.

    Both ``document_store`` and ``index`` are cleared first so they always
    describe the same documents. Clearing only the dictionary left stale
    vectors whose IDs pointed at unrelated or missing text, and repeated calls
    added duplicate vectors for the same IDs. A missing or empty file leaves
    both empty.

    Raises:
        OSError / UnicodeDecodeError: if the file exists but cannot be read as UTF-8.
    """
    global document_store
    document_store = {}
    index.reset()

    if not os.path.exists(knowledgebase_file):
        print(f"Error: {knowledgebase_file} not found!")
        return

    with open(knowledgebase_file, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            text = line.strip()
            if text:
                document_store[line_no] = {"text": text}
    print(f"Loaded {len(document_store)} documents into document_store.")

    # Encoding/adding an empty batch is pointless and can fail in add_with_ids.
    if not document_store:
        return

    ids = list(document_store)
    texts = [document_store[doc_id]["text"] for doc_id in ids]
    embeddings = bertmodel.encode(texts).astype("float32")
    index.add_with_ids(embeddings, np.array(ids, dtype=np.int64))
    print(f"Added {len(texts)} document embeddings to FAISS index.")
def load_document_store_once(file_path):
    """Rebuild ``document_store`` and the FAISS ``index`` from one uploaded file.

    Only the base name of ``file_path`` is used; the file is read from
    ``UPLOAD_DIR``. Each non-blank line becomes one document whose ID is its
    0-based line number, and its embedding is added to ``index`` under the
    same ID.

    ``document_store`` and ``index`` are both cleared first, so afterwards
    every search hit maps back to its own text. Vectors left over from a
    previous load would otherwise resolve to lines of this file, or to
    nothing. A missing or empty file leaves both empty.

    Raises:
        OSError / UnicodeDecodeError: if the file exists but cannot be read as UTF-8.
    """
    global document_store
    document_store = {}
    index.reset()

    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    if not os.path.exists(file_location):
        print(f"Error: {file_location} not found!")
        return

    with open(file_location, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            text = line.strip()
            if text:
                document_store[line_no] = {"text": text}
    print(f"Loaded {len(document_store)} documents into document_store.")

    # Encoding/adding an empty batch is pointless and can fail in add_with_ids.
    if not document_store:
        return

    ids = list(document_store)
    texts = [document_store[doc_id]["text"] for doc_id in ids]
    embeddings = bertmodel.encode(texts).astype("float32")
    index.add_with_ids(embeddings, np.array(ids, dtype=np.int64))
    print(f"Added {len(texts)} document embeddings to FAISS index.")
# Function to upload document
def upload_document(file_path, embed_model):
    """Copy a text file into ``UPLOAD_DIR``, embed its whole text, and index it.

    The document gets a random 63-bit ID. Its vector is added to the FAISS
    ``index`` and its path/text to ``document_store``, and the index is then
    written to ``faiss_index_file``.

    Args:
        file_path: Path of the file to ingest. Only its base name is used for
            the destination inside ``UPLOAD_DIR``; if the file already is that
            destination, it is not copied.
        embed_model: Object whose ``encode(text)`` returns a numpy vector
            (the module's Sentence-BERT model, for example).

    Returns:
        ``None`` on success. On failure, a ``(payload_dict, status_code)``
        tuple: 507 read error, 508 index-save error, 509 embedding/indexing
        error, 500 anything else.
    """
    try:
        # Random non-negative ID that fits in FAISS's signed int64 labels.
        doc_id = uuid.uuid4().int % (2**63 - 1)

        file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        print(f"Saving file to: {file_location}")
        try:
            shutil.copy(file_path, file_location)
        except shutil.SameFileError:
            # The file already lives in UPLOAD_DIR (e.g. the caller saved it
            # there first); copying it onto itself would raise.
            pass

        try:
            with open(file_location, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            print(f"Error reading file {file_location}: {e}")
            return {"error": f"Error reading file: {e}"}, 507

        try:
            if embed_model is None:
                raise ValueError("Embedding model is not initialized properly.")
            vector = embed_model.encode(text).astype("float32")
            print(f"Generated vector for document {doc_id}: {vector}")
            index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
            document_store[doc_id] = {"path": file_location, "text": text}

            print(f"Saving FAISS index to: {faiss_index_file}")
            try:
                faiss.write_index(index, faiss_index_file)
                print(f"Document uploaded with doc_id: {doc_id}")
            except Exception as e:
                print(f"Error saving FAISS index: {e}")
                return {"error": f"Error saving FAISS index: {e}"}, 508
        except Exception as e:
            print(f"Error during document upload: {e}")
            return {"error": f"Error during document upload: {e}"}, 509
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"error": f"Unexpected error: {e}"}, 500
@app.route("/list_uploads", methods=["GET"])
def list_uploaded_files():
    """List the entry names in UPLOAD_DIR as JSON.

    Responses: 200 with "files" when the directory has entries, 200 with a
    "message" when it is empty, 400 when it does not exist, 504 when listing
    raises.
    """
    try:
        if os.path.exists(UPLOAD_DIR):
            entries = os.listdir(UPLOAD_DIR)
            if entries:
                return jsonify({"files": entries}), 200
            return jsonify({"message": "No files found in the upload directory"}), 200
        return jsonify({"error": "Upload directory does not exist"}), 400
    except Exception as err:
        return jsonify({"error": f"Error listing files: {err}"}), 504
@app.route("/upload", methods=["POST"])
def handle_upload():
    """Save an uploaded text file into UPLOAD_DIR and re-index it.

    Expects a multipart form field named ``file``. Only the base name of the
    client-supplied filename is used, so names such as ``../../etc/passwd``
    or absolute paths cannot escape UPLOAD_DIR.

    Responses:
        400 -- no file part, or an empty/unusable filename.
        501 -- permission error creating or fixing the upload directory.
        502 -- the file could not be saved.
        503 -- the file was saved but indexing (load_document_store /
               load_document_store_once) failed.
        200 -- file saved and indexed.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]

    # Never trust the client filename: os.path.join() drops UPLOAD_DIR when
    # given an absolute path, and "../" segments would walk out of it.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        return jsonify({"error": "Invalid file name"}), 400
    file_path = os.path.join(UPLOAD_DIR, filename)

    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        if not os.access(UPLOAD_DIR, os.W_OK):
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 501

    try:
        file.save(file_path)
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 502

    # Indexing failures are processing errors, not save errors.
    try:
        print("File uploaded successfully. Calling load_document_store()...")
        load_document_store()
        load_document_store_once(file_path)
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 503

    return jsonify({"message": "File uploaded and processed successfully"}), 200
def get_model_and_tokenizer(model_id: str):
    """
    Return the cached ``(model, tokenizer)`` pair for ``model_id``, loading it on first use.

    On a cache miss:
    - the tokenizer and causal-LM weights are fetched with ``from_pretrained``;
    - the model is passed through ``accelerator.prepare``;
    - the pair is stored in ``loaded_models``.

    The module-level ``model`` / ``tokenizer`` globals are assigned only on a
    cache miss. Loading errors are logged with a full traceback and re-raised.
    """
    global model, tokenizer

    if model_id in loaded_models:
        return loaded_models[model_id]

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = accelerator.prepare(model)
        loaded_models[model_id] = (model, tokenizer)
    except Exception:
        print("Error loading model:")
        print(traceback.format_exc())
        raise
    return loaded_models[model_id]
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.

    Returns the text that follows the first "sentence:" or "sentence "
    (case-insensitive), up to the end of that line, stripped of surrounding
    whitespace. If there is no such marker, or only whitespace follows it,
    ``user_input`` is returned unchanged rather than an empty string.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        core = match.group(0).strip()
        if core:
            return core
    return user_input
def classify_intent(user_input):
    """
    Return the single most likely intent label for ``user_input``.

    Runs the module-level zero-shot ``classifier`` against a fixed list of
    candidate intents and returns the label with the highest score (the
    first such label on ties).
    """
    candidate_labels = [
        "grammar correction", "information request", "task completion",
        "dialog continuation", "personal opinion", "product inquiry",
        "feedback request", "recommendation request", "clarification request",
        "affirmation or agreement", "real-time data request", "current information",
    ]
    outcome = classifier(user_input, candidate_labels)
    # max() with a key keeps the first maximal pair, matching list.index(max(...)).
    best_label, _best_score = max(
        zip(outcome["labels"], outcome["scores"]),
        key=lambda pair: pair[1],
    )
    return best_label
# Reformulate the prompt based on the classified intent.
def reformulate_prompt(user_input, intent_label):
    """
    Turn ``user_input`` into an instruction tailored to ``intent_label``.

    Intents that have a template are rewritten around the core sentence
    returned by ``extract_core_sentence``. Intents without one (for example
    "task completion", "real-time data request", "current information", or
    any unknown label) fall back to the raw ``user_input``, so the user's
    request still reaches the model.
    """
    # str.format templates are filled lazily, so the core sentence is only
    # extracted when a template actually applies.
    templates = {
        "grammar correction": "Fix the grammar in this sentence: {}",
        "information request": "Provide information about: {}",
        "dialog continuation": "Continue the conversation based on the previous dialog:\n{}\n",
        "personal opinion": "What is your personal opinion on: {}?",
        "product inquiry": "Provide details about the product: {}",
        "feedback request": "Please provide feedback on: {}",
        "recommendation request": "Recommend something related to: {}",
        "clarification request": "Clarify the following: {}",
        "affirmation or agreement": "Affirm or agree with the statement: {}",
    }
    template = templates.get(intent_label)
    if template is None:
        return user_input
    return template.format(extract_core_sentence(user_input))
# Seed conversation replayed ahead of every prompt by generate_response(),
# stored as (user message, assistant reply) tuples. generate_response()
# also appends each new exchange to this list.
chat_history = [
    (
        "Hi there, how are you?",
        "I am fine. How are you?",
    ),
    (
        "Tell me a joke!",
        "The capital of France is Paris.",
    ),
    (
        "Can you tell me another joke?",
        "Why don't scientists trust atoms? Because they make up everything!",
    ),
]
def _retrieve_knowledge(user_input):
    """Return the text of the knowledge-base document nearest to ``user_input``, or a placeholder."""
    query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
    _distances, ids = index.search(query_vector, 1)
    retrieved_id = int(ids[0][0])
    # FAISS reports -1 when it has no result to return (e.g. an empty index).
    if retrieved_id == -1:
        return "No relevant information found."
    return document_store.get(retrieved_id, {}).get("text", "No relevant information found.")


def _generation_config_for_intent(intent, pad_token_id):
    """Build decoding settings: sampling for open-ended intents, greedy decoding otherwise."""
    if intent == "dialog continuation":
        temperature = 0.7
    elif intent == "recommendation request":
        temperature = 0.2
    else:
        temperature = None
    return GenerationConfig(
        do_sample=intent in ("dialog continuation", "recommendation request"),
        temperature=temperature,
        top_k=5 if intent == "recommendation request" else None,
        max_length=150,
        repetition_penalty=1.2,
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
    )


def _pick_best_sentence(formatted_prompt, generated_text):
    """Split ``generated_text`` into sentences and return the one most similar to the prompt ("" if none)."""
    sentences = [s.strip() for s in re.split(r"(?<=\w[.!?]) +", generated_text)]
    sentences = [s for s in sentences if s]
    if not sentences:
        return ""
    prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
    candidate_embeddings = bertmodel.encode(sentences, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
    return sentences[int(similarities.argmax())]


def generate_response(user_input, model_id):
    """Answer ``user_input`` with the causal LM identified by ``model_id``.

    Steps:
    1. Retrieve the nearest knowledge-base text from the FAISS ``index``.
    2. Build a plain-text "role: content" prompt from that knowledge, the
       running ``chat_history``, and the intent-reformulated user message.
    3. Generate up to 50 new tokens with intent-specific decoding settings.
    4. Keep the generated sentence most similar to the prompt.
    5. Record the exchange in ``chat_history``.

    Returns:
        The selected response text ("" if the model generated no text).

    Raises:
        Any exception from loading, retrieval or generation, after logging
        its traceback.
    """
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device

        retrieved_knowledge = _retrieve_knowledge(user_input)
        prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
        print(f"Generated prompt: {prompt}")

        messages = [{"role": "system", "content": prompt}]
        for past_user, past_assistant in chat_history:
            messages.append({"role": "user", "content": str(past_user)})
            messages.append({"role": "assistant", "content": str(past_assistant)})

        # Use the intent classified for THIS input. The module-level
        # highest_label is never assigned, so reading it disabled all
        # intent-specific decoding and truncation.
        intent = classify_intent(user_input)
        messages.append({"role": "user", "content": reformulate_prompt(user_input, intent)})
        formatted_prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        generation_config = _generation_config_for_intent(intent, tokenizer.eos_token_id)
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(
            gpt_inputs["input_ids"],
            attention_mask=gpt_inputs.get("attention_mask"),
            max_new_tokens=50,
            generation_config=generation_config,
        )

        # Decode only the newly generated tokens. Regex-stripping the prompt
        # from the decoded text is unreliable because decode(encode(x)) need
        # not equal x, which left prompt text in the candidates.
        prompt_length = gpt_inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(gpt_output[0][prompt_length:], skip_special_tokens=True).strip()

        best_response = _pick_best_sentence(formatted_prompt, generated_text)

        if intent == "dialog continuation":
            sentences = best_response.split(". ")
            best_response = ". ".join(sentences[:3]) if len(sentences) > 3 else best_response

        # Same (user, assistant) tuple shape as the seed entries: the loop
        # above unpacks every entry as a pair, so appending dicts crashed the
        # next request.
        chat_history.append((user_input, best_response))
        return best_response
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())
        raise
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    """Generate a chat reply for a JSON POST body.

    Expected body: ``{"inputs": "<message>", "model_id": "<optional model id>"}``.
    When absent, ``inputs`` defaults to "No message provided." and
    ``model_id`` to "meta-llama/Llama-3.1-8B-Instruct".

    Responses:
        200 -- ``{"received_message", "model_id", "status"}``.
        400 -- the body is missing, is not valid JSON, or is not a JSON object.
        500 -- any error raised while generating the reply.
    """
    try:
        # silent=True: a malformed body or a non-JSON Content-Type yields None
        # (reported as 400 below) instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            # e.g. a JSON list or string: .get() below would raise AttributeError.
            return jsonify({"error": "JSON body must be an object"}), 400

        message = data.get("inputs", "No message provided.")
        # NOTE(review): model_id is client-controlled, so any Hub model can be
        # requested, downloaded and loaded into memory; consider an allow-list.
        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!",
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")