import time
import threading

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Configuration
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
MAX_LENGTH = 2048        # Maximum prompt length (tokens) fed to the model
MAX_NEW_TOKENS = 512     # Total generation budget across thinking steps
TEMPERATURE = 0.7
TOP_P = 0.9
THINKING_STEPS = 3       # Number of thinking steps

# Global variables for model and tokenizer
model = None
tokenizer = None


def load_model_and_tokenizer():
    """Load the model and tokenizer once and cache them in module globals."""
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    print(f"Loading model: {MODEL_ID}")
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
    )
    # Load model with optimizations for limited resources.
    # Note: the BitNet checkpoint is already low-bit, so bitsandbytes 4-bit
    # loading (load_in_4bit) is not applied on top of it.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    print("Model and tokenizer loaded successfully!")


# Initialize Flask app
app = Flask(__name__)
CORS(app)


def generate_with_thinking(prompt, thinking_steps=THINKING_STEPS):
    """Generate a response in two phases: explicit thinking steps, then a final answer."""
    full_prompt = prompt
    # Add thinking prefix
    thinking_prompt = full_prompt + "\n\nLet me think through this step by step:"

    # Phase 1: generate the thinking steps, each conditioned on the previous ones
    thinking_output = ""
    for step in range(thinking_steps):
        inputs = tokenizer(
            thinking_prompt + thinking_output,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=MAX_NEW_TOKENS // thinking_steps,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Extract only the newly generated tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        thinking_step_output = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Add this step to the accumulated thinking output
        thinking_output += f"\n\nStep {step + 1}: {thinking_step_output}"

    # Phase 2: generate the final answer based on the thinking
    final_prompt = (
        full_prompt + "\n\n" + thinking_output
        + "\n\nBased on this thinking, my final answer is:"
    )
    inputs = tokenizer(
        final_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS // 2,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Extract only the new tokens (the answer)
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Return thinking process and final answer
    return {
        "thinking": thinking_output,
        "answer": answer,
        "full_response": thinking_output
        + "\n\nBased on this thinking, my final answer is: " + answer,
    }


# API endpoint for chat
@app.route('/api/chat', methods=['POST'])
def chat():
    try:
        # Ensure model is loaded
        if model is None or tokenizer is None:
            load_model_and_tokenizer()

        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        include_thinking = data.get('include_thinking', False)

        if not prompt:
            return jsonify({'error': 'Prompt is required'}), 400

        start_time = time.time()
        response = generate_with_thinking(prompt)
        end_time = time.time()

        result = {
            'answer': response['answer'],
            'time_taken': round(end_time - start_time, 2),
        }
        # Include thinking steps if requested
        if include_thinking:
            result['thinking'] = response['thinking']
        return jsonify(result)
    except Exception as e:
        import traceback
        print(f"Error in chat endpoint: {e}")
        print(traceback.format_exc())
        return jsonify({'error': str(e)}), 500


# Simple health check endpoint
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'ok'})


# Gradio Web UI
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# BitNet Specialist Chatbot with Step-by-Step Thinking")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your question",
                    placeholder="Ask me anything...",
                    lines=3,
                )
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.Button("Clear")
                show_thinking = gr.Checkbox(label="Show thinking steps", value=True)
            with gr.Column():
                thinking_output = gr.Markdown(label="Thinking Process", visible=True)
                answer_output = gr.Markdown(label="Final Answer")

        def respond(question, show_thinking):
            if not question.strip():
                return "", "Please enter a question"
            # Ensure model is loaded
            if model is None or tokenizer is None:
                load_model_and_tokenizer()
            response = generate_with_thinking(question)
            if show_thinking:
                return response["thinking"], response["answer"]
            return "", response["answer"]

        submit_btn.click(
            respond,
            inputs=[input_text, show_thinking],
            outputs=[thinking_output, answer_output],
        )
        clear_btn.click(
            lambda: ("", "", ""),
            inputs=None,
            outputs=[input_text, thinking_output, answer_output],
        )
    return demo


# Entry point: load the model, start the REST API, and launch the Gradio UI
if __name__ == "__main__":
    # Load model at startup
    load_model_and_tokenizer()
    # Serve the Flask REST API (/api/chat, /health) in a background thread so it
    # runs alongside the Gradio UI; port 5000 is an arbitrary choice here.
    threading.Thread(
        target=lambda: app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False),
        daemon=True,
    ).start()
    # Create and launch the Gradio interface
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
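

# ---------------------------------------------------------------------------
# Example client usage (sketch, not part of the server). It assumes the Flask
# REST API started in __main__ above is listening on http://localhost:5000;
# adjust the host/port if you serve it differently. Run from a separate
# process, e.g. a small client.py:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:5000/api/chat",
#       json={"prompt": "What is 17 * 24?", "include_thinking": True},
#       timeout=300,
#   )
#   resp.raise_for_status()
#   data = resp.json()
#   print("Answer:", data["answer"])
#   print("Thinking:", data.get("thinking", "(not requested)"))
#   print("Time taken (s):", data["time_taken"])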