import json
import re
import time
from datetime import datetime
import gradio as gr
import chat_client
CHAT_URL = "ws://chat.petals.ml/api/v2/generate"
# CHAT_URL='ws://localhost:8000/api/v2/generate'
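# Per-session UI state shared between the Gradio callbacks (kept in gr.State below)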
EMPTY_STATE = {
"generate": False,
"model": None,
"client": None,
"history": [],
}
def generate(state, prompt, model, context, output, *args):
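    """Gradio callback that streams chat updates for the given prompt.

    Thin wrapper around _generate() that retries if the websocket session
    to the Petals server breaks mid-generation.
    """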
    # Mark that we're inside the generation loop
state["generate"] = True
try:
yield from _generate(state, prompt, model, context, output, *args)
except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Broken session; try to renew it
        # TODO: This is a bit fragile because of the recursive call...
print("Retrying session...")
context = output
output = ""
yield from generate(state, prompt, model, context, output, *args)
finally:
state["generate"] = False
def _generate(
state,
prompt,
model,
context,
output,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
):
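    """Run a single generation round against the Petals server and stream UI updates.

    Yields (state, chat history, prompt box value, raw prompt log) tuples,
    matching the `outputs` list wired to the Gradio widgets below.
    """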
start = time.time()
cnt = 0 # Tokens generated
def stats():
# Produces inline stats for generation speed
if cnt == 0:
return "\u2026 | ? sec/t"
if cnt > time.time() - start:
items_per_sec = cnt / (time.time() - start)
return f" | {items_per_sec:.1f} t/sec"
sec_per_item = (time.time() - start) / cnt
return f" | {sec_per_item:.1f} sec/t"
    # bloomz models use an explicit </s> end-of-sequence token; plain bloom chats use blank lines
    eos = "</s>\n" if "bloomz" in model else "\n\n"
    if state["model"] != model and output:
        # When a broken session is resumed, generate() has already moved output
        # into context, so a non-empty output here means the user switched models.
        context = output
        output = ""
# Update widgets even before we get the first response
print("prompt", prompt)
yield state, state["history"] + [[prompt, stats()]], "", output
    if (
        state["model"] != model
        or state["client"] is None
        or not state["client"].is_session()
    ):
try:
state["client"] = chat_client.ModelClient(CHAT_URL)
state["client"].open_session(f"bigscience/{model}-petals", max_length)
state["model"] = model
except Exception as e:
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
    else:
        # The session is already open and holds the earlier context on the server,
        # so only the new prompt needs to be sent.
        context = ""
client = state["client"]
context += eos
    # Fix any EOS-token mismatch and append the EOS token to the context and prompt
if "bloomz" in model:
context = context.replace("\n\n", eos)
prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
else:
context = context.replace("</s>", eos)
context = re.sub(r"\n\n+", "\n\n", context)
prompt2 = prompt.replace("</s>", eos) + "\n\n"
prompt2 = f"{context}Human: {prompt2}AI:"
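    # At this point prompt2 looks like:
    #   "<earlier context>Human: <user input><eos>AI:"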
    # Translate checkbox items to the actual stop sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")
    # Only one of top_k and top_p may be set; treat 0 as "unset"
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    # Treat temperature == 0 as "unset" and fall back to the default
    if temperature == 0:
        temperature = 1.0
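    # These parameters are passed straight to client.generate(); top_k and top_p
    # only take effect when do_sample is True (see the do_sample checkbox below).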
output += prompt2
orig_history = state["history"]
new_line = ""
try:
for out in client.generate(
prompt2,
max_new_tokens=1,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
extra_stop_sequences=seq,
):
if not state["generate"]:
client.close_session()
yield state, [], "", ""
# Stopping generation
return
cnt += 1
new_line += out
# Detect end sequences and finish the generation
# prematurely if found.
for s in seq:
spl = new_line.split(s)
new_line = spl[0]
if len(spl) > 1:
state["history"] = orig_history + [[prompt, new_line]]
output += new_line
yield state, state["history"], "", output
# Stopping generation
return
            # Keep the original history untouched; only the current partial
            # line is appended on each iteration.
state["history"] = orig_history + [[prompt, new_line + stats()]]
yield state, state["history"], "", output
# Final line w/o statistics
yield state, state["history"], "", output
except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Session was interrupted; the retry logic in generate() handles it
client.close_session()
state["client"] = None
state["model"] = None
print("Broken session!")
raise
except Exception as e:
client.close_session()
state["client"] = None
state["model"] = None
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
def reset(state):
"""Resets the session and clears the chat window."""
state.update(EMPTY_STATE)
return state, [], ""
# ---------------------------------------------------------
# Defining Gradio layout
with gr.Blocks() as iface_chat:
gr.Markdown("""**Let's talk to Bloom in a chat!**""")
with gr.Row():
model = gr.Radio(
["bloom", "bloomz", "bloom-7b1"], value="bloomz", label="Use model"
)
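        # The chosen name is expanded to f"bigscience/{model}-petals"
        # when the inference session is opened in _generate() above.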
        # Additional end sequences at which generation should stop
endseq = gr.CheckboxGroup(
["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
value=["Human:", "AI:", "\\n", "</s>"],
label="Extra end sequences",
)
# Maximum length of inference session
max_length = gr.Radio(
[64, 128, 256, 512, 1024, 2048],
value=1024,
interactive=True,
label="Max length",
)
with gr.Row():
with gr.Column():
# Switch between sampling and greedy generation
do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
context = gr.Textbox(
lines=3,
label="Initial context:",
interactive=True,
value="A Human talks to a powerful AI that follows "
"the Human's instructions.\n"
"AI is talkative, friendly, positive and provides "
"detailed answers to any question.</s>\n"
"Human: Hi!</s>\n"
"AI: How can I help you?",
)
# Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
# TODO num_beams
# Generation temperature
temperature = gr.Number(
value=0.75, precision=2, interactive=True, label="Temperature"
)
chat = gr.Chatbot(label="Chat window")
prompt = gr.Textbox(
show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
).style(container=False)
with gr.Row():
button_generate = gr.Button("Generate")
button_reset = gr.Button("Reset session")
with gr.Accordion("Raw prompt log", open=False):
output = gr.Textbox(lines=3, show_label=False).style(container=False)
    # Shared session state (model, websocket client, chat history)
state = gr.State(EMPTY_STATE)
# Define button actions
inputs = [
state,
prompt,
model,
context,
output,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
]
outputs = [state, chat, prompt, output]
prompt.submit(generate, inputs=inputs, outputs=outputs)
button_generate.click(generate, inputs=inputs, outputs=outputs)
button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
examples = gr.Examples(
inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
examples=[
[
"A Human talks to a powerful AI that follows the Human's instructions. "
"AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
"Human: Hi!</s>\n"
"AI: Hi! How can I help you?",
"Could you remind me please who was Neil Armstrong?",
"bloomz",
True,
0,
0.9,
0.75,
],
],
)
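
# Note: this module only defines `iface_chat`; presumably a top-level app mounts it
# (e.g. as one tab among several) rather than launching it directly. For a quick
# standalone test one could run:
#
#     iface_chat.launch()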