lewtun HF staff commited on
Commit
fe5de16
2 Parent(s): 2af3373 cf5c7de

Merge branch 'main' of https://huggingface.co/spaces/HuggingFaceH4/starchat-playground

Browse files
Files changed (1) hide show
  1. app.py +76 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import datetime
2
  import os
3
  import re
 
4
  from io import StringIO
5
 
6
  import gradio as gr
@@ -23,6 +24,11 @@ model2endpoint = {
23
  model_names = list(model2endpoint.keys())
24
 
25
 
 
 
 
 
 
26
  def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
27
  buffer = StringIO()
28
  timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
@@ -77,6 +83,7 @@ def has_no_history(chatbot, history):
77
 
78
 
79
  def generate(
 
80
  model_name,
81
  system_message,
82
  user_message,
@@ -98,7 +105,11 @@ def generate(
98
  if not user_message:
99
  print("Empty input")
100
 
101
- history.append(user_message)
 
 
 
 
102
 
103
  past_messages = []
104
  for data in chatbot:
@@ -138,7 +149,7 @@ def generate(
138
  repetition_penalty=repetition_penalty,
139
  do_sample=True,
140
  truncate=4096,
141
- seed=42,
142
  stop_sequences=["<|end|>"],
143
  )
144
 
@@ -208,6 +219,45 @@ def process_example(args):
208
  return [x, y]
209
 
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  title = """<h1 align="center">⭐ StarChat Playground 💬</h1>"""
212
  custom_css = """
213
  #banner-image {
@@ -270,7 +320,8 @@ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
270
  with gr.Row():
271
  send_button = gr.Button("Send", elem_id="send-btn", visible=True)
272
 
273
- # regenerate_button = gr.Button("Regenerate", elem_id="send-btn", visible=True)
 
274
  delete_turn_button = gr.Button("Delete last turn", elem_id="delete-btn", visible=True)
275
 
276
  clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
@@ -335,12 +386,15 @@ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
335
  )
336
 
337
  history = gr.State([])
 
 
338
  # To clear out "message" input textbox and use this to regenerate message
339
  last_user_message = gr.State("")
340
 
341
  user_message.submit(
342
  generate,
343
  inputs=[
 
344
  selected_model,
345
  system_message,
346
  user_message,
@@ -359,6 +413,7 @@ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
359
  send_button.click(
360
  generate,
361
  inputs=[
 
362
  selected_model,
363
  system_message,
364
  user_message,
@@ -374,6 +429,24 @@ with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
374
  outputs=[chatbot, history, last_user_message, user_message],
375
  )
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
378
  clear_chat_button.click(clear_chat, outputs=[chatbot, history])
379
  selected_model.change(clear_chat, outputs=[chatbot, history])
 
1
  import datetime
2
  import os
3
  import re
4
+ import random
5
  from io import StringIO
6
 
7
  import gradio as gr
 
24
  model_names = list(model2endpoint.keys())
25
 
26
 
27
+ def randomize_seed_generator():
28
+ seed = random.randint(0, 1000000)
29
+ return seed
30
+
31
+
32
  def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
33
  buffer = StringIO()
34
  timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
 
83
 
84
 
85
  def generate(
86
+ RETRY_FLAG,
87
  model_name,
88
  system_message,
89
  user_message,
 
105
  if not user_message:
106
  print("Empty input")
107
 
108
+ if not RETRY_FLAG:
109
+ history.append(user_message)
110
+ seed=42
111
+ else:
112
+ seed=randomize_seed_generator()
113
 
114
  past_messages = []
115
  for data in chatbot:
 
149
  repetition_penalty=repetition_penalty,
150
  do_sample=True,
151
  truncate=4096,
152
+ seed=seed,
153
  stop_sequences=["<|end|>"],
154
  )
155
 
 
219
  return [x, y]
220
 
221
 
222
+ # Regenerate response
223
+ def retry_last_answer(
224
+ selected_model,
225
+ system_message,
226
+ user_message,
227
+ chat,
228
+ history,
229
+ temperature,
230
+ top_k,
231
+ top_p,
232
+ max_new_tokens,
233
+ repetition_penalty,
234
+ do_save):
235
+
236
+ if chat and history:
237
+ # Removing the previous conversation from chat
238
+ chat.pop(-1)
239
+ # Removing bot response from the history
240
+ history.pop(-1)
241
+ # Setting up a flag to capture a retry
242
+ RETRY_FLAG = True
243
+ # Getting last message from user
244
+ user_message = history[-1]
245
+
246
+ yield from generate(
247
+ RETRY_FLAG,
248
+ selected_model,
249
+ system_message,
250
+ user_message,
251
+ chat,
252
+ history,
253
+ temperature,
254
+ top_k,
255
+ top_p,
256
+ max_new_tokens,
257
+ repetition_penalty,
258
+ do_save)
259
+
260
+
261
  title = """<h1 align="center">⭐ StarChat Playground 💬</h1>"""
262
  custom_css = """
263
  #banner-image {
 
320
  with gr.Row():
321
  send_button = gr.Button("Send", elem_id="send-btn", visible=True)
322
 
323
+ regenerate_button = gr.Button("Regenerate", elem_id="retry-btn", visible=True)
324
+
325
  delete_turn_button = gr.Button("Delete last turn", elem_id="delete-btn", visible=True)
326
 
327
  clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
 
386
  )
387
 
388
  history = gr.State([])
389
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
390
+
391
  # To clear out "message" input textbox and use this to regenerate message
392
  last_user_message = gr.State("")
393
 
394
  user_message.submit(
395
  generate,
396
  inputs=[
397
+ RETRY_FLAG,
398
  selected_model,
399
  system_message,
400
  user_message,
 
413
  send_button.click(
414
  generate,
415
  inputs=[
416
+ RETRY_FLAG,
417
  selected_model,
418
  system_message,
419
  user_message,
 
429
  outputs=[chatbot, history, last_user_message, user_message],
430
  )
431
 
432
+ regenerate_button.click(
433
+ retry_last_answer,
434
+ inputs = [
435
+ selected_model,
436
+ system_message,
437
+ user_message,
438
+ chatbot,
439
+ history,
440
+ temperature,
441
+ top_k,
442
+ top_p,
443
+ max_new_tokens,
444
+ repetition_penalty,
445
+ do_save,
446
+ ],
447
+ outputs = [chatbot, history, last_user_message, user_message]
448
+ )
449
+
450
  delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
451
  clear_chat_button.click(clear_chat, outputs=[chatbot, history])
452
  selected_model.change(clear_chat, outputs=[chatbot, history])