import json
from pathlib import Path

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

# Default system prompt for the chat interface
DEFAULT_SYSTEM_PROMPT = """You are DeepThink, a helpful and knowledgeable AI assistant. You aim to provide accurate, informative, and engaging responses while maintaining a professional and friendly demeanor."""


class PresenceFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Apply OpenAI-style presence and frequency penalties to next-token logits.

    ``transformers`` has no built-in ``presence_penalty`` / ``frequency_penalty``
    generation arguments (passing them to ``generate`` is rejected as unused
    model kwargs), so this processor implements them directly:

        logit[t] -= frequency_penalty * count(t) + presence_penalty * (count(t) > 0)

    where ``count(t)`` is how many times token ``t`` already appears among the
    tokens generated so far. Prompt tokens are not counted.
    """

    def __init__(self, presence_penalty, frequency_penalty, prompt_length):
        """
        Args:
            presence_penalty: Flat penalty for any token already generated.
            frequency_penalty: Penalty scaled by how often a token was generated.
            prompt_length: Number of prompt tokens at the start of ``input_ids``
                that are excluded from the counts.
        """
        self.presence_penalty = float(presence_penalty)
        self.frequency_penalty = float(frequency_penalty)
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores):
        """Return ``scores`` with the penalties subtracted for generated tokens."""
        generated = input_ids[:, self.prompt_length:]
        if generated.shape[1] == 0:
            # First decoding step: nothing generated yet, nothing to penalize.
            return scores
        counts = torch.zeros_like(scores)
        counts.scatter_add_(1, generated, torch.ones_like(generated, dtype=scores.dtype))
        penalty = (
            self.frequency_penalty * counts
            + self.presence_penalty * (counts > 0).to(scores.dtype)
        )
        return scores - penalty


class ChatInterface:
    """Main chat interface handler with memory and parameter management.

    NOTE(review): the module creates a single instance that every browser
    session shares, so all users see and extend the same chat history.
    """

    def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"):
        """Load the tokenizer/model and initialise history and system prompt.

        Args:
            model_name: Hugging Face model id to load (defaults to the original
                hard-coded DeepSeek distill model).
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # torch_dtype="auto" keeps the checkpoint's stored precision instead of
        # upcasting every weight to float32.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype="auto"
        ).to(self.device)
        self.chat_history = []  # list of (user_message, assistant_response)
        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    def load_context_from_json(self, file_obj):
        """Load a system prompt from an uploaded JSON file.

        Args:
            file_obj: Value produced by ``gr.File``. This is a filepath string in
                recent Gradio versions, or a tempfile-like object exposing
                ``.name`` in older ones. ``None`` when nothing was uploaded.

        Returns:
            ``(status_message, current_system_prompt)`` for the status and
            system-prompt textboxes. On failure, the prompt is left unchanged.
        """
        if file_obj is None:
            return "No file uploaded", self.system_prompt
        try:
            # gr.File hands over a path (or an object whose .name is the path),
            # not an open stream, so json.load(file_obj) cannot work directly.
            path = Path(getattr(file_obj, "name", file_obj))
            content = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, TypeError, ValueError) as e:
            return f"Error loading context: {str(e)}", self.system_prompt
        if not isinstance(content, dict) or "system_prompt" not in content:
            return (
                "Error loading context: JSON must be an object with a "
                "'system_prompt' key",
                self.system_prompt,
            )
        self.system_prompt = str(content["system_prompt"])
        return "Context loaded successfully!", self.system_prompt

    def _build_prompt(self, message):
        """Render system prompt, prior turns and the new message as plain text."""
        # NOTE(review): if the tokenizer defines a chat template,
        # tokenizer.apply_chat_template may suit this model better than this
        # plain "Human:/Assistant:" format — consider switching.
        turns = "".join(
            f"Human: {user}\nAssistant: {bot}\n\n" for user, bot in self.chat_history
        )
        return f"System: {self.system_prompt}\n\n{turns}Human: {message}\nAssistant:"

    def generate_response(self, message, temperature, max_length, top_p,
                          presence_penalty, frequency_penalty, system_prompt=None):
        """Generate an assistant reply and append the turn to the history.

        Args:
            message: The user's new message.
            temperature: Sampling temperature (sampling is enabled).
            max_length: Maximum number of NEW tokens to generate. This excludes
                the prompt, so long histories cannot exhaust the budget.
            top_p: Nucleus-sampling probability mass.
            presence_penalty: OpenAI-style presence penalty (0 disables).
            frequency_penalty: OpenAI-style frequency penalty (0 disables).
            system_prompt: If given, replaces the stored system prompt before
                generating (e.g. the current contents of the UI textbox).

        Returns:
            ``(response, formatted_history)``.
        """
        if system_prompt is not None:
            self.system_prompt = system_prompt

        inputs = self.tokenizer(self._build_prompt(message), return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]

        logits_processor = LogitsProcessorList()
        if presence_penalty or frequency_penalty:
            logits_processor.append(
                PresenceFrequencyPenaltyLogitsProcessor(
                    presence_penalty, frequency_penalty, prompt_length
                )
            )

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            **inputs,  # input_ids AND attention_mask
            do_sample=True,  # without this, temperature/top_p are ignored (greedy)
            temperature=float(temperature),
            top_p=float(top_p),
            max_new_tokens=int(max_length),
            logits_processor=logits_processor,
            pad_token_id=pad_token_id,
        )

        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # The plain-text format has no per-turn stop token, so the model may
        # continue by inventing the next user turn; cut the reply off there.
        response = response.split("\nHuman:")[0].strip()
        self.chat_history.append((message, response))
        return response, self.format_chat_history()

    def format_chat_history(self):
        """Return the history as (user, assistant) display pairs for gr.Chatbot."""
        return [(f"User: {user}", f"Assistant: {bot}") for user, bot in self.chat_history]

    def clear_history(self):
        """Clear the chat history and return the (now empty) display history."""
        self.chat_history = []
        return self.format_chat_history()


# Initialize the chat interface (loads the model when this module is imported)
chat_interface = ChatInterface()

# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat interface
            chatbot = gr.Chatbot(
                label="Chat History",
                height=600,
                show_label=True,
            )
            with gr.Row():
                message = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                )
                submit_btn = gr.Button("Send", variant="primary")

        with gr.Column(scale=1):
            # System settings and parameters. gr.Group does not accept a
            # `label` argument in current Gradio versions, so each group gets
            # a Markdown heading instead.
            with gr.Group():
                gr.Markdown("### System Configuration")
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=4,
                )
                context_file = gr.File(
                    label="Upload Context JSON",
                    file_types=[".json"],
                )
                upload_button = gr.Button("Load Context")
                context_status = gr.Textbox(label="Context Status", interactive=False)

            with gr.Group():
                gr.Markdown("### Generation Parameters")
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                )
                max_length = gr.Slider(
                    minimum=50, maximum=2000, value=500, step=50,
                    label="Max New Tokens",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                    label="Top P",
                )
                presence_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="Presence Penalty",
                )
                frequency_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="Frequency Penalty",
                )

            clear_btn = gr.Button("Clear Chat History")

    # Event handlers
    def submit_message(message, temperature, max_length, top_p,
                       presence_penalty, frequency_penalty, system_prompt_text=None):
        """Send one chat message; returns (cleared textbox, updated history).

        ``system_prompt_text`` is the live contents of the System Prompt box, so
        edits there take effect on the next message.
        """
        if not message or not message.strip():
            # Nothing to send: do not run the model or touch the history.
            return "", chat_interface.format_chat_history()
        _, history = chat_interface.generate_response(
            message, temperature, max_length, top_p,
            presence_penalty, frequency_penalty,
            system_prompt=system_prompt_text,
        )
        return "", history

    generation_inputs = [
        message, temperature, max_length, top_p,
        presence_penalty, frequency_penalty, system_prompt,
    ]

    submit_btn.click(
        submit_message,
        inputs=generation_inputs,
        outputs=[message, chatbot],
    )

    message.submit(
        submit_message,
        inputs=generation_inputs,
        outputs=[message, chatbot],
    )

    clear_btn.click(
        lambda: (chat_interface.clear_history(), ""),
        outputs=[chatbot, message],
    )

    upload_button.click(
        chat_interface.load_context_from_json,
        inputs=[context_file],
        outputs=[context_status, system_prompt],
    )

# Launch the interface only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()