mike23415 committed
Commit c5dd812 · verified · 1 Parent(s): 7625bb8

Update app.py

Files changed (1)
  1. app.py +172 -150
app.py CHANGED
@@ -1,166 +1,188 @@

Old version of app.py (lines marked with - were removed):

 import os
 import time
-import json
-import gc
-from pathlib import Path
-from flask import Flask, request, jsonify, Response
-from flask_cors import CORS
 import torch

-# Caching
-cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
-cache_dir.mkdir(parents=True, exist_ok=True)

 app = Flask(__name__)
 CORS(app)

-MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
-MAX_NEW_TOKENS = 256
-DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

-tokenizer = None
-model = None

 def load_model():
-    global tokenizer, model
-    if tokenizer and model:
-        return True

-    try:
-        from transformers import AutoTokenizer, AutoModelForCausalLM

-        print(f"Loading {MODEL_NAME} on {DEVICE}...")
-        hf_token = os.environ.get("HF_TOKEN")
-        token_kwargs = {"token": hf_token} if hf_token else {}

-        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=str(cache_dir), trust_remote_code=False, **token_kwargs)

-        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_NAME,
-            cache_dir=str(cache_dir),
-            torch_dtype=torch.bfloat16 if DEVICE == "cpu" else torch.float16,
-            low_cpu_mem_usage=True,
-            trust_remote_code=False,
-            **token_kwargs
-        )

-        if DEVICE == "cuda":
-            model = model.to("cuda")

-        print("✅ Phi-3 Mini loaded successfully!")
-        return True
-    except Exception as e:
-        print(f"❌ Model load failed: {e}")
-        return False

-def stream_generator(prompt):
-    if not load_model():
-        yield json.dumps({"type": "error", "content": "Model failed to load"}) + '\n'
-        return

-    thinking = ["🧠 Thinking...", "🤖 Preparing answer..."]
-    for step in thinking:
-        yield json.dumps({"type": "thinking", "content": step}) + '\n'
-        time.sleep(0.4)

-    try:
-        formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
-        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE if DEVICE == "cuda" else "cpu")

         with torch.no_grad():
-            output = model.generate(
-                **inputs,
-                max_new_tokens=MAX_NEW_TOKENS,
-                temperature=0.7,
-                top_p=0.9,
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id
             )

-        new_tokens = output[0][inputs.input_ids.shape[-1]:]
-        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

-        for i in range(0, len(generated_text), 12):
-            yield json.dumps({"type": "answer", "content": generated_text[i:i+12]}) + '\n'
-            time.sleep(0.03)

-    except Exception as e:
-        yield json.dumps({"type": "error", "content": str(e)}) + '\n'

-    yield json.dumps({"type": "complete"}) + '\n'
-    if DEVICE == "cuda":
-        torch.cuda.empty_cache()
-    gc.collect()

-@app.route('/stream_chat', methods=['POST'])
-def stream_chat():
-    data = request.get_json()
-    prompt = data.get('prompt', '').strip()
-    if not prompt:
-        return jsonify({"error": "Empty prompt"}), 400

-    return Response(
-        stream_generator(prompt),
-        mimetype='text/event-stream',
-        headers={
-            'Cache-Control': 'no-cache',
-            'X-Accel-Buffering': 'no',
-            'Connection': 'keep-alive'
-        }
-    )

-@app.route('/chat', methods=['POST'])
 def chat():
-    if not load_model():
-        return jsonify({"error": "Model failed to load"}), 500

-    data = request.get_json()
-    prompt = data.get('prompt', '').strip()
-    if not prompt:
-        return jsonify({"error": "Empty prompt"}), 400

     try:
-        formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
-        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE if DEVICE == "cuda" else "cpu")

-        with torch.no_grad():
-            output = model.generate(
-                **inputs,
-                max_new_tokens=MAX_NEW_TOKENS,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True,
-                pad_token_id=tokenizer.eos_token_id
-            )

-        response_text = tokenizer.decode(output[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
-        return jsonify({"response": response_text})
-    except Exception as e:
-        return jsonify({"error": str(e)}), 500

-@app.route('/health')
-def health():
-    import psutil
-    model_loaded = model is not None
-    return jsonify({
-        "status": "ok" if model_loaded else "waiting",
-        "model_loaded": model_loaded,
-        "memory": f"{psutil.virtual_memory().used/1024**3:.2f}GB used",
-        "device": DEVICE,
-    })

-@app.route('/')
-def home():
-    return jsonify({
-        "service": "Phi-3 Mini Chat API",
-        "status": "online",
-        "endpoints": {
-            "POST /chat": "Single-response",
-            "POST /stream_chat": "Streaming chat"
         }
-    })

-if __name__ == '__main__':
-    if os.getenv('PRELOAD_MODEL', 'false') == 'true':
-        load_model()
-    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", 5000)))

New version of app.py (lines marked with + were added):

 import os
 import time
 import torch
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import gradio as gr

+# Initialize Flask app
 app = Flask(__name__)
 CORS(app)

+# Global variables
+MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
+MAX_LENGTH = 2048
+MAX_NEW_TOKENS = 512
+TEMPERATURE = 0.7
+TOP_P = 0.9
+THINKING_STEPS = 3  # Number of thinking steps

+# Load model and tokenizer
+@app.before_first_request
 def load_model():
+    global model, tokenizer

+    print(f"Loading model: {MODEL_ID}")

+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

+    # Load model with optimizations for limited resources
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        load_in_4bit=True,
+    )

+    print("Model and tokenizer loaded successfully!")

+# Helper function for step-by-step thinking
+def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
+    # Initialize conversation with prompt
+    full_prompt = prompt

+    # Add thinking prefix
+    thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"

+    # Generate thinking steps
+    thinking_output = ""
+    for step in range(thinking_steps):
+        # Generate step i of thinking
+        inputs = tokenizer(thinking_prompt + thinking_output, return_tensors="pt").to(model.device)

         with torch.no_grad():
+            outputs = model.generate(
+                inputs["input_ids"],
+                max_length=MAX_LENGTH,
+                max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
+                temperature=TEMPERATURE,
+                top_p=TOP_P,
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id
             )

+        # Extract only new tokens
+        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+        thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)

+        # Add this step to our thinking output
+        thinking_output += f"\n\nStep {step+1}: {thinking_step_output}"

+    # Now generate final answer based on the thinking
+    final_prompt = full_prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"

+    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        outputs = model.generate(
+            inputs["input_ids"],
+            max_length=MAX_LENGTH,
+            max_new_tokens=MAX_NEW_TOKENS // 2,
+            temperature=TEMPERATURE,
+            top_p=TOP_P,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )

+    # Extract only the new tokens (the answer)
+    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

+    # Return thinking process and final answer
+    return {
+        "thinking": thinking_output,
+        "answer": answer,
+        "full_response": thinking_output + "\n\nBased on this thinking, my final answer is: " + answer
+    }

+# API endpoint for chat
+@app.route('/api/chat', methods=['POST'])
 def chat():
     try:
+        data = request.json
+        prompt = data.get('prompt', '')
+        include_thinking = data.get('include_thinking', False)

+        if not prompt:
+            return jsonify({'error': 'Prompt is required'}), 400

+        start_time = time.time()
+        response = generate_with_thinking(prompt)
+        end_time = time.time()

+        result = {
+            'answer': response['answer'],
+            'time_taken': round(end_time - start_time, 2)
         }

+        # Include thinking steps if requested
+        if include_thinking:
+            result['thinking'] = response['thinking']

+        return jsonify(result)

+    except Exception as e:
+        return jsonify({'error': str(e)}), 500

+# Simple health check endpoint
+@app.route('/health', methods=['GET'])
+def health_check():
+    return jsonify({'status': 'ok'})

+# Gradio Web UI
+def create_ui():
+    with gr.Blocks() as demo:
+        gr.Markdown("# BitNet Specialist Chatbot with Step-by-Step Thinking")

+        with gr.Row():
+            with gr.Column():
+                input_text = gr.Textbox(
+                    label="Your question",
+                    placeholder="Ask me anything...",
+                    lines=3
+                )

+                with gr.Row():
+                    submit_btn = gr.Button("Submit")
+                    clear_btn = gr.Button("Clear")

+                show_thinking = gr.Checkbox(label="Show thinking steps", value=True)

+            with gr.Column():
+                thinking_output = gr.Markdown(label="Thinking Process", visible=True)
+                answer_output = gr.Markdown(label="Final Answer")

+        def respond(question, show_thinking):
+            if not question.strip():
+                return "", "Please enter a question"

+            response = generate_with_thinking(question)

+            if show_thinking:
+                return response["thinking"], response["answer"]
+            else:
+                return "", response["answer"]

+        submit_btn.click(
+            respond,
+            inputs=[input_text, show_thinking],
+            outputs=[thinking_output, answer_output]
+        )

+        clear_btn.click(
+            lambda: ("", "", ""),
+            inputs=None,
+            outputs=[input_text, thinking_output, answer_output]
+        )

+    return demo

+# Create Gradio UI and launch the app
+if __name__ == "__main__":
+    # Load model at startup for Gradio
+    load_model()

+    # Create and launch Gradio interface
+    demo = create_ui()
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
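
A small, model-free illustration of the prompt flow inside generate_with_thinking: each thinking step re-feeds the accumulated thinking text, and the final answer is then generated from the original prompt plus all thinking steps. This is only a sketch; fake_generate is a hypothetical stand-in for the model.generate/tokenizer.decode pair in the committed code.

def fake_generate(text):
    # Hypothetical stand-in for model.generate + tokenizer.decode
    return f"<continuation of a {len(text)}-character prompt>"

prompt = "Why is the sky blue?"
thinking_prompt = prompt + "\n\nLet me think through this step by step:"

thinking_output = ""
for step in range(3):  # THINKING_STEPS in the committed code
    # Each step sees the thinking prompt plus all previously generated steps
    step_text = fake_generate(thinking_prompt + thinking_output)
    thinking_output += f"\n\nStep {step+1}: {step_text}"

# The final answer is conditioned on the prompt and the accumulated thinking
final_prompt = prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"
answer = fake_generate(final_prompt)
print(answer)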
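
For reference, a minimal sketch of calling the new POST /api/chat route from a client. The request and response fields (prompt, include_thinking, answer, time_taken, thinking) are the ones defined in the updated chat() handler; the host and port are assumptions, since in this revision `python app.py` only launches the Gradio UI on port 7860, so the Flask routes would have to be served separately (for example with a WSGI server such as `gunicorn app:app`).

import requests

resp = requests.post(
    "http://localhost:5000/api/chat",        # assumed host/port for the Flask app
    json={
        "prompt": "Why is the sky blue?",    # required field
        "include_thinking": True,            # optional: also return the thinking steps
    },
    timeout=300,                             # generation can take a while on CPU
)
resp.raise_for_status()
data = resp.json()

print("Answer:", data["answer"])
print("Time taken (s):", data["time_taken"])
if "thinking" in data:                       # present only when include_thinking is true
    print("Thinking:", data["thinking"])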