import os
import time
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Generation configuration
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_LENGTH = 2048
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9
THINKING_STEPS = 3  # Number of intermediate "thinking" generations

# Global variables for model and tokenizer (loaded lazily)
model = None
tokenizer = None

# Function to load model and tokenizer (once, on first use)
def load_model_and_tokenizer():
    global model, tokenizer

    if model is not None and tokenizer is not None:
        return

    print(f"Loading model: {MODEL_ID}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
    )

    # Load model with optimizations for limited resources
    # (load_in_4bit requires the bitsandbytes package; drop it to load in plain bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,
    )

    print("Model and tokenizer loaded successfully!")

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Helper function for step-by-step thinking
def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
    # Initialize conversation with the user prompt
    full_prompt = prompt

    # Add thinking prefix
    thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"

    # Generate thinking steps
    thinking_output = ""
    for step in range(thinking_steps):
        # Generate one step of thinking, truncating the input to the context window
        inputs = tokenizer(
            thinking_prompt + thinking_output,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Extract only the newly generated tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Append this step to the accumulated thinking output
        thinking_output += f"\n\nStep {step+1}: {thinking_step_output}"

    # Now generate the final answer based on the thinking
    final_prompt = full_prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"
    inputs = tokenizer(
        final_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS // 2,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Extract only the new tokens (the answer)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Return thinking process and final answer
    return {
        "thinking": thinking_output,
        "answer": answer,
        "full_response": thinking_output + "\n\nBased on this thinking, my final answer is: " + answer,
    }
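
# A small usage sketch (illustrative only; assumes load_model_and_tokenizer() has run):
#   result = generate_with_thinking("Why is the sky blue?")
#   print(result["thinking"])
#   print(result["answer"])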

# API endpoint for chat
@app.route('/api/chat', methods=['POST'])
def chat():
    try:
        # Ensure model is loaded
        if model is None or tokenizer is None:
            load_model_and_tokenizer()

        # Parse the JSON body defensively (missing/invalid JSON yields an empty dict)
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        include_thinking = data.get('include_thinking', False)

        if not prompt:
            return jsonify({'error': 'Prompt is required'}), 400

        start_time = time.time()
        response = generate_with_thinking(prompt)
        end_time = time.time()

        result = {
            'answer': response['answer'],
            'time_taken': round(end_time - start_time, 2)
        }

        # Include thinking steps if requested
        if include_thinking:
            result['thinking'] = response['thinking']

        return jsonify(result)
    except Exception as e:
        import traceback
        print(f"Error in chat endpoint: {str(e)}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500
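
# Example request (illustrative; assumes the Flask app is actually served, e.g. via
# app.run() or a WSGI server such as gunicorn, here on port 5000):
#   curl -X POST http://localhost:5000/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Why is the sky blue?", "include_thinking": true}'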

# Simple health check endpoint
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'ok'})

# Gradio Web UI
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# BitNet Specialist Chatbot with Step-by-Step Thinking")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your question",
                    placeholder="Ask me anything...",
                    lines=3
                )
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.Button("Clear")
                show_thinking = gr.Checkbox(label="Show thinking steps", value=True)

            with gr.Column():
                thinking_output = gr.Markdown(label="Thinking Process", visible=True)
                answer_output = gr.Markdown(label="Final Answer")

        def respond(question, show_thinking):
            if not question.strip():
                return "", "Please enter a question"

            # Ensure model is loaded
            if model is None or tokenizer is None:
                load_model_and_tokenizer()

            response = generate_with_thinking(question)

            if show_thinking:
                return response["thinking"], response["answer"]
            else:
                return "", response["answer"]

        submit_btn.click(
            respond,
            inputs=[input_text, show_thinking],
            outputs=[thinking_output, answer_output]
        )

        clear_btn.click(
            lambda: ("", "", ""),
            inputs=None,
            outputs=[input_text, thinking_output, answer_output]
        )

    return demo

# Create Gradio UI and launch the app
if __name__ == "__main__":
    # Load model at startup
    load_model_and_tokenizer()

    # Create and launch Gradio interface
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
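
# Note: the Flask routes defined above (/api/chat, /health) are only reachable if the
# Flask app itself is served; the __main__ block launches only the Gradio UI. A minimal
# sketch, assuming you want the REST API instead of (or alongside) the UI:
#   app.run(host="0.0.0.0", port=5000)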