from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import os
import torch
import time
import logging
import threading
import queue
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix caching issue on Hugging Face Spaces
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Global model variables
tokenizer = None
model = None

# Initialize models once on startup
def initialize_models():
    global tokenizer, model
    try:
        logger.info("Loading language model...")
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use float16 for lower memory on CPU
            device_map="cpu",  # Explicitly set to CPU
            low_cpu_mem_usage=True  # Optimize memory loading
        )
        
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
            
        logger.info("Models initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise

# Function to simulate "thinking" process
def thinking_process(message, result_queue):
    """
    This function simulates a thinking process and puts the result in the queue
    """
    try:
        # Simulate thinking process
        logger.info(f"Thinking about: '{message}'")
        
        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}
        
        # Generate answer with streaming
        streamer = TextStreamer(tokenizer, result_queue)
        
        # Generate response
        model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            streamer=streamer,
            num_beams=1,
            no_repeat_ngram_size=3
        )
        
        # Signal generation is complete
        result_queue.put(None)
        
    except Exception as e:
        logger.error(f"Error in thinking process: {str(e)}")
        result_queue.put(f"I apologize, but I encountered an error while processing your request: {str(e)}")
        # Signal generation is complete
        result_queue.put(None)

# TextStreamer class for token-by-token generation
class TextStreamer:
    def __init__(self, tokenizer, queue):
        self.tokenizer = tokenizer
        self.queue = queue
        self.current_tokens = []
        self.prompt_received = False

    def put(self, token_ids):
        # generate() first calls put() with the full prompt (a 2-D tensor);
        # skip it so only newly generated text reaches the client
        if not self.prompt_received:
            self.prompt_received = True
            return
        # Later calls pass the newly generated token(s); flatten to 1-D before decoding
        if token_ids.dim() > 1:
            token_ids = token_ids[0]
        self.current_tokens.extend(token_ids.tolist())
        text = self.tokenizer.decode(self.current_tokens, skip_special_tokens=True)
        self.queue.put(text)

    def end(self):
        pass
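
# Note (assumption): recent versions of transformers also ship TextIteratorStreamer,
# which pushes decoded text through its own internal queue and could replace this
# hand-rolled streamer.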

# API route for home page
@app.route('/')
def home():
    return jsonify({"message": "AI Chat API is running!"})

# API route for streaming chat responses
@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json()
    message = data.get("message", "")
    
    if not message:
        return jsonify({"error": "Message is required"}), 400
    
    try:
        def generate():
            # Create a queue for communication between threads
            result_queue = queue.Queue()
            
            # Start thinking in a separate thread
            thread = threading.Thread(target=thinking_process, args=(message, result_queue))
            thread.start()
            
            # Stream results as they become available
            previous_text = ""
            while True:
                try:
                    result = result_queue.get(block=True, timeout=30)  # 30 second timeout
                    if result is None:  # End of generation
                        break
                        
                    # Only yield the new part of the text
                    if isinstance(result, str):
                        new_part = result[len(previous_text):]
                        previous_text = result
                        if new_part:
                            # SSE clients join multiple "data:" lines with a newline,
                            # so prefix every line of the chunk to keep the protocol valid
                            for line in new_part.split("\n"):
                                yield f"data: {line}\n"
                            yield "\n"  # blank line ends the SSE event
                            
                except queue.Empty:
                    # Timeout occurred
                    yield "data: [Generation timeout. The model is taking too long to respond.]\n\n"
                    break
                    
            yield "data: [DONE]\n\n"
            
        return Response(stream_with_context(generate()), mimetype='text/event-stream')
        
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500
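
# Example request against the streaming endpoint (sketch; assumes the server is
# reachable at localhost:7860, the port configured below):
#
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What is the capital of France?"}'
#
# The -N flag disables curl's buffering so each "data:" frame is printed as it arrives.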

# Simple API for non-streaming chat (fallback)
@app.route('/chat-simple', methods=['POST'])
def chat_simple():
    data = request.get_json()
    message = data.get("message", "")
    
    if not message:
        return jsonify({"error": "Message is required"}), 400
    
    try:
        # Create prompt with system message
        prompt = f"""<|im_start|>system
You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
        
        # Handle inputs
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to('cpu') for k, v in inputs.items()}
        
        # Generate answer
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,
            no_repeat_ngram_size=3
        )
        
        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
        
        return jsonify({"response": answer})
        
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500
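
# Example request against the non-streaming fallback (same host/port assumption as above):
#
#   curl -X POST http://localhost:7860/chat-simple \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What is the capital of France?"}'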

if __name__ == "__main__":
    try:
        # Initialize models at startup
        initialize_models()
        logger.info("Starting Flask application")
        app.run(host="0.0.0.0", port=7860)
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")