File size: 5,270 Bytes
3d3362f
 
 
820d5f8
3d3362f
 
 
 
 
 
dac87a3
 
 
 
 
 
 
 
 
 
 
3d3362f
 
 
820d5f8
 
 
 
 
dac87a3
820d5f8
3d3362f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104f494
 
 
3d3362f
 
 
 
 
dac87a3
3d3362f
820d5f8
 
3d3362f
 
 
 
 
 
820d5f8
3d3362f
dac87a3
 
 
 
820d5f8
 
 
3d3362f
dac87a3
820d5f8
 
dac87a3
820d5f8
3d3362f
dac87a3
 
 
 
 
d17e7da
 
3d3362f
 
 
 
 
 
 
 
 
dac87a3
3d3362f
 
 
 
 
 
 
512c082
3d3362f
 
 
 
58ea66a
3d3362f
 
 
 
58ea66a
dac87a3
3d3362f
 
 
dac87a3
3d3362f
 
 
 
 
 
dac87a3
104f494
dac87a3
104f494
dac87a3
3d3362f
 
 
58ea66a
 
dac87a3
 
3d3362f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python
# or gradio app.py

import traceback
import gradio as gr
import chat_client

# WebSocket endpoint passed to chat_client.ModelClient when a generation starts.
CHAT_URL='ws://chat.petals.ml/api/v2/generate'
# Alternative endpoint for a locally running backend; swap the comment to use it.
#CHAT_URL='ws://localhost:8000/api/v2/generate'

def generate(state, *args):
    """Stream UI updates from ``_generate`` while flagging generation as running.

    ``state['generate']`` is switched on before streaming so the inner loop can
    notice a Stop press (``stop()`` clears the flag). It is always switched back
    off afterwards: on normal exhaustion, on an exception, or when the consumer
    closes this generator early.

    Args:
        state: Mutable dict held in a ``gr.State``; only its ``'generate'`` key
            is written here.
        *args: Forwarded unchanged to ``_generate`` (prompt, model, endseq, ...).

    Yields:
        Each item produced by ``_generate`` -- ``(state, prompt, output)`` updates.
    """
    state['generate'] = True  # polled by the streaming loop; stop() resets it

    try:
        for update in _generate(state, *args):
            yield update
    finally:
        # Runs on completion, on error and on generator close().
        state['generate'] = False

# Labels shown in the "Extra end sequences" checkbox group, mapped to the
# literal stop strings handed to the model client.
_ENDSEQ_LABELS = {
    "\\n": "\n",
    "</s>": "</s>",
    "? (question mark)": "?",
    ". (dot)": ".",
}


def _endseq_to_stop_sequences(endseq):
    """Translate checkbox labels into stop strings, preserving their order.

    Unknown labels are ignored and ``None`` is treated as "nothing selected".
    """
    return [
        _ENDSEQ_LABELS[label]
        for label in (endseq or ())
        if label in _ENDSEQ_LABELS
    ]


def _normalize_sampling(top_k, top_p, temperature):
    """Map raw UI sampling values to the arguments passed to ``client.generate``.

    A 0 for ``top_k`` / ``top_p`` means "not set" and becomes ``None``. Only one
    of the two may be active: when both are set, ``top_p`` wins and ``top_k`` is
    dropped. A falsy temperature (0 or empty) falls back to 1.0.

    Returns:
        ``(top_k, top_p, temperature)``.
    """
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if not temperature:
        temperature = 1.0
    return top_k, top_p, temperature


def _close_session_quietly(client):
    """Best-effort ``client.close_session()``: log a failure instead of raising."""
    try:
        client.close_session()
    except Exception:
        print(traceback.format_exc())


def _generate(state, prompt, model, endseq, max_length,
        do_sample, top_k, top_p, temperature,
        add_stoptoken, copy_output):
    """Open a model session and stream generated text back to the UI.

    Args:
        state: dict with a ``'generate'`` flag; streaming stops early once the
            flag becomes falsy (``stop()`` clears it).
        prompt: Prompt text entered by the user.
        model: Short model name; the session is opened for
            ``bigscience/{model}-petals``.
        endseq: Selected labels from the "Extra end sequences" checkbox group
            (see ``_ENDSEQ_LABELS``).
        max_length: Passed to ``client.open_session`` as the session length.
        do_sample, top_k, top_p, temperature: Sampling settings, normalized by
            ``_normalize_sampling``.
        add_stoptoken: Append an end-of-prompt marker before generating
            (``</s>`` for bloomz models, a blank line otherwise).
        copy_output: Also append every generated chunk to the returned prompt.

    Yields:
        ``(state, prompt, output)`` tuples. Errors are reported as text in the
        output slot instead of being raised.
    """
    try:
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(f"bigscience/{model}-petals", max_length)
    except Exception:
        print(traceback.format_exc())
        yield state, prompt, "Error: " + traceback.format_exc()
        return

    # The session is open from here on. The finally below closes it on every
    # exit path (completion, Stop, streaming error, or the UI dropping this
    # generator), not only when Stop is pressed.
    try:
        if add_stoptoken:
            prompt += "</s>" if "bloomz" in model else "\n\n"

        seq = _endseq_to_stop_sequences(endseq)
        top_k, top_p, temperature = _normalize_sampling(top_k, top_p, temperature)

        prompt2 = prompt
        output = ''

        # Render the (possibly amended) prompt immediately instead of waiting
        # for the first generated chunk.
        yield state, prompt2, output

        try:
            # max_new_tokens=1: presumably the client then yields one step at a
            # time so the UI updates incrementally -- TODO confirm in chat_client.
            for out in client.generate(prompt,
                        max_new_tokens=1,
                        do_sample=do_sample,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        extra_stop_sequences=seq
                ):

                # stop() cleared the flag: abandon the stream (closed in finally).
                if not state['generate']:
                    return

                output += out
                if copy_output:
                    prompt2 += out

                yield state, prompt2, output
        except Exception:
            print(traceback.format_exc())
            # Yield prompt2 so text already copied into the prompt box is kept.
            yield state, prompt2, output + "\nError: " + traceback.format_exc()
    finally:
        _close_session_quietly(client)

def stop(state):
    """Ask a running generation to stop.

    Clears the ``'generate'`` flag in *state* in place and hands the same dict
    back, so Gradio writes it into the ``gr.State`` output.
    """
    state['generate'] = False
    return state

# Raw-prompt playground UI. The components declared below are wired to
# generate() and stop() at the end of this block.
with gr.Blocks() as iface_prompt:
    gr.Markdown("""**Useful for testing raw prompts with zero, one or few-shot prompting.**""")

    with gr.Row():
        # Short model name; the session is opened for "bigscience/{model}-petals".
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloom', label="Use model")

        # Additional ending sequences at which generation should stop
        # (labels are translated to literal stop strings before generating).
        endseq = gr.CheckboxGroup(["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["\\n", "</s>"], label='Extra end sequences')

        # Maximum length of inference session
        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=512, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Should the app append a stop sequence at the end of the prompt, or leave the prompt open?
            add_stoptoken = gr.Checkbox(value=True, interactive=True, label="Automatically add eos token to the prompt.")

        # Only one of top_k and top_p is used: 0 means "unset", and if both are
        # non-zero top_k is dropped. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature (0 / empty falls back to 1.0 when generating)
        temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")

    prompt = gr.Textbox(lines=3, label='Prompt', placeholder="Prompt Here...")
    # Holds the 'generate' flag: set while a generation runs, cleared by Stop.
    state = gr.State({'generate': False})

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_stop = gr.Button("Stop")

        # Automatically copy the output at the end of prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label='Output')

    # Order must match the positional parameters of generate()/_generate().
    inputs = [state, prompt, model, endseq, max_length, do_sample,
            top_k, top_p, temperature, add_stoptoken, copy_output]
    # generate() yields (state, prompt, output) updates into these components.
    outputs = [state, prompt, output]
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_stop.click(stop, inputs=[state], outputs=[state])

    # Each example row: prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken.
    examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
        ["The SQL command to extract all the users whose name starts with A is: ", "bloom-7b1", False, 0, 0, 1, False],
        ["The Spanish translation of thank you for your help is: ", "bloom-7b1", False, 0, 0, 1, False],
        ["A human talks to a powerful AI that follows the Human's instructions.\n"
         "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
         "Human: Hi!</s>\n"
         "AI: Hi! How can I help you?</s>\n"
         "Human: What's the capital of Portugal?</s>\n"
         "AI: ", "bloomz", True, 0, 0.9, 0.75, False]
        ])