slush0 committed
Commit 57a7522 · Parent(s): dac87a3

Added automatic resuming of sessions in Chat mode (resends the context to the API).

Files changed (3):
  1. app.py +1 -1
  2. chat.py +71 -35
  3. chat_client.py +13 -7
app.py CHANGED
@@ -18,4 +18,4 @@ with gr.Blocks() as iface:
 
 # Queues are required to enable generators
 iface.queue(concurrency_count=5, max_size=50)
-iface.launch()
+iface.launch(show_error=True)
chat.py CHANGED
@@ -10,37 +10,54 @@ CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 
 EMPTY_STATE = {
     'generate': False,
+    'model': None,
+    'client': None,
     'history': [],
 }
 
-def generate(state, *args):
+def generate(state, prompt, model, context, output, *args):
     # Save that we're in generating loop
     state['generate'] = True
 
     try:
-        for x in _generate(state, *args):
+        for x in _generate(state, prompt, model, context, output, *args):
             yield x
+    except BrokenPipeError:
+        # Broken session, try to renew
+        # TODO This is a bit fragile because of recursive call...
+        print("Retrying session...")
+        context = output
+        output = ''
+        yield from generate(state, prompt, model, context, output, *args)
     finally:
         state['generate'] = False
 
-def _generate(state, prompt, model, endseq, max_length,
-              do_sample, top_k, top_p, temperature,
-              context):
+def _generate(state, prompt, model, context, output, endseq, max_length,
+              do_sample, top_k, top_p, temperature):
 
+    print('prompt', prompt)
     eos = "</s>\n" if "bloomz" in model else "\n\n"
 
-    try:
-        client = chat_client.ModelClient(CHAT_URL)
-        client.open_session(f"bigscience/{model}-petals", max_length)
-        state['client'] = client
-    except Exception:
-        print(traceback.format_exc())
-        yield state, state['history'], prompt, "Error: " + traceback.format_exc()
-        return
+    if state['model'] != model or \
+            state['client'] == None or state['client'].is_session() == False:
+
+        try:
+            state['client'] = chat_client.ModelClient(CHAT_URL)
+            state['client'].open_session(f"bigscience/{model}-petals", max_length)
+            state['model'] = model
+        except Exception:
+            print(traceback.format_exc())
+            yield state, state['history'], prompt, output, \
+                gr.update(visible=True, value=traceback.format_exc())
+            return
+    else:
+        context = ''
+
+    client = state['client']
 
     context += eos
-    for question, answer in state['history']:
-        context += f"Human: {question}{eos}AI: {answer}{eos}"
+    #for question, answer in state['history']:
+    #    context += f"Human: {question}{eos}AI: {answer}{eos}"
 
     # Fix eventual eos token mismatch and add eos token to context and prompt
     if "bloomz" in model:
@@ -50,7 +67,7 @@ def _generate(state, prompt, model, endseq, max_length,
         context = context.replace("</s>", eos)
         prompt2 = prompt.replace("</s>", eos) + "\n\n"
 
-    prompt2 = f"{context}Human: {prompt2}AI: "
+    prompt2 = f"{context}Human: {prompt2}AI:"
 
     # Translate checkbox items to actual sequences
     seq = []
@@ -79,12 +96,13 @@ def _generate(state, prompt, model, endseq, max_length,
     if temperature == 0:
         temperature = 1.0
 
+    output += prompt2
+
     # Update widgets even before we get the first response
-    yield state, state['history'] + [[prompt, '']], None, prompt2
+    yield state, state['history'] + [[prompt, '']], None, output, gr.update(visible=False)
 
-    output = ''
-    output_raw = ''
     orig_history = state['history']
+    new_line = ''
     try:
         for out in client.generate(prompt2,
                                    max_new_tokens=1,
@@ -97,36 +115,53 @@ def _generate(state, prompt, model, endseq, max_length,
 
             if not state['generate']:
                 client.close_session()
-                yield state, [], None, ''
+                yield state, [], None, '', ''
+                # Stopping generation
                 return
 
-            output_raw += out
-            output += out
+            new_line += out
 
             # Detect end sequences and finish the generation
             # prematurely if found.
             for s in seq:
-                spl = output.split(s)
-                output = spl[0]
+                spl = new_line.split(s)
+                new_line = spl[0]
                 if len(spl) > 1:
-                    state['history'] = orig_history + [[prompt, output]]
-                    yield state, state['history'], None, prompt2 + output_raw
+                    state['history'] = orig_history + [[prompt, new_line]]
+                    output += new_line
+                    yield state, state['history'], None, output, ''
+                    # Stopping generation
                     return
 
             # Keep original history untouched as we're adding just
             # a chunks at one moment.
-            state['history'] = orig_history + [[prompt, output]]
-
-            yield state, state['history'], None, prompt2 + output_raw
+            state['history'] = orig_history + [[prompt, new_line]]
+            yield state, state['history'], None, output, ''
+
+    except BrokenPipeError:
+        # Session was interrupted
+        # Handled in upstream func
+        client.close_session()
+        state['client'] = None
+        state['model'] = None
+
+        print("Broken session!")
+        raise
     except Exception:
+        client.close_session()
+        state['client'] = None
+        state['model'] = None
+
         print(traceback.format_exc())
-        yield state, state['history'], prompt, output_raw + "\nError: " + traceback.format_exc()
+        # TODO Store errors outside output log
+        yield state, state['history'], prompt, output, \
+            gr.update(visible=True, value=traceback.format_exc())
         return
 
 def reset(state):
     """Resets the session and clears the chat window."""
     state.update(EMPTY_STATE)
-    return state, [], ''
+    return state, [], '', gr.update(visible=False, value='')
 
 with gr.Blocks() as iface_chat:
     gr.Markdown("""**Let's talk to Bloom in a chat!**""")
@@ -163,6 +198,7 @@ with gr.Blocks() as iface_chat:
     chat = gr.Chatbot(label='Chat window')
     prompt = gr.Textbox(show_label=False, label='Prompt',
                         placeholder="Prompt Here and press Enter...").style(container=False)
+    error = gr.Textbox(label="Error log", visible=False, elem_id="error")
 
     with gr.Row():
         button_generate = gr.Button("Generate")
@@ -174,13 +210,13 @@ with gr.Blocks() as iface_chat:
     # Chat history
     state = gr.State(EMPTY_STATE)
 
-    inputs = [state, prompt, model, endseq, max_length, do_sample,
-              top_k, top_p, temperature, context]
-    outputs=[state, chat, prompt, output]
+    inputs = [state, prompt, model, context, output, endseq,
+              max_length, do_sample, top_k, top_p, temperature]
+    outputs=[state, chat, prompt, output, error]
 
     prompt.submit(generate, inputs=inputs, outputs=outputs)
     button_generate.click(generate, inputs=inputs, outputs=outputs)
-    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
+    button_reset.click(reset, inputs=[state], outputs=[state, chat, output, error])
 
     examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
         examples=[
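
The heart of this change is the new except BrokenPipeError branch in generate(): when the websocket session to the chat backend dies mid-stream, the accumulated output log becomes the new context and the call is retried, which makes _generate() reopen a session and resend that context. Below is a minimal standalone sketch of the same idea using the chat_client.ModelClient API from this repo; the model name, max_length value, and the loop-based (rather than recursive) retry are illustrative assumptions, not part of the actual code.

# Sketch only: resume-on-broken-session, assuming the chat_client.ModelClient
# API shown in this commit. The retry loop stands in for the recursive
# `yield from generate(...)` call used in chat.py.
import chat_client

CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'

def generate_with_resume(context, prompt, model="bloom", max_length=1024, retries=1):
    transcript = f"{context}Human: {prompt}AI:"   # text sent to the model so far
    for attempt in range(retries + 1):
        try:
            client = chat_client.ModelClient(CHAT_URL)
            client.open_session(f"bigscience/{model}-petals", max_length)
            for token in client.generate(transcript, max_new_tokens=1):
                transcript += token               # keep the full log around
                yield token
            return
        except BrokenPipeError:
            # Session broke mid-stream: the accumulated transcript becomes
            # the context for the retry (mirrors `context = output` above).
            print("Retrying session...")
            if attempt == retries:
                raise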
chat_client.py CHANGED
@@ -22,29 +22,35 @@ class ModelClient(object):
         self.ws.send(json.dumps(payload))
         assert json.loads(self.ws.recv())['ok'] == True
 
+    def is_session(self):
+        return self.ws != None
+
     def close_session(self):
         if self.ws:
             self.ws.close()
+            self.ws = None
 
     def generate(self, prompt, **kwargs):
+        try:
+            return self._generate(prompt, **kwargs)
+        except:
+            self.close_session()
+            raise
+
+    def _generate(self, prompt, **kwargs):
         payload = {
             "type": "generate",
             "inputs": prompt,
             "max_new_tokens": 1,
             "do_sample": 0,
-            "temperature": 0,
+            "temperature": 1,
             "stop_sequence": "</s>" if "bloomz" in self.model else "\n\n",
         }
         payload = {**payload, **kwargs}
        self.ws.send(json.dumps(payload))
 
         while True:
-            try:
-                data = json.loads(self.ws.recv())
-            except json.decoder.JSONDecodeError:
-                self.close_session()
-                raise
-
+            data = json.loads(self.ws.recv())
             if not data['ok']:
                 raise Exception(data['traceback'])
             yield data['outputs']
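
On the client side, close_session() now clears self.ws, is_session() reports whether a websocket is still open, and generate() closes the session on any error before re-raising, so callers can detect a dead session and simply reconnect. A hypothetical caller-side sketch of that contract follows; the get_client() helper and the model/length values are illustrative and not part of this commit.

# Hypothetical usage of the updated ModelClient session contract.
import chat_client

CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'
_client = None

def get_client(model="bloom", max_length=1024):
    # Reuse the session while it is open, otherwise reopen it -- the same
    # check chat.py now performs against state['client'] / state['model'].
    global _client
    if _client is None or not _client.is_session():
        _client = chat_client.ModelClient(CHAT_URL)
        _client.open_session(f"bigscience/{model}-petals", max_length)
    return _client

try:
    for token in get_client().generate("Human: Hello!\n\nAI:"):
        print(token, end='', flush=True)
except Exception:
    # generate() already called close_session(), so is_session() is now
    # False and the next get_client() call opens a fresh session.
    pass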