Chris4K committed
Commit e1e7413 · verified · 1 parent: 797e2ef

Update app.py

Files changed (1): app.py (+94, -52)
app.py CHANGED
@@ -3,7 +3,8 @@ import os
 from threading import Thread
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Dict, Any, Optional
+import torch
 
 @dataclass
 class AppConfig:
@@ -12,8 +13,7 @@ class AppConfig:
     MAX_LENGTH: int = 4096
     DEFAULT_TEMP: float = 0.7
     CHAT_HEIGHT: int = 450
-
-# Simplified chat template
+    PAD_TOKEN: str = "[PAD]"
 
 CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
 {%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}
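Note: custom Jinja chat templates like CHAT_TEMPLATE are easy to get subtly wrong, and rendering a toy conversation is the quickest sanity check. A minimal sketch, assuming the DeepSeek-R1-Distill-Qwen-1.5B checkpoint that the interface title suggests (the actual AppConfig.MODEL_NAME value is not shown in this diff):

```python
from transformers import AutoTokenizer

# Assumed checkpoint; the diff never shows AppConfig.MODEL_NAME's value
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
tokenizer.chat_template = CHAT_TEMPLATE  # the template defined above

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# tokenize=False returns the rendered prompt string rather than token ids
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```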
@@ -27,14 +27,15 @@ CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_gener
 {%- endfor -%}
 {%- if add_generation_prompt %}<|Assistant|>{% endif -%}"""
 
-# Improved CSS with better organization and variables
 CSS = """
 :root {
     --primary-color: #1565c0;
+    --secondary-color: #1976d2;
     --text-primary: rgba(0, 0, 0, 0.87);
     --text-secondary: rgba(0, 0, 0, 0.65);
     --spacing-lg: 30px;
     --border-radius: 100vh;
+    --shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
 }
 
 .container {
@@ -46,37 +47,41 @@ CSS = """
 .header {
     text-align: center;
     margin-bottom: var(--spacing-lg);
+    padding: 20px;
+    background: var(--primary-color);
+    color: white;
+    border-radius: 8px;
 }
 
 .header h1 {
     font-size: 28px;
-    color: var(--text-primary);
     margin-bottom: 8px;
 }
 
 .header p {
     font-size: 18px;
-    color: var(--text-secondary);
-}
-
-.action-button {
-    background: var(--primary-color);
-    color: white;
-    border-radius: var(--border-radius);
-    padding: 8px 16px;
-    cursor: pointer;
-    border: none;
-    transition: opacity 0.2s;
-}
-
-.action-button:hover {
     opacity: 0.9;
 }
 
 #chatbot {
     border-radius: 8px;
     background: white;
-    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
+    box-shadow: var(--shadow);
+}
+
+.message {
+    padding: 12px 16px;
+    border-radius: 8px;
+    margin: 8px 0;
+}
+
+.user-message {
+    background: var(--primary-color);
+    color: white;
+}
+
+.assistant-message {
+    background: #f5f5f5;
 }
 """
 
@@ -86,35 +91,54 @@ class ChatBot:
         self.setup_model()
 
     def setup_model(self):
-        """Initialize the model and tokenizer"""
+        """Initialize the model and tokenizer with proper configuration"""
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME)
+
+        # Add pad token if it doesn't exist
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.add_special_tokens({'pad_token': self.config.PAD_TOKEN})
+
         self.tokenizer.chat_template = CHAT_TEMPLATE
 
         self.model = AutoModelForCausalLM.from_pretrained(
             self.config.MODEL_NAME,
-            device_map="auto"  # Automatically choose best device
+            device_map="auto",
+            torch_dtype=torch.float16  # Use half precision for better memory efficiency
         )
-
-    def generate_response(self, message: str,
-                          history: List[tuple],
-                          temperature: float,
-                          max_new_tokens: int) -> str:
-        """Generate streaming response with improved error handling"""
+
+        # Resize token embeddings if needed
+        self.model.resize_token_embeddings(len(self.tokenizer))
+
+    def _convert_history_to_messages(self, history: List[tuple]) -> List[Dict[str, str]]:
+        """Convert tuple history to message format"""
+        messages = []
+        for user, assistant in history:
+            messages.extend([
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant}
+            ])
+        return messages
+
+    def generate_response(self,
+                          message: str,
+                          history: List[tuple],
+                          temperature: float,
+                          max_new_tokens: int) -> str:
+        """Generate streaming response with improved error handling and attention mask"""
         try:
-            conversation = []
-            for user, assistant in history:
-                conversation.extend([
-                    {"role": "user", "content": user},
-                    {"role": "assistant", "content": assistant}
-                ])
+            # Convert history to messages format
+            conversation = self._convert_history_to_messages(history)
             conversation.append({"role": "user", "content": message})
 
-            input_ids = self.tokenizer.apply_chat_template(
+            # Prepare input with attention mask
+            inputs = self.tokenizer.apply_chat_template(
                 conversation,
                 return_tensors="pt",
                 add_generation_prompt=True
            ).to(self.model.device)
-
+
+            attention_mask = torch.ones_like(inputs)
+
             streamer = TextIteratorStreamer(
                 self.tokenizer,
                 timeout=10.0,
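Note: the pad-token handling added to setup_model() follows the usual transformers recipe: adding a new [PAD] token grows the vocabulary, so resize_token_embeddings must run afterwards or the new id would index past the embedding matrix. A condensed sketch of the same sequence (checkpoint name assumed, as above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)

# Adding [PAD] extends the vocabulary by one entry...
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = AutoModelForCausalLM.from_pretrained(
    name, device_map="auto", torch_dtype=torch.float16
)
# ...so the embedding matrix must grow to match the new vocabulary size
model.resize_token_embeddings(len(tokenizer))
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)
```

An alternative that avoids resizing entirely is to reuse the existing EOS token as the pad token (tokenizer.pad_token = tokenizer.eos_token), a common shortcut in causal-LM demos.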
@@ -123,12 +147,14 @@ class ChatBot:
             )
 
             generate_kwargs = {
-                "input_ids": input_ids,
+                "input_ids": inputs,
+                "attention_mask": attention_mask,
                 "streamer": streamer,
                 "max_new_tokens": max_new_tokens,
                 "do_sample": temperature > 0,
                 "temperature": temperature,
-                "eos_token_id": [self.tokenizer.eos_token_id],
+                "pad_token_id": self.tokenizer.pad_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
             }
 
             thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
@@ -143,29 +169,34 @@ class ChatBot:
         """Process the streaming output with improved text cleaning"""
         outputs = []
         for text in streamer:
-            # Clean thinking tags and other special tokens
+            # Clean special tokens and normalize whitespace
            text = (text.replace("<think>", "[think]")
                        .replace("</think>", "[/think]")
+                       .replace("<|end▁of▁sentence|>", "")
                        .strip())
            outputs.append(text)
            yield "".join(outputs)
 
 def create_gradio_interface(chatbot: ChatBot):
-    """Create the Gradio interface with improved layout"""
+    """Create the Gradio interface with improved layout and modern message format"""
     examples = [
-        ['Why is A.I. good?'],
-        ['What is an A.I. state machine?']
+        ['Tell me about artificial intelligence.'],
+        ['What are neural networks?'],
+        ['Explain machine learning in simple terms.']
     ]
 
-    chatbot_interface = gr.Chatbot(
-        height=chatbot.config.CHAT_HEIGHT,
-        container=True,
-        elem_id="chatbot"
-    )
-
     with gr.Blocks(css=CSS) as demo:
         with gr.Column(elem_classes="container"):
-            gr.Markdown("# DeepSeek R1 Distill Qwen 1.5B Chat Interface")
+            with gr.Column(elem_classes="header"):
+                gr.Markdown("# DeepSeek R1 Chat Interface")
+                gr.Markdown("An efficient and responsive chat interface powered by DeepSeek R1 Distill")
+
+            chatbot_interface = gr.Chatbot(
+                height=chatbot.config.CHAT_HEIGHT,
+                container=True,
+                elem_id="chatbot",
+                type="messages"  # Use modern message format
+            )
 
             interface = gr.ChatInterface(
                 fn=chatbot.generate_response,
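Note: model.generate() blocks until decoding finishes, which is why generate_response() runs it on a background Thread and consumes the TextIteratorStreamer on the caller's side, yielding the accumulated text to Gradio. A stripped-down sketch of that flow, reusing tokenizer and model from the sketch above:

```python
from threading import Thread
from transformers import TextIteratorStreamer

# Encode a prompt; `tokenizer` and `model` are assumed from the previous sketch
inputs = tokenizer("Hello!", return_tensors="pt").input_ids.to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs={"input_ids": inputs, "streamer": streamer, "max_new_tokens": 64},
)
thread.start()
for chunk in streamer:  # yields decoded text incrementally as tokens arrive
    print(chunk, end="", flush=True)
thread.join()
```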
@@ -174,16 +205,21 @@ def create_gradio_interface(chatbot: ChatBot):
                 gr.Slider(
                     minimum=0, maximum=1,
                     value=chatbot.config.DEFAULT_TEMP,
-                    label="Temperature"
+                    label="Temperature",
+                    info="Higher values make the output more random"
                 ),
                 gr.Slider(
                     minimum=128, maximum=chatbot.config.MAX_LENGTH,
                     value=1024,
-                    label="Max new tokens"
+                    label="Max new tokens",
+                    info="Maximum length of the generated response"
                 ),
             ],
             examples=examples,
             cache_examples=False,
+            retry_btn="Regenerate Response",
+            undo_btn="Undo Last",
+            clear_btn="Clear Chat",
         )
 
     return demo
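Note: gr.ChatInterface passes the value of each additional_inputs component as an extra positional argument after (message, history), which is the contract generate_response(message, history, temperature, max_new_tokens) relies on. A toy sketch of that wiring (the echo function is illustrative, not part of the app):

```python
import gradio as gr

# Arguments arrive in order: message, history, then one per additional input
def echo(message, history, temperature, max_new_tokens):
    return f"[temp={temperature}, max={max_new_tokens}] {message}"

demo = gr.ChatInterface(
    fn=echo,
    additional_inputs=[
        gr.Slider(0, 1, value=0.7, label="Temperature"),
        gr.Slider(128, 4096, value=1024, label="Max new tokens"),
    ],
)
```

One caveat worth checking against the pinned Gradio version: retry_btn, undo_btn, and clear_btn were accepted by ChatInterface in Gradio 4.x but were removed in 5.x, so this call may need adjusting there.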
@@ -192,4 +228,10 @@ if __name__ == "__main__":
     config = AppConfig()
     chatbot = ChatBot(config)
     demo = create_gradio_interface(chatbot)
-    demo.launch(debug=True)
+    demo.launch(
+        debug=True,
+        share=False,  # Set to True to create a public link
+        server_name="0.0.0.0",
+        server_port=7860,
+        ssr=False  # Disable SSR to avoid experimental features
+    )
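Note: the launch() keywords are version-sensitive; in particular, recent Gradio releases appear to spell the server-side-rendering toggle ssr_mode rather than ssr, and an unknown keyword would raise a TypeError at startup. A defensive sketch that only forwards the flag when the installed version accepts it (a hedged pattern, not part of this commit):

```python
import inspect
import gradio as gr

# Forward the SSR flag only if this Gradio version's launch() accepts it
launch_kwargs = {"debug": True, "share": False,
                 "server_name": "0.0.0.0", "server_port": 7860}
if "ssr_mode" in inspect.signature(gr.Blocks.launch).parameters:
    launch_kwargs["ssr_mode"] = False  # Gradio 5.x name for the SSR toggle

demo = create_gradio_interface(chatbot)  # `chatbot` as constructed in __main__
demo.launch(**launch_kwargs)
```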