#!/usr/bin/env python
# Run directly, or serve with `gradio app.py`.
import traceback

import gradio as gr

import chat_client

# Public Petals chat endpoint; switch to the localhost URL when running
# the backend locally.
CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'
#CHAT_URL = 'ws://localhost:8000/api/v2/generate'
EMPTY_STATE = {
    'generate': False,
    'model': None,
    'client': None,
    'history': [],
}
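
# A minimal sketch of the `chat_client.ModelClient` interface this app
# relies on, inferred from the calls below; the actual client lives in
# the separate chat_client module:
#
#   client = chat_client.ModelClient(CHAT_URL)
#   client.open_session("bigscience/bloomz-petals", max_length)
#   client.is_session()                  # -> bool, True while usable
#   for chunk in client.generate(prompt, max_new_tokens=1, ...):
#       ...                              # yields decoded text chunks
#   client.close_session()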

def generate(state, prompt, model, context, output, *args):
    """Wrap _generate() and transparently renew a broken session."""
    # Note that we're inside the generating loop.
    state['generate'] = True
    try:
        for x in _generate(state, prompt, model, context, output, *args):
            yield x
    except BrokenPipeError:
        # Broken session; try to renew it by resending the output
        # accumulated so far as the new context.
        # TODO This is a bit fragile because of the recursive call...
        print("Retrying session...")
        context = output
        output = ''
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state['generate'] = False

def _generate(state, prompt, model, context, output, endseq, max_length,
              do_sample, top_k, top_p, temperature):
    print('prompt', prompt)

    # bloomz was fine-tuned with </s> as its end-of-sequence token;
    # the plain bloom models use a blank line instead.
    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if (state['model'] != model or state['client'] is None
            or not state['client'].is_session()):
        try:
            state['client'] = chat_client.ModelClient(CHAT_URL)
            state['client'].open_session(f"bigscience/{model}-petals", max_length)
            state['model'] = model
        except Exception:
            print(traceback.format_exc())
            yield state, state['history'], prompt, output, \
                gr.update(visible=True, value=traceback.format_exc())
            return
    else:
        # The session is still alive; the server already holds the
        # context, so don't resend it.
        context = ''

    client = state['client']
    context += eos
    #for question, answer in state['history']:
    #    context += f"Human: {question}{eos}AI: {answer}{eos}"

    # Fix any eos-token mismatch, then append the eos token to the
    # context and the prompt.
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"
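
    # For illustration (not from the original source): with the bloomz
    # eos, a first turn renders roughly as
    #   "<initial context></s>\nHuman: Hi!</s>\nAI:"
    # and the model is expected to continue after the trailing "AI:".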
    # Translate checkbox items to the actual stop sequences,
    # e.g. the default selection yields ["Human:", "AI:", "\n", "</s>"].
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")
    # Only one of top_k and top_p may be set; a zero value means "unset".
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None

    # A temperature of 0 is degenerate; fall back to the neutral 1.0.
    if temperature == 0:
        temperature = 1.0
    output += prompt2

    # Update the widgets even before we get the first response.
    yield state, state['history'] + [[prompt, '']], None, output, gr.update(visible=False)

    orig_history = state['history']
    new_line = ''
    try:
        for out in client.generate(prompt2,
                                   max_new_tokens=1,
                                   do_sample=do_sample,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   extra_stop_sequences=seq
                                   ):
            if not state['generate']:
                # The user asked to stop; end the session and generation.
                client.close_session()
                yield state, [], None, '', ''
                return

            new_line += out

            # Detect stop sequences and finish the generation
            # prematurely if one is found.
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state['history'] = orig_history + [[prompt, new_line]]
                    output += new_line
                    # Stopping generation
                    yield state, state['history'], None, output, ''
                    return

            # Keep the original history untouched; we're only appending
            # one chunk at a time.
            state['history'] = orig_history + [[prompt, new_line]]
            yield state, state['history'], None, output, ''
    except BrokenPipeError:
        # The session was interrupted; handled in the upstream generate().
        client.close_session()
        state['client'] = None
        state['model'] = None
        print("Broken session!")
        raise
    except Exception:
        client.close_session()
        state['client'] = None
        state['model'] = None
        print(traceback.format_exc())
        # TODO Store errors outside the output log
        yield state, state['history'], prompt, output, \
            gr.update(visible=True, value=traceback.format_exc())
        return

def reset(state):
    """Reset the session and clear the chat window."""
    state.update(EMPTY_STATE)
    return state, [], '', gr.update(visible=False, value='')

with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to Bloom in a chat!**""")

    with gr.Row():
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloomz', label="Use model")

        # Additional end sequences at which generation should stop.
        endseq = gr.CheckboxGroup(["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
                                  value=["Human:", "AI:", "\\n", "</s>"], label='Extra end sequences')

        # Maximum length of the inference session.
        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=1024, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation.
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
                                 value="A human talks to a powerful AI that follows the human's instructions.\n"
                                       "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                                       "Human: Hi!</s>\n"
                                       "AI: How can I help you?")

            # Only one of top_k and top_p can be set. Both require do_sample=True to work.
            top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
            top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
            # TODO num_beams

            # Generation temperature.
            temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")

    chat = gr.Chatbot(label='Chat window')
    prompt = gr.Textbox(show_label=False, label='Prompt',
                        placeholder="Prompt Here and press Enter...").style(container=False)
    error = gr.Textbox(label="Error log", visible=False, elem_id="error")

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)
    # Per-user chat history and session state.
    state = gr.State(EMPTY_STATE)

    inputs = [state, prompt, model, context, output, endseq,
              max_length, do_sample, top_k, top_p, temperature]
    outputs = [state, chat, prompt, output, error]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output, error])

    examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                           examples=[
                               ["A Human talks to a powerful AI that follows the Human's instructions. "
                                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                                "Human: Hi!</s>\n"
                                "AI: Hi! How can I help you?",
                                "Could you please remind me who Neil Armstrong was?",
                                "bloomz", True, 0, 0.9, 0.75],
                           ])
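
# If this file is run standalone (see the shebang above), Gradio needs an
# explicit launch call; `gradio app.py` discovers the Blocks object on its
# own. This guard is an addition, assuming default launch settings.
if __name__ == "__main__":
    iface_chat.launch()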