# -*- coding: utf-8 -*-
# ===================================================
#
# Author : Fan Zhang
# Email : [email protected]
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
# Create On : 2023-12-11 15:35
# Last Modified : 2023-12-19 15:02
# File Name : generation_frontend.py
# Description :
#
# ===================================================
import base64
import json
import io
import time

from PIL import Image
import requests
import gradio as gr

from emu.constants import EVA_IMAGE_SIZE
from .meta import ConvMeta, Role, DataMeta
from .utils import frontend_logger as logging

CONTROLLER_URL = ""


def submit(
    meta,
    enable_grd,
    left,
    top,
    right,
    bottom,
    image,
    text,
):
    """Append the user's text/image (optionally with a grounding box) to the
    conversation and reset the input widgets."""
    if meta is None:
        meta = ConvMeta()

    meta.pop_error()

    # once a generation has been produced, the next submission starts a new conversation
    if meta.has_gen:
        meta.clear()

    if enable_grd:
        if text == "" and image is None:
            logging.info(f"{meta.log_id}: invalid input: no valid data for grounding input")
            gr.Error("text or image must be given when grounding generation is enabled")
            return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
        meta.append(Role.USER, DataMeta.build(text=text, image=image, coordinate=[left, top, right, bottom]))
    elif image is not None and text != "":
        logging.info(f"{meta.log_id}: invalid input: text and image given simultaneously for single-modality input")
        gr.Error("Do not submit text and image data at the same time!!!")
        return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""
    elif image is not None:
        meta.append(Role.USER, DataMeta.build(image=image))
    elif text != "":
        meta.append(Role.USER, DataMeta.build(text=text))

    return meta.format_chatbot(), meta, False, 0, 0, EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, None, ""


def clear_history(meta):
    if meta is None:
        meta = ConvMeta()
    meta.clear()
    return meta.format_chatbot(), meta
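

# Controller contract, as inferred from generate() below (the controller itself is
# not part of this file): the frontend POSTs a multipart form to "<controller>/v1/mmg"
# with fields
#   log_id                   -- conversation identifier,
#   prompt                   -- JSON-encoded list of ["TEXT", <str>] / ["IMAGE", "[<IMAGEi>]"] pairs,
#   classifier_free_guidance -- CFG scale,
#   steps                    -- number of diffusion steps,
# plus each referenced image as a PNG upload keyed by its "[<IMAGEi>]" placeholder.
# The reply is expected to be JSON: {"code": 0, "data": <base64-encoded PNG>} on
# success, or a non-zero code with an error message in "data".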
def generate(meta, classifier_free_guidance, steps):
    if meta is None:
        meta = ConvMeta()

    meta.pop_error()
    meta.pop()

    # serialize the conversation into a text/image prompt list and collect the
    # images as PNG uploads for the multipart request
    prompt = meta.format_prompt()
    prompt_list, image_list = [], {}
    for idx, p in enumerate(prompt):
        if isinstance(p, Image.Image):
            key = f"[<IMAGE{idx}>]"
            prompt_list.append(["IMAGE", key])
            buf = io.BytesIO()
            p.save(buf, format="PNG")
            image_list[key] = (key, io.BytesIO(buf.getvalue()), "image/png")
        else:
            prompt_list.append(["TEXT", p])
    if len(image_list) == 0:
        image_list = None

    logging.info(f"{meta.log_id}: construct generation request with prompt {prompt_list}")

    t0 = time.time()
    try:
        rsp = requests.post(
            CONTROLLER_URL + "/v1/mmg",
            files=image_list,
            data={
                "log_id": meta.log_id,
                "prompt": json.dumps(prompt_list),
                "classifier_free_guidance": classifier_free_guidance,
                "steps": steps,
            },
        )
    except Exception:
        # the request never reached the controller; fake a response with a sentinel status code
        rsp = requests.Response()
        rsp.status_code = 1099
    t1 = time.time()
    logging.info(f"{meta.log_id}: get response with status code: {rsp.status_code}, time: {(t1-t0)*1000:.3f}ms")

    if rsp.ok:
        content = json.loads(rsp.text)
        if content["code"] == 0:
            image = Image.open(io.BytesIO(base64.b64decode(content["data"])))
            meta.append(Role.ASSISTANT, DataMeta.build(image=image, resize=False))
        else:
            meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: {content['data']}"))
    else:
        meta.append(Role.ASSISTANT, DataMeta.build(text=f"GENERATE FAILED: http failed with code {rsp.status_code}"))

    return meta.format_chatbot(), meta


def build_generation(args):
    global CONTROLLER_URL
    CONTROLLER_URL = args.controller_url

    with gr.Blocks(title="Emu", theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    imagebox = gr.Image(type="pil")

                with gr.Row():
                    with gr.Accordion("Grounding Parameters", open=True, visible=True) as grounding_row:
                        enable_grd = gr.Checkbox(label="Enable")
                        left = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="left")
                        top = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=0, step=1, interactive=True, label="top")
                        right = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="right")
                        bottom = gr.Slider(minimum=0, maximum=EVA_IMAGE_SIZE, value=EVA_IMAGE_SIZE, step=1, interactive=True, label="bottom")

                with gr.Row():
                    with gr.Accordion("Diffusion Parameters", open=True, visible=True) as parameters_row:
                        cfg = gr.Slider(minimum=1, maximum=30, value=3, step=0.5, interactive=True, label="classifier free guidance")
                        steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, interactive=True, label="steps")

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="Emu Chatbot",
                    visible=True,
                    height=720,
                )

                with gr.Row():
                    with gr.Column(scale=8):
                        textbox = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and add to prompt",
                            visible=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=60):
                        add_btn = gr.Button(value="Add")

                with gr.Row(visible=True) as button_row:
                    # upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    # downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear History")
                    generate_btn = gr.Button(value="Generate")

        clear_btn.click(clear_history, inputs=state, outputs=[chatbot, state])

        textbox.submit(
            submit,
            inputs=[
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
            outputs=[
                chatbot,
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
        )

        add_btn.click(
            submit,
            inputs=[
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
            outputs=[
                chatbot,
                state,
                enable_grd,
                left,
                top,
                right,
                bottom,
                imagebox,
                textbox,
            ],
        )

        generate_btn.click(
            generate,
            inputs=[
                state,
                cfg,
                steps,
            ],
            outputs=[
                chatbot,
                state,
            ],
        )

    return demo
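

# A minimal local-launch sketch, assuming only what build_generation() actually reads
# from `args` (a `controller_url` attribute); the CLI flag, default URL, and server
# host/port below are illustrative assumptions, not part of the original demo.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Emu generation frontend")
    parser.add_argument("--controller-url", dest="controller_url",
                        default="http://127.0.0.1:21002",
                        help="base URL of the generation controller")
    args = parser.parse_args()

    demo = build_generation(args)
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)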