import os
import re
import traceback

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS only for the origins allowed to call /send_message
CORS(app, resources={r"/send_message": {"origins": [
    "http://localhost:3000",
    "https://main.dbn2ikif9ou3g.amplifyapp.com",
]}})

# Global variables so the model and tokenizer are loaded only once
model = None
tokenizer = None


def get_model_and_tokenizer(model_id: str):
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id}")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=bnb_config,
                device_map="auto",  # place the quantized model on the available GPU(s)
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # fix padding issue
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # log the full traceback
            raise  # re-raise the exception to stop execution


def generate_response(user_input, model_id):
    try:
        get_model_and_tokenizer(model_id)
        prompt = user_input
        device = "cuda" if torch.cuda.is_available() else "cpu"

        generation_config = GenerationConfig(
            penalty_alpha=0.6,  # contrastive-search parameter; has no effect while do_sample=True
            do_sample=True,
            top_p=0.2,
            top_k=50,
            temperature=0.3,
            repetition_penalty=1.2,
            max_new_tokens=60,
            pad_token_id=tokenizer.eos_token_id,
        )

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Do not call model.to(device) here: the model was loaded with
        # device_map="auto", and .to() is not supported for 4-bit
        # bitsandbytes-quantized models.
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up the response: strip role markers and keep only the first line
        cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
        return cleaned_response.split("\n")[0]
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # log the full traceback
        raise


@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!",
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # log the full traceback
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
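
# Example client call (a minimal sketch, kept as a comment so this file stays a
# plain server script): it assumes the server is running and reachable at
# http://localhost:7860; the `requests` library is used only for illustration
# and is not a dependency of this app. The "inputs" and "model_id" keys match
# what handle_post_request() reads from the JSON body.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/send_message",
#       json={"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"},
#   )
#   print(resp.json())  # {"received_message": ..., "model_id": ..., "status": ...}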