# opsgenius3 / app.py
import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global variables for model and tokenizer
model = None
tokenizer = None
accelerator = Accelerator()
def get_model_and_tokenizer(model_id: str):
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id}")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # Align padding with the EOS token

            # accelerator.prepare() handles device assignment, so the model is not moved manually
            model = accelerator.prepare(model)
        except Exception as e:
            print("Error loading model:")
            print(traceback.format_exc())  # Log the full error traceback
            raise e  # Re-raise the exception to stop execution
def generate_response(user_input, model_id):
    try:
        get_model_and_tokenizer(model_id)
        prompt = formatted_prompt(user_input)
        # prompt = user_input

        device = accelerator.device  # GPU or CPU, depending on the accelerator setup
        generation_config = GenerationConfig(
            penalty_alpha=0.6,
            do_sample=True,
            top_p=0.2,
            top_k=50,
            temperature=0.3,
            repetition_penalty=1.2,
            max_new_tokens=60,
            pad_token_id=tokenizer.eos_token_id
        )

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # The model is already dispatched to the correct device, so it is not moved here
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Optional clean-up of the raw decoded text:
        # cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
        # return cleaned_response.split("\n")[0]
        return response
    except Exception as e:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Log the full traceback
        raise e
def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
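# For illustration (the question text is made up), formatted_prompt("What is OpsGenius?")
# produces the ChatML-style string:
#   "<|im_start|>user\nWhat is OpsGenius?<|im_end|>\n<|im_start|>assistant:"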
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")

        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Log the full traceback
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
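# Example client call, as a sketch only: it assumes the server is running locally on
# port 7860, the `requests` package is installed, and the prompt text is made up.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/send_message",
#       json={"inputs": "How do I restart the service?",
#             "model_id": "YALCINKAYA/FinetunedByYalcin"},
#   )
#   print(resp.json()["received_message"])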