import os
import time

# Set HF_HOME before importing transformers so the Hub cache location is
# picked up (the cache path is resolved when the library is imported)
os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space

from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

app = Flask(__name__)

# Enable CORS for the routes this app actually serves, restricted to the known frontends
CORS(app, resources={r"/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global state for the lazily loaded model and tokenizer
model = None
tokenizer = None
loaded_model_id = None  # Which model_id is currently in memory
last_loaded_time = 0
COOLDOWN_PERIOD = 300  # Reload the model once it is older than 5 minutes (300 seconds)

def get_model_and_tokenizer(model_id):
    global model, tokenizer, loaded_model_id
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS as the padding token

        print(f"Loading model for model_id: {model_id}")
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(model_id)  # add device_map="auto" to use a GPU
        model.config.use_cache = False
        loaded_model_id = model_id
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Reset the globals so a half-loaded model/tokenizer pair is never served
        model = None
        tokenizer = None

def is_model_loaded_and_fresh(model_id):
    # Fresh means: the requested model is the one in memory and it was
    # loaded within the cooldown window
    return (
        model is not None
        and loaded_model_id == model_id
        and (time.time() - last_loaded_time) < COOLDOWN_PERIOD
    )
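# A brief illustration of the reload policy above, assuming COOLDOWN_PERIOD = 300:
#   t=0    first request            -> nothing in memory, model is loaded
#   t=120  same model_id requested  -> fresh (120 < 300), in-memory model is reused
#   t=400  same model_id requested  -> stale (400 >= 300), model is reloaded
#   any t  different model_id       -> always reloaded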

def generate_response(user_input, model_id):
    global last_loaded_time

    prompt = formatted_prompt(user_input)

    # (Re)load the model if none is in memory, it has gone stale, or a
    # different model_id was requested
    if not is_model_loaded_and_fresh(model_id):
        get_model_and_tokenizer(model_id)
        last_loaded_time = time.time()  # Record when the model was (re)loaded

    if model is None or tokenizer is None:
        return "Error: model could not be loaded."

    # Tokenize the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    # Note: truncation is a tokenizer argument, not a generation one,
    # so it belongs in the tokenizer call above
    generation_config = GenerationConfig(
        max_new_tokens=100,
        min_length=5,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    try:
        # Generate and decode; note the decoded text includes the prompt itself
        outputs = model.generate(**inputs, generation_config=generation_config)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Error generating response."
    
def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"

@app.route("/", methods=["GET"])
def handle_get_request():
    message = request.args.get("message", "No message provided.")
    return jsonify({"message": message, "status": "GET request successful!"})

@app.route("/send_message", methods=["POST"])
def handle_post_request():
    data = request.get_json()
    if data is None:
        return jsonify({"error": "No JSON data provided"}), 400

    message = data.get("inputs", "No message provided.") 
    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")  # Default model if not provided

    try:
        # Generate a response from the model
        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response, 
            "model_id": model_id, 
            "status": "POST request successful!"
        })
    except Exception as e:
        print(f"Error handling POST request: {e}")
        return jsonify({"error": "An error occurred while processing your request."}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
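
# A quick way to exercise the endpoint once the server is running locally
# (assumes the default model_id; any other Hub model id can be substituted):
#
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"}'
#
# On success the reply is JSON of the form:
#   {"received_message": "...", "model_id": "...", "status": "POST request successful!"}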