# BLOOMZ_Compare / app.py
# Author: gururise
# Commit: 8541fdf ("updates")
import gradio as gr
import codecs
from datetime import datetime
import gc
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# dtype handed to from_pretrained when the client-side model is created.
TORCH_DTYPE = torch.bfloat16

# Selectable checkpoints; infer() picks one of these by index.
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# Single-slot cache: "model" holds a (tokenizer, model) tuple once loaded,
# "model_name" records which checkpoint that tuple belongs to.
models = dict(model=None, model_name=None)

# Most recent generated text, keyed by model name.
output = {name: "" for name in MODEL_NAMES}

print(DEVICE)
def to_md(text):
    """Return *text* with every newline replaced by an HTML ``<br />`` tag."""
    line_break = "<br />"
    return line_break.join(text.split("\n"))
def _parse_stop_words(stop):
    """Turn a comma-separated stop string into a list of decoded stop words.

    Each entry is stripped of surrounding spaces and its backslash escapes
    (for example a literal ``\\n`` typed into a text field) are interpreted.
    Empty entries are dropped.
    """
    words = []
    for raw in stop.split(","):
        # Going through latin-1 with backslashreplace before decoding keeps
        # non-ASCII stop words intact; handing the str straight to the
        # unicode_escape decoder would garble them.
        word = (
            raw.strip(" ")
            .encode("latin-1", "backslashreplace")
            .decode("unicode_escape")
        )
        if word:
            words.append(word)
    return words


def infer(
    prompt,
    model_idx=0,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    """Stream a completion for *prompt* from the selected model, token by token.

    This is a generator: after every generated token it yields the full text
    produced so far, and it yields the final text once more when generation
    has ended.

    Args:
        prompt: Text to continue. An empty prompt is replaced by one space.
        model_idx: Index into ``MODEL_NAMES`` of the model to use.
        max_new_tokens: Maximum number of tokens to generate (1..384).
        temperature: Sampling temperature (0.0..1.0); 0.0 is raised to 0.01.
        top_p: Nucleus-sampling threshold (0.0..1.0).
        repetition_penalty: Repetition penalty passed to ``generate``
            (0.9..3.0).
        stop: Comma-separated stop words; backslash escapes are interpreted.
            Generation ends once any of them occurs in the generated text.
        num_completions: Only range-checked (1..5); not otherwise used.
        seed: Accepted for interface compatibility; currently unused.

    Yields:
        The text generated so far (the prompt itself is not included).

    Raises:
        AssertionError: If a numeric argument is outside its allowed range.
    """
    global models

    print("Loading Models\n")
    model_name = MODEL_NAMES[model_idx]
    if models["model_name"] != model_name:
        # Drop the previously cached model first so its memory can be
        # reclaimed before the replacement is loaded.
        models = {"model": None, "model_name": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
        model = DistributedBloomForCausalLM.from_pretrained(
            model_name, torch_dtype=TORCH_DTYPE, request_timeout=300
        )
        model = model.to(DEVICE)
        models["model"] = tokenizer, model
        models["model_name"] = model_name
    else:
        tokenizer, model = models["model"]
    # From here on only the local tokenizer/model are used, so a concurrent
    # request that swaps the cached model cannot change them mid-generation.

    output[model_name] = ""
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # Decode the stop words once instead of once per generated token.
    stop_words = _parse_stop_words(stop)

    # Range checks are kept as asserts so the raised exception type stays
    # AssertionError, as before.
    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    # The text is accumulated locally and mirrored into the shared ``output``
    # dict, so two requests for the same model cannot interleave their text
    # in what is yielded to the caller.
    generated = ""
    token_cnt = 0
    with model.inference_session(max_length=512) as sess:
        print("Encode Input Prompt")
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        print(f"Start Inference ({sess})")
        while not done:
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            # Skip the prompt tokens that are echoed back on the first step.
            generated += tokenizer.decode(outputs[0, n_input_tokens:])
            output[model_name] = generated
            token_cnt += 1
            print("\n[" + str(model_name) + "]" + generated, end="", flush=True)
            yield generated

            if any(word in generated for word in stop_words):
                print("\nDONE (stop)")
                done = True
            if token_cnt >= max_new_tokens:
                print("\nDONE (max tokens)")
                done = True

            # The prompt is passed only for the first generated token; later
            # steps continue within the same session.
            inputs = None
            n_input_tokens = 0

    print("\nEnd")
    yield generated
# Example rows for the UI. Every row supplies exactly one value per input, in
# this order: prompt, model choice, max new tokens, temperature, top_p,
# repetition penalty, stop words.
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''',
        "BLOOMZ", 3, 0.2, 1.0, 1.0, "\\n,</s>",
    ],
    [
        # Chatbot conversation (BLOOM)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''',
        "BLOOM", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>",
    ],
    [
        # Chatbot conversation (BLOOMZ), same prompt as above for comparison
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''',
        "BLOOMZ", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>",
    ],
    [
        # Expert answers
        '''Expert Questions & Helpful Answers
Ask Research Experts
Question:
Are humans good or bad?
Full Answer:''',
        "BLOOM", 120, 0.85, 0.9, 1.0, "</s>",
    ],
    [
        # Novel outline (creative writing)
        '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''',
        "BLOOM", 120, 0.85, 0.9, 1.0, "</s>",
    ],
]
# Build the web UI around infer(). The input components are listed in the same
# order as infer's leading parameters: prompt, model_idx, max_new_tokens,
# temperature, top_p, repetition_penalty, stop.
iface = gr.Interface(
    fn=infer,
    allow_flagging="never",
    inputs=[
        # NOTE(review): lines=20 together with max_lines=10 looks
        # contradictory -- confirm which height is intended.
        gr.Textbox(lines=20, label="Input Prompt", max_lines=10),  # prompt
        # type="index" makes the radio pass 0/1, used to index MODEL_NAMES.
        gr.Radio(["BLOOM", "BLOOMZ"], value="BLOOM", type="index", label="Choose 176 billion parameter Model"),
        gr.Slider(1, 256, value=15),  # max_tokens
        gr.Slider(0.0, 1.0, value=0.2),  # temperature
        gr.Slider(0.0, 1.0, value=0.9),  # top_p
        gr.Slider(0.9, 3.0, value=1.0),  # repetition penalty
        gr.Textbox(lines=1, value="\\n\\n,</s>"),  # stop
    ],
    outputs=gr.Textbox(lines=20, label="Generated Output:"),
    examples=examples,
    cache_examples=False,
    title="BLOOM vs BLOOMZ",
    # The first paragraph is now closed with </p>; it previously ended in a
    # second opening <p> tag.
    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176 billion parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt to a minimum size to speed things up.</p>
<p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals Swarm.</p>
<p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>''',
)
# Enable the request queue with a concurrency count of 2, then start serving.
iface.queue(concurrency_count=2)
iface.launch()