import os

# HF_HOME must be set before transformers is imported so the cache path takes effect
os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

# Resolve the device once; the quantized model itself is placed via device_map="auto"
device = "cuda" if torch.cuda.is_available() else "cpu"

app = Flask(__name__)

# Enable CORS on the chat endpoint for the allowed frontend origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
 
# Global cache for the lazily loaded model and tokenizer
model = None
tokenizer = None
loaded_model_id = None

def get_model_and_tokenizer(model_id):
    global model, tokenizer, loaded_model_id
    if loaded_model_id != model_id:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token

            print(f"Loading model for model_id: {model_id} on {device}")

            # 4-bit NF4 quantization with double quantization to cut memory use
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype="float16",
                bnb_4bit_use_double_quant=True,
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )

            model.config.use_cache = True  # KV cache speeds up generation at inference time
            model.config.pretraining_tp = 1
            loaded_model_id = model_id
        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # Re-raise so the POST handler can return a 500
    else:
        print(f"Model and tokenizer for {model_id} are already loaded.")

def generate_response(user_input, model_id):
    # Ensure model and tokenizer are loaded
    get_model_and_tokenizer(model_id)

    prompt = user_input

    generation_config = GenerationConfig(
        do_sample=True,
        top_p=0.2,
        top_k=50,
        temperature=0.3,
        repetition_penalty=1.2,
        max_new_tokens=60,
        pad_token_id=tokenizer.eos_token_id,
        # stop_strings requires transformers >= 4.39 and a tokenizer passed to generate()
        stop_strings=["User:", "Assistant:", "\n"],
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, generation_config=generation_config, tokenizer=tokenizer)

    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
    return cleaned_response.split("\n")[0]  # Keep only the first line of the response
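# The cleanup above assumes a plain chat-transcript style of prompting; a
# hypothetical template that matches those markers would be:
#
#   prompt = f"User: {user_input}\nAssistant:"
#
# With such a template, the "User:"/"Assistant:" replacements and the
# first-line split reliably isolate the assistant's reply.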

@app.route("/", methods=["GET"])
def handle_get_request():
    message = request.args.get("message", "No message provided.")
    return jsonify({"message": message, "status": "GET request successful!"})

@app.route("/send_message", methods=["POST"])
def handle_post_request():
    data = request.get_json()
    if data is None:
        return jsonify({"error": "No JSON data provided"}), 400

    message = data.get("inputs", "No message provided.") 
    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")  # Default model if not provided

    try:
        print(f"Generating response with model_id: {model_id}")
        # Generate a response from the model
        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response, 
            "model_id": model_id, 
            "status": "POST request successful!"
        })
    except Exception as e:
        print(f"Error handling POST request: {e}")
        return jsonify({"error": "An error occurred while processing your request."}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
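
# Example requests against the two routes above (hypothetical values; adjust
# host/port to your deployment):
#
#   curl "http://localhost:7860/?message=ping"
#
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"}'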