Spaces:
Build error
update gradio interface for iterative outputs
app.py
CHANGED
@@ -1,46 +1,73 @@
 import gradio as gr
 import threading
 import codecs
-#from ast import literal_eval
 from datetime import datetime
-
-import os
-#os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
-#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"
-
 from transformers import BloomTokenizerFast
 from petals.client import DistributedBloomForCausalLM
 import torch
-import ...
+import time

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 TORCH_DTYPE = torch.bfloat16
 MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

-models = {}
-output = {}
+models = {MODEL_NAMES[0]:None,MODEL_NAMES[1]:None}
+output = {MODEL_NAMES[0]:"",MODEL_NAMES[1]:""}
+kill = threading.Event()

+def stop_threads():
+    global kill
+    print("Force stopping threads")
+    kill.set()

-def gen_thread(model_name, ...
+def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
     global output
-... (old lines 26-35 not shown in this view)
+
+    if kill.is_set():
+        return
+
+    flag = False
+    token_cnt = 0
+    with models[model_name][1].inference_session(max_length=512) as sess:
+        print(f"Thread Start -> {threading.get_ident()}")
+        output[model_name] = ""
+        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+        n_input_tokens = inputs.shape[1]
+        done = False
+        while not done and not kill.is_set():
+            outputs = models[model_name][1].generate(
+                inputs,
+                max_new_tokens=1,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                session=sess
+            )
+            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
+            token_cnt += 1
+            print("\n["+ str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
+
+            for stop_word in stop:
+                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+                if stop_word != '' and stop_word in output[model_name]:
+                    print(f"\nDONE (stop) -> {threading.get_ident()}")
+                    done = True
+            if flag or (token_cnt >= max_tokens):
+                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
+                done = True
+            inputs = None  # Prefix is passed only for the 1st token of the bot's response
+            n_input_tokens = 0
+    print(f"\nThread End -> {threading.get_ident()}")

 def to_md(text):
-    # return text.replace("\n", "<br />")
     return text.replace("\n", "<br />")

+threads = list()
+
 def infer(
     prompt,
-... (old line 43 not shown in this view)
+    model_idx = ["BLOOM","BLOOMZ"],
     max_new_tokens=10,
     temperature=0.1,
     top_p=1.0,
@@ -49,27 +76,32 @@ def infer(
     num_completions=1,
     seed=42,
 ):
+    global threads
+    global output
+    global models

-... (old lines 53-54 not shown in this view)
+    if len(model_idx) == 0:
+        return

-... (old lines 56-57 not shown in this view)
+    kill.clear()
+    print("Loading Models\n")
+    for idx in model_idx:
+        model_name = MODEL_NAMES[idx]
+        if models[model_name] == None:
             tokenizer = BloomTokenizerFast.from_pretrained(model_name)
             model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
             model = model.to(DEVICE)
             models[model_name] = tokenizer, model
+            output[model_name] = ""

     max_new_tokens = int(max_new_tokens)
-    num_completions = int(num_completions)
     temperature = float(temperature)
     top_p = float(top_p)
-    stop = stop.split( ...
+    stop = [x.strip(' ') for x in stop.split(',')]
     repetition_penalty = float(repetition_penalty)
     seed = seed

     assert 1 <= max_new_tokens <= 384
-    assert 0 <= min_length <= max_new_tokens
     assert 1 <= num_completions <= 5
     assert 0.0 <= temperature <= 1.0
     assert 0.0 <= top_p <= 1.0
@@ -80,45 +112,19 @@ def infer(
     if prompt == "":
         prompt = " "

-    threads = list()
     print(f"START -> ({datetime.now()})\n")
     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
-    for ...
-... (old line 87 not shown in this view)
-        x = threading.Thread(target=gen_thread, args=(model_name, ...
+    for idx in model_idx:
+        model_name = MODEL_NAMES[idx]
+        x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
         threads.append(x)
         x.start()
-        #n_input_tokens = inputs.shape[1]
-        # outputs = models[model_name][1].generate(inputs,
-        # max_new_tokens=max_new_tokens,
-        # min_length=min_length,
-        # do_sample=True,
-        # temperature=temperature,
-        # top_p=top_p,
-        # repetition_penalty=repetition_penalty
-        # )
-        #output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])
-
-        #output[model_name] = outputs[len(prompt):]

     # Join Threads
     for model_name, thread in enumerate(threads):
-... (old lines 106-110 not shown in this view)
-    for model_name in MODEL_NAMES:
-        stop = codecs.getdecoder("unicode_escape")(stop[0])[0]
-        stop = [x.strip(' ') for x in stop.split(',')]
-        for stop_word in stop:
-            if stop_word != '' and stop_word in output[model_name]:
-                output[model_name] = output[model_name][:output[model_name].find(stop_word)]
-
-        print(f"--- START: {model_name} --- \n{output[model_name]}\n--- END {model_name} ---\n\n")
-
-    print(f"DONE -> ({datetime.now()})\n")
-    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
+        while thread.is_alive():
+            thread.join(timeout=0.2)
+            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]


 examples = [
@@ -126,7 +132,7 @@ examples = [
     # Question Answering
     '''Please answer the following question:
 Question: What is the capital of Germany?
-Answer:''', ...
+Answer:''',["BLOOM","BLOOMZ"] , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
     [
     # Natural Language Interface
     '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
@@ -136,28 +142,36 @@ Label: entailment
 Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
 Label: contradiction
 Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
-Label:''', ...
+Label:''',["BLOOM","BLOOMZ"] , 2, 0.2, 1.0, 1.0, "\\n,</s>"]
 ]

-... (old lines 142-147 not shown in this view)
-gr. ...
-... (old lines 149-163 not shown in this view)
+def clear_prompt():
+    return "","",""
+
+with gr.Blocks() as demo:
+    gr.Markdown("Start typing below and then click **Run** to see the output.")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
+            with gr.Box():
+                chk_boxes = gr.CheckboxGroup(choices=["BLOOM","BLOOMZ"], value=["BLOOM","BLOOMZ"], type="index", label="Model")
+                #min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length") #min_length
+                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens") # max_tokens
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature") # temperature
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P") # top_p
+                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty") # repetition penalty
+                stop = gr.Textbox(lines=1, value="\\n,</s>", label="Stop Token") # stop
+        with gr.Column():
+            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
+            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")
+    with gr.Row():
+        btn_clear = gr.Button("Clear", variant="secondary")
+        btn_run = gr.Button("Run", variant="primary")
+        btn_stop = gr.Button("Stop", variant="stop")
+    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out, bloomz_out])
+    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
+    btn_stop.click(stop_threads, cancels=click_run)
+    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
+
+demo.queue(concurrency_count=3)
+demo.launch()
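The core of this commit is turning infer into a generator: per-model worker threads append tokens to the shared output dict while infer polls them with thread.join(timeout=0.2) and yields the partial text, which Gradio streams into the two output textboxes; a shared threading.Event (kill) plus cancels=click_run lets the Stop button end generation early. The sketch below shows only that Gradio pattern in isolation. It is a minimal illustration assuming Gradio 3.x with queuing enabled; slow_echo and the component names here are made-up stand-ins for the Petals generation threads, not code from this Space.

import time
import gradio as gr

def slow_echo(prompt):
    # Generator handler: each yield pushes a partial result to the output box.
    text = ""
    for word in prompt.split():
        time.sleep(0.3)      # stand-in for generating one token
        text += word + " "
        yield text

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Output")
    run = gr.Button("Run")
    stop = gr.Button("Stop")
    click = run.click(slow_echo, inputs=prompt, outputs=out)
    stop.click(None, cancels=[click])   # cancel the in-flight generator

demo.queue()    # queuing must be enabled for iterative (generator) outputs
demo.launch()

In the Space itself, cancels=click_run only stops the infer generator from streaming further updates; the worker threads it spawned would otherwise keep generating, which is why gen_thread checks kill.is_set() on every iteration and the Stop button also calls stop_threads to set that event.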