import os
import time
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache" # Change this to a writable path in your space
app = Flask(__name__)
# Enable CORS for specific origins
CORS(app, resources={r"/api/predict/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
# Global variables for model and tokenizer
model = None
tokenizer = None
last_loaded_time = 0
COOLDOWN_PERIOD = 300 # Set your cooldown period to 5 minutes (300 seconds)
def get_model_and_tokenizer(model_id):
    global model, tokenizer
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        tokenizer.pad_token = tokenizer.eos_token

        print(f"Loading model for model_id: {model_id}")
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(model_id)  # , device_map="auto")
        model.config.use_cache = False
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # Re-raise so callers do not continue with a half-loaded model
def is_model_loaded_and_fresh():
    global last_loaded_time
    current_time = time.time()
    return model is not None and (current_time - last_loaded_time) < COOLDOWN_PERIOD
def generate_response(user_input, model_id):
    global model, tokenizer, last_loaded_time

    prompt = formatted_prompt(user_input)

    # (Re)load the model if it is missing or the cooldown period has elapsed
    if not is_model_loaded_and_fresh():
        get_model_and_tokenizer(model_id)  # Load model and tokenizer
        last_loaded_time = time.time()  # Update the last load time

    # Tokenize the prompt (model and inputs stay on the CPU here)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    generation_config = GenerationConfig(
        max_new_tokens=100,
        min_length=5,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    try:
        # Generate response
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Error generating response."
def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
@app.route("/", methods=["GET"])
def handle_get_request():
message = request.args.get("message", "No message provided.")
return jsonify({"message": message, "status": "GET request successful!"})
@app.route("/send_message", methods=["POST"])
def handle_post_request():
data = request.get_json()
if data is None:
return jsonify({"error": "No JSON data provided"}), 400
message = data.get("inputs", "No message provided.")
model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin") # Default model if not provided
try:
# Generate a response from the model
model_response = generate_response(message, model_id)
return jsonify({
"received_message": model_response,
"model_id": model_id,
"status": "POST request successful!"
})
except Exception as e:
print(f"Error handling POST request: {e}")
return jsonify({"error": "An error occurred while processing your request."}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
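# Note: app.run starts Flask's development server. For a deployed Space, a WSGI
# server could be used instead (sketch; assumes this file is named app.py):
#
#     gunicorn -w 1 -b 0.0.0.0:7860 app:app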