import os
from flask import Flask, jsonify, request
from flask_cors import CORS 
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig  
import re

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/api/predict/*": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})

# Global variables for model and tokenizer
model = None
tokenizer = None

def get_model_and_tokenizer(model_id):
    global model, tokenizer
    try:
        print(f"Loading tokenizer for model_id: {model_id}")
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
        
        print(f"Loading model and for model_id: {model_id}")
        # Load the model
        model = AutoModelForCausalLM.from_pretrained(model_id) #, device_map="auto")
        model.config.use_cache = False
   
    except Exception as e:
        print(f"Error loading model: {e}")

def extract_relevant_text(response):
    """
    This function extracts the first complete 'user' and 'assistant' blocks
    between <|im_start|> and <|im_end|> in the generated response.
    If the tags are corrupted, it returns the text up to the first <|im_end|> tag.
    """
    # Regex to match content between <|im_start|> and <|im_end|> tags
    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
    matches = pattern.findall(response)

    # Debugging: print the matches found
    print("Matches found:", matches)

    # If complete matches found, extract them
    if len(matches) >= 2:
        user_message = matches[0].strip()  # First <|im_start|> block
        assistant_message = matches[1].strip()  # Second <|im_start|> block
        return f"user: {user_message}\nassistant: {assistant_message}"
    
    # If no complete blocks found, check for a partial extraction
    if '<|im_end|>' in response:
        # Extract everything before the first <|im_end|>
        partial_response = response.split('<|im_end|>')[0].strip()
        return f"{partial_response}"

    return "No complete blocks found. Please check the format of the response."

def generate_response(user_input, model_id):
    prompt = formatted_prompt(user_input)
    
    global model, tokenizer

    # Load the model and tokenizer if not loaded yet or if the model_id has changed
    if model is None or tokenizer is None or (model.config._name_or_path != model_id):
        get_model_and_tokenizer(model_id)  # Load model and tokenizer
    # Bail out early if loading failed (get_model_and_tokenizer only logs the error)
    if model is None or tokenizer is None:
        return "Error loading model."

    # Prepare the input tensors
    inputs = tokenizer(prompt, return_tensors="pt")  # Tokenize the prompt into PyTorch tensors
    
    generation_config = GenerationConfig(
        penalty_alpha=0.6,                    # Contrastive penalty to balance exploration and exploitation
        do_sample=True,                       # Sample to allow variability in responses
        top_k=3,                              # Small top_k to narrow down candidate tokens
        temperature=0.7,                      # Moderate temperature for focused but varied responses
        repetition_penalty=1.2,               # Moderate penalty to discourage repeated phrases
        max_new_tokens=60,                    # Cap the length of the generated completion
        pad_token_id=tokenizer.eos_token_id,  # EOS is used as the pad token (set above)
        truncation=True,                      # Enable truncation for longer prompts
    )

    try:
        # Generate a completion for the prompt
        outputs = model.generate(**inputs, generation_config=generation_config)

        # Slice off the prompt tokens so only the newly generated text is decoded
        response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
        return extract_relevant_text(response)
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Error generating response."
    
def formatted_prompt(question) -> str:
    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"

@app.route("/", methods=["GET"])
def handle_get_request():
    message = request.args.get("message", "No message provided.")
    return jsonify({"message": message, "status": "GET request successful!"})

@app.route("/send_message", methods=["POST"])
def handle_post_request():
    data = request.get_json()
    if data is None:
        return jsonify({"error": "No JSON data provided"}), 400

    message = data.get("inputs", "No message provided.") 
    model_id = data.get("model_id", "YALCINKAYA/FinetunedByYalcin")  # Default model if not provided

    try:
        # Generate a response from the model
        model_response = generate_response(message, model_id)
        return jsonify({
            "received_message": model_response, 
            "model_id": model_id, 
            "status": "POST request successful!"
        })
    except Exception as e:
        print(f"Error handling POST request: {e}")
        return jsonify({"error": "An error occurred while processing your request."}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)