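# NOTE: the following lines were copied from the repository web page header;
# they are kept as comments so the file remains valid Python.
# opsgenius3 / app.py
# YALCINKAYA's picture
# Update app.py
# a47c058 verified
# raw
# history blame
# 13.5 kB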
import os
import torch
import uuid
import shutil
import numpy as np
import faiss
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# Set the HF_HOME environment variable to a writable directory.
# NOTE(review): transformers / sentence_transformers are imported above this line,
# and Hugging Face libraries typically resolve their cache location from HF_HOME at
# import time, so this assignment may have no effect - confirm, and consider moving
# it above those imports.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"
# Flask application that serves the /upload and /send_message routes.
app = Flask(__name__)
# Enable CORS for specific origins. Only /send_message is listed, so /upload
# responses carry no CORS headers for these origins.
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
# Load zero-shot classification pipeline (used by classify_intent). No model name
# is passed, so the library's default zero-shot model is used - presumably
# downloaded on first start; confirm which model that resolves to.
classifier = pipeline("zero-shot-classification")
# Load Sentence-BERT model: embeds uploaded documents and queries for the FAISS
# index, and is also used to rank generated sentences against the prompt.
bertmodel = SentenceTransformer('all-MiniLM-L6-v2') # Lightweight, efficient model; choose larger if needed
# Global variables for model and tokenizer; get_model_and_tokenizer reassigns
# them whenever it loads a model id that is not yet cached.
model = None
tokenizer = None
# Accelerator whose `prepare` wraps loaded models and whose `device` is where
# tokenized inputs are sent.
accelerator = Accelerator()
# NOTE(review): nothing in this file reassigns this global after startup
# (classify_intent uses a local variable of the same name), so it stays None.
highest_label = None
# Cache of model_id -> (model, tokenizer), filled by get_model_and_tokenizer.
loaded_models = {}
# Directory that receives uploaded documents; the FAISS index file is saved
# alongside them.
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")
# Ensure the upload directory exists and is writable. Failures are reported but
# not fatal: the upload route re-checks the directory on every request.
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        # NOTE(review): 0o777 makes the directory world-writable; fixing
        # ownership or using a narrower mode would be safer if the deployment allows.
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except OSError as e:
    # PermissionError is the common case, but a read-only or full filesystem
    # raises other OSError subclasses that previously aborted the app at import.
    # The message is unchanged for PermissionError.
    print(f"{type(e).__name__}: {e}. Try adjusting directory ownership or running with elevated permissions.")
# In-memory map: FAISS vector id -> {"path": <saved file>, "text": <file contents>}.
document_store = {}
# Width of the vectors stored in the index; it must equal the size of the
# embeddings `bertmodel` produces (384 for all-MiniLM-L6-v2).
_EMBEDDING_DIM = 384


def _new_faiss_index():
    """Return an empty exact (flat) L2 index that stores vectors under caller-chosen ids."""
    return faiss.IndexIDMap(faiss.IndexFlatL2(_EMBEDDING_DIM))


# A previously saved index is only inspected, never reused. document_store (the
# id -> text map) starts empty on every start, so ids from an earlier run could
# not be resolved to text. Persisted vectors are therefore dropped, and the file
# is overwritten by the next upload.
if not os.path.exists(faiss_index_file):
    index = _new_faiss_index()
else:
    try:
        _persisted = faiss.read_index(faiss_index_file)
        if _persisted.ntotal > 0:
            print(f"FAISS index loaded with {_persisted.ntotal} vectors.")
            _persisted.reset()
        index = _new_faiss_index()
    except Exception as e:
        print(f"Error loading FAISS index: {e}, reinitializing.")
        index = _new_faiss_index()
# Function to upload document
def upload_document(file_path, embed_model):
    """Copy a text file into UPLOAD_DIR, embed it and add it to the FAISS index.

    Args:
        file_path: Path of a UTF-8 text file. It may already live in UPLOAD_DIR,
            as it does when called from the /upload route.
        embed_model: Object whose ``encode(text)`` returns a vector matching the
            index width (presumably the module's SentenceTransformer).

    Returns:
        The new document id on success, or None if the file could not be read
        or indexed. Errors are printed rather than raised.

    Raises:
        Whatever ``shutil.copy`` raises for a missing or unreadable source file
        (other than the same-file case, which is skipped).
    """
    # Random id reduced into the non-negative int64 range FAISS ids require.
    doc_id = uuid.uuid4().int % (2**63 - 1)
    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    try:
        shutil.copy(file_path, file_location)
    except shutil.SameFileError:
        # The route handler already saved the file at this exact path; without
        # this guard every upload from /upload failed here.
        pass

    try:
        with open(file_location, "r", encoding="utf-8") as f:
            text = f.read()
    except Exception as e:
        print(f"Error reading file {file_location}: {e}")
        return None

    try:
        vector = embed_model.encode(text).astype("float32")
        index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
        document_store[doc_id] = {"path": file_location, "text": text}
        # Persist the index after every successful addition.
        faiss.write_index(index, faiss_index_file)
    except Exception as e:
        print(f"Error during document upload: {e}")
        return None
    print(f"Document uploaded with doc_id: {doc_id}")
    return doc_id
@app.route("/upload", methods=["POST"])
def handle_upload():
    """Save a multipart-uploaded file under UPLOAD_DIR and index it.

    The file is expected in the ``file`` form field.

    Returns:
        200 with a success message; 400 when no file or an unusable file name
        is sent; 500 when the directory, save or indexing step raises.

    NOTE(review): upload_document prints and returns (instead of raising) when
    the file cannot be read or embedded, so this route can still report
    success for a document that was not indexed.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]

    # Never trust the client-supplied name. Keep only its final path component,
    # treating "\" as a separator too, so "../../etc/passwd" cannot escape
    # UPLOAD_DIR. Also refuse names that would overwrite the FAISS index file.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        return jsonify({"error": "Invalid file name"}), 400
    file_path = os.path.join(UPLOAD_DIR, filename)
    if os.path.abspath(file_path) == os.path.abspath(faiss_index_file):
        return jsonify({"error": "Invalid file name"}), 400

    # Re-create the directory and fix permissions in case they changed since startup.
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        if not os.access(UPLOAD_DIR, os.W_OK):
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 500

    try:
        file.save(file_path)
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 500

    # bertmodel is the module-level SentenceTransformer used for all embeddings.
    try:
        upload_document(file_path, bertmodel)
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 500

    return jsonify({"message": "File uploaded and processed successfully"}), 200
def get_model_and_tokenizer(model_id: str):
    """Load and cache the causal LM and tokenizer for *model_id*.

    The first request for an id loads both with ``from_pretrained``, passes the
    model through ``accelerator.prepare`` and stores the pair in
    ``loaded_models``. The module-level ``model``/``tokenizer`` globals are
    updated only when both loads succeed. Later calls for a cached id return
    the cached pair without touching those globals.

    Returns:
        ``(model, tokenizer)`` tuple.

    Raises:
        Whatever loading or ``accelerator.prepare`` raises; the traceback is
        printed first.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    cached = loaded_models.get(model_id)
    if cached is not None:
        return cached
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(model_id)
        new_model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(model_id))
    except Exception:
        print("Error loading model:")
        print(traceback.format_exc())  # Logs the full error traceback
        raise  # Bare raise keeps the original traceback intact
    # Publish the globals only after both halves loaded, so a failed model load
    # can no longer leave a new tokenizer paired with an old model.
    model, tokenizer = new_model, new_tokenizer
    loaded_models[model_id] = (new_model, new_tokenizer)
    return loaded_models[model_id]
# Matches everything after the first "sentence:" / "sentence " (any case).
# DOTALL lets the captured text span several lines instead of stopping at the
# first newline.
_CORE_SENTENCE_RE = re.compile(r"(?<=sentence[: ]).+", re.IGNORECASE | re.DOTALL)


def extract_core_sentence(user_input):
    """Extract the core sentence needing grammar correction from the user input.

    Returns the text following the first "sentence:" or "sentence " marker
    (case-insensitive), stripped of surrounding whitespace. If there is no
    marker, or only whitespace follows it, the input is returned unchanged
    rather than an empty string.

    >>> extract_core_sentence("Fix this sentence: he go home")
    'he go home'
    """
    match = _CORE_SENTENCE_RE.search(user_input)
    if match:
        core = match.group(0).strip()
        if core:
            return core
    return user_input
def classify_intent(user_input):
    """Return the intent label that best fits *user_input*.

    The module-level zero-shot ``classifier`` scores the input against a fixed
    list of candidate intents. The label with the highest score wins; on ties,
    the earliest one wins.
    """
    labels = [
        "grammar correction",
        "information request",
        "task completion",
        "dialog continuation",
        "personal opinion",
        "product inquiry",
        "feedback request",
        "recommendation request",
        "clarification request",
        "affirmation or agreement",
        "real-time data request",
        "current information",
    ]
    prediction = classifier(user_input, labels)
    scores = prediction['scores']
    # max() over indices returns the first index holding the maximum score.
    top = max(range(len(scores)), key=scores.__getitem__)
    return prediction['labels'][top]
# Reformulate the user's message into an intent-specific instruction.
def reformulate_prompt(user_input, intent_label):
    """Rewrite *user_input* as an instruction matching *intent_label*.

    The core sentence comes from extract_core_sentence and is inserted into an
    intent-specific template. Intents without a template fall back to the raw
    user input, so the question still reaches the model. Examples are the
    classify_intent labels "task completion", "real-time data request" and
    "current information". Previously these intents replaced the input with the
    fixed text "Input does not require a defined action.".
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, user_input)
# Seed conversation that generate_response replays, as alternating user and
# assistant messages, in front of every prompt. Seed entries are
# (user message, assistant reply) tuples, read as msg[0] / msg[1].
# The joke turn previously answered with an unrelated fact ("The capital of
# France is Paris."), which taught the model non-sequitur replies.
chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "Why did the scarecrow win an award? Because he was outstanding in his field!"),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]
# Maximum number of (user, assistant) turns kept in chat_history. Older turns
# are dropped so the replayed conversation cannot grow without bound.
MAX_HISTORY_TURNS = 10

# Intents that switch decoding from greedy to sampling, with their settings.
_SAMPLING_SETTINGS = {
    "dialog continuation": {"temperature": 0.7, "top_k": None},
    "recommendation request": {"temperature": 0.2, "top_k": 5},
}


def _retrieve_knowledge(user_input):
    """Return the text of the uploaded document nearest to *user_input*, or a placeholder."""
    query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
    _, ids = index.search(query_vector, 1)
    # FAISS reports id -1 when it has no result (e.g. an empty index). Document
    # ids are always non-negative, so -1 falls through to the placeholder.
    return document_store.get(int(ids[0][0]), {}).get("text", "No relevant information found.")


def _build_generation_config(intent, tokenizer):
    """Build decoding settings: sampling for the intents above, greedy otherwise."""
    sampling = _SAMPLING_SETTINGS.get(intent, {})
    return GenerationConfig(
        do_sample=bool(sampling),
        temperature=sampling.get("temperature"),
        top_k=sampling.get("top_k"),
        # Only max_new_tokens is set. Previously max_length=150 was also set,
        # but generate() ignores it whenever max_new_tokens is given.
        max_new_tokens=50,
        repetition_penalty=1.2,
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )


def _trim_to_context(model, input_ids, attention_mask, max_new_tokens):
    """Left-truncate the prompt so prompt plus generated tokens fit the model's context window.

    Applies only when the model config exposes ``max_position_embeddings``
    (e.g. 1024 for GPT-2). The oldest tokens are dropped so the latest user
    turn survives.
    """
    context_size = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if not context_size:
        return input_ids, attention_mask
    budget = context_size - max_new_tokens
    if budget <= 0 or input_ids.shape[-1] <= budget:
        return input_ids, attention_mask
    trimmed_mask = attention_mask[:, -budget:] if attention_mask is not None else None
    return input_ids[:, -budget:], trimmed_mask


def _pick_best_sentence(reference, text):
    """Return the sentence of *text* most similar to *reference* ("" if *text* has none)."""
    sentences = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', text) if s.strip()]
    if not sentences:
        # An empty candidate list used to crash in argmax().
        return ""
    reference_embedding = bertmodel.encode(reference, convert_to_tensor=True)
    sentence_embeddings = bertmodel.encode(sentences, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(reference_embedding, sentence_embeddings)[0]
    return sentences[int(similarities.argmax())]


def generate_response(user_input, model_id):
    """Generate a reply to *user_input* with the causal LM identified by *model_id*.

    Steps:
    1. Retrieve the nearest uploaded document as system context.
    2. Replay chat_history.
    3. Classify the intent and reformulate the user turn.
    4. Generate, then keep the generated sentence most similar to the prompt.

    The exchange is appended to chat_history as a (user, assistant) tuple, and
    the history is trimmed to the last MAX_HISTORY_TURNS turns.

    Returns:
        The selected reply sentence, or "" if the model generated no text.

    Raises:
        Any exception from loading, retrieval, classification or generation;
        the traceback is printed before re-raising.
    """
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device

        knowledge = _retrieve_knowledge(user_input)
        messages = [{"role": "system", "content": f"Use the following knowledge:\n{knowledge}"}]
        for past_user, past_assistant in chat_history:
            messages.append({"role": "user", "content": str(past_user)})
            messages.append({"role": "assistant", "content": str(past_assistant)})

        # Use the freshly classified intent. The old code read the module global
        # `highest_label`, which is always None, so sampling never applied.
        intent = classify_intent(user_input)
        messages.append({"role": "user", "content": reformulate_prompt(user_input, intent)})
        formatted_prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        generation_config = _build_generation_config(intent, tokenizer)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        input_ids, attention_mask = _trim_to_context(
            model, inputs["input_ids"], inputs.get("attention_mask"), generation_config.max_new_tokens
        )
        output_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, generation_config=generation_config
        )
        # Decode only the newly generated tokens. Regex-stripping the decoded
        # prompt failed whenever decoding did not reproduce the prompt verbatim.
        generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)

        best_response = _pick_best_sentence(formatted_prompt, generated_text)
        if intent == "dialog continuation":
            # Keep at most the first three ". "-separated sentences.
            sentences = best_response.split('. ')
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response

        # Store the turn as a tuple, matching the seeded entries. Dicts broke the
        # msg[0]/msg[1] replay above with KeyError on the next request.
        chat_history.append((user_input, best_response))
        del chat_history[:-MAX_HISTORY_TURNS]
        return best_response
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    """Chat endpoint: reply to the ``inputs`` text of a JSON request body.

    Request JSON:
        inputs: the user's message (required, non-blank string).
        model_id: model id passed to generate_response; defaults to
            "openai-community/gpt2-large".

    Returns:
        200 with ``received_message``, ``model_id`` and ``status``.
        400 when the body is not a JSON object or ``inputs`` is missing or blank.
        500 with the error text when generation fails.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so bad requests get a 400 rather than the generic 500 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "No JSON data provided"}), 400
    message = data.get("inputs")
    if not isinstance(message, str) or not message.strip():
        # Previously a placeholder text was sent to the model instead.
        return jsonify({"error": "No message provided."}), 400
    # NOTE(review): model_id comes straight from the client, so any caller can
    # make the server download and load an arbitrary model - consider an allow-list.
    model_id = data.get("model_id", "openai-community/gpt2-large")
    try:
        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500
    return jsonify({
        "received_message": model_response,
        "model_id": model_id,
        "status": "POST request successful!"
    })
if __name__ == '__main__':
    # Run Flask's built-in server on port 7860; host 0.0.0.0 listens on every
    # network interface, not just localhost.
    app.run(port=7860, host='0.0.0.0')