updates
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
-import threading
 import codecs
 from datetime import datetime
+import gc
 from transformers import BloomTokenizerFast
 from petals.client import DistributedBloomForCausalLM
 import torch
@@ -11,63 +11,18 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 TORCH_DTYPE = torch.bfloat16
 MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
 
-models = {
+models = {"model":None,"model_name":None}
 output = {MODEL_NAMES[0]:"",MODEL_NAMES[1]:""}
-kill = threading.Event()
 
-def stop_threads():
-    global kill
-    print("Force stopping threads")
-    kill.set()
 
-
-    global output
-
-    if kill.is_set():
-        return
-
-    flag = False
-    token_cnt = 0
-    with models[model_name][1].inference_session(max_length=512) as sess:
-        print(f"Thread Start -> {threading.get_ident()}")
-        output[model_name] = ""
-        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
-        n_input_tokens = inputs.shape[1]
-        done = False
-        while not done and not kill.is_set():
-            outputs = models[model_name][1].generate(
-                inputs,
-                max_new_tokens=1,
-                do_sample=True,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                session=sess
-            )
-            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
-            token_cnt += 1
-            print("\n["+ str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
-
-            for stop_word in stop:
-                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
-                if stop_word != '' and stop_word in output[model_name]:
-                    print(f"\nDONE (stop) -> {threading.get_ident()}")
-                    done = True
-            if flag or (token_cnt >= max_tokens):
-                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
-                done = True
-            inputs = None # Prefix is passed only for the 1st token of the bot's response
-            n_input_tokens = 0
-        print(f"\nThread End -> {threading.get_ident()}")
+print (DEVICE)
 
 def to_md(text):
     return text.replace("\n", "<br />")
 
-threads = list()
-
 def infer(
     prompt,
-    model_idx =
+    model_idx = 0,
     max_new_tokens=10,
     temperature=0.1,
     top_p=1.0,
@@ -76,24 +31,22 @@ def infer(
     num_completions=1,
    seed=42,
 ):
-    global threads
     global output
     global models
 
-    if len(model_idx) == 0:
-        return
-
-    kill.clear()
     print("Loading Models\n")
-[... old lines 88-96 not rendered in this view ...]
+    model_name = MODEL_NAMES[model_idx]
+    if (models["model_name"] == None or models["model_name"] != model_name):
+        models = {"model":None,"model_name":None}
+        gc.collect()
+        if (DEVICE == "cuda"):
+            torch.cuda.empty_cache()
+        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
+        model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE, request_timeout=300)
+        model = model.to(DEVICE)
+        models["model"] = tokenizer, model
+        models["model_name"] = model_name
+    output[model_name] = ""
 
     max_new_tokens = int(max_new_tokens)
     temperature = float(temperature)
@@ -115,71 +68,96 @@ def infer(
 
     print(f"START -> ({datetime.now()})\n")
     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
-    for idx in model_idx:
-        model_name = MODEL_NAMES[idx]
-        x = threading.Thread(target=gen_thread, args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop))
-        threads.append(x)
-        x.start()
 
-[... old lines 124-129 not rendered in this view ...]
+    flag = False
+    token_cnt = 0
+    with models["model"][1].inference_session(max_length=512) as sess:
+        print(f"Encode Input Prompt")
+        output[model_name] = ""
+        inputs = models["model"][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+        n_input_tokens = inputs.shape[1]
+        done = False
+        print(f"Start Inference ({sess})")
+        while not done:
+            outputs = models["model"][1].generate(
+                inputs,
+                max_new_tokens=1,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                session=sess
+            )
+            output[model_name] += models["model"][0].decode(outputs[0, n_input_tokens:])
+            token_cnt += 1
+            print("\n["+ str(model_name) + "]" + output[model_name], end="", flush=True)
+            yield output[model_name]
+            for stop_word in stop:
+                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
+                if stop_word != '' and stop_word in output[model_name]:
+                    print(f"\nDONE (stop)")
+                    done = True
+            if flag or (token_cnt >= max_new_tokens):
+                print(f"\nDONE (max tokens)")
+                done = True
+            inputs = None # Prefix is passed only for the 1st token of the bot's response
+            n_input_tokens = 0
+    print(f"\nEnd")
+    yield output[model_name]
 
 examples = [
     [
    # Question Answering
    '''Please answer the following question:
Question: What is the capital of Germany?
-Answer:''',
+Answer:''',"BLOOMZ" , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]],
    [
-    #
-    '''
-[... old lines 140-146 not rendered in this view ...]
+    # Chatbot 1
+    '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
+Alex: Good morning, Fritz!
+Fritz:''',"BLOOM" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
+    [
+    # Chatbot 1
+    '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
+Alex: Good morning, Fritz!
+Fritz:''',"BLOOMZ" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
+    [
+    # Expert Answers
+    '''Expert Questions & Helpful Answers
+Ask Research Experts
+Question:
+Are humans good or bad?
+
+Full Answer:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"],
+    [
+    # G
+    '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"
+    ]
 ]
 
-[... old lines 149-174 not rendered in this view ...]
-    with gr.Row():
-        btn_clear = gr.Button("Clear", variant="secondary")
-        btn_run = gr.Button("Run", variant="primary")
-        btn_stop = gr.Button("Stop", variant="stop")
-    click_run = btn_run.click(infer, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop], outputs=[bloom_out,bloomz_out])
-    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
-    btn_stop.click(stop_threads,cancels=click_run)
-    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])
-
-demo.queue(concurrency_count=1)
-demo.launch()
+
+
+iface = gr.Interface(
+    fn=infer,
+    allow_flagging="never",
+    inputs=[
+        gr.Textbox(lines=20,label="Input Prompt", max_lines=10), # prompt
+        gr.Radio(["BLOOM","BLOOMZ"], value="BLOOM", type="index", label="Choose 176 billion parameter Model"),
+        gr.Slider(1, 256, value=15), # max_tokens
+        gr.Slider(0.0, 1.0, value=0.2), # temperature
+        gr.Slider(0.0, 1.0, value=0.9), # top_p
+        gr.Slider(0.9, 3.0, value=1.0), # repetition penalty
+        gr.Textbox(lines=1, value="\\n\\n,</s>") # stop
+    ],
+    outputs=gr.Textbox(lines=20, label="Generated Output:"),
+
+    examples=examples,
+    #cache_examples=True,
+    title="BLOOM vs BLOOMZ",
+    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt to a minimum size to speed things up.<p>
+<p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals Swarm.</p>
+<p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>'''
+)
+
+iface.queue(concurrency_count=2)
+iface.launch()
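The updated loader keeps a single (tokenizer, model) pair resident and reloads only when the selected model name changes, freeing the previous weights first. Below is a minimal sketch of that caching pattern; the get_model helper and load_fn callable are illustrative stand-ins for the from_pretrained calls in app.py, not part of the commit.

import gc
import torch

# One resident model at a time, keyed by name
# (mirrors models = {"model": None, "model_name": None} in app.py).
_cache = {"model": None, "model_name": None}

def get_model(model_name, load_fn):
    # load_fn is a hypothetical callable that builds and returns the
    # (tokenizer, model) pair, standing in for the real from_pretrained calls.
    global _cache
    if _cache["model_name"] != model_name:
        _cache = {"model": None, "model_name": None}  # drop the old model first
        gc.collect()                                  # reclaim host memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()                  # reclaim GPU memory
        _cache["model"] = load_fn(model_name)
        _cache["model_name"] = model_name
    return _cache["model"]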
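The rewritten infer is a generator rather than a thread target: it yields the text accumulated so far after every generated token, and gr.Interface streams each yielded value into the output textbox, which is why the app enables queuing. A minimal runnable sketch of that streaming pattern follows; the dummy token loop stands in for the Petals inference session and the names here are illustrative only.

import time
import gradio as gr

def infer(prompt, max_new_tokens=10):
    # Stand-in for the token-by-token Petals loop: after each "token",
    # yield everything generated so far so the UI updates live.
    output = ""
    for i in range(int(max_new_tokens)):
        output += f" token{i}"
        time.sleep(0.1)  # simulate per-token latency
        yield output
    yield output         # final value once generation is done

demo = gr.Interface(
    fn=infer,
    inputs=[gr.Textbox(label="Input Prompt"), gr.Slider(1, 64, value=10)],
    outputs=gr.Textbox(label="Generated Output"),
)

if __name__ == "__main__":
    demo.queue()   # queuing must be enabled for generator (streaming) outputs
    demo.launch()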