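"""Gradio demo that streams completions from the 176-billion-parameter BLOOM
and BLOOMZ models over the Petals distributed inference network."""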
import gradio as gr
import codecs
from datetime import datetime
import gc
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# Cache for the currently loaded (tokenizer, model) pair and its name.
models = {"model": None, "model_name": None}
# Accumulated generation text, keyed by model name.
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}

print(DEVICE)


def to_md(text):
    return text.replace("\n", "<br />")
def infer(
    prompt,
    model_idx=0,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,  # accepted for API compatibility; not used below
    seed=42,            # accepted for API compatibility; not used below
):
    global output
    global models

    print("Loading Models\n")
    model_name = MODEL_NAMES[model_idx]
    if models["model_name"] != model_name:
        # A different model was requested: free the old one before loading.
        models = {"model": None, "model_name": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
        model = DistributedBloomForCausalLM.from_pretrained(
            model_name, torch_dtype=TORCH_DTYPE, request_timeout=300
        )
        model = model.to(DEVICE)
        models["model"] = tokenizer, model
        models["model_name"] = model_name

    output[model_name] = ""
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    stop = [x.strip(' ') for x in stop.split(',')]

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    token_cnt = 0
    tokenizer, model = models["model"]
    with model.inference_session(max_length=512) as sess:
        print("Encode Input Prompt")
        output[model_name] = ""
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        print(f"Start Inference ({sess})")
        while not done:
            # Generate a single token per step so partial text can be streamed.
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            output[model_name] += tokenizer.decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print("\n[" + str(model_name) + "]" + output[model_name], end="", flush=True)
            yield output[model_name]
            for stop_word in stop:
                # Stop strings arrive escaped (e.g. "\\n"); decode them first.
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print("\nDONE (stop)")
                    done = True
            if token_cnt >= max_new_tokens:
                print("\nDONE (max tokens)")
                done = True
            inputs = None  # The prompt is passed only for the first generated token.
            n_input_tokens = 0
    print("\nEnd")
    yield output[model_name]
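
# A minimal usage sketch (hypothetical; not wired into the Gradio UI below):
# `infer` is a generator, so a caller can consume partial completions as they
# stream in. The prompt and sampling values here are illustrative only.
def run_example():
    final = ""
    for partial in infer(
        "Question: What is the capital of Germany?\nAnswer:",
        model_idx=0,        # 0 selects BLOOM, 1 selects BLOOMZ
        max_new_tokens=8,
        temperature=0.2,
        top_p=0.9,
        repetition_penalty=1.0,
        stop="\\n,</s>",
    ):
        final = partial  # each yield is the full text generated so far
    return final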
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', "BLOOMZ", 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Chatbot (BLOOM)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''', "BLOOM", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Chatbot (BLOOMZ)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''', "BLOOMZ", 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Expert answers
        '''Expert Questions & Helpful Answers
Ask Research Experts
Question:
Are humans good or bad?
Full Answer:''', "BLOOM", 120, 0.85, 0.9, 1.0, "</s>"],
    [
        # Ghostwriting outline
        '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''', "BLOOM", 120, 0.85, 0.9, 1.0, "</s>"],
]
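# Each example row maps positionally to the `inputs` list below:
# [prompt, model, max_new_tokens, temperature, top_p, repetition_penalty, stop]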
iface = gr.Interface(
    fn=infer,
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=20, label="Input Prompt", max_lines=10),  # prompt
        gr.Radio(["BLOOM", "BLOOMZ"], value="BLOOM", type="index",
                 label="Choose 176 billion parameter Model"),
        gr.Slider(1, 256, value=15),     # max_new_tokens
        gr.Slider(0.0, 1.0, value=0.2),  # temperature
        gr.Slider(0.0, 1.0, value=0.9),  # top_p
        gr.Slider(0.9, 3.0, value=1.0),  # repetition penalty
        gr.Textbox(lines=1, value="\\n\\n,</s>"),  # stop
    ],
    outputs=gr.Textbox(lines=20, label="Generated Output:"),
    examples=examples,
    # cache_examples=True,
    title="BLOOM vs BLOOMZ",
    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models on the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt short to speed things up.</p>
<p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>Petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals swarm.</p>
<p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing the initial compute resources.</p>''',
)
iface.queue(concurrency_count=2)
iface.launch()