import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
import re
import traceback

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global variables for model and tokenizer
model = None
tokenizer = None


def get_model_and_tokenizer(model_id: str):
    """Load the tokenizer and 4-bit quantized model once and cache them globally."""
    global model, tokenizer
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id}")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # Fix padding issue
        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # Log the full error traceback
            raise  # Re-raise the exception to stop execution


def generate_response(user_input, model_id):
    """Generate a short, single-line reply for the given user input."""
    try:
        get_model_and_tokenizer(model_id)
        prompt = user_input
        device = "cuda" if torch.cuda.is_available() else "cpu"

        generation_config = GenerationConfig(
            penalty_alpha=0.6,
            do_sample=True,
            top_p=0.2,
            top_k=50,
            temperature=0.3,
            repetition_penalty=1.2,
            max_new_tokens=60,
            pad_token_id=tokenizer.eos_token_id,
        )

        # device_map="auto" already placed the quantized model, and .to() is not
        # supported for 4-bit bitsandbytes models, so only the inputs are moved.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip role markers and return only the first line of the reply
        cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
        return cleaned_response.split("\n")[0]
    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Log the full traceback
        raise


@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")
        print(f"Processing request with model_id: {model_id}")

        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Log the full traceback
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
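

# Example usage (a minimal sketch; the prompt below is an illustrative placeholder
# and the server is assumed to be reachable locally on port 7860):
#
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello", "model_id": "YALCINKAYA/FinetunedByYalcin"}'
#
# A successful call returns JSON with "received_message", "model_id" and "status".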