import os
import re
import logging
import base64
from threading import Thread
from typing import List

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable faster CUDNN kernels
torch.backends.cudnn.benchmark = True

# Model setup
model_name = "smirki/UIGEN-T1.1-Qwen-7B"
logger.info("Loading model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

# Prompt templates
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""

# Constants
THINK_MAX_NEW_TOKENS = 8000
ANSWER_MAX_NEW_TOKENS = 8000


def initialize_gen_kwargs():
    return {
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
    }


def extract_html_code_block(text: str) -> str:
    # Pull the contents of a ```html ... ``` fenced block; fall back to the raw text.
    pattern = r"```html\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else text.strip()


def send_to_sandbox(html_code: str) -> str:
    # Embed the generated HTML in an iframe via a base64 data URI so it renders in isolation.
    encoded_html = base64.b64encode(html_code.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'


@spaces.GPU(duration=120)  # Allocate GPU for 120 seconds
def generate_response(text: str, history: List[List[str]]):
    history.append([text, ""])
    logger.info(f"New chat prompt: {text}")

    # Think Phase
    formatted_think_prompt = s1_inference_prompt_think_only.format(question=text)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = (input_ids_think != tokenizer.pad_token_id).to(model.device)

    gen_kwargs_think = initialize_gen_kwargs()
    gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
    think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_think["streamer"] = think_streamer

    full_think = ""
    try:
        with torch.inference_mode():
            thread = Thread(
                target=lambda: model.generate(
                    input_ids=input_ids_think,
                    attention_mask=attention_mask_think,
                    **gen_kwargs_think
                )
            )
            thread.start()
            for new_text in think_streamer:
                full_think += new_text
                history[-1][1] = f"<|im_start|>think\n{full_think.strip()}"
                yield history
            thread.join()
    except Exception as e:
        logger.error(f"Error during think phase: {e}")
        history[-1][1] = f"Error in think phase: {str(e)}"
        yield history
        return

    # Answer Phase
    new_prompt = (
        formatted_think_prompt
        + full_think.strip()
        + "\n<|im_start|>answer\n"
    )
    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
    attention_mask_answer = (input_ids_answer != tokenizer.pad_token_id).to(model.device)

    gen_kwargs_answer = initialize_gen_kwargs()
    gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
    answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_answer["streamer"] = answer_streamer

    full_answer = ""
    try:
        with torch.inference_mode():
            thread = Thread(
                target=lambda: model.generate(
                    input_ids=input_ids_answer,
                    attention_mask=attention_mask_answer,
                    **gen_kwargs_answer
                )
            )
            thread.start()
            for new_text in answer_streamer:
                full_answer += new_text
                display_text = (
                    f"<|im_start|>think\n{full_think.strip()}\n\n"
f"<|im_start|>answer\n{full_answer.strip()}" ) history[-1][1] = display_text yield history thread.join() except Exception as e: logger.error(f"Error during answer phase: {e}") history[-1][1] = f"Error in answer phase: {str(e)}" yield history return def process_artifact(history: List[List[str]]): if not history or not history[-1][1]: return "" html_code = extract_html_code_block(history[-1][1]) return send_to_sandbox(html_code) def clear_chat(): return [], "", "" # Gradio UI css = """ .left_header { display: flex; flex-direction: column; justify-content: center; align-items: center; } .right_panel { margin-top: 16px; border: 1px solid #BFBFC4; border-radius: 8px; overflow: hidden; } .render_header { height: 30px; width: 100%; padding: 5px 16px; background-color: #f5f5f5; } .header_btn { display: inline-block; height: 10px; width: 10px; border-radius: 50%; margin-right: 4px; } .render_header > .header_btn:nth-child(1) { background-color: #f5222d; } .render_header > .header_btn:nth-child(2) { background-color: #faad14; } .render_header > .header_btn:nth-child(3) { background-color: #52c41a; } .right_content { height: 920px; display: flex; flex-direction: column; justify-content: center; align-items: center; } .html_content { width: 100%; height: 920px; } """ svg_logo = """ """ def launch_app(): with gr.Blocks(title=model_name.split('/')[-1], css=css) as demo: gr.HTML(f"""
            <div class="left_header">
                {svg_logo}
                <h1>
                    {model_name.split('/')[-1]} - Chat + Artifacts
                </h1>
                <p>
                    (Two-phase chain-of-thought with artifact extraction)
                </p>
            </div>
""") with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot( label="Chat", height=520, show_copy_button=True ) with gr.Row(): text_input = gr.Textbox( label="Prompt", placeholder="Enter your query...", lines=1 ) with gr.Row(): submit_btn = gr.Button("Send", variant="primary") clear_btn = gr.Button("Clear", variant="secondary") with gr.Column(scale=6): gr.HTML('
') artifact_html = gr.HTML( value="", elem_classes="html_content" ) # Event handlers text_input.submit( fn=generate_response, inputs=[text_input, chatbot], outputs=chatbot ).then( fn=lambda: "", outputs=text_input ).then( fn=process_artifact, inputs=[chatbot], outputs=artifact_html ) submit_btn.click( fn=generate_response, inputs=[text_input, chatbot], outputs=chatbot ).then( fn=lambda: "", outputs=text_input ).then( fn=process_artifact, inputs=[chatbot], outputs=artifact_html ) clear_btn.click( fn=clear_chat, outputs=[chatbot, text_input, artifact_html] ) return demo if __name__ == "__main__": logger.info("Launching Gradio demo...") demo = launch_app() demo.queue().launch(server_name="0.0.0.0", share=True)