slush0 committed
Commit 2caf2e7 · Parent: 8d5fa1d

Calculates and prints generation speed.

Files changed (2)
  1. chat.py +22 -4
  2. prompt.py +23 -6
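
The commit threads a small timing helper through both generator functions: it counts tokens as they stream in and appends an inline speed suffix (tokens/sec, or sec/token once generation slows below one token per second) to every intermediate Gradio update, then emits a final update without the suffix. Below is a minimal, self-contained sketch of that logic; the token list and time.sleep are stand-ins for the chat_client stream, not part of the app.

import time

def stream_with_stats(tokens, delay=0.2):
    # Mirrors the stats() helper added in this commit: intermediate updates
    # carry an inline speed suffix, the final update drops it.
    start = time.time()
    cnt = 0          # tokens generated so far
    output = ''

    def stats():
        # "?" before the first token; t/sec while faster than one token per
        # second, sec/t once generation is slower than that.
        if cnt == 0:
            return "\u2026 | ? sec/t"
        elapsed = time.time() - start
        if cnt > elapsed:
            return f" | {cnt / elapsed:.1f} t/sec"
        return f" | {elapsed / cnt:.1f} sec/t"

    for tok in tokens:
        time.sleep(delay)        # stand-in for waiting on the chat_client stream
        cnt += 1
        output += tok
        yield output + stats()   # what the Gradio widgets show while generating

    yield output                 # final result without the statistics


if __name__ == '__main__':
    for update in stream_with_stats(["Hello", ",", " world", "!"]):
        print(update)
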
chat.py CHANGED
@@ -4,6 +4,7 @@
 import traceback
 import gradio as gr
 import chat_client
+import time
 import json
 import re

@@ -36,6 +37,19 @@ def generate(state, prompt, model, context, output, *args):
 def _generate(state, prompt, model, context, output, endseq, max_length,
               do_sample, top_k, top_p, temperature):

+    start = time.time()
+    cnt = 0  # Tokens generated
+
+    def stats():
+        # Produces inline stats for generation speed
+        if cnt == 0:
+            return "\u2026 | ? sec/t"
+        if cnt > time.time() - start:
+            items_per_sec = cnt / (time.time() - start)
+            return f" | {items_per_sec:.1f} t/sec"
+        sec_per_item = (time.time() - start) / cnt
+        return f" | {sec_per_item:.1f} sec/t"
+
     print('prompt', prompt)
     eos = "</s>\n" if "bloomz" in model else "\n\n"

@@ -54,7 +68,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
             state['model'] = model
         except Exception:
             print(traceback.format_exc())
-            raise gr.Error(traceback.format_exc())
+            raise gr.Error(traceback.format_exc(limit=3))

    else:
        context = ''
@@ -106,7 +120,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
     output += prompt2

     # Update widgets even before we get the first response
-    yield state, state['history'] + [[prompt, '']], None, output
+    yield state, state['history'] + [[prompt, stats()]], None, output

     orig_history = state['history']
     new_line = ''
@@ -126,6 +140,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
                 # Stopping generation
                 return

+            cnt += 1
             new_line += out

             # Detect end sequences and finish the generation
@@ -142,9 +157,12 @@ def _generate(state, prompt, model, context, output, endseq, max_length,

             # Keep original history untouched as we're adding just
             # a chunks at one moment.
-            state['history'] = orig_history + [[prompt, new_line]]
+            state['history'] = orig_history + [[prompt, new_line + stats()]]
             yield state, state['history'], None, output

+        # Final line w/o statistics
+        yield state, state['history'], None, output
+
     except (json.decoder.JSONDecodeError, BrokenPipeError):
         # Session was interrupted
         # Handled in upstream func
@@ -160,7 +178,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
         state['model'] = None

         print(traceback.format_exc())
-        raise gr.Error(traceback.format_exc())
+        raise gr.Error(traceback.format_exc(limit=3))

 def reset(state):
     """Resets the session and clears the chat window."""
prompt.py CHANGED
@@ -4,6 +4,7 @@
 import traceback
 import gradio as gr
 import chat_client
+import time

 CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 #CHAT_URL='ws://localhost:8000/api/v2/generate'
@@ -22,13 +23,25 @@ def _generate(state, prompt, model, endseq, max_length,
               do_sample, top_k, top_p, temperature,
               add_stoptoken, copy_output):

+    start = time.time()
+    cnt = 0
+
+    def stats():
+        # Produces inline stats for generation speed
+        if cnt == 0:
+            return "\u2026 | ? sec/t"
+        if cnt > time.time() - start:
+            items_per_sec = cnt / (time.time() - start)
+            return f" | {items_per_sec:.1f} t/sec"
+        sec_per_item = (time.time() - start) / cnt
+        return f" | {sec_per_item:.1f} sec/t"
+
     try:
         client = chat_client.ModelClient(CHAT_URL)
         client.open_session(f"bigscience/{model}-petals", max_length)
     except Exception:
         print(traceback.format_exc())
-        yield state, prompt, "Error: " + traceback.format_exc()
-        return
+        raise gr.Error(traceback.format_exc(limit=3))

     if add_stoptoken:
         prompt += "</s>" if "bloomz" in model else "\n\n"
@@ -61,7 +74,7 @@ def _generate(state, prompt, model, endseq, max_length,

     # This render prompt dialog immediately and
     # don't wait to generator to return first result
-    yield [state, prompt2, output]
+    yield [state, prompt2, stats()]

     try:
         for out in client.generate(prompt,
@@ -77,15 +90,19 @@ def _generate(state, prompt, model, endseq, max_length,
                 client.close_session()
                 return

+            cnt += 1
             output += out
+
             if copy_output:
                 prompt2 += out

-            yield state, prompt2, output
+            yield state, prompt2, output + stats()
+
+        # Prints final result w/o statistics
+        yield state, prompt2, output
     except Exception:
         print(traceback.format_exc())
-        yield state, prompt, output + "\nError: " + traceback.format_exc()
-        return
+        raise gr.Error(traceback.format_exc(limit=3))

 def stop(state):
     """Stops generating."""