slush0 committed
Commit f2475e8 · Parent(s): 2f09f5b

Adding session reset button to Chat mode.

Files changed (2):
  1. app.py +1 -1
  2. chat.py +49 -16
app.py CHANGED
@@ -17,5 +17,5 @@ with gr.Blocks() as iface:
     gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
 
 # Queues are required to enable generators
-iface.queue(concurrency_count=5)
+iface.queue(concurrency_count=5, max_size=50)
 iface.launch()
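
The added max_size=50 caps how many requests may wait in Gradio's queue, alongside the existing concurrency_count=5 workers. As the comment in app.py notes, the queue is also what lets generator callbacks (such as generate() in chat.py) stream partial output to the UI. A minimal sketch of that pattern, assuming Gradio 3.x where Blocks.queue() accepts these arguments; slow_count() is an illustrative callback, not part of this commit:

import time
import gradio as gr

def slow_count(n):
    # Generator callback: each yield pushes an intermediate result to the output widget.
    text = ""
    for i in range(int(n)):
        time.sleep(0.5)
        text += f"{i} "
        yield text

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="Count to")
    out = gr.Textbox(label="Progress")
    gr.Button("Go").click(slow_count, inputs=n, outputs=out)

# Without queue(), generator outputs do not stream in Gradio 3.x;
# max_size bounds how many requests may wait before new ones are rejected.
demo.queue(concurrency_count=5, max_size=50)
demo.launch()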
chat.py CHANGED
@@ -8,22 +8,38 @@ import chat_client
 CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 #CHAT_URL='ws://localhost:8000/api/v2/generate'
 
-def generate(prompt, model, endseq, max_length,
+EMPTY_STATE = {
+    'generate': False,
+    'history': [],
+}
+
+def generate(state, *args):
+    # Save that we're in generating loop
+    state['generate'] = True
+
+    try:
+        for x in _generate(state, *args):
+            yield x
+    finally:
+        state['generate'] = False
+
+def _generate(state, prompt, model, endseq, max_length,
               do_sample, top_k, top_p, temperature,
-              context, state):
+              context):
 
     eos = "</s>\n" if "bloomz" in model else "\n\n"
 
     try:
         client = chat_client.ModelClient(CHAT_URL)
         client.open_session(f"bigscience/{model}-petals", max_length)
+        state['client'] = client
     except Exception:
         print(traceback.format_exc())
-        yield state, state, prompt, "Error: " + traceback.format_exc()
+        yield state, state['history'], prompt, "Error: " + traceback.format_exc()
         return
 
     context += eos
-    for question, answer in state:
+    for question, answer in state['history']:
         context += f"Human: {question}{eos}AI: {answer}{eos}"
 
     # Fix eventual eos token mismatch and add eos token to context and prompt
@@ -64,10 +80,11 @@ def generate(prompt, model, endseq, max_length,
         temperature = 1.0
 
     # Update widgets even before we get the first response
-    yield state + [[prompt, '']], state, None, prompt2
+    yield state, state['history'] + [[prompt, '']], None, prompt2
 
     output = ''
     output_raw = ''
+    orig_history = state['history']
     try:
         for out in client.generate(prompt2,
                                    max_new_tokens=1,
@@ -78,6 +95,13 @@ def generate(prompt, model, endseq, max_length,
                                    extra_stop_sequences=seq
                                    ):
 
+            if not state['generate']:
+                print("Stopping generation.")
+                client.close_session()
+                yield state, [], None, ''
+                return
+                #return state, state['history'], None, prompt2 + output_raw
+
             output_raw += out
             output += out
 
@@ -87,17 +111,25 @@ def generate(prompt, model, endseq, max_length,
                 spl = output.split(s)
                 output = spl[0]
                 if len(spl) > 1:
-                    state2 = state + [[prompt, output]]
-                    yield state2, state2, None, prompt2 + output_raw
+                    state['history'] = orig_history + [[prompt, output]]
+                    yield state, state['history'], None, prompt2 + output_raw
                     return
 
-            state2 = state + [[prompt, output]]
-            yield state2, state2, None, prompt2 + output_raw
+            # Keep original history untouched as we're adding just
+            # a chunks at one moment.
+            state['history'] = orig_history + [[prompt, output]]
+
+            yield state, state['history'], None, prompt2 + output_raw
     except Exception:
         print(traceback.format_exc())
-        yield state, state, prompt, "Error: " + traceback.format_exc()
+        yield state, state['history'], prompt, output_raw + "\nError: " + traceback.format_exc()
        return
 
+def reset(state):
+    """Resets the session and clears the chat window."""
+    state.update(EMPTY_STATE)
+    return state, [], ''
+
 with gr.Blocks() as iface_chat:
     gr.Markdown("""**Let's talk to Bloom in a chat!**""")
 
@@ -116,7 +148,7 @@ with gr.Blocks() as iface_chat:
        # Switch between sampling and greedy generation
        do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
        context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
-           value="A human talks to a powerful AI that follows the human's instructions. "
+           value="A human talks to a powerful AI that follows the human's instructions.\n"
                  "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                  "Human: Hi!</s>\n"
                  "AI: How can I help you?")
@@ -136,21 +168,22 @@ with gr.Blocks() as iface_chat:
 
    with gr.Row():
        button_generate = gr.Button("Generate")
-       # button_clear = gr.Button("Clear session") # TODO
+       button_reset = gr.Button("Reset/Clear session")
        # button_stop = gr.Button("Stop") # TODO, not supported by websocket API yet.
 
    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)
 
    # Chat history
-   state = gr.State([])
+   state = gr.State(EMPTY_STATE)
 
-   inputs = [prompt, model, endseq, max_length, do_sample,
-             top_k, top_p, temperature, context, state]
-   outputs=[chat, state, prompt, output]
+   inputs = [state, prompt, model, endseq, max_length, do_sample,
+             top_k, top_p, temperature, context]
+   outputs=[state, chat, prompt, output]
 
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
+   button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
 
    examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                           examples=[
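
Taken together, the chat.py changes replace the bare history list with a per-session dict held in gr.State that carries both the chat history and a 'generate' flag; generate() now wraps _generate() so the flag is set while streaming and cleared afterwards, and reset() empties the dict so the "Reset/Clear session" button both clears the chat window and stops an in-flight generation when the loop sees the flag go False. A stripped-down sketch of the same pattern outside this Space, assuming Gradio 3.x; the token list standing in for streamed model output is illustrative, not part of the commit:

import gradio as gr

EMPTY_STATE = {'generate': False, 'history': []}

def generate(state, prompt):
    # Mark that a generation loop is running for this session.
    state['generate'] = True
    try:
        reply = ""
        for token in ["Hel", "lo ", "the", "re!"]:   # stand-in for streamed model output
            if not state['generate']:                # reset was pressed mid-stream
                yield state, [], ""
                return
            reply += token
            yield state, state['history'] + [[prompt, reply]], ""
        state['history'] = state['history'] + [[prompt, reply]]
    finally:
        state['generate'] = False

def reset(state):
    # Clear the flag and the history; the running generator sees the same dict.
    state.update({'generate': False, 'history': []})
    return state, [], ""

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    prompt = gr.Textbox(label="Prompt")
    state = gr.State(EMPTY_STATE)
    inputs, outputs = [state, prompt], [state, chat, prompt]
    prompt.submit(generate, inputs=inputs, outputs=outputs)
    gr.Button("Reset/Clear session").click(reset, inputs=[state], outputs=[state, chat, prompt])

demo.queue(concurrency_count=5, max_size=50)
demo.launch()

The stop-on-reset check relies on both handlers receiving the same state dict within a session, which is the mechanism the commit uses; mutating the dict (rather than replacing it) is what makes the flag flip visible to the loop already in progress.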