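"""BitNet chatbot with explicit step-by-step "thinking".

Loads microsoft/bitnet-b1.58-2B-4T via Hugging Face transformers and exposes it
two ways: a Flask JSON API (/api/chat, /health) and a Gradio web UI on port 7860.
"""
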
import os
import time
import threading

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_LENGTH = 2048
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9
THINKING_STEPS = 3
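
# TEMPERATURE and TOP_P configure nucleus sampling. THINKING_STEPS is the number
# of intermediate reasoning passes generate_with_thinking() runs before the final
# answer; each pass is budgeted MAX_NEW_TOKENS // THINKING_STEPS new tokens.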

model = None
tokenizer = None


def load_model_and_tokenizer():
    """Load the tokenizer and model once and cache them in the module globals."""
    global model, tokenizer

    if model is not None and tokenizer is not None:
        return

    print(f"Loading model: {MODEL_ID}")

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
    )

    # The checkpoint already stores 1.58-bit weights, so no additional
    # bitsandbytes quantization is applied; the model is loaded in bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    print("Model and tokenizer loaded successfully!")


app = Flask(__name__)
CORS(app)
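
# CORS(app) with its defaults allows cross-origin requests from any origin, so a
# separately hosted front end can call the JSON API directly.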


def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
    """Generate a response in two phases: numbered "thinking" steps, then a final answer."""
    full_prompt = prompt

    thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"

    thinking_output = ""
    for step in range(thinking_steps):
        inputs = tokenizer(thinking_prompt + thinking_output, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # Passing **inputs supplies the attention mask along with the input ids;
            # max_new_tokens alone bounds the length of each thinking step.
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Keep only the tokens generated in this step.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)

        thinking_output += f"\n\nStep {step + 1}: {thinking_step_output}"

    # Ask for a final answer conditioned on the accumulated thinking.
    final_prompt = full_prompt + "\n\n" + thinking_output + "\n\nBased on this thinking, my final answer is:"

    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS // 2,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

    return {
        "thinking": thinking_output,
        "answer": answer,
        "full_response": thinking_output + "\n\nBased on this thinking, my final answer is: " + answer,
    }
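
# Illustrative direct use of the generator once the model is loaded (the prompt
# here is just an example):
#
#   load_model_and_tokenizer()
#   result = generate_with_thinking("Why is the sky blue?")
#   print(result["thinking"])  # the numbered "Step 1/2/3" reasoning text
#   print(result["answer"])    # the final answer only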


@app.route('/api/chat', methods=['POST'])
def chat():
    try:
        if model is None or tokenizer is None:
            load_model_and_tokenizer()

        # get_json(silent=True) yields None instead of raising on a malformed body,
        # so bad requests fall through to the 400 below rather than a 500.
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        include_thinking = data.get('include_thinking', False)

        if not prompt:
            return jsonify({'error': 'Prompt is required'}), 400

        start_time = time.time()
        response = generate_with_thinking(prompt)
        end_time = time.time()

        result = {
            'answer': response['answer'],
            'time_taken': round(end_time - start_time, 2)
        }

        if include_thinking:
            result['thinking'] = response['thinking']

        return jsonify(result)

    except Exception as e:
        import traceback
        print(f"Error in chat endpoint: {str(e)}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'ok'})
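
# Example requests against the JSON API (assuming the Flask app is being served
# on port 5000, as in the __main__ block below):
#
#   curl -X POST http://localhost:5000/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain 1.58-bit weights in one paragraph.", "include_thinking": true}'
#
#   curl http://localhost:5000/health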


def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# BitNet Specialist Chatbot with Step-by-Step Thinking")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your question",
                    placeholder="Ask me anything...",
                    lines=3
                )

                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.Button("Clear")

                show_thinking = gr.Checkbox(label="Show thinking steps", value=True)

            with gr.Column():
                thinking_output = gr.Markdown(label="Thinking Process", visible=True)
                answer_output = gr.Markdown(label="Final Answer")

        def respond(question, show_thinking):
            if not question.strip():
                return "", "Please enter a question"

            if model is None or tokenizer is None:
                load_model_and_tokenizer()

            response = generate_with_thinking(question)

            if show_thinking:
                return response["thinking"], response["answer"]
            else:
                return "", response["answer"]

        submit_btn.click(
            respond,
            inputs=[input_text, show_thinking],
            outputs=[thinking_output, answer_output]
        )

        clear_btn.click(
            lambda: ("", "", ""),
            inputs=None,
            outputs=[input_text, thinking_output, answer_output]
        )

    return demo


if __name__ == "__main__":
    # Load the model up front so the first request does not pay the cost.
    load_model_and_tokenizer()

    # Serve the Flask JSON API (/api/chat, /health) in a background thread so it
    # is reachable alongside the Gradio UI; port 5000 is an assumed default.
    threading.Thread(
        target=lambda: app.run(host="0.0.0.0", port=5000),
        daemon=True,
    ).start()

    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)