Spaces:
Build error
Build error
import gradio as gr | |
import codecs | |
from datetime import datetime | |
import gc | |
from transformers import BloomTokenizerFast | |
from petals.client import DistributedBloomForCausalLM | |
import torch | |
import time | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
TORCH_DTYPE = torch.bfloat16 | |
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"] | |
models = {"model":None,"model_name":None} | |
output = {MODEL_NAMES[0]:"",MODEL_NAMES[1]:""} | |
print (DEVICE) | |
def to_md(text): | |
return text.replace("\n", "<br />") | |
def infer( | |
prompt, | |
model_idx = 0, | |
max_new_tokens=10, | |
temperature=0.1, | |
top_p=1.0, | |
repetition_penalty = 1.0, | |
stop="\n", | |
num_completions=1, | |
seed=42, | |
): | |
global output | |
global models | |
print("Loading Models\n") | |
model_name = MODEL_NAMES[model_idx] | |
if (models["model_name"] == None or models["model_name"] != model_name): | |
models = {"model":None,"model_name":None} | |
gc.collect() | |
if (DEVICE == "cuda"): | |
torch.cuda.empty_cache() | |
tokenizer = BloomTokenizerFast.from_pretrained(model_name) | |
model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE, request_timeout=300) | |
model = model.to(DEVICE) | |
models["model"] = tokenizer, model | |
models["model_name"] = model_name | |
output[model_name] = "" | |
max_new_tokens = int(max_new_tokens) | |
temperature = float(temperature) | |
top_p = float(top_p) | |
stop = [x.strip(' ') for x in stop.split(',')] | |
repetition_penalty = float(repetition_penalty) | |
seed = seed | |
assert 1 <= max_new_tokens <= 384 | |
assert 1 <= num_completions <= 5 | |
assert 0.0 <= temperature <= 1.0 | |
assert 0.0 <= top_p <= 1.0 | |
assert 0.9 <= repetition_penalty <= 3.0 | |
if temperature == 0.0: | |
temperature = 0.01 | |
if prompt == "": | |
prompt = " " | |
print(f"START -> ({datetime.now()})\n") | |
print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n") | |
flag = False | |
token_cnt = 0 | |
with models["model"][1].inference_session(max_length=512) as sess: | |
print(f"Encode Input Prompt") | |
output[model_name] = "" | |
inputs = models["model"][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE) | |
n_input_tokens = inputs.shape[1] | |
done = False | |
print(f"Start Inference ({sess})") | |
while not done: | |
outputs = models["model"][1].generate( | |
inputs, | |
max_new_tokens=1, | |
do_sample=True, | |
top_p=top_p, | |
temperature=temperature, | |
repetition_penalty=repetition_penalty, | |
session=sess | |
) | |
output[model_name] += models["model"][0].decode(outputs[0, n_input_tokens:]) | |
token_cnt += 1 | |
print("\n["+ str(model_name) + "]" + output[model_name], end="", flush=True) | |
yield output[model_name] | |
for stop_word in stop: | |
stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0] | |
if stop_word != '' and stop_word in output[model_name]: | |
print(f"\nDONE (stop)") | |
done = True | |
if flag or (token_cnt >= max_new_tokens): | |
print(f"\nDONE (max tokens)") | |
done = True | |
inputs = None # Prefix is passed only for the 1st token of the bot's response | |
n_input_tokens = 0 | |
print(f"\nEnd") | |
yield output[model_name] | |
examples = [ | |
[ | |
# Question Answering | |
'''Please answer the following question: | |
Question: What is the capital of Germany? | |
Answer:''',"BLOOMZ" , 3, 0.2, 1.0, 1.0, "\\n,</s>", ["BLOOM","BLOOMZ"]], | |
[ | |
# Chatbot 1 | |
'''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI... | |
Alex: Good morning, Fritz! | |
Fritz:''',"BLOOM" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"], | |
[ | |
# Chatbot 1 | |
'''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI... | |
Alex: Good morning, Fritz! | |
Fritz:''',"BLOOMZ" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"], | |
[ | |
# Expert Answers | |
'''Expert Questions & Helpful Answers | |
Ask Research Experts | |
Question: | |
Are humans good or bad? | |
Full Answer:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"], | |
[ | |
# G | |
'''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>" | |
] | |
] | |
iface = gr.Interface( | |
fn=infer, | |
allow_flagging="never", | |
inputs=[ | |
gr.Textbox(lines=20,label="Input Prompt", max_lines=10), # prompt | |
gr.Radio(["BLOOM","BLOOMZ"], value="BLOOM", type="index", label="Choose 176 billion parameter Model"), | |
gr.Slider(1, 256, value=15), # max_tokens | |
gr.Slider(0.0, 1.0, value=0.2), # temperature | |
gr.Slider(0.0, 1.0, value=0.9), # top_p | |
gr.Slider(0.9, 3.0, value=1.0), # repetition penalty | |
gr.Textbox(lines=1, value="\\n\\n,</s>") # stop | |
], | |
outputs=gr.Textbox(lines=20, label="Generated Output:"), | |
examples=examples, | |
#cache_examples=True, | |
title="BLOOM vs BLOOMZ", | |
description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt to a minimum size to speed things up.<p> | |
<p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals Swarm.</p> | |
<p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>''' | |
) | |
iface.queue(concurrency_count=2) | |
iface.launch() |