Update app.py
app.py
CHANGED
@@ -1,166 +1,188 @@
  import os
  import time
- import json
- import gc
- from pathlib import Path
- from flask import Flask, request, jsonify, Response
- from flask_cors import CORS
  import torch

- #
- cache_dir = Path(os.getenv('TRANSFORMERS_CACHE', '/app/cache'))
- cache_dir.mkdir(parents=True, exist_ok=True)
-
  app = Flask(__name__)
  CORS(app)

  def load_model():
-     global
-
- def stream_generator(prompt):
-     if not load_model():
-         yield json.dumps({"type": "error", "content": "Model failed to load"}) + '\n'
-         return
-
-     thinking = ["🧠 Thinking...", "🤖 Preparing answer..."]
-     for step in thinking:
-         yield json.dumps({"type": "thinking", "content": step}) + '\n'
-         time.sleep(0.4)
-
-     try:
-         formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
-         inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE if DEVICE == "cuda" else "cpu")
-
          with torch.no_grad():
                  do_sample=True,
                  pad_token_id=tokenizer.eos_token_id
              )

  def chat():
-     if not load_model():
-         return jsonify({"error": "Model failed to load"}), 500
-
-     data = request.get_json()
-     prompt = data.get('prompt', '').strip()
-     if not prompt:
-         return jsonify({"error": "Empty prompt"}), 400
-
      try:
-         return jsonify({"response": response_text})
-     except Exception as e:
-         return jsonify({"error": str(e)}), 500
-
- @app.route('/health')
- def health():
-     import psutil
-     model_loaded = model is not None
-     return jsonify({
-         "status": "ok" if model_loaded else "waiting",
-         "model_loaded": model_loaded,
-         "memory": f"{psutil.virtual_memory().used/1024**3:.2f}GB used",
-         "device": DEVICE,
-     })
-
- @app.route('/')
- def home():
-     return jsonify({
-         "service": "Phi-3 Mini Chat API",
-         "status": "online",
-         "endpoints": {
-             "POST /chat": "Single-response",
-             "POST /stream_chat": "Streaming chat"
          }
-
- if
  import os
  import time
  import torch
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import gradio as gr

+ # Initialize Flask app
  app = Flask(__name__)
  CORS(app)

+ # Global variables
+ MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
+ MAX_LENGTH = 2048
+ MAX_NEW_TOKENS = 512
+ TEMPERATURE = 0.7
+ TOP_P = 0.9
+ THINKING_STEPS = 3  # Number of thinking steps

+ # Load model and tokenizer
+ @app.before_first_request
  def load_model():
+     global model, tokenizer
+
+     print(f"Loading model: {MODEL_ID}")
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+     # Load model with optimizations for limited resources
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         device_map="auto",
+         torch_dtype=torch.bfloat16,
+         load_in_4bit=True,
+     )
+
+     print("Model and tokenizer loaded successfully!")
+
+ # Helper function for step-by-step thinking
+ def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
+     # Initialize conversation with prompt
+     full_prompt = prompt
+
+     # Add thinking prefix
+     thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"
+
+     # Generate thinking steps
+     thinking_output = ""
+     for step in range(thinking_steps):
+         # Generate step i of thinking
+         inputs = tokenizer(thinking_prompt + thinking_output, return_tensors="pt").to(model.device)
+
          with torch.no_grad():
+             outputs = model.generate(
+                 inputs["input_ids"],
+                 max_length=MAX_LENGTH,
+                 max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
+                 temperature=TEMPERATURE,
+                 top_p=TOP_P,
                  do_sample=True,
                  pad_token_id=tokenizer.eos_token_id
              )
+
+         # Extract only new tokens
+         new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+         thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+         # Add this step to our thinking output
+         thinking_output += f"\n\nStep {step+1}: {thinking_step_output}"
+
+     # Now generate final answer based on the thinking
+     final_prompt = full_prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"
+
+     inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs["input_ids"],
+             max_length=MAX_LENGTH,
+             max_new_tokens=MAX_NEW_TOKENS // 2,
+             temperature=TEMPERATURE,
+             top_p=TOP_P,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Extract only the new tokens (the answer)
+     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+     answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     # Return thinking process and final answer
+     return {
+         "thinking": thinking_output,
+         "answer": answer,
+         "full_response": thinking_output + "\n\nBased on this thinking, my final answer is: " + answer
+     }
+
+ # API endpoint for chat
+ @app.route('/api/chat', methods=['POST'])
  def chat():
      try:
+         data = request.json
+         prompt = data.get('prompt', '')
+         include_thinking = data.get('include_thinking', False)
+
+         if not prompt:
+             return jsonify({'error': 'Prompt is required'}), 400
+
+         start_time = time.time()
+         response = generate_with_thinking(prompt)
+         end_time = time.time()
+
+         result = {
+             'answer': response['answer'],
+             'time_taken': round(end_time - start_time, 2)
          }
+
+         # Include thinking steps if requested
+         if include_thinking:
+             result['thinking'] = response['thinking']
+
+         return jsonify(result)
+
+     except Exception as e:
+         return jsonify({'error': str(e)}), 500
+
+ # Simple health check endpoint
+ @app.route('/health', methods=['GET'])
+ def health_check():
+     return jsonify({'status': 'ok'})
+
+ # Gradio Web UI
+ def create_ui():
+     with gr.Blocks() as demo:
+         gr.Markdown("# BitNet Specialist Chatbot with Step-by-Step Thinking")
+
+         with gr.Row():
+             with gr.Column():
+                 input_text = gr.Textbox(
+                     label="Your question",
+                     placeholder="Ask me anything...",
+                     lines=3
+                 )
+
+                 with gr.Row():
+                     submit_btn = gr.Button("Submit")
+                     clear_btn = gr.Button("Clear")
+
+                 show_thinking = gr.Checkbox(label="Show thinking steps", value=True)
+
+             with gr.Column():
+                 thinking_output = gr.Markdown(label="Thinking Process", visible=True)
+                 answer_output = gr.Markdown(label="Final Answer")
+
+         def respond(question, show_thinking):
+             if not question.strip():
+                 return "", "Please enter a question"
+
+             response = generate_with_thinking(question)
+
+             if show_thinking:
+                 return response["thinking"], response["answer"]
+             else:
+                 return "", response["answer"]
+
+         submit_btn.click(
+             respond,
+             inputs=[input_text, show_thinking],
+             outputs=[thinking_output, answer_output]
+         )
+
+         clear_btn.click(
+             lambda: ("", "", ""),
+             inputs=None,
+             outputs=[input_text, thinking_output, answer_output]
+         )
+
+     return demo
+
+ # Create Gradio UI and launch the app
+ if __name__ == "__main__":
+     # Load model at startup for Gradio
+     load_model()
+
+     # Create and launch Gradio interface
+     demo = create_ui()
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
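For reference, below is a minimal client sketch for exercising the /api/chat and /health routes added in this commit. It is not part of the commit and makes two assumptions: the Flask app is being served separately (for example with flask run, whose default address is http://127.0.0.1:5000), since the __main__ block above only launches the Gradio UI on port 7860, and the base URL and example prompt are placeholders.

# Hypothetical client for the endpoints defined in app.py above.
# Assumes the Flask app is served separately (e.g. `flask run`); adjust BASE_URL as needed.
import requests

BASE_URL = "http://127.0.0.1:5000"  # assumption: default `flask run` address

# Health check; the endpoint returns {"status": "ok"}
print(requests.get(f"{BASE_URL}/health").json())

# Chat request, asking for the thinking steps to be included in the response
payload = {"prompt": "Why is the sky blue?", "include_thinking": True}
resp = requests.post(f"{BASE_URL}/api/chat", json=payload)
resp.raise_for_status()
data = resp.json()

print("Answer:", data["answer"])
print("Time taken (s):", data["time_taken"])
if "thinking" in data:
    print("Thinking steps:", data["thinking"])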