import gradio as gr
import threading
import codecs
from datetime import datetime
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
kill = threading.Event()
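# Shared state for the two worker threads: `models` caches a (tokenizer, model)
# pair per model name so each distributed model is loaded at most once,
# `output` holds the partial generation each thread streams to the UI, and
# `kill` signals both threads to stop early.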
def stop_threads():
    global kill
    print("Force stopping threads")
    kill.set()
def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    global output
    if kill.is_set():
        return

    token_cnt = 0
    # NOTE: the session length caps prompt + generated tokens at 512
    with models[model_name][1].inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        while not done and not kill.is_set():
            outputs = models[model_name][1].generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess
            )
            output[model_name] += models[model_name][0].decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print(f"\nDONE (stop) -> {threading.get_ident()}")
                    done = True
            if token_cnt >= max_tokens:
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True
            inputs = None  # Prefix is passed only for the 1st token of the bot's response
            n_input_tokens = 0
    print(f"\nThread End -> {threading.get_ident()}")
def to_md(text):
    return text.replace("\n", "<br />")  # render newlines as Markdown line breaks
threads = list()
def infer(
    prompt,
    model_idx=["BLOOM", "BLOOMZ"],  # the CheckboxGroup (type="index") passes integer indices at runtime
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    global threads
    global output
    global models

    if len(model_idx) == 0:
        return

    kill.clear()
    threads = []  # drop thread handles left over from a previous run

    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
        output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = [x.strip(' ') for x in stop.split(',')]
    repetition_penalty = float(repetition_penalty)
    # `seed` is accepted from the UI but not currently used

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop)
        )
        threads.append(x)
        x.start()

    # Join threads, streaming the partial outputs while generation is running
    for thread in threads:
        while thread.is_alive():
            thread.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
    # Final yield so the last generated tokens are not dropped
    yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
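# Hypothetical local usage (outside Gradio, values illustrative): `infer` is a
# generator, and models are selected by index into MODEL_NAMES:
#
#     for bloom_text, bloomz_text in infer("The capital of Germany is", model_idx=[0], max_new_tokens=8):
#         print(bloom_text)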
examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', ["BLOOM", "BLOOMZ"], 3, 0.2, 1.0, 1.0, "\\n,"],
    [
        # Natural Language Inference
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', ["BLOOM", "BLOOMZ"], 2, 0.2, 1.0, 1.0, "\\n,"]
]
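# Each example row mirrors infer's inputs in order:
# (prompt, models, max_new_tokens, temperature, top_p, repetition_penalty, stop)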
def clear_prompt():
    return "", "", ""
with gr.Blocks() as demo:
    gr.Markdown("# BLOOM vs BLOOMZ Comparison")
    gr.Markdown("")
    gr.Markdown(
        "Test Inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and "
        "[BLOOMZ](https://huggingface.co/bigscience/bloomz) 176 Billion Parameter models using Petals. "
        "Please consider contributing your unused GPU cycles to the "
        "[Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference."
    )