import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)

MODEL_ID = "FuseAI/FuseO1-DeepSeekR1-QwQ-SkyT1-32B-Preview"

DEFAULT_SYSTEM_PROMPT = """
**Role:** You are an Expert Coding Assistant. Your responses MUST follow this structured workflow:
```
[Understand]: Analyze the problem, identify constraints, and clarify objectives.
[Plan]: Outline a technical methodology with numbered steps (algorithms, tools, etc.).
[Reason]: Execute the plan using code snippets, equations, or logic flows.
[Verify]: Validate correctness via tests, edge cases, or formal proofs.
[Conclude]: Summarize results with key insights/recommendations.
```
**Rules:**
1. Use markdown code blocks for all code/equations (e.g., `python`, `javascript`, `latex`).
2. Prioritize computational thinking (e.g., "To solve X, we can model it as a graph problem because...").
3. Structure EVERY answer using the exact tags: [Understand], [Plan], [Reason], [Verify], [Conclude].
4. Never combine steps - keep sections distinct.
5. Use technical precision over verbose explanations.

**Example Output Format:**

[Understand]
- Key problem: "Develop a function to find prime numbers..."
- Constraints: O(n log n) time, memory < 500MB.

[Plan]
1. Implement Sieve of Eratosthenes
2. Optimize memory via bitwise array
3. Handle edge case: n < 2

[Reason]
```python
def count_primes(n: int) -> int:
    if n <= 2:
        return 0
    sieve = [True] * n
    # ... (full implementation)
```

[Verify]
Test Cases:
- n=10 → Primes [2,3,5,7] → Output 4 ✔️
- n=1 → Output 0 ✔️
- Benchmark: 1e6 in 0.8s ✅

[Conclude]
Solution achieves O(n log log n) time with bitwise compression. Recommended for large-scale prime detection.

Always Use Code to solve your problems.
""" CSS = """ .gr-chatbot { min-height: 500px; border-radius: 15px; } .special-tag { color: #2ecc71; font-weight: 600; } footer { display: none !important; } """ class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: return input_ids[0][-1] == tokenizer.eos_token_id def initialize_model(): quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="cuda", quantization_config=quantization_config, torch_dtype=torch.bfloat16, trust_remote_code=True ).to("cuda") return model, tokenizer def format_response(text): return text.replace("[Understand]", '\n[Understand]\n') \ .replace("[Plan]", '\n[Plan]\n') \ .replace("[Conclude]", '\n[Conclude]\n') \ .replace("[Reason]", '\n[Reason]\n') \ .replace("[Verify]", '\n[Verify]\n') @spaces.GPU(duration=360) def generate_response(message, chat_history, system_prompt, temperature, max_tokens): # Create conversation history for model conversation = [{"role": "system", "content": system_prompt}] for user_msg, bot_msg in chat_history: conversation.extend([ {"role": "user", "content": user_msg}, {"role": "assistant", "content": bot_msg} ]) conversation.append({"role": "user", "content": message}) # Tokenize input input_ids = tokenizer.apply_chat_template( conversation, add_generation_prompt=True, return_tensors="pt" ).to(model.device) # Setup streaming streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, stopping_criteria=StoppingCriteriaList([StopOnTokens()]) ) # Start generation thread Thread(target=model.generate, kwargs=generate_kwargs).start() # Initialize response buffer partial_message = "" new_history = chat_history + [(message, "")] # Stream response for new_token in streamer: partial_message += new_token formatted = format_response(partial_message) new_history[-1] = (message, formatted + "▌") yield new_history # Final update without cursor new_history[-1] = (message, format_response(partial_message)) yield new_history model, tokenizer = initialize_model() with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo: gr.Markdown("""

# 🧠 AI Reasoning Assistant

Ask me hard questions

""") chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot") msg = gr.Textbox(label="Your Question", placeholder="Type your question...") with gr.Accordion("⚙️ Settings", open=False): system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions") temperature = gr.Slider(0, 1, value=0.5, label="Creativity") max_tokens = gr.Slider(128, 4096, value=2048, label="Max Response Length") clear = gr.Button("Clear History") msg.submit( generate_response, [msg, chatbot, system_prompt, temperature, max_tokens], [chatbot], show_progress=True ) clear.click(lambda: None, None, chatbot, queue=False) if __name__ == "__main__": demo.queue().launch()