import gradio as gr
import threading
import codecs
from datetime import datetime
import os

# Point the HF cache at persistent storage and cap the CUDA allocator's block
# size before transformers/torch are imported.
os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"

from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

models = {}  # model_name -> (tokenizer, model); filled lazily on first request
output = {}  # model_name -> decoded completion, written by the worker threads
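
# Worker for one model: runs generate() in a background thread and stores the
# decoded completion (prompt tokens stripped) in the shared `output` dict.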
def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
    global output
    n_input_tokens = inputs.shape[1]
    outputs = models[model_name][1].generate(
        inputs,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty
    )
    # Decode only the newly generated tokens, skipping the prompt.
    output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])
def to_md(text):
    return text.replace("\n", "<br />")
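
# Gradio callback: lazily loads both models, generates from each in parallel,
# trims the completions at the first stop word, and returns both results.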
def infer(
    prompt,
    min_length=2,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    # Load each tokenizer/model pair once, on the first request, so the Space
    # starts quickly and only connects to the Petals swarm when needed.
    if not models:
        for model_name in MODEL_NAMES:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = (tokenizer, model)
    max_new_tokens = int(max_new_tokens)
    min_length = int(min_length)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    # seed is accepted for interface compatibility but is not currently used.

    # Unescape the stop string (e.g. "\\n" -> "\n") and split it into a list of
    # stop words once, rather than re-parsing it inside the per-model loop.
    stop = codecs.getdecoder("unicode_escape")(stop.split(";")[0])[0]
    stop = [x.strip(' ') for x in stop.split(',')]

    assert 1 <= max_new_tokens <= 384
    assert 0 <= min_length <= max_new_tokens
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    # Sampling needs a nonzero temperature; an empty prompt becomes a space.
    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "
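
    # Run both models concurrently: start one generation thread per model,
    # then join them all before post-processing.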
    threads = list()
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    for model_name in MODEL_NAMES:
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty),
        )
        threads.append(x)
        x.start()
    # Wait for both generation threads to finish.
    for model_name, thread in zip(MODEL_NAMES, threads):
        print(f"waiting on: {model_name}\n")
        thread.join()
        print(f"{model_name} thread done\n")
    # Trim each completion at the first occurrence of any stop word.
    for model_name in MODEL_NAMES:
        for stop_word in stop:
            if stop_word != '' and stop_word in output[model_name]:
                output[model_name] = output[model_name][:output[model_name].find(stop_word)]
        print(f"--- START: {model_name} ---\n{output[model_name]}\n--- END {model_name} ---\n\n")

    print(f"DONE -> ({datetime.now()})\n")
    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
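
# Cached examples; each row supplies: prompt, min_length, max_new_tokens,
# temperature, top_p, repetition_penalty, stop.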
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural language inference (entailment vs. contradiction)
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]
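
# Build and launch the Gradio interface: sampling controls on the left,
# BLOOM and BLOOMZ outputs side by side on the right.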
def main():
    iface = gr.Interface(
        fn=infer,
        allow_flagging="never",
        inputs=[
            gr.Textbox(lines=20),                  # prompt
            gr.Slider(0, 256, value=1),            # min_length
            gr.Slider(1, 384, value=20),           # max_new_tokens
            gr.Slider(0.0, 1.0, value=0.2),        # temperature
            gr.Slider(0.0, 1.0, value=0.9),        # top_p
            gr.Slider(0.9, 3.0, value=1.0),        # repetition_penalty
            gr.Textbox(lines=1, value="\\n,</s>")  # stop words (comma-separated, escapes allowed)
        ],
        outputs=[
            gr.Textbox(lines=7, label="BLOOM OUTPUT:"),
            gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")
        ],
        examples=examples,
        cache_examples=True,
        title="BLOOM vs BLOOMZ",
        description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models using the <a href="https://petals.ml/">Petals</a> network. Please consider joining the Petals network to help speed up inference.</p><p>Big thanks to <a href="https://www.rftcapital.com">RFTCapital</a> for providing initial compute resources.</p>'''
    )
    iface.launch(debug=True, share=False)


if __name__ == '__main__':
    main()