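# File: opsgenius3 / app.py
# Last change: "Update app.py" by YALCINKAYA (commit d632349, verified); size 9.02 kB.
# (Header reconstructed from the file-viewer page chrome: avatar caption, "raw", "history blame" links.)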
import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# ---------------------------------------------------------------------------
# Module-level setup: cache location, Flask app, CORS policy, helper models
# and the shared mutable state used by the request handlers below.
# ---------------------------------------------------------------------------
# Set the HF_HOME environment variable to a writable directory
# NOTE(review): this assignment runs after `transformers` and
# `sentence_transformers` have already been imported above; the Hugging Face
# libraries may resolve their cache paths at import time, in which case this
# has no effect — confirm, and move it above those imports if so.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"
app = Flask(__name__)
# Enable CORS for specific origins
# Only the /send_message route is exposed cross-origin, and only to the two
# origins listed here.
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
# Load zero-shot classification pipeline
# No model name is passed, so the library picks its default checkpoint for
# this task (presumably downloaded on first start — TODO confirm/pin a model).
classifier = pipeline("zero-shot-classification")
# Load Sentence-BERT model
bertmodel = SentenceTransformer('all-MiniLM-L6-v2') # Lightweight, efficient model; choose larger if needed
# Global variables for model and tokenizer
# get_model_and_tokenizer() rebinds these two names via `global` whenever it
# loads a model that is not cached yet.
model = None
tokenizer = None
# Shared Accelerator: wraps loaded models and supplies the target device.
accelerator = Accelerator()
# NOTE(review): no function in this file declares `global highest_label`, so
# this module-level value stays None for the life of the process.
highest_label = None
# Cache: model_id -> (model, tokenizer), filled by get_model_and_tokenizer().
loaded_models = {}
def get_model_and_tokenizer(model_id: str):
    """
    Return the cached ``(model, tokenizer)`` pair for ``model_id``.

    On a cache miss the tokenizer and causal LM are loaded with
    ``from_pretrained``, the model is passed through ``accelerator.prepare``
    and the pair is stored in ``loaded_models``. The module-level ``model``
    and ``tokenizer`` names are rebound to the freshly loaded objects.

    Any loading error is printed with its traceback and re-raised.
    """
    global model, tokenizer  # rebound below when a new model is loaded

    # Fast path: this model_id has been loaded before.
    if model_id in loaded_models:
        return loaded_models[model_id]

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = accelerator.prepare(model)
        loaded_models[model_id] = (model, tokenizer)
    except Exception as e:
        print("Error loading model:")
        print(traceback.format_exc())  # full traceback for the server log
        raise e  # propagate so the caller sees the failure

    return loaded_models[model_id]
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.

    The text following the first case-insensitive occurrence of ``sentence:``
    or ``sentence `` (up to the end of that line) is returned with surrounding
    whitespace removed.

    Args:
        user_input: The raw message text.

    Returns:
        The extracted text, or ``user_input`` unchanged when there is no
        marker or when nothing but whitespace follows it (so callers never
        receive an empty sentence for a non-empty input).
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        core = match.group(0).strip()
        # A marker followed only by blanks carries no sentence; fall through
        # to the original input instead of returning "".
        if core:
            return core
    return user_input
def classify_intent(user_input):
    """
    Return the intent label the zero-shot classifier scores highest.

    ``user_input`` is scored against a fixed list of candidate intents and
    the label at the position of the maximum score is returned (the first
    one when several scores tie).
    """
    candidate_labels = [
        "grammar correction",
        "information request",
        "task completion",
        "dialog continuation",
        "personal opinion",
        "product inquiry",
        "feedback request",
        "recommendation request",
        "clarification request",
        "affirmation or agreement",
        "real-time data request",
        "current information",
    ]
    outcome = classifier(user_input, candidate_labels)
    scores = outcome['scores']
    # Position of the largest score; max() keeps the first index on ties.
    top_position = max(range(len(scores)), key=scores.__getitem__)
    return outcome['labels'][top_position]
# Build the intent-specific prompt that is sent to the language model
def reformulate_prompt(user_input, intent_label):
    """
    Turn the user's input into an instruction matching ``intent_label``.

    The core sentence is extracted from ``user_input`` and inserted into the
    template registered for the intent. Intents without a template yield the
    fixed text "Input does not require a defined action.".
    """
    sentence = extract_core_sentence(user_input)

    # One template per supported intent; "{}" marks where the sentence goes.
    templates = {
        "grammar correction": "Fix the grammar in this sentence: {}",
        "information request": "Provide information about: {}",
        "dialog continuation": "Continue the conversation based on the previous dialog:\n{}\n",
        "personal opinion": "What is your personal opinion on: {}?",
        "product inquiry": "Provide details about the product: {}",
        "feedback request": "Please provide feedback on: {}",
        "recommendation request": "Recommend something related to: {}",
        "clarification request": "Clarify the following: {}",
        "affirmation or agreement": "Affirm or agree with the statement: {}",
    }

    chosen = templates.get(intent_label)
    if chosen is None:
        return "Input does not require a defined action."
    return chosen.format(sentence)
# Seed conversation, kept at module level and therefore shared by every
# request served by this process. Each entry is a (user message, assistant
# reply) pair; generate_response() replays the entries as alternating
# user/assistant turns ahead of the new prompt.
chat_history = [
("Hi there, how are you?", "I am fine. How are you?"),
# NOTE(review): this reply answers a joke request with an unrelated fact —
# confirm whether that is intentional before relying on it as an example.
("Tell me a joke!", "The capital of France is Paris."),
("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]
def generate_response(user_input, model_id):
    """
    Generate a reply to ``user_input`` with the causal LM named ``model_id``.

    Steps:
      1. Replay the shared ``chat_history`` as alternating user/assistant
         turns.
      2. Classify the intent of ``user_input``, reformulate it into an
         intent-specific prompt and append that as the final user turn.
      3. Generate at most 50 new tokens. Sampling is enabled only for the
         "dialog continuation" (temperature 0.7) and "recommendation request"
         (temperature 0.2, top_k 5) intents; every other intent decodes
         greedily.
      4. Split the generated continuation into sentences and keep the one
         whose Sentence-BERT embedding is most similar to the prompt.
      5. Record the exchange in ``chat_history`` as a ``(user, assistant)``
         tuple, the same shape the seed entries use.

    Args:
        user_input: The raw user message.
        model_id: Hugging Face model identifier passed to
            ``get_model_and_tokenizer``.

    Returns:
        The selected response sentence, or ``""`` when the model produced no
        usable text (nothing is recorded in the history in that case).

    Raises:
        Exception: any error from loading, classification or generation is
            printed with its traceback and re-raised unchanged.
    """
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator

        # Replay the stored conversation as role-tagged turns.
        func_caller = []
        for user_turn, assistant_turn in chat_history:
            func_caller.append({"role": "user", "content": str(user_turn)})
            func_caller.append({"role": "assistant", "content": str(assistant_turn)})

        # Use the label classified for THIS request. The module-level
        # `highest_label` is never updated, so reading it would silently
        # disable sampling and the dialog trimming below.
        intent_label = classify_intent(user_input)

        # Reformulated prompt based on intent classification
        reformulated_prompt = reformulate_prompt(user_input, intent_label)
        func_caller.append({"role": "user", "content": reformulated_prompt})
        formatted_prompt = "\n".join(f"{m['role']}: {m['content']}" for m in func_caller)

        is_dialog = intent_label == "dialog continuation"
        is_recommendation = intent_label == "recommendation request"
        generation_config = GenerationConfig(
            do_sample=is_dialog or is_recommendation,
            temperature=0.7 if is_dialog else (0.2 if is_recommendation else None),
            top_k=5 if is_recommendation else None,
            # Only the new-token budget is set: a total `max_length` would
            # conflict with it and is overridden by `max_new_tokens` anyway.
            max_new_tokens=50,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Generate response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        input_ids = gpt_inputs["input_ids"]
        gpt_output = model.generate(
            input_ids,
            # pad_token_id equals eos_token_id, so the mask cannot be
            # inferred reliably; pass the tokenizer's mask explicitly.
            attention_mask=gpt_inputs.get("attention_mask"),
            generation_config=generation_config,
        )

        # Keep only the newly generated tokens. Slicing by prompt length is
        # exact, whereas removing the prompt text with a regex depends on the
        # decode round-trip reproducing the prompt character for character.
        prompt_length = input_ids.shape[-1]
        generated_text = tokenizer.decode(
            gpt_output[0][prompt_length:], skip_special_tokens=True
        ).strip()

        # Split into candidate sentences at ., ! or ? followed by spaces.
        candidates = [
            s.strip()
            for s in re.split(r'(?<=\w[.!?]) +', generated_text)
            if s.strip()
        ]
        if not candidates:
            # Nothing to rank; avoid calling argmax on an empty tensor.
            return ""

        if len(candidates) == 1:
            best_response = candidates[0]
        else:
            # Encode the prompt and candidates
            prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
            candidate_embeddings = bertmodel.encode(candidates, convert_to_tensor=True)
            # Compute similarity scores between prompt and each candidate
            similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
            # Pick the candidate with the highest similarity score
            best_response = candidates[int(similarities.argmax())]

        if is_dialog:
            # Keep at most the first three '. '-separated sentences.
            sentences = best_response.split('. ')
            if len(sentences) > 3:
                best_response = '. '.join(sentences[:3])

        # Store the exchange in the same (user, assistant) tuple shape the
        # replay loop above unpacks; appending dicts here would break the
        # next request.
        chat_history.append((user_input, best_response))
        return best_response
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    """
    Handle ``POST /send_message``: run the model on the posted message.

    Expected JSON body (an object):
        inputs:   the user message (defaults to "No message provided.")
        model_id: Hugging Face model id (defaults to
                  "openai-community/gpt2-large")

    Returns:
        200 with ``received_message`` (the model's reply), ``model_id`` and
        ``status``; 400 when the body is missing, is not valid JSON or is not
        a JSON object; 500 with the error text when processing fails.

    NOTE(review): ``model_id`` comes straight from the client and is passed
    to ``from_pretrained``; consider restricting it to an allow-list.
    """
    try:
        # silent=True returns None for a missing/invalid JSON body instead of
        # raising, so the 400 branch below is actually reachable (otherwise
        # the broad except turns the client error into a 500).
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            # Valid JSON that is not an object (list, string, number) has no
            # .get(); reject it as a client error.
            return jsonify({"error": "JSON body must be an object"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "openai-community/gpt2-large")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Script entry point: start the Flask server on every network
    # interface, listening on port 7860.
    app.run(port=7860, host='0.0.0.0')