#!/usr/bin/env python
# Petals playground — Gradio front-end for the chat.petals.ml websocket API.
# Run with: python app.py   (or: gradio app.py)

import traceback

import gradio as gr

import chat_client

# Websocket endpoint of the public Petals inference swarm.
CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'
#CHAT_URL = 'ws://localhost:8000/api/v2/generate'


def generate(prompt, model, endseq, max_length,
             do_sample, top_k, top_p, temperature,
             add_stoptoken, copy_output):
    """Stream model output over the Petals websocket API.

    This is a generator: it yields ``[prompt, output]`` pairs so Gradio can
    re-render the prompt and the partial output after every generated token.

    Args:
        prompt: Text to complete.
        model: Model short name ("bloom", "bloomz", "bloom-7b1"); expanded
            to the "bigscience/<model>-petals" repository id.
        endseq: Checkbox labels selecting extra stop sequences.
        max_length: Maximum length of the inference session.
        do_sample: Sampling vs. greedy decoding.
        top_k, top_p: Sampling cutoffs; 0 means "unset". Only one of the
            two may be active at a time.
        temperature: Sampling temperature.
        add_stoptoken: Whether to append an end-of-prompt marker.
        copy_output: Whether generated text is appended to the prompt box.

    Errors are not raised: the traceback is printed and also yielded into
    the output field so the user sees what went wrong.
    """
    try:
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(f"bigscience/{model}-petals", max_length)
    except Exception:
        # Report session-setup failures in the UI instead of crashing.
        exc = traceback.format_exc()
        print(exc)
        yield [prompt, "Error: " + exc]
        return

    if add_stoptoken:
        # BLOOMZ was instruction-tuned with an explicit end-of-sequence
        # token; plain BLOOM responds better to a paragraph break.
        prompt += "</s>" if "bloomz" in model else "\n\n"

    # Translate checkbox items to actual stop sequences.
    seq = []
    for s in endseq:
        if s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k / top_p can be set; 0 from the UI means "unset".
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None

    prompt2 = prompt
    output = ''

    # Render the prompt immediately instead of waiting for the
    # generator to produce its first token.
    yield [prompt2, output]

    try:
        # max_new_tokens=1 makes the server stream one token per message,
        # so the UI updates after every token.
        for out in client.generate(prompt,
                                   max_new_tokens=1,
                                   do_sample=do_sample,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   extra_stop_sequences=seq):
            output += out
            if copy_output:
                prompt2 += out
            yield [prompt2, output]
    except Exception:
        # Mid-generation failures (network drop, server error) are also
        # surfaced in the output field.
        exc = traceback.format_exc()
        print(exc)
        yield [prompt, "Error: " + exc]
        return


with gr.Blocks() as iface:
    gr.Markdown("""# Petals playground

**Let's play with prompts and inference settings for BLOOM and BLOOMZ 176B models!**

This space uses websocket API of [chat.petals.ml](http://chat.petals.ml). Health status of Petals network [lives here](http://health.petals.ml).

Do NOT talk to BLOOM as an entity, it's not a chatbot but a webpage/blog/article completion model. For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.

BLOOMZ performs better in chat mode and understands the instructions better.""")

    with gr.Row():
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"],
                         value='bloom', label="Use model")

        # Additional ending sequences, at which generation should stop.
        endseq = gr.CheckboxGroup(["\\n", "</s>", "? (question mark)", ". (dot)"],
                                  value=["\\n", "</s>"], label='Extra end sequences')

        # Maximum length of the inference session.
        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048],
                              value=256, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation.
            do_sample = gr.Checkbox(value=True, interactive=True,
                                    label="do_sample")

            # Should the app append the stop sequence at the end of the
            # prompt, or leave the prompt open?
            add_stoptoken = gr.Checkbox(value=True, interactive=True,
                                        label="Automatically add eos token to the prompt.")

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # Generation temperature.
        temperature = gr.Number(value=0.75, precision=2, interactive=True,
                                label="Temperature")

    prompt = gr.Textbox(lines=2, label='Prompt', placeholder="Prompt Here...")

    with gr.Row():
        button_generate = gr.Button("Generate")
        # button_stop = gr.Button("Stop")  # TODO, not supported by websocket API yet.

        # Automatically copy the output at the end of the prompt.
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label='Output')

    button_generate.click(generate,
                          inputs=[prompt, model, endseq, max_length,
                                  do_sample, top_k, top_p, temperature,
                                  add_stoptoken, copy_output],
                          outputs=[prompt, output])

    examples = gr.Examples(
        inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
            ["The SQL command to extract all the users whose name starts with A is: ",
             "bloom", False, 0, 0, 1, False],
            ["The Spanish translation of thank you for your help is: ",
             "bloom", False, 0, 0, 1, False],
            ["A human talks to a powerful AI that follows the human's instructions.\n"
             "Human: Hi!\n"
             "AI: Hi! How can I help you?\n"
             "Human: What's the capital of Portugal?\n"
             "AI: ", "bloomz", True, 0, 0.9, 0.75, False],
        ])

# queue() is required for generator-based callbacks (streaming output).
iface.queue()
iface.launch()