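"""Flask API for a Hugging Face Space.

Serves a 4-bit quantized causal language model (transformers + bitsandbytes)
behind a CORS-enabled /send_message endpoint.
"""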
import os
import re

# Set HF_HOME to a writable cache directory *before* importing transformers,
# so huggingface_hub picks it up when it resolves its cache paths.
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

app = Flask(__name__)

# Enable CORS for the specific origins allowed to call /send_message
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global cache for the loaded model and tokenizer
model = None
tokenizer = None

def get_model_and_tokenizer(model_id: str):
    """Load the tokenizer and 4-bit quantized model once and cache them globally."""
    global model, tokenizer
    # NOTE: the cache is keyed only on "already loaded", so a different
    # model_id in a later request will not trigger a reload.
    if model is None or tokenizer is None:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id}")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )
            model.config.use_cache = False
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # Fix padding issue
        except Exception as e:
            print(f"Error loading model: {e}")
            raise e

def generate_response(user_input, model_id):
    # Ensure the model and tokenizer are loaded
    get_model_and_tokenizer(model_id)
    prompt = user_input

    generation_config = GenerationConfig(
        penalty_alpha=0.6,  # only used by contrastive search (do_sample=False); ignored when sampling
        do_sample=True,
        top_p=0.2,
        top_k=50,
        temperature=0.3,
        repetition_penalty=1.2,
        max_new_tokens=60,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The model was loaded with device_map="auto", so it is already placed on the
    # available device(s); .to() is not supported for 4-bit bitsandbytes models.
    # Move the inputs to the model's device instead.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip role prefixes and keep only the first line of the reply
    cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
    return cleaned_response.split("\n")[0]
@app.route("/", methods=["GET"])
def handle_get_request():
message = request.args.get("message", "No message provided.")
return jsonify({"message": message, "status": "GET request successful!"})
@app.route("/send_message", methods=["POST"])
def handle_post_request():
data = request.get_json()
if data is None:
return jsonify({"error": "No JSON data provided"}), 400
message = data.get("inputs", "No message provided.")
model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")
try:
print(f"Processing request")
model_response = generate_response(message, model_id)
return jsonify({
"received_message": model_response,
"model_id": model_id,
"status": "POST request successful!"
})
except Exception as e:
error_message = str(e) if app.debug else "An error occurred while processing your request."
print(f"Error handling POST request: {e}")
return jsonify({"error": error_message}), 500

if __name__ == '__main__':
    # Port 7860 is the port exposed by Hugging Face Spaces
    app.run(host='0.0.0.0', port=7860)
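
# Example client call (illustrative; assumes the server runs locally on port 7860
# and that the request comes from one of the origins allowed by the CORS config above):
#
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"}'
#
# Expected response shape:
#   {"received_message": "...", "model_id": "...", "status": "POST request successful!"}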