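"""Gradio demo that runs a prompt through the BLOOM and BLOOMZ 176B models in
parallel over the Petals distributed-inference network and shows both outputs
side by side."""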
import gradio as gr
import threading
import codecs
from datetime import datetime

import os
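# Cache model files under /data and cap the CUDA allocator's split size to
# reduce memory fragmentation during long-running inference.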
os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"

from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# Shared state across requests: models[name] -> (tokenizer, model);
# output[name] -> the latest completion produced by each generation thread.
models = {}
output = {}


def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
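    """Generate a completion with one model and store the decoded text in the
    shared `output` dict, keyed by model name; run as a worker thread."""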
    global output
    n_input_tokens = inputs.shape[1]
    outputs = models[model_name][1].generate(
        inputs,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])

def to_md(text):
    # Convert newlines to HTML line breaks (currently unused helper).
    return text.replace("\n", "<br />")

def infer(
        prompt,
        min_length=2,
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        repetition_penalty = 1.0,
        stop="\n",
        num_completions=1,
        seed=42,
):
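    """Run the prompt through BLOOM and BLOOMZ in parallel and return both
    completions. `num_completions` and `seed` are accepted for interface
    compatibility but are not currently used during generation."""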

    if not models:
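        # Load both tokenizers and models lazily on the first request; the
        # transformer blocks themselves are served by the Petals swarm.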
        for model_name in MODEL_NAMES:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model

    max_new_tokens = int(max_new_tokens)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = stop.split(";")
    repetition_penalty = float(repetition_penalty)

    # Guard against out-of-range values (mirroring the Gradio slider bounds).
    assert 1 <= max_new_tokens <= 384
    assert 0 <= min_length <= max_new_tokens
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    # do_sample=True rejects a temperature of exactly 0, and an empty prompt
    # produces an empty input tensor, so nudge both to safe minimums.
    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "
    
    threads = list()
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    # Tokenize the prompt for each model and start one generation thread per
    # model so BLOOM and BLOOMZ run concurrently.
    for model_name in MODEL_NAMES:
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty),
        )
        threads.append(x)
        x.start()
    
    # Wait for both generation threads to finish.
    for model_name, thread in zip(MODEL_NAMES, threads):
        print(f"waiting on: {model_name}\n")
        thread.join()
        print(f"{model_name} thread done\n")
    
    
    # Decode escape sequences in the first ';'-separated stop segment once
    # (e.g. "\\n" -> newline), then split it into comma-separated stop words.
    stop = codecs.getdecoder("unicode_escape")(stop[0])[0]
    stop = [x.strip(' ') for x in stop.split(',')]

    # Truncate each model's output at the first stop word it contains.
    for model_name in MODEL_NAMES:
        for stop_word in stop:
            if stop_word != '' and stop_word in output[model_name]:
                output[model_name] = output[model_name][:output[model_name].find(stop_word)]

        print(f"--- START: {model_name} ---\n{output[model_name]}\n--- END {model_name} ---\n\n")
    
    print(f"DONE -> ({datetime.now()})\n")
    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]


# Example prompts preloaded into the UI, matching the input order:
# (prompt, min_length, max_new_tokens, temperature, top_p, repetition_penalty, stop).
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural Language Interface
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]


def main():
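    """Build and launch the Gradio interface comparing the two models."""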
    iface = gr.Interface(
        fn=infer,
        allow_flagging="never",
        inputs=[
            gr.Textbox(lines=20),  # prompt
            gr.Slider(0, 256, value=1),  # min_length
            gr.Slider(1, 384, value=20),  # max_new_tokens
            gr.Slider(0.0, 1.0, value=0.2),  # temperature
            gr.Slider(0.0, 1.0, value=0.9),  # top_p
            gr.Slider(0.9, 3.0, value=1.0),  # repetition_penalty
            gr.Textbox(lines=1, value="\\n,</s>")  # stop
        ],
        outputs=[gr.Textbox(lines=7, label="BLOOM OUTPUT:"), gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")],
        examples=examples,
        cache_examples=True,
        title="BLOOM vs BLOOMZ",
        description='''Compare outputs of the 176-billion-parameter BLOOM and BLOOMZ models running on the [Petals](https://petals.ml/) network. Please consider joining the Petals swarm to help speed up inference.

Big thanks to [RFTCapital](https://www.rftcapital.com) for providing the initial compute resources.'''
    )

    iface.launch(debug=True, share=False)


if __name__ == '__main__':
    main()
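
# Usage sketch (assumptions: this script is saved as app.py and the standard
# PyPI package names apply):
#   pip install gradio petals torch transformers
#   python app.py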