from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import os
import torch
import time
import logging
import threading
import queue
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
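
# This app exposes a streaming chat endpoint (/chat, Server-Sent Events) and a
# non-streaming fallback (/chat-simple), backed by a Hugging Face causal LM
# (Qwen/Qwen2.5-1.5B-Instruct by default) that is loaded once at startup and
# runs on CPU.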

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
tokenizer = None
model = None

# Initialize models once on startup
def initialize_models():
    global tokenizer, model
    try:
        logger.info("Loading language model...")
        
        # You can change the model here if needed
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # Good balance of quality and speed for CPU
        
        # Load tokenizer with caching
        logger.info(f"Loading tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True  # Use the fast tokenizers when available
        )
        
        # Load model with optimizations for CPU
        logger.info(f"Loading model: {model_name}")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # float16 is poorly supported for CPU generation; use float32 (or bfloat16 on newer CPUs)
            device_map="cpu",  # Explicitly set to CPU
            low_cpu_mem_usage=True,  # Optimize memory loading
            offload_folder="offload"  # Use disk offloading if needed
        )
        
        # Handle padding tokens
        if tokenizer.pad_token is None:
            logger.info("Setting pad token to EOS token")
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
            
        # Set default generation parameters on generation_config (the supported home
        # for sampling settings in current transformers; model.config is deprecated for this)
        model.generation_config.do_sample = True   # Enable sampling
        model.generation_config.temperature = 0.7  # Default temperature
        model.generation_config.top_p = 0.9        # Default top_p
        
        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise

# Function to simulate "thinking" process
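# Note: each request runs this in its own worker thread while sharing one model;
# on a CPU-only box, concurrent generate() calls will contend for cores, so
# parallel traffic slows every response down.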
def thinking_process(message, result_queue):
    """
    This function simulates a thinking process and puts the result in the queue.
    It includes both an explicit thinking stage and then a generation stage.
    """
    try:
        # Simulate explicit thinking stage
        logger.info(f"Thinking about: '{message}'")
        
        # Pause to simulate deeper thinking (helps with more complex queries)
        time.sleep(1)
        
        # Create thoughtful prompt with system message and thinking instructions
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. 
Let's approach the user's request step by step:
1. First, understand what the user is really asking
2. Consider the key aspects we need to address
3. Think about the best way to structure the response
4. Provide clear, accurate information in a conversational tone

Always think carefully before responding, consider different angles, and provide thoughtful, detailed answers.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}
        
        # Stream tokens into the queue via the custom TextStreamer defined below
        # (not the transformers class of the same name)
        streamer = TextStreamer(tokenizer, result_queue)
        
        # Simulate thinking first by sending some initial dots
        result_queue.put("Let me think about this...")
        time.sleep(0.5)
        
        # Generate response - we use a temperature of 0.7 for more thoughtful outputs
        # and top_p for nucleus sampling to avoid repetitive or generic responses
        try:
            model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                streamer=streamer,
                num_beams=1,  # transformers does not allow a streamer together with beam search
                no_repeat_ngram_size=3,
                repetition_penalty=1.2  # Discourages token repetition
            )
        except Exception as e:
            logger.error(f"Model generation error: {str(e)}")
            result_queue.put("\n\nI apologize, but I encountered an error while processing your request.")
        
        # Signal generation is complete
        result_queue.put(None)
        
    except Exception as e:
        logger.error(f"Error in thinking process: {str(e)}")
        result_queue.put(f"I apologize, but I encountered an error while processing your request: {str(e)}")
        # Signal generation is complete
        result_queue.put(None)

# Minimal streamer used by model.generate(): put() is called first with the
# prompt ids and then once per newly generated token; end() is called when
# generation finishes. We skip the prompt and queue the cumulative decoded text.
class TextStreamer:
    def __init__(self, tokenizer, queue):
        self.tokenizer = tokenizer
        self.queue = queue
        self.current_tokens = []
        self.prompt_skipped = False

    def put(self, token_ids):
        # The first call carries the full prompt (a 2D batch tensor); skip it so
        # only newly generated text is streamed back to the client.
        if not self.prompt_skipped:
            self.prompt_skipped = True
            return
        ids = token_ids.tolist()
        if ids and isinstance(ids[0], list):  # flatten a batched tensor
            ids = ids[0]
        self.current_tokens.extend(ids)
        text = self.tokenizer.decode(self.current_tokens, skip_special_tokens=True)
        self.queue.put(text)

    def end(self):
        pass
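
# Note: recent transformers versions also ship TextIteratorStreamer, which plays a
# similar role by exposing newly generated text through an internal queue; the
# hand-rolled class above keeps this app's queue semantics explicit.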

# API route for home page
@app.route('/')
def home():
    return jsonify({"message": "AI Chat API is running!"})

# API route for streaming chat responses
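# Example of consuming the streaming endpoint (assuming the server is running
# locally on port 7860, as configured in __main__):
#
#   curl -N "http://localhost:7860/chat?message=Hello"
#
# The response is a text/event-stream of "data:" lines; each chunk after the
# initial "[Thinking...]" marker is a JSON-encoded string fragment, and the
# stream ends with "data: [DONE]".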
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    # Handle both POST JSON and GET query parameters for flexibility
    if request.method == 'POST':
        try:
            data = request.get_json(silent=True) or {}
            message = data.get("message", "")
        except Exception:
            # If JSON parsing fails, fall back to form data
            message = request.form.get("message", "")
    else:  # GET
        message = request.args.get("message", "")
    
    if not message:
        return jsonify({"error": "Message is required"}), 400
    
    try:
        def generate():
            # SSE preamble: reconnection delay and the event name for the first event
            yield "retry: 1000\n"
            yield "event: message\n"

            # Show thinking indicator
            yield "data: [Thinking...]\n\n"
            
            # Create a queue for communication between threads
            result_queue = queue.Queue()
            
            # Start thinking in a separate thread
            thread = threading.Thread(target=thinking_process, args=(message, result_queue))
            thread.daemon = True  # Make thread die when main thread exits
            thread.start()
            
            # Stream results as they become available
            previous_text = ""
            while True:
                try:
                    result = result_queue.get(block=True, timeout=30)  # 30 second timeout
                    if result is None:  # End of generation
                        break
                        
                    # The streamer sends cumulative text, so only yield the new part
                    if isinstance(result, str):
                        # If a fresh cumulative stream begins (e.g. after the initial
                        # "Let me think about this..." message), reset the diff base
                        if not result.startswith(previous_text):
                            previous_text = ""
                        new_part = result[len(previous_text):]
                        previous_text = result
                        if new_part:
                            yield f"data: {json.dumps(new_part)}\n\n"
                            time.sleep(0.01)  # Small delay for a more natural typing effect
                            
                except queue.Empty:
                    # Timeout occurred
                    yield "data: [Generation timeout. The model is taking too long to respond.]\n\n"
                    break
                    
            yield "data: [DONE]\n\n"
            
        return Response(
            stream_with_context(generate()), 
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'X-Accel-Buffering': 'no'  # Disable buffering for Nginx
            }
        )
        
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

# Simple API for non-streaming chat (fallback)
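# Example request for this non-streaming fallback (assuming a local server on
# port 7860):
#
#   curl -X POST http://localhost:7860/chat-simple \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'
#
# Returns a single JSON object of the form {"response": "..."}.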
@app.route('/chat-simple', methods=['POST'])
def chat_simple():
    data = request.get_json()
    message = data.get("message", "")
    
    if not message:
        return jsonify({"error": "Message is required"}), 400
    
    try:
        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}
        
        # Generate answer
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,
            no_repeat_ngram_size=3
        )
        
        # Decode only the newly generated tokens; decoding all of output[0] would
        # echo the prompt back, because skip_special_tokens strips the chat markers
        # that the old "<|im_end|>" cleanup relied on
        prompt_length = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
        
        return jsonify({"response": answer})
        
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")