#!/usr/bin/env python
# Run directly with `python app.py`, or with `gradio app.py` for auto-reload.

import traceback

import gradio as gr

import chat_client

CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'
#CHAT_URL = 'ws://localhost:8000/api/v2/generate'
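
# How chat_client.ModelClient (a local module) is used below:
#   open_session(model_name, max_length) starts an inference session over the
#   websocket endpoint above, is_session() reports whether one is live,
#   generate(...) yields text chunks as they arrive, and close_session()
#   tears the connection down.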

EMPTY_STATE = {
    'generate': False,
    'model': None,
    'client': None,
    'history': [],
}
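
# Note: gr.State (declared near the bottom of this file) makes a per-session
# copy of this template, so edits to one user's state don't leak to others.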


def generate(state, prompt, model, context, output, *args):
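    """Top-level generation handler: wraps _generate() and retries on a
    broken session. Yields the same 5-tuples as _generate() so Gradio can
    stream widget updates."""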
    # Record that we're inside the generating loop
    state['generate'] = True
    try:
        yield from _generate(state, prompt, model, context, output, *args)
    except BrokenPipeError:
        # Broken session; try to renew it.
        # TODO This is a bit fragile because of the recursive call...
        print("Retrying session...")
        context = output
        output = ''
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state['generate'] = False


def _generate(state, prompt, model, context, output, endseq, max_length,
              do_sample, top_k, top_p, temperature):
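    """Open (or reuse) a model session and stream the generated tokens.

    Each yielded 5-tuple maps onto the `outputs` list wired up below:
    (state, chat history, prompt box, raw prompt log, error box update).
    """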
    print('prompt', prompt)
    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if state['model'] != model or \
            state['client'] is None or not state['client'].is_session():
        try:
            state['client'] = chat_client.ModelClient(CHAT_URL)
            state['client'].open_session(f"bigscience/{model}-petals", max_length)
            state['model'] = model
        except Exception:
            print(traceback.format_exc())
            yield state, state['history'], prompt, output, \
                gr.update(visible=True, value=traceback.format_exc())
            return
    else:
        context = ''

    client = state['client']
    context += eos

    #for question, answer in state['history']:
    #    context += f"Human: {question}{eos}AI: {answer}{eos}"

    # Fix any eos token mismatch and append the eos token to context and prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"

    # Translate checkbox items to the actual end sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k and top_p may be set; a zero value means "unset"
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None

    # A temperature of 0 would break sampling; fall back to the neutral 1.0
    if temperature == 0:
        temperature = 1.0

    output += prompt2

    # Update the widgets even before we get the first response
    yield state, state['history'] + [[prompt, '']], None, output, gr.update(visible=False)

    orig_history = state['history']
    new_line = ''
    try:
        for out in client.generate(prompt2,
                                   max_new_tokens=1,
                                   do_sample=do_sample,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   extra_stop_sequences=seq
                                   ):
            if not state['generate']:
                client.close_session()
                yield state, [], None, '', ''
                # Stopping generation
                return

            new_line += out

            # Detect end sequences and finish the generation
            # prematurely if one is found
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state['history'] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state['history'], None, output, ''
                    # Stopping generation
                    return

            # Keep the original history untouched, since we append just one
            # chunk at a time
            state['history'] = orig_history + [[prompt, new_line]]
            yield state, state['history'], None, output, ''
    except BrokenPipeError:
        # Session was interrupted; handled in the upstream generate()
        client.close_session()
        state['client'] = None
        state['model'] = None
        print("Broken session!")
        raise
    except Exception:
        client.close_session()
        state['client'] = None
        state['model'] = None
        print(traceback.format_exc())
        # TODO Store errors outside the output log
        yield state, state['history'], prompt, output, \
            gr.update(visible=True, value=traceback.format_exc())
        return


def reset(state):
    """Resets the session and clears the chat window."""
    # Pass a fresh list for 'history' so sessions never share the mutable
    # list stored in EMPTY_STATE
    state.update(EMPTY_STATE, history=[])
    return state, [], '', gr.update(visible=False, value='')


with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to Bloom in a chat!**""")

    with gr.Row():
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloomz', label="Use model")

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
                                  value=["Human:", "AI:", "\\n", "</s>"], label='Extra end sequences')

        # Maximum length of the inference session
        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=1024, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
                                 value="A human talks to a powerful AI that follows the human's instructions.\n"
                                       "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                                       "Human: Hi!</s>\n"
                                       "AI: How can I help you?")

        # Only one of top_k and top_p can be set. Requires do_sample=True to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")

    chat = gr.Chatbot(label='Chat window')
    prompt = gr.Textbox(show_label=False, label='Prompt',
                        placeholder="Prompt Here and press Enter...").style(container=False)
    error = gr.Textbox(label="Error log", visible=False, elem_id="error")

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)

    # Chat history
    state = gr.State(EMPTY_STATE)

    inputs = [state, prompt, model, context, output, endseq,
              max_length, do_sample, top_k, top_p, temperature]
    outputs = [state, chat, prompt, output, error]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output, error])

    examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                           examples=[
                               ["A Human talks to a powerful AI that follows the Human's instructions. "
                                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                                "Human: Hi!</s>\n"
                                "AI: Hi! How can I help you?",
                                "Could you remind me please who was Neil Armstrong?",
                                "bloomz", True, 0, 0.9, 0.75],
                           ])
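

# Standalone entry point for local testing. This is an assumption: in the
# hosted Space this tab is likely combined with others by a top-level app
# that calls launch() itself, so the guard keeps the module import-safe.
if __name__ == '__main__':
    iface_chat.launch()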