# BLOOMZ_Compare / app.py
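"""Gradio Space that sends the same prompt to the BLOOM and BLOOMZ 176B models
over the Petals network and displays both completions side by side."""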
import gradio as gr
import threading
import codecs
from datetime import datetime
import os
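
# Point the Hugging Face cache at persistent storage and cap CUDA allocator
# block splitting; both must be set before transformers/torch are imported below.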
os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

models = {}  # model_name -> (tokenizer, model); filled lazily on the first request
output = {}  # model_name -> generated text; written by the worker threads


def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
    """Generate a completion for one model; runs in its own thread so both models sample concurrently."""
    global output
    n_input_tokens = inputs.shape[1]
    outputs = models[model_name][1].generate(
        inputs,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, skipping the prompt.
    output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])


def to_md(text):
    # NOTE: currently unused by the interface.
    return text.replace("\n", "<br />")


def infer(
    prompt,
    min_length=2,
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    # Lazily load the tokenizer and distributed model for both variants on the first call.
    if not models:
        for model_name in MODEL_NAMES:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
    max_new_tokens = int(max_new_tokens)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    stop = stop.split(";")
    repetition_penalty = float(repetition_penalty)
    # NOTE: `seed` and `num_completions` are accepted but not currently used.

    assert 1 <= max_new_tokens <= 384
    assert 0 <= min_length <= max_new_tokens
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    # Sampling with temperature 0 is undefined, so clamp to a small positive value.
    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    threads = list()
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    # Tokenize the prompt for each model and launch both generations in parallel.
    for model_name in MODEL_NAMES:
        inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        x = threading.Thread(
            target=gen_thread,
            args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty),
        )
        threads.append(x)
        x.start()

    # Wait for both generations to finish.
    for model_name, thread in zip(MODEL_NAMES, threads):
        print(f"waiting on: {model_name}\n")
        thread.join()
        print(f"{model_name} thread done\n")

    # Unescape the stop string (e.g. "\\n" -> newline) once, then split the
    # comma-separated list into individual stop words.
    stop = codecs.getdecoder("unicode_escape")(stop[0])[0]
    stop = [x.strip(' ') for x in stop.split(',')]

    for model_name in MODEL_NAMES:
        # Truncate each completion at the first occurrence of any stop word.
        for stop_word in stop:
            if stop_word != '' and stop_word in output[model_name]:
                output[model_name] = output[model_name][:output[model_name].find(stop_word)]
        print(f"--- START: {model_name} --- \n{output[model_name]}\n--- END {model_name} ---\n\n")

    print(f"DONE -> ({datetime.now()})\n")
    return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
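
# A minimal sanity check of infer() in a Python shell (assumes a reachable
# Petals swarm; the values mirror the first example below):
#   bloom_out, bloomz_out = infer("Question: What is the capital of Germany?\nAnswer:",
#                                 min_length=1, max_new_tokens=3, stop="\\n,</s>")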

# Each example row maps, in order, to the Gradio inputs:
# prompt, min_length, max_new_tokens, temperature, top_p, repetition_penalty, stop.
examples = [
    [
        # Question answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Natural language inference
        '''Given a pair of sentences, choose whether the two sentences agree (entailment)/disagree (contradiction) with each other.
Possible labels: 1. entailment 2. contradiction
Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
Label: entailment
Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
Label: contradiction
Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
]


def main():
    iface = gr.Interface(
        fn=infer,
        allow_flagging="never",
        inputs=[
            gr.Textbox(lines=20),                   # prompt
            gr.Slider(0, 256, value=1),             # min_length
            gr.Slider(1, 384, value=20),            # max_new_tokens
            gr.Slider(0.0, 1.0, value=0.2),         # temperature
            gr.Slider(0.0, 1.0, value=0.9),         # top_p
            gr.Slider(0.9, 3.0, value=1.0),         # repetition_penalty
            gr.Textbox(lines=1, value="\\n,</s>"),  # stop sequences (comma-separated)
        ],
        outputs=[
            gr.Textbox(lines=7, label="BLOOM OUTPUT:"),
            gr.Textbox(lines=7, label="BLOOMZ OUTPUT:"),
        ],
        examples=examples,
        cache_examples=True,
        title="BLOOM vs BLOOMZ",
        description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models using the [Petals](https://petals.ml/) network. Please consider joining the Petals network to help speed up inference.</p><p>Big thanks to [RFTCapital](https://www.rftcapital.com) for providing initial compute resources.</p>'''
    )
    iface.launch(debug=True, share=False)


if __name__ == '__main__':
    main()