import time
from datetime import datetime
import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"  # local development endpoint


def generate(state, *args):
    # Record that we're inside the generating loop
    state["generate"] = True

    try:
        yield from _generate(state, *args)
    finally:
        # Always clear the flag, even if the generator raises or is stopped
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
    add_stoptoken,
    copy_output,
):

    start = time.time()
    cnt = 0

    def stats():
        # Produce an inline generation-speed stat, shown as t/sec when
        # faster than one token per second and as sec/t otherwise
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:  # faster than one token per second
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"
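    # For example, 30 tokens generated in 10 seconds renders " | 3.0 t/sec",
    # while 5 tokens in 10 seconds renders " | 2.0 sec/t".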

    try:
        # Open a websocket inference session with the chosen model
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(model, max_length)
    except Exception as e:
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])

    if add_stoptoken:
        # Append the model's end-of-prompt token: "</s>" for bloomz models,
        # a blank line for the others
        prompt += "</s>" if "bloomz" in model else "\n\n"
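        # e.g. "Hi!" becomes "Hi!</s>" for bigscience/bloomz,
        # and "Hi!\n\n" for stabilityai/StableBeluga2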

    # Translate checkbox labels to the actual stop sequences
    stop_map = {
        "\\n": "\n",
        "</s>": "</s>",
        "? (question mark)": "?",
        ". (dot)": ".",
    }
    seq = [stop_map[s] for s in endseq if s in stop_map]

    # Only one of top_k and top_p can be set; 0 means "not set",
    # and top_p takes precedence if both are given
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
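    # e.g. top_k=0, top_p=0.9 -> nucleus sampling with top_p=0.9;
    # top_k=40, top_p=0.9 -> top_k is dropped in favor of top_p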

    # An empty or zero temperature falls back to the neutral value 1.0
    if not temperature:
        temperature = 1.0

    prompt2 = prompt
    output = ""

    # Render the prompt dialog immediately instead of waiting
    # for the generator to return its first result
    yield [state, prompt2, stats()]

    try:
        # Request one token at a time so the UI can stream partial output
        for out in client.generate(
            prompt,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):

            if not state["generate"]:
                # Stop was requested: close the session and exit the loop
                client.close_session()
                return

            cnt += 1
            output += out

            if copy_output:
                prompt2 += out

            yield state, prompt2, output + stats()

            # Stop a few tokens short of the session limit so that
            # generate() doesn't raise and break the UI
            if cnt >= max_length - 6:  # FIXME: magic safety margin
                break

        # Yield the final result without the statistics suffix
        yield state, prompt2, output
    except Exception as e:
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def stop(state):
    """Signals the generation loop to stop.

    _generate() checks state["generate"] once per token, so the stop takes
    effect at the next yielded token rather than immediately.
    """
    state.update({"generate": False})
    return state


# ---------------------------------------------------------
# Defining Gradio layout
with gr.Blocks() as iface_prompt:
    gr.Markdown(
        """**Useful for testing raw prompts with zero-,
        one- or few-shot prompting.**"""
    )

    with gr.Row():
        model = gr.Radio(
            [
                "stabilityai/StableBeluga2",
                "meta-llama/Llama-2-70b-chat-hf",
                "bigscience/bloomz",
                "bigscience/bloom",
            ],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["</s>"],
            label="Extra end sequences",
        )

        # Maximum length of inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=512,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Whether the app should append the stop token to the prompt
            # or leave the prompt open-ended
            add_stoptoken = gr.Checkbox(
                value=True,
                interactive=True,
                label="Automatically add eos token to the prompt.",
            )

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )

    prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
    state = gr.State({"generate": False})
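    # gr.State is per-session; the "generate" flag inside this dict is the
    # cooperative stop signal shared by generate() and stop()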

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_stop = gr.Button("Stop")

        # Automatically append the output to the end of the prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label="Output")

    # Define button actions
    button_generate.click(
        generate,
        inputs=[
            state,
            prompt,
            model,
            endseq,
            max_length,
            do_sample,
            top_k,
            top_p,
            temperature,
            add_stoptoken,
            copy_output,
        ],
        outputs=[state, prompt, output],
    )
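    # Since generate() is a generator, Gradio streams each yielded
    # [state, prompt, output] update to the UI as it arrives
    # (streaming generally requires the queue to be enabled at launch)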
    button_stop.click(stop, inputs=[state], outputs=[state])

    examples = gr.Examples(
        inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
            [
                "The SQL command to extract all the users whose name starts with A is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "// Returns every other value in the list as a new list.\n"
                "def every_other(l):\n",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "The Spanish translation of thank you for your help is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "A human talks to a powerful AI that follows the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: Hi! How can I help you?</s>\n"
                "Human: What's the capital of Portugal?</s>\n"
                "AI: ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Here is a very polite and formal e-mail writing to staff that they are fired because of late delivery of the project and higher costs:\n"
                "Dear staff,\n"
                "it is with utmost ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Lorem ipsum dolor sit amet, ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
        ],
    )
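

# A minimal standalone launch sketch (hypothetical: the original project may
# mount iface_prompt from a separate entry point alongside other tabs).
# Queueing is enabled because the click handler streams from a generator.
if __name__ == "__main__":
    iface_prompt.queue().launch()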