gabrielclark3330 committed
Commit c941cf9
1 Parent(s): bcc5c70

Manage long inputs and outputs
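In short, the updated app keeps each prompt inside the model's context window: it reserves room for the requested completion and, if the tokenized chat history is still too long, drops the oldest tokens from the left. A minimal sketch of that budgeting step, using the 4096-token window and tensor layout from the diff below (the helper name is illustrative, not part of the app):

    def truncate_to_context(input_ids, max_new_tokens, max_context_length=4096):
        # Reserve room for generation, then keep only the newest input tokens.
        max_input_length = max_context_length - int(max_new_tokens)
        if input_ids.size(1) > max_input_length:
            input_ids = input_ids[:, -max_input_length:]  # drop oldest tokens from the left
        return input_ids

The same check appears inline in both generate_response_* functions in app.py.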

Files changed (1)
app.py +166 -0
app.py CHANGED
@@ -1,3 +1,4 @@
+'''
 import os
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -169,3 +170,168 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.queue().launch()
+'''
+
+import os
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
+model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
+
+tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
+model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
+    model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
+)
+
+tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
+model_7B_instruct = AutoModelForCausalLM.from_pretrained(
+    model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
+)
+
+def extract_assistant_response(generated_text):
+    assistant_token = '<|im_start|> assistant'
+    end_token = '<|im_end|>'
+    start_idx = generated_text.rfind(assistant_token)
+    if start_idx == -1:
+        # Assistant token not found
+        return generated_text.strip()
+    start_idx += len(assistant_token)
+    end_idx = generated_text.find(end_token, start_idx)
+    if end_idx == -1:
+        # End token not found, return from start_idx to end
+        return generated_text[start_idx:].strip()
+    else:
+        return generated_text[start_idx:end_idx].strip()
+
+def generate_response_2_7B_instruct(chat_history, max_new_tokens):
+    sample = []
+    for turn in chat_history:
+        if turn[0]:
+            sample.append({'role': 'user', 'content': turn[0]})
+        if turn[1]:
+            sample.append({'role': 'assistant', 'content': turn[1]})
+    chat_sample = tokenizer_2_7B_instruct.apply_chat_template(sample, tokenize=False)
+    input_ids = tokenizer_2_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(model_2_7B_instruct.device)
+
+    # Handle context length limit
+    max_context_length = 4096
+    max_new_tokens = int(max_new_tokens)
+    max_input_length = max_context_length - max_new_tokens
+    if input_ids.size(1) > max_input_length:
+        input_ids = input_ids[:, -max_input_length:]  # Truncate from the left (oldest tokens)
+
+    outputs = model_2_7B_instruct.generate(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        return_dict_in_generate=False,
+        output_scores=False,
+        use_cache=True,
+        num_beams=1,
+        do_sample=False
+    )
+
+    generated_text = tokenizer_2_7B_instruct.decode(outputs[0])
+    assistant_response = extract_assistant_response(generated_text)
+    return assistant_response
+
+def generate_response_7B_instruct(chat_history, max_new_tokens):
+    sample = []
+    for turn in chat_history:
+        if turn[0]:
+            sample.append({'role': 'user', 'content': turn[0]})
+        if turn[1]:
+            sample.append({'role': 'assistant', 'content': turn[1]})
+    chat_sample = tokenizer_7B_instruct.apply_chat_template(sample, tokenize=False)
+    input_ids = tokenizer_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(model_7B_instruct.device)
+
+    # Handle context length limit
+    max_context_length = 4096
+    max_new_tokens = int(max_new_tokens)
+    max_input_length = max_context_length - max_new_tokens
+    if input_ids.size(1) > max_input_length:
+        input_ids = input_ids[:, -max_input_length:]  # Truncate from the left (oldest tokens)
+
+    outputs = model_7B_instruct.generate(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        return_dict_in_generate=False,
+        output_scores=False,
+        use_cache=True,
+        num_beams=1,
+        do_sample=False
+    )
+
+    generated_text = tokenizer_7B_instruct.decode(outputs[0])
+    assistant_response = extract_assistant_response(generated_text)
+    return assistant_response
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Zamba2 Model Selector")
+    with gr.Tabs():
+        with gr.TabItem("2.7B Instruct Model"):
+            gr.Markdown("### Zamba2-2.7B Instruct Model")
+            with gr.Column():
+                chat_history_2_7B_instruct = gr.State([])
+                chatbot_2_7B_instruct = gr.Chatbot()
+                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+
+                def user_message_2_7B_instruct(message, chat_history):
+                    chat_history = chat_history + [[message, None]]
+                    return gr.update(value=""), chat_history, chat_history
+
+                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
+                    response = generate_response_2_7B_instruct(chat_history, max_new_tokens)
+                    chat_history[-1][1] = response
+                    return chat_history, chat_history
+
+                send_button_2_7B_instruct = gr.Button("Send")
+                send_button_2_7B_instruct.click(
+                    fn=user_message_2_7B_instruct,
+                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
+                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                ).then(
+                    fn=bot_response_2_7B_instruct,
+                    inputs=[
+                        chat_history_2_7B_instruct,
+                        max_new_tokens_2_7B_instruct
+                    ],
+                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                )
+        with gr.TabItem("7B Instruct Model"):
+            gr.Markdown("### Zamba2-7B Instruct Model")
+            with gr.Column():
+                chat_history_7B_instruct = gr.State([])
+                chatbot_7B_instruct = gr.Chatbot()
+                message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+
+                def user_message_7B_instruct(message, chat_history):
+                    chat_history = chat_history + [[message, None]]
+                    return gr.update(value=""), chat_history, chat_history
+
+                def bot_response_7B_instruct(chat_history, max_new_tokens):
+                    response = generate_response_7B_instruct(chat_history, max_new_tokens)
+                    chat_history[-1][1] = response
+                    return chat_history, chat_history
+
+                send_button_7B_instruct = gr.Button("Send")
+                send_button_7B_instruct.click(
+                    fn=user_message_7B_instruct,
+                    inputs=[message_7B_instruct, chat_history_7B_instruct],
+                    outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
+                ).then(
+                    fn=bot_response_7B_instruct,
+                    inputs=[
+                        chat_history_7B_instruct,
+                        max_new_tokens_7B_instruct
+                    ],
+                    outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
+                )
+
+if __name__ == "__main__":
+    demo.queue().launch()