d-delaurier committed on
Commit
716a943
·
verified ·
1 Parent(s): f3fac44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -1
app.py CHANGED
@@ -1,3 +1,177 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- gr.load("models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import json
3
+ from pathlib import Path
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
 
7
# Default system prompt shown in the UI and used for generation until a
# context JSON upload overrides it (see ChatInterface.load_context_from_json).
DEFAULT_SYSTEM_PROMPT = """You are DeepThink, a helpful and knowledgeable AI assistant. You aim to provide accurate,
informative, and engaging responses while maintaining a professional and friendly demeanor."""
10
+
11
class ChatInterface:
    """Chat handler: wraps a causal LM with chat memory, a configurable
    system prompt, and tunable generation parameters."""

    def __init__(self):
        """Load the tokenizer and model, and start with an empty history.

        NOTE: downloads/loads the full model — this is slow and memory-heavy.
        """
        self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.chat_history = []  # list of (user_message, assistant_reply) tuples
        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    def load_context_from_json(self, file_obj):
        """Load a system-prompt override from an uploaded JSON file.

        Args:
            file_obj: gradio upload value — a tempfile-like object with a
                ``.name`` attribute or a plain filepath string, depending on
                the gradio version; ``None`` when nothing was uploaded.

        Returns:
            (status_message, current_system_prompt) tuple for the UI.
        """
        if file_obj is None:
            return "No file uploaded", self.system_prompt

        try:
            # Accept both a file-like wrapper (has .name) and a path string;
            # the original json.load(file_obj) broke on path strings.
            path = getattr(file_obj, "name", file_obj)
            with open(path, "r", encoding="utf-8") as fh:
                content = json.load(fh)
            if "system_prompt" in content:
                self.system_prompt = content["system_prompt"]
            return "Context loaded successfully!", self.system_prompt
        except Exception as e:
            # Surface the failure to the UI rather than crashing the app.
            return f"Error loading context: {str(e)}", self.system_prompt

    def generate_response(self, message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        """Generate an assistant reply and append the turn to history.

        Args:
            message: the new user message.
            temperature, top_p: sampling parameters.
            max_length: maximum number of NEW tokens to generate.
            presence_penalty, frequency_penalty: OpenAI-style penalties;
                transformers' generate() does not support them, so they are
                folded into the supported ``repetition_penalty``.

        Returns:
            (response_text, formatted_chat_history) tuple.
        """
        # Build a plain-text transcript: system prompt, prior turns, new turn.
        conversation = f"System: {self.system_prompt}\n\n"
        for user_msg, bot_msg in self.chat_history:
            conversation += f"Human: {user_msg}\nAssistant: {bot_msg}\n\n"
        conversation += f"Human: {message}\nAssistant:"

        inputs = self.tokenizer(conversation, return_tensors="pt")
        # max_new_tokens (not max_length) so a long prompt cannot consume the
        # whole token budget; do_sample=True so temperature/top_p actually
        # take effect; unsupported presence/frequency penalties are mapped to
        # repetition_penalty (1.0 == no penalty).
        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.0 + max(presence_penalty, frequency_penalty),
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text contains the whole prompt; keep only the text
        # after the final "Assistant:" marker, then record the turn.
        response = response.split("Assistant:")[-1].strip()
        self.chat_history.append((message, response))

        return response, self.format_chat_history()

    def format_chat_history(self):
        """Format chat history as labeled (user, assistant) pairs for display."""
        return [(f"User: {msg[0]}", f"Assistant: {msg[1]}") for msg in self.chat_history]

    def clear_history(self):
        """Clear the chat history and return the (empty) display list."""
        self.chat_history = []
        return self.format_chat_history()
69
+
70
# Initialize the chat interface (loads tokenizer + model weights up front).
chat_interface = ChatInterface()

# Build the Gradio UI.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Main chat panel.
            chatbot = gr.Chatbot(
                label="Chat History",
                height=600,
                show_label=True,
            )

            with gr.Row():
                message = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2
                )
                submit_btn = gr.Button("Send", variant="primary")

        with gr.Column(scale=1):
            # System settings and generation parameters.
            # gr.Group takes no `label` kwarg — the original
            # gr.Group(label=...) raised a TypeError; a Markdown heading
            # inside the group gives the same visual result.
            with gr.Group():
                gr.Markdown("### System Configuration")
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=4
                )
                context_file = gr.File(
                    label="Upload Context JSON",
                    file_types=[".json"]
                )
                upload_button = gr.Button("Load Context")
                context_status = gr.Textbox(label="Context Status", interactive=False)

            with gr.Group():
                gr.Markdown("### Generation Parameters")
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature"
                )
                max_length = gr.Slider(
                    minimum=50, maximum=2000, value=500, step=50,
                    label="Max Length"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                    label="Top P"
                )
                presence_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="Presence Penalty"
                )
                frequency_penalty = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
                    label="Frequency Penalty"
                )

            clear_btn = gr.Button("Clear Chat History")

    # ---- Event handlers --------------------------------------------------

    def submit_message(message, temperature, max_length, top_p, presence_penalty, frequency_penalty):
        """Run generation for one turn; clear the input box and refresh the
        chat display."""
        _, history = chat_interface.generate_response(
            message, temperature, max_length, top_p, presence_penalty, frequency_penalty
        )
        return "", history

    # Same handler fires on button click and on Enter in the textbox.
    gen_inputs = [message, temperature, max_length, top_p, presence_penalty, frequency_penalty]
    submit_btn.click(submit_message, inputs=gen_inputs, outputs=[message, chatbot])
    message.submit(submit_message, inputs=gen_inputs, outputs=[message, chatbot])

    # Reset both the stored history and the visible chat/input widgets.
    clear_btn.click(
        lambda: (chat_interface.clear_history(), ""),
        outputs=[chatbot, message]
    )

    upload_button.click(
        chat_interface.load_context_from_json,
        inputs=[context_file],
        outputs=[context_status, system_prompt]
    )

# Launch the interface.
demo.launch()