import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
import re
import traceback

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS on the chat endpoint for the local dev frontend and the deployed Amplify app
CORS(app, resources={
    r"/send_message": {
        "origins": [
            "http://localhost:3000",
            "https://main.dbn2ikif9ou3g.amplifyapp.com",
        ]
    }
})

# Globals caching the model and tokenizer so they are loaded once per process.
# loaded_model_id tracks which model is cached so a new model_id triggers a reload.
model = None
tokenizer = None
loaded_model_id = None

def get_model_and_tokenizer(model_id: str):
    """Lazily load and cache the tokenizer and 4-bit quantized model."""
    global model, tokenizer, loaded_model_id
    # Reload only if nothing is cached yet or a different model is requested
    if model is None or tokenizer is None or loaded_model_id != model_id:
        try:
            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

            print(f"Loading model for model_id: {model_id}")

            # 4-bit NF4 quantization with double quantization to cut memory use;
            # computation still runs in float16
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

            # device_map="auto" lets accelerate place the model on the available device(s)
            model = AutoModelForCausalLM.from_pretrained(
                model_id, quantization_config=bnb_config, device_map="auto"
            )

            model.config.use_cache = False  # note: disables the KV cache, which slows generation
            model.config.pretraining_tp = 1
            model.config.pad_token_id = tokenizer.eos_token_id  # avoid padding warnings
            loaded_model_id = model_id

        except Exception:
            print("Error loading model:")
            print(traceback.format_exc())  # log the full traceback
            raise  # re-raise to stop execution

def generate_response(user_input, model_id):
    try:
        get_model_and_tokenizer(model_id)

        prompt = user_input

        # Conservative sampling settings for short, focused replies.
        # penalty_alpha is dropped: it only takes effect in contrastive
        # search (do_sample=False) and is ignored when sampling.
        generation_config = GenerationConfig(
            do_sample=True,
            top_p=0.2,
            top_k=50,
            temperature=0.3,
            repetition_penalty=1.2,
            max_new_tokens=60,
            pad_token_id=tokenizer.eos_token_id,
        )

        # device_map="auto" has already placed the model, and .to() is not
        # supported on 4-bit bitsandbytes models, so move the inputs to the
        # model's device instead.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(**inputs, generation_config=generation_config)

        # generate() echoes the prompt, so decode only the newly generated tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Strip role labels and return the first line of the reply
        cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
        return cleaned_response.split("\n")[0]

    except Exception:
        print("Error in generate_response:")
        print(traceback.format_exc())  # log the full traceback
        raise
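
# Example request (a sample invocation, assuming the server is reachable on
# localhost:7860; adjust host/port for your deployment):
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello!", "model_id": "YALCINKAYA/FinetunedByYalcin"}'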

@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")

        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
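
# Note: app.run() starts Flask's development server. For production, a WSGI
# server is preferable, e.g. (assuming this file is saved as app.py):
#   gunicorn --bind 0.0.0.0:7860 app:app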