import os
import torch
import uuid
import shutil
import numpy as np
import faiss
from flask import Flask, jsonify, request
from flask_cors import CORS
from werkzeug.utils import secure_filename
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Load zero-shot classification pipeline
# classifier = pipeline("zero-shot-classification")

# Load Sentence-BERT model
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight, efficient model; choose a larger one if needed

# Global variables for model and tokenizer
model = None
tokenizer = None
accelerator = Accelerator()
highest_label = None
loaded_models = {}

# Load the zero-shot classifier with a pinned revision
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    revision="d7645e1",
    device=accelerator.device  # Ensures correct device placement
)

# Move the classifier model to the correct device
classifier.model = accelerator.prepare(classifier.model)

# Define upload directory and FAISS index file
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Ensure the upload directory exists and has write permissions
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")

document_store = {}


def initialize_faiss():
    """Load the FAISS index from disk if present, otherwise create an empty one.

    The dimension (384) must match the embedding size of all-MiniLM-L6-v2.
    """
    if os.path.exists(faiss_index_file):
        try:
            print(f"FAISS index file {faiss_index_file} exists, attempting to load it.")
            index = faiss.read_index(faiss_index_file)
            if index.ntotal > 0:
                print(f"FAISS index loaded with {index.ntotal} vectors.")
                # If the index has existing entries, clear it and start fresh
                index.reset()
                print("Index reset. Reinitializing index.")
                index = faiss.IndexIDMap(faiss.IndexFlatL2(384))  # Reinitialize the index
            else:
                print("Loaded index has zero vectors, reinitializing index.")
                index = faiss.IndexIDMap(faiss.IndexFlatL2(384))  # Initialize with flat L2 distance
        except Exception as e:
            print(f"Error loading FAISS index: {e}, reinitializing a new index.")
            index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
    else:
        print(f"FAISS index file {faiss_index_file} does not exist, initializing a new index.")
        index = faiss.IndexIDMap(faiss.IndexFlatL2(384))

    # Move to GPU if available
    # if torch.cuda.is_available():
    #     print("CUDA is available, moving FAISS index to GPU.")
    #     index = faiss.index_cpu_to_all_gpus(index)
    #     print("FAISS index is now on GPU.")

    return index
def save_faiss_index(index):
    try:
        # faiss.write_index requires a CPU index. The GPU transfer in
        # initialize_faiss is commented out, so this branch is normally a
        # no-op; if that transfer is re-enabled, move the index back to the
        # CPU here before writing it to disk.
        if torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") and "Gpu" in type(index).__name__:
            print("Moving FAISS index back to CPU before saving.")
            index = faiss.index_gpu_to_cpu(index)
        print(f"Saving FAISS index to {faiss_index_file}.")
        faiss.write_index(index, faiss_index_file)
        print(f"FAISS index successfully saved to {faiss_index_file}.")
    except Exception as e:
        print(f"Error saving FAISS index: {e}")


# Initialize FAISS index
index = initialize_faiss()

# Save FAISS index after modifications
save_faiss_index(index)

# Load document store and populate FAISS index
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledgebase1.txt")  # Ensure this path is correct


def load_document_store():
    """Loads knowledgebase1.txt into a dictionary mapping FAISS IDs to text, and indexes the embeddings."""
    global document_store
    document_store = {}  # Reset document store
    index.reset()  # Clear previously indexed vectors so IDs stay in sync with document_store
    all_texts = []

    if os.path.exists(knowledgebase_file):
        with open(knowledgebase_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)  # Collect all texts for embedding
        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {knowledgebase_file} not found!")

    # Generate embeddings for all documents and add them to the FAISS index
    if all_texts:
        embeddings = bertmodel.encode(all_texts)
        embeddings = embeddings.astype("float32")
        index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
        print(f"Added {len(all_texts)} document embeddings to FAISS index.")


def load_document_store_once(file_path):
    """Loads a single uploaded file into a dictionary mapping FAISS IDs to text, and indexes the embeddings."""
    global document_store
    document_store = {}  # Reset document store
    index.reset()  # Clear previously indexed vectors so IDs stay in sync with document_store
    all_texts = []

    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    if os.path.exists(file_location):
        with open(file_location, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)  # Collect all texts for embedding
        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {file_location} not found!")

    # Generate embeddings for all documents and add them to the FAISS index
    if all_texts:
        embeddings = bertmodel.encode(all_texts)
        embeddings = embeddings.astype("float32")
        index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
        print(f"Added {len(all_texts)} document embeddings to FAISS index.")
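# Note: the loaders above treat the knowledge base as plain text with one
# self-contained passage per line (empty lines are skipped). A hypothetical
# knowledgebase1.txt might therefore look like:
#
#   The capital of France is Paris.
#   FAISS retrieves the nearest stored passage for each incoming question.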
# Function to upload a document directly into the FAISS index
def upload_document(file_path, embed_model):
    try:
        # Generate a unique document ID that fits in a signed 64-bit integer
        doc_id = uuid.uuid4().int % (2**63 - 1)

        # Ensure the file is saved to the correct directory with secure handling
        file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        print(f"Saving file to: {file_location}")  # Log the location

        # Safely copy the file to the upload directory
        shutil.copy(file_path, file_location)

        # Read the content of the uploaded file
        try:
            with open(file_location, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            print(f"Error reading file {file_location}: {e}")
            return {"error": f"Error reading file: {e}"}, 507  # Error while reading file

        # Embed the text and add it to the FAISS index
        try:
            # Ensure the embedding model is valid
            if embed_model is None:
                raise ValueError("Embedding model is not initialized properly.")

            vector = embed_model.encode(text).astype("float32")
            print(f"Generated vector for document {doc_id}: {vector}")  # Log vector

            index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
            document_store[doc_id] = {"path": file_location, "text": text}

            # Log FAISS index file path
            print(f"Saving FAISS index to: {faiss_index_file}")

            # Save the FAISS index after adding the document
            try:
                faiss.write_index(index, faiss_index_file)
                print(f"Document uploaded with doc_id: {doc_id}")
            except Exception as e:
                print(f"Error saving FAISS index: {e}")
                return {"error": f"Error saving FAISS index: {e}"}, 508  # Error while saving FAISS index
        except Exception as e:
            print(f"Error during document upload: {e}")
            return {"error": f"Error during document upload: {e}"}, 509  # Error during embedding or FAISS processing
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"error": f"Unexpected error: {e}"}, 500  # General error
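# Note: upload_document() is not currently wired to any route; handle_upload below
# calls load_document_store_once() instead, and the commented-out call there shows
# the intended usage, e.g. upload_document("/tmp/notes.txt", bertmodel) with a
# hypothetical local file path.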
@app.route("/list_uploads", methods=["GET"])
def list_uploaded_files():
    try:
        # Ensure the upload directory exists
        if not os.path.exists(UPLOAD_DIR):
            return jsonify({"error": "Upload directory does not exist"}), 400

        # List all files in the upload directory
        files = os.listdir(UPLOAD_DIR)
        if not files:
            return jsonify({"message": "No files found in the upload directory"}), 200

        return jsonify({"files": files}), 200
    except Exception as e:
        return jsonify({"error": f"Error listing files: {e}"}), 504


@app.route("/upload", methods=["POST"])
def handle_upload():
    # Check if the request contains the file
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]

    # Sanitize the filename and construct the full file path
    file_path = os.path.join(UPLOAD_DIR, secure_filename(file.filename))

    # Ensure the upload directory exists and has correct permissions
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)  # Ensure the directory exists
        if not os.access(UPLOAD_DIR, os.W_OK):  # Check write permissions
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 501

    try:
        # Save the file to the upload directory, then reload the knowledge base
        file.save(file_path)
        print("File uploaded successfully. Calling load_document_store()...")
        load_document_store()  # Reload FAISS index
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 502

    # Process the uploaded document
    try:
        load_document_store_once(file_path)
        # upload_document(file_path, bertmodel)  # Alternative path using per-document IDs
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 503

    # Return success response
    return jsonify({"message": "File uploaded and processed successfully"}), 200


def get_model_and_tokenizer(model_id: str):
    """
    Load and cache the model and tokenizer for the given model_id.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    if model_id not in loaded_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
            loaded_models[model_id] = (model, tokenizer)
        except Exception as e:
            print("Error loading model:")
            print(traceback.format_exc())  # Logs the full error traceback
            raise e  # Reraise the exception to stop execution
    return loaded_models[model_id]


# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        return match.group(0).strip()
    return user_input


def classify_intent(user_input):
    """
    Classify the intent of the user input using zero-shot classification.
    """
    global highest_label  # Keep the module-level label in sync for generate_response
    candidate_labels = [
        "grammar correction", "information request", "task completion", "dialog continuation",
        "personal opinion", "product inquiry", "feedback request", "recommendation request",
        "clarification request", "affirmation or agreement", "real-time data request", "current information"
    ]
    result = classifier(user_input, candidate_labels)
    highest_score_index = result['scores'].index(max(result['scores']))
    highest_label = result['labels'][highest_score_index]
    return highest_label


# Reformulate the prompt based on the classified intent
def reformulate_prompt(user_input, intent_label):
    """
    Reformulate the prompt based on the classified intent.
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, "Input does not require a defined action.")
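# Illustrative intent flow (the actual label depends on the zero-shot classifier,
# so the label shown here is an assumption, not guaranteed output):
#   classify_intent("Please fix this sentence: she go to school")
#       -> e.g. "grammar correction"
#   reformulate_prompt("Please fix this sentence: she go to school", "grammar correction")
#       -> "Fix the grammar in this sentence: she go to school"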
chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "The capital of France is Paris."),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]


def generate_response(user_input, model_id):
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator

        # Messages that will be flattened into the final prompt
        func_caller = []

        # Retrieve the most similar document from the FAISS index
        query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
        D, I = index.search(query_vector, 1)
        retrieved_id = I[0][0]
        retrieved_knowledge = (
            document_store.get(retrieved_id, {}).get("text", "No relevant information found.")
            if retrieved_id != -1 else "No relevant information found."
        )

        # Construct the knowledge prompt
        prompt = f"Use the following knowledge:\n{retrieved_knowledge}"

        # Log the prompt (you can change this to a logging library if needed)
        print(f"Generated prompt: {prompt}")

        # Add the retrieved knowledge to the prompt
        func_caller.append({"role": "system", "content": prompt})

        # Append chat history
        for msg in chat_history:
            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})

        highest_label_result = classify_intent(user_input)

        # Reformulate the prompt based on the classified intent
        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
        func_caller.append({"role": "user", "content": f"{reformulated_prompt}"})

        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])

        # prompt = user_input
        # device = accelerator.device  # Automatically uses GPU or CPU based on accelerator setup

        generation_config = GenerationConfig(
            # Sample only for open-ended intents; otherwise decode greedily
            do_sample=(highest_label == "dialog continuation" or highest_label == "recommendation request"),
            temperature=0.7 if highest_label == "dialog continuation" else (0.2 if highest_label == "recommendation request" else None),
            top_k=5 if highest_label == "recommendation request" else None,
            # attention_mask=attention_mask,
            max_length=150,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            # stop_sequences=["User:", "Assistant:", "\n"],
        )

        # Generate response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(gpt_inputs["input_ids"], max_new_tokens=50, generation_config=generation_config)
        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)

        # Extract the AI's response only (omit the prompt)
        # ai_response2 = final_response.replace(reformulated_prompt, "").strip()
        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()

        # Split the response into sentence candidates
        # ai_response = re.split(r'(?<=\w[.!?]) +', ai_response)
        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
        if not ai_response:
            ai_response = [final_response]  # Fall back to the raw output so the ranking below has a candidate

        # Encode the prompt and candidates
        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)

        # Compute similarity scores between the prompt and each candidate
        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]

        # Pick the candidate with the highest similarity score
        best_index = similarities.argmax()
        best_response = ai_response[best_index]

        # For dialog continuation, keep at most the first three sentences
        if highest_label == "dialog continuation":
            sentences = best_response.split('. ')
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response
        # Append the exchange to the chat history, keeping the (user, assistant)
        # tuple format expected by the history loop above
        chat_history.append((user_input, best_response))

        return best_response
    except Exception as e:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise e


@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
        # model_id = data.get("model_id", "openai-community/gpt2-large")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
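# Example requests (assumes the server is reachable at http://localhost:7860;
# adjust host/port for your deployment):
#
#   # Upload a plain-text knowledge file (one passage per line):
#   curl -F "file=@knowledgebase1.txt" http://localhost:7860/upload
#
#   # List uploaded files:
#   curl http://localhost:7860/list_uploads
#
#   # Send a chat message; "model_id" is optional and defaults to
#   # meta-llama/Llama-3.1-8B-Instruct:
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "What is the capital of France?"}'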