import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
import re
import traceback

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global variables for the lazily loaded model and tokenizer
model = None
tokenizer = None
def get_model_and_tokenizer(model_id: str):
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id}")
            # 4-bit NF4 quantization so the model fits in limited GPU memory
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # Fix padding issue
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # Log the full error traceback
            raise  # Re-raise the exception to stop execution
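
# Optional warm-up (a sketch, not part of the original flow): calling
# get_model_and_tokenizer() once at startup would pay the download and
# quantization cost before the first request instead of during it. Left
# commented out so the app still starts quickly; the model id mirrors the
# default used in handle_post_request below.
#
# get_model_and_tokenizer("YALCINKAYA/FinetunedByYalcin")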
def generate_response(user_input, model_id):
    try:
        get_model_and_tokenizer(model_id)
        prompt = user_input

        # Note: with do_sample=True, transformers runs sampling, so penalty_alpha
        # (contrastive search) has no effect here.
        generation_config = GenerationConfig(
            penalty_alpha=0.6,
            do_sample=True,
            top_p=0.2,
            top_k=50,
            temperature=0.3,
            repetition_penalty=1.2,
            max_new_tokens=60,
            pad_token_id=tokenizer.eos_token_id,
        )

        # device_map="auto" already placed the quantized model, so move the inputs
        # to the model's device instead of calling model.to(), which is not
        # supported for 4-bit quantized models.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up response
        cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
        return cleaned_response.split("\n")[0]
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Log the full traceback
        raise
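
# Optional refinement (a sketch, not in the original code): the decoded output
# above includes the prompt itself, since causal LMs return prompt + new tokens.
# If only the newly generated text is wanted, slice off the prompt tokens, e.g.
#
#   new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
#   generated_only = tokenizer.decode(new_tokens, skip_special_tokens=True)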
# Route referenced by the CORS configuration above
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Log the full traceback
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
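
# Example client call (a sketch; the URL assumes the server above is reachable on
# localhost:7860, and the payload keys mirror handle_post_request):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/send_message",
#       json={"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"},
#       timeout=120,
#   )
#   print(resp.json()["received_message"])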