mike23415 committed
Commit a7b29f8 · verified · 1 Parent(s): 4ef3090

Update app.py

Files changed (1)
  1. app.py +254 -186
app.py CHANGED
@@ -1,221 +1,289 @@
 
 
 import os
-import time
 import torch
-import warnings
-from flask import Flask, request, jsonify
-from flask_cors import CORS
-from transformers import AutoModelForCausalLM, AutoTokenizer, logging
-import gradio as gr

-# Suppress warnings
-warnings.filterwarnings("ignore")
-logging.set_verbosity_error()

-# Global variables
-# Using phi-2 which is a smaller model that can run on CPU
-MODEL_ID = "microsoft/phi-2"
-MAX_LENGTH = 2048
-MAX_NEW_TOKENS = 512
-TEMPERATURE = 0.7
-TOP_P = 0.9
-THINKING_STEPS = 3 # Number of thinking steps

-# Global variables for model and tokenizer
-model = None
 tokenizer = None

-# Function to load model and tokenizer
-def load_model_and_tokenizer():
-    global model, tokenizer
-
-    if model is not None and tokenizer is not None:
-        return
-
-    print(f"Loading model: {MODEL_ID}")
-
     try:
-        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
-            MODEL_ID,
-            use_fast=True,
-            trust_remote_code=True
         )

-        # Load model with CPU optimizations - removed 4-bit quantization
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            low_cpu_mem_usage=True, # Optimize for CPU usage
-            torch_dtype=torch.float32, # Use float32 instead of bfloat16 for better CPU compatibility
-            trust_remote_code=True
         )

-        print("Model and tokenizer loaded successfully!")
     except Exception as e:
-        import traceback
-        print(f"Error loading model: {str(e)}")
-        print(traceback.format_exc())
         raise

-# Initialize Flask app
-app = Flask(__name__)
-CORS(app)

-# Helper function for step-by-step thinking
-def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
-    # Initialize conversation with prompt
-    full_prompt = prompt
-
-    # Add thinking prefix
-    thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"
-
-    # Generate thinking steps
-    thinking_output = ""
-    for step in range(thinking_steps):
-        # Generate step i of thinking
-        inputs = tokenizer(thinking_prompt + thinking_output, return_tensors="pt")
-
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs["input_ids"],
-                max_length=MAX_LENGTH,
-                max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
-                temperature=TEMPERATURE,
-                top_p=TOP_P,
                 do_sample=True,
-                pad_token_id=tokenizer.eos_token_id
             )

-        # Extract only new tokens
-        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
-        thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)

-        # Add this step to our thinking output
-        thinking_output += f"\n\nStep {step+1}: {thinking_step_output}"

-    # Now generate final answer based on the thinking
-    final_prompt = full_prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"

-    inputs = tokenizer(final_prompt, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.generate(
-            inputs["input_ids"],
-            max_length=MAX_LENGTH,
-            max_new_tokens=MAX_NEW_TOKENS // 2,
-            temperature=TEMPERATURE,
-            top_p=TOP_P,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
         )

-    # Extract only the new tokens (the answer)
-    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
-    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

-    # Return thinking process and final answer
-    return {
-        "thinking": thinking_output,
-        "answer": answer,
-        "full_response": thinking_output + "\n\nBased on this thinking, my final answer is: " + answer
-    }
-
-# API endpoint for chat
-@app.route('/api/chat', methods=['POST'])
-def chat():
     try:
-        # Ensure model is loaded
-        if model is None or tokenizer is None:
-            load_model_and_tokenizer()
-
-        data = request.json
-        prompt = data.get('prompt', '')
-        include_thinking = data.get('include_thinking', False)

-        if not prompt:
-            return jsonify({'error': 'Prompt is required'}), 400

-        start_time = time.time()
-        response = generate_with_thinking(prompt)
-        end_time = time.time()

-        result = {
-            'answer': response['answer'],
-            'time_taken': round(end_time - start_time, 2)
-        }

-        # Include thinking steps if requested
-        if include_thinking:
-            result['thinking'] = response['thinking']
-
-        return jsonify(result)
-
-    except Exception as e:
-        import traceback
-        print(f"Error in chat endpoint: {str(e)}")
-        print(traceback.format_exc())
-        return jsonify({'error': str(e)}), 500
-
-# Simple health check endpoint
-@app.route('/health', methods=['GET'])
-def health_check():
-    return jsonify({'status': 'ok'})
-
-# Gradio Web UI
-def create_ui():
-    with gr.Blocks() as demo:
-        gr.Markdown("# AI Assistant with Step-by-Step Thinking")
-
-        with gr.Row():
-            with gr.Column():
-                input_text = gr.Textbox(
-                    label="Your question",
-                    placeholder="Ask me anything...",
-                    lines=3
-                )
-
-                with gr.Row():
-                    submit_btn = gr.Button("Submit")
-                    clear_btn = gr.Button("Clear")
-
-                show_thinking = gr.Checkbox(label="Show thinking steps", value=True)
-
-            with gr.Column():
-                thinking_output = gr.Markdown(label="Thinking Process", visible=True)
-                answer_output = gr.Markdown(label="Final Answer")
-
-        def respond(question, show_thinking):
-            if not question.strip():
-                return "", "Please enter a question"
-
-            # Ensure model is loaded
-            if model is None or tokenizer is None:
-                load_model_and_tokenizer()
-
-            response = generate_with_thinking(question)
-
-            if show_thinking:
-                return response["thinking"], response["answer"]
-            else:
-                return "", response["answer"]
-
-        submit_btn.click(
-            respond,
-            inputs=[input_text, show_thinking],
-            outputs=[thinking_output, answer_output]
-        )

-        clear_btn.click(
-            lambda: ("", "", ""),
-            inputs=None,
-            outputs=[input_text, thinking_output, answer_output]
-        )
-
-    return demo

-# Create Gradio UI and launch the app
 if __name__ == "__main__":
-    # Load model at startup
-    load_model_and_tokenizer()
-
-    # Create and launch Gradio interface
-    demo = create_ui()
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
+from flask import Flask, request, jsonify, Response, stream_with_context
+from flask_cors import CORS
 import os
 import torch
+import time
+import logging
+import threading
+import queue
+import json
+from transformers import AutoTokenizer, AutoModelForCausalLM

+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)

+# Fix caching issue on Hugging Face Spaces
+os.environ["TRANSFORMERS_CACHE"] = "/tmp"
+os.environ["HF_HOME"] = "/tmp"
+os.environ["XDG_CACHE_HOME"] = "/tmp"

+app = Flask(__name__)
+CORS(app) # Enable CORS for all routes
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+logger.info(f"Using device: {device}")
+
+# Global model variables
 tokenizer = None
+model = None

+# Initialize models once on startup
+def initialize_models():
+    global tokenizer, model
     try:
+        logger.info("Loading language model...")
+
+        # You can change the model here if needed
+        model_name = "Qwen/Qwen2.5-1.5B-Instruct" # Good balance of quality and speed for CPU
+
+        # Load tokenizer with caching
+        logger.info(f"Loading tokenizer: {model_name}")
         tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            use_fast=True # Use the fast tokenizers when available
         )

+        # Load model with optimizations for CPU
+        logger.info(f"Loading model: {model_name}")
         model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16, # Use float16 for lower memory
+            device_map="cpu", # Explicitly set to CPU
+            low_cpu_mem_usage=True, # Optimize memory loading
+            offload_folder="offload" # Use disk offloading if needed
         )

+        # Handle padding tokens
+        if tokenizer.pad_token is None:
+            logger.info("Setting pad token to EOS token")
+            tokenizer.pad_token = tokenizer.eos_token
+            model.config.pad_token_id = model.config.eos_token_id
+
+        # Set up model configuration for better generation
+        model.config.do_sample = True # Enable sampling
+        model.config.temperature = 0.7 # Default temperature
+        model.config.top_p = 0.9 # Default top_p
+
+        logger.info("Models initialized successfully")
     except Exception as e:
+        logger.error(f"Error initializing models: {str(e)}")
         raise

+# Function to simulate "thinking" process
+def thinking_process(message, result_queue):
+    """
+    This function simulates a thinking process and puts the result in the queue.
+    It includes both an explicit thinking stage and then a generation stage.
+    """
+    try:
+        # Simulate explicit thinking stage
+        logger.info(f"Thinking about: '{message}'")
+
+        # Pause to simulate deeper thinking (helps with more complex queries)
+        time.sleep(1)
+
+        # Create thoughtful prompt with system message and thinking instructions
+        prompt = f"""<|im_start|>system
+You are a helpful, friendly, and thoughtful AI assistant.
+Let's approach the user's request step by step:
+1. First, understand what the user is really asking
+2. Consider the key aspects we need to address
+3. Think about the best way to structure the response
+4. Provide clear, accurate information in a conversational tone
+
+Always think carefully before responding, consider different angles, and provide thoughtful, detailed answers.
+<|im_end|>
+<|im_start|>user
+{message}<|im_end|>
+<|im_start|>assistant
+"""
+
+        # Handle inputs
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
+        inputs = {k: v.to('cpu') for k, v in inputs.items()}
+
+        # Generate answer with streaming
+        streamer = TextStreamer(tokenizer, result_queue)
+
+        # Simulate thinking first by sending some initial dots
+        result_queue.put("Let me think about this...")
+        time.sleep(0.5)
+
+        # Generate response - we use a temperature of 0.7 for more thoughtful outputs
+        # and top_p for nucleus sampling to avoid repetitive or generic responses
+        try:
+            model.generate(
+                **inputs,
+                max_new_tokens=512,
+                temperature=0.7,
+                top_p=0.9,
                 do_sample=True,
+                streamer=streamer,
+                num_beams=2, # Using 2 beams helps with coherence
+                no_repeat_ngram_size=3,
+                repetition_penalty=1.2 # Discourages token repetition
             )
+        except Exception as e:
+            logger.error(f"Model generation error: {str(e)}")
+            result_queue.put(f"\n\nI apologize, but I encountered an error while processing your request.")
+
+        # Signal generation is complete
+        result_queue.put(None)
+
+    except Exception as e:
+        logger.error(f"Error in thinking process: {str(e)}")
+        result_queue.put(f"I apologize, but I encountered an error while processing your request: {str(e)}")
+        # Signal generation is complete
+        result_queue.put(None)
+
+# TextStreamer class for token-by-token generation
+class TextStreamer:
+    def __init__(self, tokenizer, queue):
+        self.tokenizer = tokenizer
+        self.queue = queue
+        self.current_tokens = []

+    def put(self, token_ids):
+        self.current_tokens.extend(token_ids.tolist())
+        text = self.tokenizer.decode(self.current_tokens, skip_special_tokens=True)
+        self.queue.put(text)

+    def end(self):
+        pass
+
+# API route for home page
+@app.route('/')
+def home():
+    return jsonify({"message": "AI Chat API is running!"})
+
+# API route for streaming chat responses
+@app.route('/chat', methods=['POST', 'GET'])
+def chat():
+    # Handle both POST JSON and GET query parameters for flexibility
+    if request.method == 'POST':
+        try:
+            data = request.get_json()
+            message = data.get("message", "")
+        except:
+            # If JSON parsing fails, try form data
+            message = request.form.get("message", "")
+    else: # GET
+        message = request.args.get("message", "")

+    if not message:
+        return jsonify({"error": "Message is required"}), 400

+    try:
+        def generate():
+            # Signal the start of streaming with headers
+            yield "retry: 1000\n"
+            yield "event: message\n"
+
+            # Show thinking indicator
+            yield f"data: [Thinking...]\n\n"
+
+            # Create a queue for communication between threads
+            result_queue = queue.Queue()
+
+            # Start thinking in a separate thread
+            thread = threading.Thread(target=thinking_process, args=(message, result_queue))
+            thread.daemon = True # Make thread die when main thread exits
+            thread.start()
+
+            # Stream results as they become available
+            previous_text = ""
+            while True:
+                try:
+                    result = result_queue.get(block=True, timeout=30) # 30 second timeout
+                    if result is None: # End of generation
+                        break
+
+                    # Only yield the new part of the text
+                    if isinstance(result, str):
+                        new_part = result[len(previous_text):]
+                        previous_text = result
+                        if new_part:
+                            yield f"data: {json.dumps(new_part)}\n\n"
+                            time.sleep(0.01) # Small delay for more natural typing effect
+
+                except queue.Empty:
+                    # Timeout occurred
+                    yield "data: [Generation timeout. The model is taking too long to respond.]\n\n"
+                    break
+
+            yield "data: [DONE]\n\n"
+
+        return Response(
+            stream_with_context(generate()),
+            mimetype='text/event-stream',
+            headers={
+                'Cache-Control': 'no-cache',
+                'Connection': 'keep-alive',
+                'X-Accel-Buffering': 'no' # Disable buffering for Nginx
+            }
         )
+
+    except Exception as e:
+        logger.error(f"Error processing chat request: {str(e)}")
+        return jsonify({"error": f"An error occurred: {str(e)}"}), 500
+
+# Simple API for non-streaming chat (fallback)
+@app.route('/chat-simple', methods=['POST'])
+def chat_simple():
+    data = request.get_json()
+    message = data.get("message", "")

+    if not message:
+        return jsonify({"error": "Message is required"}), 400

     try:
+        # Create prompt with system message
+        prompt = f"""<|im_start|>system
+You are a helpful, friendly, and thoughtful AI assistant. Think carefully and provide informative, detailed responses.
+<|im_end|>
+<|im_start|>user
+{message}<|im_end|>
+<|im_start|>assistant
+"""

+        # Handle inputs
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
+        inputs = {k: v.to('cpu') for k, v in inputs.items()}

+        # Generate answer
+        output = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True,
+            num_beams=1,
+            no_repeat_ngram_size=3
+        )

+        # Decode and format answer
+        answer = tokenizer.decode(output[0], skip_special_tokens=True)

+        # Clean up the response
+        if "<|im_end|>" in answer:
+            answer = answer.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()

+        return jsonify({"response": answer})
+
+    except Exception as e:
+        logger.error(f"Error processing chat request: {str(e)}")
+        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

 if __name__ == "__main__":
+    try:
+        # Initialize models at startup
+        initialize_models()
+        logger.info("Starting Flask application")
+        app.run(host="0.0.0.0", port=7860)
+    except Exception as e:
+        logger.critical(f"Failed to start application: {str(e)}")
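For reference, a minimal client sketch for the two endpoints this commit introduces: /chat, which streams Server-Sent Events, and /chat-simple, which returns a single JSON response. This is illustrative only and not part of the commit; it assumes the app is reachable at a placeholder BASE_URL (for example a local run on port 7860) and that the third-party requests package is installed.

# Example client sketch (not part of app.py): stream /chat, or use /chat-simple as a fallback.
import json
import requests

BASE_URL = "http://localhost:7860"  # placeholder; point this at the actual deployment

def stream_chat(message):
    """Read the Server-Sent Events stream from /chat and return the accumulated text."""
    text = ""
    with requests.post(f"{BASE_URL}/chat", json={"message": message}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            # Skip blank keep-alives and the "retry:" / "event:" lines the server emits.
            if not line or not line.startswith("data: "):
                continue
            payload = line[len("data: "):]
            if payload == "[DONE]":
                break
            try:
                # Text chunks are JSON-encoded strings (see json.dumps in the server's generate()).
                text += json.loads(payload)
            except json.JSONDecodeError:
                # Status markers such as "[Thinking...]" are plain text; ignore them here.
                pass
    return text

def simple_chat(message):
    """Call the non-streaming /chat-simple fallback and return the response field."""
    resp = requests.post(f"{BASE_URL}/chat-simple", json={"message": message})
    resp.raise_for_status()
    return resp.json()["response"]

if __name__ == "__main__":
    print(stream_chat("Hello, what can you do?"))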