"""Gradio demo comparing BLOOM and BLOOMZ completions served through Petals.

Each selected model generates in its own worker thread, one token at a time,
while ``infer`` streams the partial text of both models to the UI.

NOTE(review): this file was reformatted from a copy whose line structure had
been collapsed and whose inline HTML appears to have been stripped.  Spots
marked ``NOTE(review)`` below (and the nesting of the layout blocks) are
reconstructions and should be confirmed against the deployed app.
"""

import codecs
import threading
import time
from datetime import datetime

import gradio as gr
import torch
from petals.client import DistributedBloomForCausalLM
from transformers import BloomTokenizerFast

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

# model name -> (tokenizer, model) once loaded by infer(), otherwise None.
models = {MODEL_NAMES[0]: None, MODEL_NAMES[1]: None}
# model name -> text generated so far by that model's worker thread.
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
# Set to ask running workers to stop after their current token.
kill = threading.Event()
# Worker threads started by the most recent infer() call.
threads = list()


def stop_threads():
    """Ask every running generation thread to stop after its current token."""
    print("Force stopping threads")
    kill.set()


def gen_thread(model_name, prompt, max_tokens, temperature, top_p, repetition_penalty, stop):
    """Generate a completion for ``prompt`` into ``output[model_name]``.

    Runs as the body of a worker thread.  Tokens are sampled one at a time so
    the partial text can be streamed and so ``kill`` is honoured between
    tokens.

    Args:
        model_name: Key into ``models`` / ``output``; the model must already
            be loaded.
        prompt: Text to complete.
        max_tokens: Stop after this many generated tokens.
        temperature, top_p, repetition_penalty: Sampling parameters forwarded
            to ``generate``.
        stop: Iterable of stop strings written with escape sequences (for
            example ``"\\\\n"``); generation ends once any non-empty one
            appears in the generated text.
    """
    if kill.is_set():
        return

    tokenizer, model = models[model_name]

    # Unescape the stop words once instead of on every generated token.
    unescape = codecs.getdecoder("unicode_escape")
    stop_words = [word for word in (unescape(raw)[0] for raw in stop) if word != ""]

    generated_ids = []
    with model.inference_session(max_length=512) as sess:
        print(f"Thread Start -> {threading.get_ident()}")
        output[model_name] = ""
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        while not done and not kill.is_set():
            outputs = model.generate(
                inputs,
                max_new_tokens=1,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                session=sess,
            )
            # Decode everything generated so far rather than each token on its
            # own: a character split across several tokens only decodes
            # correctly when its tokens are decoded together.
            generated_ids.extend(outputs[0, n_input_tokens:].tolist())
            output[model_name] = tokenizer.decode(generated_ids)
            print("\n[" + str(threading.get_ident()) + "]" + output[model_name], end="", flush=True)

            if any(word in output[model_name] for word in stop_words):
                print(f"\nDONE (stop) -> {threading.get_ident()}")
                done = True
            if len(generated_ids) >= max_tokens:
                print(f"\nDONE (max tokens) -> {threading.get_ident()}")
                done = True

            # Prefix is passed only for the 1st token of the bot's response
            inputs = None
            n_input_tokens = 0
    print(f"\nThread End -> {threading.get_ident()}")


def to_md(text):
    """Return ``text`` with newlines replaced by HTML line breaks."""
    # NOTE(review): the replacement literal was a raw line break in the
    # recovered source (most likely a stripped "<br>" tag) -- confirm.
    return text.replace("\n", "<br>")


def _require_range(name, value, low, high):
    """Raise ``ValueError`` unless ``low <= value <= high``."""
    if not low <= value <= high:
        raise ValueError(f"{name} must be between {low} and {high}, got {value!r}")


def infer(
    prompt,
    model_idx=(0, 1),
    max_new_tokens=10,
    temperature=0.1,
    top_p=1.0,
    repetition_penalty=1.0,
    stop="\n",
    num_completions=1,
    seed=42,
):
    """Stream completions of ``prompt`` from the selected models.

    This is a generator: it repeatedly yields the pair
    ``(output[MODEL_NAMES[0]], output[MODEL_NAMES[1]])`` while the worker
    threads run, and once more after they have all finished.

    Args:
        prompt: Text to complete; an empty prompt is replaced by one space.
        model_idx: Indices into ``MODEL_NAMES`` of the models to run.  Nothing
            is yielded when it is empty.
        max_new_tokens: Tokens to generate per model (1..384).
        temperature: Sampling temperature (0.0..1.0); 0.0 is bumped to 0.01.
        top_p: Nucleus sampling threshold (0.0..1.0).
        repetition_penalty: Repetition penalty (0.9..3.0).
        stop: Comma-separated stop strings, written with escape sequences.
        num_completions: Range-checked (1..5) but otherwise unused here.
        seed: Unused here.

    Raises:
        ValueError: If a numeric parameter is outside its allowed range.
    """
    if not model_idx:
        return

    # Validate before the (potentially very slow) model load.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    stop = [part.strip(" ") for part in stop.split(",")]

    _require_range("max_new_tokens", max_new_tokens, 1, 384)
    _require_range("num_completions", num_completions, 1, 5)
    _require_range("temperature", temperature, 0.0, 1.0)
    _require_range("top_p", top_p, 0.0, 1.0)
    _require_range("repetition_penalty", repetition_penalty, 0.9, 3.0)

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "

    # A previous run that was cancelled may have left workers behind.  They
    # must finish before `kill` is cleared, otherwise they would resume and
    # write into `output` alongside the new workers.
    if any(worker.is_alive() for worker in threads):
        kill.set()
        for worker in threads:
            worker.join()
    threads.clear()
    kill.clear()

    print("Loading Models\n")
    for idx in model_idx:
        model_name = MODEL_NAMES[idx]
        if models[model_name] is None:
            tokenizer = BloomTokenizerFast.from_pretrained(model_name)
            model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
            model = model.to(DEVICE)
            models[model_name] = tokenizer, model
        output[model_name] = ""

    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")

    for idx in model_idx:
        worker = threading.Thread(
            target=gen_thread,
            args=(MODEL_NAMES[idx], prompt, max_new_tokens, temperature, top_p, repetition_penalty, stop),
        )
        threads.append(worker)
        worker.start()

    # Poll the workers, streaming the partial text while any is still alive.
    for worker in threads:
        while worker.is_alive():
            worker.join(timeout=0.2)
            yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]

    # A worker may append its last token after the final yield above, so emit
    # the finished text once more.
    yield output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]


# Each row supplies one value per input component, in the order:
# prompt, models, max tokens, temperature, top p, repetition penalty, stop.
# NOTE(review): line breaks inside the example prompts appear to have been
# collapsed to spaces in the recovered source; the text is kept as found.
examples = [
    [
        # Question Answering
        "Please answer the following question: Question: What is the capital of Germany? Answer:",
        ["BLOOM", "BLOOMZ"],
        3,
        0.2,
        1.0,
        1.0,
        "\\n,",
    ],
    [
        # Natural Language Interface
        "Given a pair of sentences, choose whether the two sentences agree "
        "(entailment)/disagree (contradiction) with each other. "
        "Possible labels: 1. entailment 2. contradiction "
        "Sentence 1: The skier was on the edge of the ramp. \n"
        "Sentence 2: The skier was dressed in winter clothes. "
        "Label: entailment "
        "Sentence 1: The boy skated down the staircase railing. "
        "Sentence 2: The boy is a newbie skater. "
        "Label: contradiction "
        "Sentence 1: Two middle-aged people stand by a golf hole. "
        "Sentence 2: A couple riding in a golf cart. "
        "Label:",
        ["BLOOM", "BLOOMZ"],
        2,
        0.2,
        1.0,
        1.0,
        "\\n,",
    ],
]


def clear_prompt():
    """Return empty strings for the prompt box and both output boxes."""
    return "", "", ""


with gr.Blocks() as demo:
    # NOTE(review): the heading's surrounding markup was lost in the recovered
    # source; a plain level-1 Markdown heading is used instead.
    gr.Markdown("# BLOOM vs BLOOMZ Comparison")
    gr.Markdown("")
    # NOTE(review): the "<br>" below is reconstructed from a raw line break in
    # the recovered source.
    gr.Markdown(
        "Test Inference on the [BLOOM](https://huggingface.co/bigscience/bloom) and "
        "[BLOOMZ](https://huggingface.co/bigscience/bloomz) 176 Billion Parameter models using Petals. "
        "Please consider contributing your unused GPU cycles to the "
        "[Petals Swarm](https://github.com/bigscience-workshop/petals) to speed up inference.<br>\n "
        "Due to heavy resource requirements of these large models, token generation can take "
        "upwards of 3-5 seconds per token. Try to keep Max Tokens to a minimum."
    )
    gr.Markdown("")
    gr.Markdown(
        "Special thanks to [RFT Capital](https://www.rftcapital.com/) for supporting our "
        "experiments with compute time donations."
    )
    gr.Markdown("Type a Prompt and then click **Run** to see the output.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=17, label="Prompt", placeholder="Enter Prompt", interactive=True)
            with gr.Box():
                # type="index" makes the callback receive positions into
                # MODEL_NAMES rather than the labels.
                chk_boxes = gr.CheckboxGroup(
                    choices=["BLOOM", "BLOOMZ"],
                    value=["BLOOM", "BLOOMZ"],
                    type="index",
                    label="Model",
                )
                # min_length = gr.Slider(minimum=0, maximum=256, value=1, label="Minimum Length")
                max_tokens = gr.Slider(minimum=1, maximum=256, value=15, label="Max Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.2, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top P")
                rep_penalty = gr.Slider(minimum=0.9, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty")
                stop = gr.Textbox(lines=1, value="\\n,", label="Stop Token")
        with gr.Column():
            bloom_out = gr.Textbox(lines=7, label="BLOOM OUTPUT:")
            bloomz_out = gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")

    with gr.Row():
        btn_clear = gr.Button("Clear", variant="secondary")
        btn_run = gr.Button("Run", variant="primary")
        btn_stop = gr.Button("Stop", variant="stop")

    click_run = btn_run.click(
        infer,
        inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop],
        outputs=[bloom_out, bloomz_out],
    )
    btn_clear.click(clear_prompt, outputs=[prompt, bloom_out, bloomz_out])
    # Stop both signals the worker threads and cancels the streaming event.
    btn_stop.click(stop_threads, cancels=click_run)

    gr.Examples(examples, inputs=[prompt, chk_boxes, max_tokens, temperature, top_p, rep_penalty, stop])

demo.queue(concurrency_count=1)
demo.launch()