# Source: Emu2/demo/chat_frontend.py (initial commit 9aa6aea by ryanzhangfan, 8.03 kB)
# -*- coding: utf-8 -*-
# ===================================================
#
# Author : Fan Zhang
# Email : [email protected]
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
# Create On : 2023-12-12 18:05
# Last Modified : 2023-12-19 15:00
# File Name : chat_frontend.py
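# Description : Gradio chat front end for the Emu demo; sends the conversation to the controller's /v1/mmc endpoint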
#
# ===================================================
import json
import io
import time
from PIL import Image
import requests
import gradio as gr
from .meta import ConvMeta, Role, DataMeta
from .utils import extract_frames
from .utils import frontend_logger as logging
# Base URL of the controller service. It stays empty until build_chat()
# assigns args.controller_url; generate() posts to CONTROLLER_URL + "/v1/mmc".
CONTROLLER_URL: str = ""
def submit(
    meta,
    image,
    video,
    text,
    num_frames,
):
    """Append exactly one user input (text, image or video) to the conversation.

    Wired to both the textbox ``submit`` event and the "Add" button.

    Args:
        meta: Conversation state (``ConvMeta``) held in ``gr.State``; ``None``
            on the first call of a session, in which case a new one is created.
        image: Value of the image box, or ``None`` when empty.
        video: Value of the ``gr.Video`` box (handed to ``extract_frames``),
            or ``None`` when empty.
        text: Content of the textbox; ``""`` or ``None`` means "no text".
        num_frames: How many frames to sample from ``video``.

    Returns:
        ``(chatbot_history, meta, None, None, "")``; the trailing values clear
        the image box, the video box and the textbox.

    Raises:
        gr.Error: If zero or more than one modality is supplied. Gradio shows
            the message to the user and leaves the inputs untouched.
    """
    if meta is None:
        meta = ConvMeta()

    has_text = bool(text)  # both "" and None count as "no text"
    has_image = image is not None
    has_video = video is not None
    num_inputs = has_text + has_image + has_video
    if num_inputs != 1:
        logging.info(
            f"{meta.log_id}: invalid input: got {num_inputs} modalities, expected exactly one"
        )
        # gr.Error only reaches the user when *raised*; merely constructing it
        # was a no-op and the user's input was silently thrown away.
        raise gr.Error("Invalid input number, must give exactly one modality input at a time")

    # Mutate the state only once the input is known to be valid, so that an
    # aborted call (which does not refresh the chatbot) leaves it untouched.
    meta.pop_error()
    if has_text:
        meta.append(Role.USER, DataMeta.build(text=text))
    elif has_image:
        meta.append(Role.USER, DataMeta.build(image=image))
    else:
        # gr.Number returns a float unless precision=0; a frame count is integral.
        frames = extract_frames(video, int(num_frames))
        meta.append(Role.USER, DataMeta.build(frames=frames))
    return meta.format_chatbot(), meta, None, None, ""
def clear_history(meta):
    """Reset the session's conversation and return the refreshed chat view.

    Args:
        meta: The session's ``ConvMeta``, or ``None`` if none exists yet (a
            fresh one is created in that case).

    Returns:
        ``(chatbot_history, meta)`` as produced after ``meta.clear()``.
    """
    conv = meta if meta is not None else ConvMeta()
    conv.clear()
    history = conv.format_chatbot()
    return history, conv
# Pseudo status code logged/reported when the HTTP request itself fails
# (connection refused, unconfigured URL, ...) and there is no real response.
_REQUEST_FAILED_CODE = 1099


def _encode_prompt(prompt):
    """Split a formatted chat prompt into a JSON-able segment list and upload files.

    Returns ``(prompt_list, files)``. ``prompt_list`` holds ``["IMAGE", key]`` /
    ``["TEXT", item]`` pairs in prompt order. ``files`` maps each image key to
    a ``(filename, fileobj, content_type)`` tuple for ``requests`` and is
    ``None`` when the prompt has no images.
    """
    prompt_list, files = [], {}
    for idx, item in enumerate(prompt):
        if isinstance(item, Image.Image):
            key = f"[<IMAGE{idx}>]"
            prompt_list.append(["IMAGE", key])
            buf = io.BytesIO()
            item.save(buf, format="PNG")
            buf.seek(0)  # rewind so the upload starts at the first byte
            files[key] = (key, buf, "image/png")
        else:
            prompt_list.append(["TEXT", item])
    return prompt_list, (files or None)


def _assistant_reply(rsp):
    """Turn a controller response (or ``None`` on transport failure) into ``(text, is_error)``."""
    if rsp is None:
        return f"GENERATE FAILED: http failed with code {_REQUEST_FAILED_CODE}", True
    if not rsp.ok:
        return f"GENERATE FAILED: http failed with code {rsp.status_code}", True
    try:
        content = json.loads(rsp.text)
        code = content["code"]
        if code == 0:
            return content["data"], False
        return f"GENERATE FAILED: {content['data']}", True
    except (ValueError, KeyError, TypeError) as err:
        # Non-JSON body or unexpected schema: report it instead of crashing the handler.
        return f"GENERATE FAILED: invalid response from controller ({err!r})", True


def generate(
    meta,
    do_sample,
    max_new_tokens,
    temperature,
    top_k,
    top_p,
    length_penalty,
    num_beams,
    repetition_penalty,
):
    """Send the conversation to the controller and append the assistant reply.

    Args:
        meta: Conversation state (``ConvMeta``) from ``gr.State``; ``None`` on
            the first call of a session, in which case a new one is created.
        do_sample, max_new_tokens, temperature, top_k, top_p, length_penalty,
        num_beams, repetition_penalty: Generation parameters from the UI,
            forwarded verbatim as form fields of the ``/v1/mmc`` request.

    Returns:
        ``(chatbot_history, meta)``. Failures (transport error, non-2xx status,
        malformed body, non-zero ``code``) are appended as an error turn
        prefixed with ``"GENERATE FAILED:"`` rather than raised.
    """
    if meta is None:
        meta = ConvMeta()
    meta.pop_error()
    # NOTE(review): presumably drops a trailing assistant turn so a repeated
    # "Generate" replaces it -- confirm against ConvMeta.pop.
    meta.pop()

    prompt_list, files = _encode_prompt(meta.format_chat())
    logging.info(f"{meta.log_id}: construct chat request with prompt {prompt_list}")
    payload = {
        "log_id": meta.log_id,
        "prompt": json.dumps(prompt_list),
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "length_penalty": length_penalty,
        "num_beams": num_beams,
        "repetition_penalty": repetition_penalty,
    }

    t0 = time.time()
    try:
        rsp = requests.post(CONTROLLER_URL + "/v1/mmc", files=files, data=payload)
    except requests.RequestException as err:
        # The old code faked a Response with status 1099. Response.ok is True for
        # codes outside 400-599, so the empty body then crashed json.loads.
        logging.info(f"{meta.log_id}: chat request failed: {err!r}")
        rsp = None
    elapsed_ms = (time.time() - t0) * 1000
    status_code = _REQUEST_FAILED_CODE if rsp is None else rsp.status_code
    logging.info(f"{meta.log_id}: get response with status code: {status_code}, time: {elapsed_ms:.3f}ms")

    text, is_error = _assistant_reply(rsp)
    if is_error:
        meta.append(Role.ASSISTANT, DataMeta.build(text=text), is_error=True)
    else:
        meta.append(Role.ASSISTANT, DataMeta.build(text=text))
    return meta.format_chatbot(), meta
def build_chat(args):
    """Build the Gradio Blocks UI for the Emu chat demo.

    Side effect: stores ``args.controller_url`` in the module-level
    ``CONTROLLER_URL`` that ``generate`` posts to.

    Args:
        args: Object with a ``controller_url`` attribute (e.g. parsed CLI args).

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    global CONTROLLER_URL
    CONTROLLER_URL = args.controller_url
    with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
        # Per-session conversation; starts as None and is created lazily by the handlers.
        state = gr.State()
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    imagebox = gr.Image(type="pil")
                with gr.Row():
                    videobox = gr.Video()
                with gr.Accordion("Parameters", open=True, visible=True):
                    do_sample = gr.Checkbox(value=False, label="Do Sample", interactive=True)
                    max_new_tokens = gr.Slider(minimum=0, maximum=2048, value=512, step=1, interactive=True, label="Max Output Tokens")
                    temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.05, interactive=True, label="Temperature")
                    top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, interactive=True, label="Top K")
                    top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, interactive=True, label="Top P")
                    length_penalty = gr.Slider(minimum=0, maximum=5, value=3, step=0.1, interactive=True, label="Length Penalty")
                    num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size")
                    repetition_penalty = gr.Slider(minimum=1.0, maximum=10.0, value=1.0, step=0.5, interactive=True, label="Repetition Penalty")
                    # precision=0 makes Gradio return an int (the default yields a float);
                    # minimum=1 rules out empty or negative frame counts.
                    num_frames = gr.Number(interactive=True, value=8, minimum=1, maximum=12, precision=0, label="Num Video Frames")
            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Emu Chatbot",
                    visible=True,
                    height=1070,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and add to prompt",
                            visible=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=60):
                        add_btn = gr.Button(value="Add")
                with gr.Row(visible=True):
                    # Feedback buttons, not wired up yet:
                    # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    # Label was mojibake ("πŸ—‘οΈ", UTF-8 decoded as cp1253); restored the emoji.
                    clear_btn = gr.Button(value="🗑️ Clear History")
                    generate_btn = gr.Button(value="Generate")

        clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])

        # Enter in the textbox and the "Add" button both append the pending input.
        submit_inputs = [state, imagebox, videobox, textbox, num_frames]
        submit_outputs = [chatbot, state, imagebox, videobox, textbox]
        textbox.submit(submit, inputs=submit_inputs, outputs=submit_outputs)
        add_btn.click(submit, inputs=submit_inputs, outputs=submit_outputs)

        generate_btn.click(
            generate,
            inputs=[
                state,
                do_sample,
                max_new_tokens,
                temperature,
                top_k,
                top_p,
                length_penalty,
                num_beams,
                repetition_penalty,
            ],
            outputs=[
                chatbot,
                state,
            ],
        )
    return demo