import os

# Set HF_HOME before importing transformers so the cache location is picked up
# at import time
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

import torch
import uuid
import shutil
import re
import traceback
import numpy as np
import faiss
from flask import Flask, jsonify, request
from flask_cors import CORS
from werkzeug.utils import secure_filename
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
from accelerate import Accelerator
from sentence_transformers import SentenceTransformer, util

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
  
# Load Sentence-BERT model (all-MiniLM-L6-v2 produces 384-dimensional embeddings,
# matching the FAISS index dimension used below). Lightweight and efficient;
# choose a larger model if needed.
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')

# Global variables for model and tokenizer
model = None
tokenizer = None
accelerator = Accelerator() 
highest_label = None 
loaded_models = {}

# Load model with accelerator
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    revision="d7645e1",
    device=accelerator.device  # Ensures correct device placement
)

# Move model to correct device
classifier.model = accelerator.prepare(classifier.model)
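
# Illustrative usage (not executed here): the zero-shot pipeline returns a dict
# with "labels" and "scores" sorted by descending score, e.g.
#   classifier("Fix this sentence for me", ["grammar correction", "information request"])
#   -> {"sequence": "...", "labels": ["grammar correction", ...], "scores": [0.87, ...]}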
 
# Define upload directory and FAISS index file
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Ensure upload directory exists and has write permissions
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")
    
document_store = {}

def initialize_faiss():
    """Return a fresh 384-dimensional FAISS index. The knowledge base is
    re-embedded on every load, so any previously saved index is discarded."""
    if os.path.exists(faiss_index_file):
        try:
            print(f"FAISS index file {faiss_index_file} exists, attempting to load it.")
            index = faiss.read_index(faiss_index_file)
            print(f"FAISS index loaded with {index.ntotal} vectors; discarding it and reinitializing.")
        except Exception as e:
            print(f"Error loading FAISS index: {e}, reinitializing a new index.")
    else:
        print(f"FAISS index file {faiss_index_file} does not exist, initializing a new index.")

    # 384 matches the embedding dimension of all-MiniLM-L6-v2; IndexIDMap lets us
    # attach our own document IDs to each vector
    index = faiss.IndexIDMap(faiss.IndexFlatL2(384))

    # Move to GPU if available (disabled; requires a faiss-gpu build):
    # if torch.cuda.is_available():
    #     print("CUDA is available, moving FAISS index to GPU.")
    #     index = faiss.index_cpu_to_all_gpus(index)

    return index

def save_faiss_index(index):
    try:
        # faiss.write_index only accepts CPU indexes, so move a GPU-resident
        # index back to the CPU first (GPU indexes expose getDevice())
        if hasattr(index, "getDevice"):
            print("Moving FAISS index back to CPU before saving.")
            index = faiss.index_gpu_to_cpu(index)
        print(f"Saving FAISS index to {faiss_index_file}.")
        faiss.write_index(index, faiss_index_file)
        print(f"FAISS index successfully saved to {faiss_index_file}.")
    except Exception as e:
        print(f"Error saving FAISS index: {e}")

# Initialize FAISS index
index = initialize_faiss()

# Persist the freshly initialized index so the index file always exists
save_faiss_index(index)

# Load document store and populate FAISS index
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledgebase1.txt")  # Ensure this path is correct
   
def load_document_store():
    """Load the knowledge base file into document_store (FAISS ID -> text) and
    rebuild the FAISS index from its line embeddings."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []

    if not os.path.exists(knowledgebase_file):
        print(f"Error: {knowledgebase_file} not found!")
        return

    with open(knowledgebase_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for i, line in enumerate(lines):
        text = line.strip()
        if text:
            document_store[i] = {"text": text}  # Store text mapped to FAISS ID
            all_texts.append(text)  # Collect all texts for embedding

    print(f"Loaded {len(document_store)} documents into document_store.")

    if not all_texts:
        return

    # Generate embeddings for all documents
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Rebuild the index so repeated calls don't accumulate duplicate IDs
    index.reset()
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")

def load_document_store_once(file_path):
    """Load a single uploaded file into document_store (FAISS ID -> text) and
    rebuild the FAISS index from its line embeddings."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []
    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))

    if not os.path.exists(file_location):
        print(f"Error: {file_location} not found!")
        return

    with open(file_location, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for i, line in enumerate(lines):
        text = line.strip()
        if text:
            document_store[i] = {"text": text}  # Store text mapped to FAISS ID
            all_texts.append(text)  # Collect all texts for embedding

    print(f"Loaded {len(document_store)} documents into document_store.")

    if not all_texts:
        return

    # Generate embeddings for all documents
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Rebuild the index so stale vectors from earlier loads don't linger
    index.reset()
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")

 
# Function to upload document
def upload_document(file_path, embed_model):
    try:
        # Generate unique document ID
        doc_id = uuid.uuid4().int % (2**63 - 1)
        
        # Ensure the file is saved to the correct directory with secure handling
        file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        print(f"Saving file to: {file_location}")  # Log the location
        
        # Safely copy the file to the upload directory
        shutil.copy(file_path, file_location)
        
        # Read the content of the uploaded file
        try:
            with open(file_location, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            print(f"Error reading file {file_location}: {e}")
            return {"error": f"Error reading file: {e}"}, 507  # Error while reading file
        
        # Embed the text and add it to the FAISS index
        try:
            # Ensure the embedding model is valid
            if embed_model is None:
                raise ValueError("Embedding model is not initialized properly.")
            
            vector = embed_model.encode(text).astype("float32")
            print(f"Generated vector for document {doc_id}: {vector}")  # Log vector
            
            index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
            document_store[doc_id] = {"path": file_location, "text": text}
            
            # Log FAISS index file path
            print(f"Saving FAISS index to: {faiss_index_file}")  # Log the file path
            
            # Save the FAISS index after adding the document
            try:
                faiss.write_index(index, faiss_index_file)
                print(f"Document uploaded with doc_id: {doc_id}")
            except Exception as e:
                print(f"Error saving FAISS index: {e}")
                return {"error": f"Error saving FAISS index: {e}"}, 508  # Error while saving FAISS index
            
        except Exception as e:
            print(f"Error during document upload: {e}")
            return {"error": f"Error during document upload: {e}"}, 509  # Error during embedding or FAISS processing
    
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"error": f"Unexpected error: {e}"}, 500  # General error

@app.route("/list_uploads", methods=["GET"])
def list_uploaded_files():
    try:
        # Ensure the upload directory exists
        if not os.path.exists(UPLOAD_DIR):
            return jsonify({"error": "Upload directory does not exist"}), 400

        # List all files in the upload directory
        files = os.listdir(UPLOAD_DIR)
        
        if not files:
            return jsonify({"message": "No files found in the upload directory"}), 200
        
        return jsonify({"files": files}), 200
    except Exception as e:
        return jsonify({"error": f"Error listing files: {e}"}), 504

@app.route("/upload", methods=["POST"])
def handle_upload():
    # Check if the request contains the file
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    
    file = request.files["file"]
    
    # Sanitize the filename and construct the full file path
    filename = secure_filename(file.filename)
    if not filename:
        return jsonify({"error": "Invalid filename"}), 400
    file_path = os.path.join(UPLOAD_DIR, filename)
    
    # Ensure the upload directory exists and has correct permissions
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)  # Ensure the directory exists
        if not os.access(UPLOAD_DIR, os.W_OK):  # Check write permissions
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 501

    try:
        # Save the file to the upload directory
        file.save(file_path)
        print("File uploaded successfully. Indexing its contents...")
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 502

    # Index the uploaded document (rebuilds document_store and the FAISS index;
    # upload_document(file_path, bertmodel) could be used instead to append the
    # whole file as a single entry)
    try:
        load_document_store_once(file_path)
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 503
    
    # Return success response
    return jsonify({"message": "File uploaded and processed successfully"}), 200

def get_model_and_tokenizer(model_id: str):
    """
    Load and cache the model and tokenizer for the given model_id.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    if model_id not in loaded_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
            loaded_models[model_id] = (model, tokenizer)
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # Log the full error traceback
            raise  # Re-raise to stop execution (preserves the original traceback)
    return loaded_models[model_id]
        
        
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        return match.group(0).strip()
    return user_input
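
# Illustrative example: extract_core_sentence("Fix this sentence: he go to school")
# returns "he go to school"; input without a "sentence" marker is returned unchanged.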

def classify_intent(user_input):
    """
    Classify the intent of the user input using zero-shot classification.
    """
    candidate_labels = [
        "grammar correction", "information request", "task completion", 
        "dialog continuation", "personal opinion", "product inquiry",
        "feedback request", "recommendation request", "clarification request", 
        "affirmation or agreement", "real-time data request", "current information"
    ]
    result = classifier(user_input, candidate_labels)
    highest_score_index = result['scores'].index(max(result['scores']))
    highest_label = result['labels'][highest_score_index]
    return highest_label
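
# Illustrative output: for "Can you recommend a laptop?", the pipeline would likely
# rank "recommendation request" highest. Since the zero-shot pipeline returns labels
# sorted by descending score, the manual max above is equivalent to result['labels'][0].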


# Reformulate the prompt based on the classified intent
def reformulate_prompt(user_input, intent_label):
    """
    Reformulate the prompt based on the classified intent.
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, "Input does not require a defined action.")
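
# Illustrative example:
#   reformulate_prompt("Fix this sentence: he go to school", "grammar correction")
#   -> "Fix the grammar in this sentence: he go to school"
# Labels without a template (e.g. "task completion") fall back to the default string.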

chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "The capital of France is Paris."),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]
    
 
def generate_response(user_input, model_id):
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator
       
        # Build the message list for the model prompt
        func_caller = []
          
        query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
        D, I = index.search(query_vector, 1)

        # Retrieve document
        retrieved_id = I[0][0]
        retrieved_knowledge = (
            document_store.get(retrieved_id, {}).get("text", "No relevant information found.")
            if retrieved_id != -1 else "No relevant information found."
        )
        
        # Construct the knowledge prompt
        prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
        
        # Log the prompt (you can change this to a logging library if needed)
        print(f"Generated prompt: {prompt}")  # <-- Log the prompt here
        
        # Add the retrieved knowledge to the prompt
        func_caller.append({"role": "system", "content": prompt}) 

        for msg in chat_history:
            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
 
        highest_label_result = classify_intent(user_input) 

        # Reformulated prompt based on intent classification
        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
          
        func_caller.append({"role": "user", "content": f'{reformulated_prompt}'})
        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])

        generation_config = GenerationConfig(
            # Sample only for open-ended intents; use greedy decoding otherwise
            # (the original code read the module-level highest_label, which is
            # always None; the classified intent is what should drive decoding)
            do_sample=highest_label_result in ("dialog continuation", "recommendation request"),
            temperature=0.7 if highest_label_result == "dialog continuation" else (0.2 if highest_label_result == "recommendation request" else None),
            top_k=5 if highest_label_result == "recommendation request" else None,
            max_length=150,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        # Generate the response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(
            gpt_inputs["input_ids"],
            attention_mask=gpt_inputs["attention_mask"],
            max_new_tokens=50,
            generation_config=generation_config,
        )
        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)
        # Strip the echoed prompt so only newly generated text remains
        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
        # Split into sentences, dropping empty fragments
        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
       
        # Guard against an empty candidate list (e.g. the model only echoed the prompt)
        if not ai_response:
            return "No response generated."

        # Encode the prompt and candidate sentences
        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)

        # Compute similarity scores between the prompt and each candidate
        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]

        # Pick the candidate most similar to the prompt
        best_index = int(similarities.argmax())
        best_response = ai_response[best_index]
        
        if highest_label_result == "dialog continuation":
            # Keep at most the first three sentences of the response
            sentences = best_response.split('. ')
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response

        # Append the exchange to the chat history as a (user, assistant) tuple,
        # matching the format consumed at the top of this function
        chat_history.append((user_input, best_response))

        return best_response
        
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Log the full traceback
        raise
 
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
        
        #model_id = data.get("model_id", "openai-community/gpt2-large")
        
        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)