slush0 committed
Commit 104f494 · Parent: d17e7da

Adding chat mode.

Files changed (3)
  1. app.py +2 -1
  2. chat.py +160 -0
  3. prompt.py +7 -2
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 
 from prompt import iface_prompt
+from chat import iface_chat
 
 with gr.Blocks() as iface:
     gr.Markdown("""# Petals playground
@@ -13,7 +14,7 @@ with gr.Blocks() as iface:
 
 BLOOMZ performs better in chat mode and understands the instructions better.""")
 
-    gr.TabbedInterface([iface_prompt, ], ["Prompt mode",])
+    gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
 
 # Queues are required to enable generators
 iface.queue(concurrency_count=5)
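
Note (not part of the commit): each mode lives in its own module as a gr.Blocks object, and app.py only composes them with gr.TabbedInterface; iface.queue() is needed because the handlers are Python generators that stream partial results. Below is a minimal, self-contained sketch of that pattern, assuming gradio 3.x; the module, widget, and tab names are made up for illustration and are not the actual prompt.py/chat.py contents.

# sketch.py -- hypothetical standalone example of the composition pattern used by app.py
import time
import gradio as gr

# Each "mode" is built as its own gr.Blocks object (like iface_prompt / iface_chat).
with gr.Blocks() as tab_echo:
    box_in = gr.Textbox(label="Input")
    box_out = gr.Textbox(label="Output")

    def stream_echo(text):
        # A generator handler: each yield updates the output widget,
        # which is why the app must enable queuing.
        out = ""
        for ch in text:
            out += ch
            time.sleep(0.05)
            yield out

    box_in.submit(stream_echo, inputs=box_in, outputs=box_out)

with gr.Blocks() as tab_static:
    gr.Markdown("A second tab, just to show the composition.")

# app.py-style composition: one tab per Blocks object.
with gr.Blocks() as demo:
    gr.TabbedInterface([tab_echo, tab_static], ["Echo mode", "Static mode"])

demo.queue(concurrency_count=5)  # required for generator handlers in gradio 3.x

if __name__ == "__main__":
    demo.launch()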
chat.py ADDED
@@ -0,0 +1,160 @@
+#!/usr/bin/env python
+# or run with: gradio app.py
+
+import traceback
+import gradio as gr
+import chat_client
+
+CHAT_URL='ws://chat.petals.ml/api/v2/generate'
+#CHAT_URL='ws://localhost:8000/api/v2/generate'
+
+def generate(prompt, model, endseq, max_length,
+             do_sample, top_k, top_p, temperature,
+             context, state):
+
+    eos = "</s>\n" if "bloomz" in model else "\n\n"
+
+    try:
+        client = chat_client.ModelClient(CHAT_URL)
+        client.open_session(f"bigscience/{model}-petals", max_length)
+    except Exception:
+        print(traceback.format_exc())
+        yield state, state, prompt, "Error: " + traceback.format_exc()
+        return
+
+    context += eos
+    for question, answer in state:
+        context += f"Human: {question}{eos}AI: {answer}{eos}"
+
+    # Fix a possible eos token mismatch and add the eos token to the context and prompt
+    if "bloomz" in model:
+        context = context.replace("\n\n", eos)
+        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
+    else:
+        context = context.replace("</s>", eos)
+        prompt2 = prompt.replace("</s>", eos) + "\n\n"
+
+    prompt2 = f"{context}Human: {prompt2}AI: "
+
+    # Translate checkbox items to actual sequences
+    seq = []
+    for s in endseq:
+        if s == "Human:":
+            seq.append("Human:")
+        if s == "AI:":
+            seq.append("AI:")
+        if s == "\\n":
+            seq.append("\n")
+        elif s == "</s>":
+            seq.append("</s>")
+        elif s == "? (question mark)":
+            seq.append("?")
+        elif s == ". (dot)":
+            seq.append(".")
+
+    # only top_k or top_p can be set
+    if top_k == 0:
+        top_k = None
+    if top_p == 0:
+        top_p = None
+    if top_p and top_k:
+        top_k = None
+
+    if temperature == 0:
+        temperature = 1.0
+
+    # Update widgets even before we get the first response
+    yield state + [[prompt, '']], state, None, prompt2
+
+    output = ''
+    output_raw = ''
+    try:
+        for out in client.generate(prompt2,
+                                   max_new_tokens=1,
+                                   do_sample=do_sample,
+                                   temperature=temperature,
+                                   top_k=top_k,
+                                   top_p=top_p,
+                                   extra_stop_sequences=seq
+                                   ):
+
+            output_raw += out
+            output += out
+
+            # Detect end sequences and finish the generation
+            # prematurely if one is found.
+            for s in seq:
+                spl = output.split(s)
+                output = spl[0]
+                if len(spl) > 1:
+                    state2 = state + [[prompt, output]]
+                    yield state2, state2, None, prompt2 + output_raw
+                    return
+
+            state2 = state + [[prompt, output]]
+            yield state2, state2, None, prompt2 + output_raw
+    except Exception:
+        print(traceback.format_exc())
+        yield state, state, prompt, "Error: " + traceback.format_exc()
+        return
+
+with gr.Blocks() as iface_chat:
+    gr.Markdown("""**Let's talk to Bloom in a chat!**""")
+
+    with gr.Row():
+        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloomz', label="Use model")
+
+        # Additional ending sequences, at which generation should stop
+        endseq = gr.CheckboxGroup(["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
+                                  value=["Human:", "AI:", "\\n", "</s>"], label='Extra end sequences')
+
+        # Maximum length of inference session
+        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=1024, interactive=True, label="Max length")
+
+    with gr.Row():
+        with gr.Column():
+            # Switch between sampling and greedy generation
+            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
+            context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
+                                 value="A human talks to a powerful AI that follows the human's instructions.</s>\n"
+                                       "Human: Hi!</s>\n"
+                                       "AI: How can I help you?")
+
+            # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
+            top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
+            top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
+            # TODO num_beams
+
+            # Generation temperature
+            temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")
+
+    chat = gr.Chatbot(label='Chat window')
+    prompt = gr.Textbox(show_label=False,
+                        placeholder="Prompt Here and press Enter...").style(container=False)
+
+    with gr.Row():
+        button_generate = gr.Button("Generate")
+        # button_clear = gr.Button("Clear session") # TODO
+        # button_stop = gr.Button("Stop") # TODO, not supported by websocket API yet.
+
+    output = gr.Textbox(lines=3, label='Raw Prompt Log')
+
+    # Chat history
+    state = gr.State([])
+
+    inputs = [prompt, model, endseq, max_length, do_sample,
+              top_k, top_p, temperature, context, state]
+    outputs = [chat, state, prompt, output]
+
+    prompt.submit(generate, inputs=inputs, outputs=outputs)
+    button_generate.click(generate, inputs=inputs, outputs=outputs)
+
+    examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
+                           examples=[
+                               ["A human talks to a powerful AI that follows the human's instructions.</s>\n"
+                                "Human: Hi!</s>\n"
+                                "AI: Hi! How can I help you?",
+                                "Could you remind me please what's the capital of Portugal?",
+                                "bloomz", True, 0, 0.9, 0.75]
+                           ])
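
chat.py imports a chat_client module that is not included in this commit. The call sites above assume only a small surface: ModelClient(url), open_session(model_name, max_length), and a generate(...) method that acts as a generator yielding text chunks. The stub below is a hypothetical stand-in sketching that assumed interface; it is not the real websocket client and does not speak the chat.petals.ml protocol.

# chat_client_stub.py -- hypothetical stand-in for the chat_client module used by chat.py.
# It mirrors only the call sites visible in the diff; the real client would stream
# tokens from the websocket endpoint instead of echoing canned text.
from typing import Iterator, List, Optional


class ModelClient:
    def __init__(self, endpoint_url: str):
        # chat.py passes CHAT_URL ('ws://chat.petals.ml/api/v2/generate').
        self.endpoint_url = endpoint_url
        self.model = None
        self.max_length = None

    def open_session(self, model: str, max_length: int) -> None:
        # chat.py calls this with f"bigscience/{model}-petals" and the "Max length" radio value.
        self.model = model
        self.max_length = max_length

    def generate(self, prompt: str, max_new_tokens: int = 1,
                 do_sample: bool = True,
                 temperature: Optional[float] = None,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = None,
                 extra_stop_sequences: Optional[List[str]] = None) -> Iterator[str]:
        # chat.py iterates over this generator and concatenates the yielded chunks,
        # so the real client yields small pieces of text as tokens arrive.
        for chunk in ["Hi! ", "(stub ", "response)"]:
            yield chunk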
prompt.py CHANGED
@@ -43,6 +43,9 @@ def generate(prompt, model, endseq, max_length,
     if top_p and top_k:
         top_k = None
 
+    if not temperature:
+        temperature = 1.0
+
     prompt2 = prompt
     output = ''
 
@@ -110,8 +113,10 @@ with gr.Blocks() as iface_prompt:
 
     output = gr.Textbox(lines=3, label='Output')
 
-    button_generate.click(generate, inputs=[prompt, model, endseq,
-        max_length, do_sample, top_k, top_p, temperature, add_stoptoken, copy_output], outputs=[prompt, output])
+    inputs = [prompt, model, endseq, max_length, do_sample,
+              top_k, top_p, temperature, add_stoptoken, copy_output]
+    outputs = [prompt, output]
+    button_generate.click(generate, inputs=inputs, outputs=outputs)
 
     examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
         examples=[
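
Both chat.py and prompt.py now normalize the sampling widgets the same way: a value of 0 from the gr.Number inputs means "not set" for top_k/top_p, only one of the two is forwarded, and temperature falls back to 1.0. A possible follow-up, sketched here purely as an illustration and not part of this commit, would be to factor that into a shared helper.

# sampling.py -- hypothetical shared helper; not part of this commit.
def normalize_sampling(top_k, top_p, temperature):
    """Mirror the widget handling in chat.py/prompt.py:
    0 from the gr.Number widgets means 'not set'."""
    top_k = top_k or None
    top_p = top_p or None
    if top_p and top_k:
        # Only one of top_k / top_p can be used at a time; prefer top_p.
        top_k = None
    if not temperature:
        temperature = 1.0
    return top_k, top_p, temperature


# Usage inside generate():
#   top_k, top_p, temperature = normalize_sampling(top_k, top_p, temperature)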