"""Gradio chat demo that streams replies from a Hugging Face causal LM."""

from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class ChatInterface:
    """Hold a tokenizer/model pair and expose a streaming chat callback."""

    def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
        """Load the tokenizer and the model weights for ``model_name``.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def format_chat_prompt(self, message, history, system_message):
        """Build the prompt string for the next assistant turn.

        Args:
            message: The new user message.
            history: Previous turns. Either ``(user, assistant)`` pairs or
                dicts carrying ``"role"`` and ``"content"`` keys; ``None``
                is treated as an empty history.
            system_message: Text for the leading system message.

        Returns:
            The string produced by the tokenizer's chat template, with the
            generation prompt appended.
        """
        messages = [{"role": "system", "content": system_message}]

        for turn in history or []:
            if isinstance(turn, dict):
                # Message-style history: forward role/content as given.
                # Unpacking a dict as a pair would yield its keys instead.
                role = turn.get("role")
                content = turn.get("content")
                if role and content:
                    messages.append({"role": role, "content": content})
                continue

            # Pair-style history; either side may be empty and is skipped.
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({"role": "user", "content": message})

        # Format messages according to the model's expected chat template.
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    @spaces.GPU
    def generate_response(
        self,
        message,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
    ):
        """Stream the model's reply to ``message``.

        Args:
            message: The new user message.
            history: Previous turns, as accepted by ``format_chat_prompt``.
            system_message: Text for the leading system message.
            max_tokens: Upper bound on newly generated tokens.
            temperature: Sampling temperature forwarded to ``generate``.
            top_p: Nucleus-sampling threshold forwarded to ``generate``.

        Yields:
            The accumulated response text after each streamed chunk.
        """
        prompt = self.format_chat_prompt(message, history, system_message)

        # The templated prompt already carries the template's special tokens,
        # so the tokenizer must not add its own on top of them.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.model.device)

        # Setup streamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Unpack the encoding so generate() receives its entries (input ids,
        # attention mask) as keyword arguments rather than one opaque object.
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=int(max_tokens),  # slider values may be floats
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

        # Generate in a separate thread so the streamer can be consumed here.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the response
        response = ""
        for new_text in streamer:
            response += new_text
            yield response

        # The stream is exhausted, so the worker is finishing; reap it.
        thread.join()


def create_demo():
    """Create the Gradio chat UI wired to a fresh ``ChatInterface``.

    Returns:
        The configured ``gr.ChatInterface`` (not yet launched).
    """
    chat_interface = ChatInterface()

    demo = gr.ChatInterface(
        chat_interface.generate_response,
        additional_inputs=[
            gr.Textbox(
                value="You are a friendly Chatbot.",
                label="System message",
            ),
            gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()