import json
import re
import time
from datetime import datetime
import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"

EMPTY_STATE = {
    "generate": False,
    "model": None,
    "client": None,
    "history": [],
}
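# Each browser session gets its own copy of EMPTY_STATE via gr.State (see below);
# "client" holds the open inference session and "history" the [prompt, reply]
# pairs shown in the chat window.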


def generate(state, prompt, model, context, output, *args):
    # Mark that the generation loop is running (cleared in the finally block)
    state["generate"] = True

    try:
        yield from _generate(state, prompt, model, context, output, *args)
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session broke; move the output into the context and retry
        # TODO: This is a bit fragile because of the recursive call
        print("Retrying session...")
        context = output
        output = ""
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    context,
    output,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
):

    start = time.time()
    cnt = 0  # Tokens generated

    def stats():
        # Return an inline generation-speed indicator for the chat window
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:
            # Faster than one token per second: report tokens/sec
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        # Slower than one token per second: report seconds/token
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"
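
    # e.g. stats() returns "… | ? sec/t" before the first token arrives,
    # " | 2.5 t/sec" when faster than one token/sec, " | 1.3 sec/t" otherwise.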

    # bloomz ends turns with an explicit </s> token; the other models here use blank lines
    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if state["model"] != model and output:
        # When a broken session is renewed, generate() has already moved the
        # output into the context, so this branch runs when the user switches models.
        context = output
        output = ""

    # Update widgets even before we get the first response
    print("prompt", prompt)
    yield state, state["history"] + [[prompt, stats()]], "", output

    if (
        state["model"] != model
        or state["client"] is None
        or not state["client"].is_session()
    ):

        try:
            state["client"] = chat_client.ModelClient(CHAT_URL)
            state["client"].open_session(model, max_length)
            state["model"] = model
        except Exception as e:
            print(datetime.now(), str(e)[-500:])
            raise gr.Error(str(e)[-500:])

    else:
        # An open session already holds the previous context; don't resend it
        context = ""

    client = state["client"]
    context += eos

    # Fix any EOS-token mismatch and append the EOS token to the context and prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        context = re.sub(r"\n\n+", "\n\n", context)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"
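
    # The assembled prompt now looks like (non-bloomz case):
    #   "<context>\n\nHuman: <prompt>\n\nAI:"
    # and for bloomz:
    #   "<context></s>\nHuman: <prompt></s>\nAI:"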

    # Translate checkbox labels to the actual stop sequences
    label_to_seq = {
        "Human:": "Human:",
        "AI:": "AI:",
        "\\n": "\n",
        "</s>": "</s>",
        "? (question mark)": "?",
        ". (dot)": ".",
    }
    seq = [label_to_seq[s] for s in endseq if s in label_to_seq]

    # Only one of top_k and top_p may be set; 0 means "unset"
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None  # prefer top_p when both are given

    if temperature == 0:
        temperature = 1.0  # treat 0 as "unset" and fall back to the default
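    # e.g. the UI defaults (top_k=0, top_p=0.9, temperature=0.75) reach the
    # client as top_k=None, top_p=0.9, temperature=0.75, i.e. nucleus sampling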

    output += prompt2

    orig_history = state["history"]
    new_line = ""
    try:
        for out in client.generate(
            prompt2,
            max_new_tokens=1,  # stream one token at a time
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):

            if not state["generate"]:
                # Generation was cancelled (e.g. by the Reset button)
                client.close_session()
                yield state, [], "", ""
                return

            cnt += 1
            new_line += out

            # Detect end sequences and finish the generation
            # prematurely if found.
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state["history"] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state["history"], "", output
                    # Stopping generation
                    return

            # Keep the original history untouched; we only re-append the
            # latest chunk on every iteration.
            state["history"] = orig_history + [[prompt, new_line + stats()]]
            yield state, state["history"], "", output

            # Stop a few tokens early so client.generate() doesn't overrun the
            # session length and raise, which would surface as a UI error.
            if cnt >= max_length - 6:  # FIXME: magic safety margin
                break

        # Final line without the inline speed statistics
        state["history"] = orig_history + [[prompt, new_line]]
        output += new_line
        yield state, state["history"], "", output

    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session was interrupted; generate() above handles the retry
        client.close_session()
        state["client"] = None
        state["model"] = None

        print("Broken session!")
        raise
    except Exception as e:
        client.close_session()
        state["client"] = None
        state["model"] = None

        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def reset(state):
    """Resets the session and clears the chat window."""
    state.update(EMPTY_STATE)
    return state, [], ""
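
# Note: reset() clears the shared state dict in place, so a still-running
# generation loop sees state["generate"] == False and closes the session itself.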


# ---------------------------------------------------------
# Defining Gradio layout
with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to AI in a chat!**""")

    with gr.Row():
        model = gr.Radio(
            [
                "stabilityai/StableBeluga2",
                "meta-llama/Llama-2-70b-chat-hf",
                "bigscience/bloomz",
            ],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["Human:", "AI:", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=1024,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            context = gr.Textbox(
                lines=3,
                label="Initial context:",
                interactive=True,
                value="A Human talks to a powerful AI that follows "
                "the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides "
                "detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: How can I help you?",
            )

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )

    chat = gr.Chatbot(label="Chat window")
    prompt = gr.Textbox(
        show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
    ).style(container=False)

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)

    # Per-session state (model, client, chat history)
    state = gr.State(EMPTY_STATE)

    # Define button actions
    inputs = [
        state,
        prompt,
        model,
        context,
        output,
        endseq,
        max_length,
        do_sample,
        top_k,
        top_p,
        temperature,
    ]
    outputs = [state, chat, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])

    examples = gr.Examples(
        inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
        examples=[
            [
                "Human talks to a powerful AI that follows the Human's instructions. "
                "AI is a smart, talkative, friendly, honest, helpful, harmless assistant to Human. "
                "AI has instant access to an online encyclopedia containing all the facts about the world "
                "and answers any question in detail. AI never says common misconceptions, "
                "outdated information, lies, fiction, myths, jokes, or memes.</s>\n"
                "AI: Hi! How can I help you?</s>\n",
                "Could you remind me please who was Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
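            # Czech-language variant of the example above. The context reads:
            # "A Human talks to a powerful, intelligent and omniscient AI that
            # follows the Human's instructions. AI is talkative, friendly,
            # positive and provides detailed answers to any question." The prompt
            # asks: "Could you remind me please who Neil Armstrong was?"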
            [
                "Human mluví s mocnou, inteligentní a vševědoucí AI, která plní instrukce od Human. "
                "AI je výřečná, přátelská, pozitivní a poskytuje detailní odpovědi na jakoukoliv otázku.</s>\n"
                "Human: Ahoj!</s>\n"
                "AI: Ahoj! Jak ti mohu pomoci?",
                "Můžeš mi prosím připomenout, kdo byl Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
        ],
    )
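

# A minimal sketch for running this tab standalone (assumption: in the full app,
# `iface_chat` is imported and mounted alongside other tabs):
if __name__ == "__main__":
    iface_chat.launch()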