import os

# Set the HF_HOME environment variable to a writable directory.
# This must happen before the Hugging Face libraries are imported,
# otherwise they resolve their cache paths from the default location.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

import re
import shutil
import traceback
import uuid

import faiss
import numpy as np
import torch
from accelerate import Accelerator
from flask import Flask, jsonify, request
from flask_cors import CORS
from sentence_transformers import SentenceTransformer, util
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    pipeline,
)
app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Load Sentence-BERT model (lightweight and efficient; choose a larger model if needed)
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')

# Global variables for model and tokenizer
model = None
tokenizer = None
accelerator = Accelerator()
highest_label = None
loaded_models = {}

# Load zero-shot classification pipeline with the accelerator's device
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    revision="d7645e1",
    device=accelerator.device  # Ensures correct device placement
)

# Move model to correct device
classifier.model = accelerator.prepare(classifier.model)
# Define upload directory and FAISS index file
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Ensure upload directory exists and has write permissions
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")

document_store = {}
def initialize_faiss():
    """Load the FAISS index from disk if present, otherwise create a fresh one."""
    if os.path.exists(faiss_index_file):
        try:
            print(f"FAISS index file {faiss_index_file} exists, attempting to load it.")
            index = faiss.read_index(faiss_index_file)
            if index.ntotal > 0:
                print(f"FAISS index loaded with {index.ntotal} vectors.")
                # If the index has non-zero entries, reset it and reinitialize
                index.reset()
                print("Index reset. Reinitializing index.")
                index = faiss.IndexIDMap(faiss.IndexFlatL2(384))  # 384 = embedding size of all-MiniLM-L6-v2
            else:
                print("Loaded index has zero vectors, reinitializing index.")
                index = faiss.IndexIDMap(faiss.IndexFlatL2(384))  # Initialize with flat L2 distance
        except Exception as e:
            print(f"Error loading FAISS index: {e}, reinitializing a new index.")
            index = faiss.IndexIDMap(faiss.IndexFlatL2(384))
    else:
        print(f"FAISS index file {faiss_index_file} does not exist, initializing a new index.")
        index = faiss.IndexIDMap(faiss.IndexFlatL2(384))

    # Move to GPU if available
    # if torch.cuda.is_available():
    #     print("CUDA is available, moving FAISS index to GPU.")
    #     index = faiss.index_cpu_to_all_gpus(index)
    #     print("FAISS index is now on GPU.")
    return index
def save_faiss_index(index):
    try:
        # faiss.write_index expects a CPU index. The GPU transfer in
        # initialize_faiss is commented out, so the index already lives on the
        # CPU; if it is ever moved to the GPU, convert it back with
        # faiss.index_gpu_to_cpu(index) before saving.
        print(f"Saving FAISS index to {faiss_index_file}.")
        faiss.write_index(index, faiss_index_file)
        print(f"FAISS index successfully saved to {faiss_index_file}.")
    except Exception as e:
        print(f"Error saving FAISS index: {e}")
# Initialize FAISS index
index = initialize_faiss()

# Save FAISS index after modifications
save_faiss_index(index)

# Load document store and populate FAISS index
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledgebase1.txt")  # Ensure this path is correct
def load_document_store():
    """Load knowledgebase1.txt into a dictionary mapping FAISS IDs to text, then embed and index it."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []

    if os.path.exists(knowledgebase_file):
        with open(knowledgebase_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)              # Collect all texts for embedding
        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {knowledgebase_file} not found!")

    if not all_texts:
        print("No documents to embed, skipping FAISS update.")
        return

    # Generate embeddings for all documents
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Add embeddings to FAISS index
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")
def load_document_store_once(file_path):
    """Load a single uploaded file into the document store and add its line embeddings to the FAISS index."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []

    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    if os.path.exists(file_location):
        with open(file_location, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)              # Collect all texts for embedding
        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {file_location} not found!")

    if not all_texts:
        print("No documents to embed, skipping FAISS update.")
        return

    # Generate embeddings for all documents
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Add embeddings to FAISS index
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")
# Function to upload a single document into the FAISS index
def upload_document(file_path, embed_model):
    try:
        # Generate unique document ID
        doc_id = uuid.uuid4().int % (2**63 - 1)

        # Ensure the file is saved to the correct directory with secure handling
        file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        print(f"Saving file to: {file_location}")

        # Safely copy the file to the upload directory
        shutil.copy(file_path, file_location)

        # Read the content of the uploaded file
        try:
            with open(file_location, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            print(f"Error reading file {file_location}: {e}")
            return {"error": f"Error reading file: {e}"}, 507  # Error while reading file

        # Embed the text and add it to the FAISS index
        try:
            # Ensure the embedding model is valid
            if embed_model is None:
                raise ValueError("Embedding model is not initialized properly.")
            vector = embed_model.encode(text).astype("float32")
            print(f"Generated vector for document {doc_id}: {vector}")
            index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
            document_store[doc_id] = {"path": file_location, "text": text}

            # Save the FAISS index after adding the document
            print(f"Saving FAISS index to: {faiss_index_file}")
            try:
                faiss.write_index(index, faiss_index_file)
                print(f"Document uploaded with doc_id: {doc_id}")
            except Exception as e:
                print(f"Error saving FAISS index: {e}")
                return {"error": f"Error saving FAISS index: {e}"}, 508  # Error while saving FAISS index
        except Exception as e:
            print(f"Error during document upload: {e}")
            return {"error": f"Error during document upload: {e}"}, 509  # Error during embedding or FAISS processing

        return {"message": "Document uploaded successfully", "doc_id": doc_id}, 200
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"error": f"Unexpected error: {e}"}, 500  # General error
def list_uploaded_files():
    try:
        # Ensure the upload directory exists
        if not os.path.exists(UPLOAD_DIR):
            return jsonify({"error": "Upload directory does not exist"}), 400

        # List all files in the upload directory
        files = os.listdir(UPLOAD_DIR)
        if not files:
            return jsonify({"message": "No files found in the upload directory"}), 200
        return jsonify({"files": files}), 200
    except Exception as e:
        return jsonify({"error": f"Error listing files: {e}"}), 504
# The route path here is an assumption; adjust it to match the frontend's upload endpoint.
@app.route("/upload", methods=["POST"])
def handle_upload():
    # Check if the request contains the file
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]

    # Ensure the filename is safe and construct the full file path
    file_path = os.path.join(UPLOAD_DIR, file.filename)

    # Ensure the upload directory exists and has correct permissions
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)   # Ensure the directory exists
        if not os.access(UPLOAD_DIR, os.W_OK):   # Check write permissions
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 501

    try:
        # Save the file to the upload directory, then reload the knowledge base
        file.save(file_path)
        print("File uploaded successfully. Calling load_document_store()...")
        load_document_store()  # Reload FAISS index
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 502

    # Process the uploaded document
    try:
        load_document_store_once(file_path)
        # upload_document(file_path, bertmodel)  # Alternative per-document upload path
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 503

    # Return success response
    return jsonify({"message": "File uploaded and processed successfully"}), 200
def get_model_and_tokenizer(model_id: str):
    """
    Load and cache the model and tokenizer for the given model_id.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    if model_id not in loaded_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
            loaded_models[model_id] = (model, tokenizer)
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # Logs the full error traceback
            raise  # Re-raise the exception to stop execution
    return loaded_models[model_id]
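# BitsAndBytesConfig is imported above but never used. If the default
# meta-llama/Llama-3.1-8B-Instruct checkpoint does not fit in the available
# GPU memory, a quantized load is one option. This is only a sketch of that
# idea, not part of the pipeline above:
#
#     bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#     model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
#
# device_map="auto" relies on the accelerate package, which is already a dependency here.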
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        return match.group(0).strip()
    return user_input
def classify_intent(user_input):
    """
    Classify the intent of the user input using zero-shot classification.
    """
    candidate_labels = [
        "grammar correction", "information request", "task completion",
        "dialog continuation", "personal opinion", "product inquiry",
        "feedback request", "recommendation request", "clarification request",
        "affirmation or agreement", "real-time data request", "current information"
    ]
    result = classifier(user_input, candidate_labels)
    highest_score_index = result['scores'].index(max(result['scores']))
    return result['labels'][highest_score_index]
# Reformulate the prompt based on the classified intent
def reformulate_prompt(user_input, intent_label):
    """
    Reformulate the prompt based on the classified intent.
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, "Input does not require a defined action.")
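# Illustrative example of how the two helpers compose (the exact label depends
# on what the zero-shot classifier returns at runtime):
#
#     intent = classify_intent("Please fix this sentence: he go to school")
#     # e.g. "grammar correction"
#     reformulate_prompt("Please fix this sentence: he go to school", intent)
#     # -> "Fix the grammar in this sentence: he go to school"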
chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "The capital of France is Paris."),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]
def generate_response(user_input, model_id):
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator

        # Retrieve the closest document from the FAISS index
        func_caller = []
        query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
        D, I = index.search(query_vector, 1)
        retrieved_id = I[0][0]
        retrieved_knowledge = (
            document_store.get(retrieved_id, {}).get("text", "No relevant information found.")
            if retrieved_id != -1 else "No relevant information found."
        )

        # Construct the knowledge prompt and log it
        prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
        print(f"Generated prompt: {prompt}")

        # Add the retrieved knowledge, then append the chat history
        func_caller.append({"role": "system", "content": prompt})
        for msg in chat_history:
            func_caller.append({"role": "user", "content": str(msg[0])})
            func_caller.append({"role": "assistant", "content": str(msg[1])})

        # Reformulate the prompt based on the classified intent
        highest_label_result = classify_intent(user_input)
        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
        func_caller.append({"role": "user", "content": reformulated_prompt})
        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])

        # Sampling is only enabled for intents that benefit from it
        generation_config = GenerationConfig(
            do_sample=(highest_label_result in ("dialog continuation", "recommendation request")),
            temperature=0.7 if highest_label_result == "dialog continuation" else (0.2 if highest_label_result == "recommendation request" else None),
            top_k=5 if highest_label_result == "recommendation request" else None,
            max_length=150,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Generate response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(
            gpt_inputs["input_ids"],
            attention_mask=gpt_inputs["attention_mask"],
            max_new_tokens=50,
            generation_config=generation_config,
        )
        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)

        # Strip the prompt from the decoded output and split it into sentences
        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]

        # Pick the candidate sentence most similar to the prompt
        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)
        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
        best_index = similarities.argmax()
        best_response = ai_response[best_index]

        # For dialog continuation, keep at most the first three sentences
        if highest_label_result == "dialog continuation":
            sentences = best_response.split('. ')
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response

        # Append the exchange to the chat history as a (user, assistant) pair,
        # matching the format the history loop above expects
        chat_history.append((user_input, best_response))

        return best_response
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise
# The /send_message path matches the CORS configuration above
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
        # model_id = data.get("model_id", "openai-community/gpt2-large")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
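# Example request against the /send_message endpoint (payload keys match the
# handler above; host and port assume the local app.run settings):
#
#     curl -X POST http://localhost:7860/send_message \
#          -H "Content-Type: application/json" \
#          -d '{"inputs": "Fix this sentence: he go to school", "model_id": "meta-llama/Llama-3.1-8B-Instruct"}'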