# SeedEdit-APP / app_future.py
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import uuid
from loguru import logger
import hashlib
import gradio as gr
import io
import base64
from caller import (
SeedT2ICaller,
SeedEditCaller
)
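# Caller interfaces assumed from how they are used below (see the caller module for the authoritative signatures):
#   SeedT2ICaller(cfg).generate(prompt, batch_size=1) -> (PIL image or None, success flag)
#   SeedEditCaller(cfg).edit(image, instruction, batch_size=1, cfg_scale=..., filename=...) -> (PIL image or None, success flag)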
from PIL import Image
API_KEY = ""
help_text = """
## How to use this Demo
1. Type a caption into the text box and click "Generate" to create an initial image with Seed_T2I_V14 (CFG and steps are not used for generation).
2. Type an instruction into the text box and click "Edit" to edit the current image.
3. Click "Undo" if you are not satisfied with the current result, then re-edit; otherwise, the next edit is applied to the current result.
4. We currently do not support as many rounds of editing as shown in our video, since the current API has not yet been updated to the new model.
This is a demo with limited QPS and a simple interface.
For a better experience, please use the Doubao/Dreamina APP.
<font size=2>Note: This demo is governed by the CC BY-NC license. \
We strongly advise users not to knowingly generate, or allow others to knowingly generate, harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受CC BY-NC的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)</font>
"""
def image2str(image):
    # Serialize a PIL image to PNG and embed it in an HTML snippet as a base64 <img> tag.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    i_str = base64.b64encode(buf.getvalue()).decode()
    return f'<div style="float:left"><img src="data:image/png;base64,{i_str}"></div>'
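# Usage sketch (hypothetical, not wired into the demo below): the HTML returned by
# image2str could be rendered with a gr.HTML component, for example:
#   preview = gr.HTML()
#   some_button.click(lambda img: image2str(img), [cur_image], preview)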
def main():
resolution = 1024
    max_edit_iter = 3  # cap on cached images (initial generation + edits) per session
cfg_t2i = {
"resolution": resolution
}
model_t2i = SeedT2ICaller(cfg_t2i)
cfg_edit = cfg_t2i
model_edit = SeedEditCaller(cfg_edit)
logger.info("All models loaded")
def generate_t2i(instruction: str, state):
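        """Generate an initial image from the caption and cache it in the session state."""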
logger.info("Generate images ...")
        # Call the model to generate an image and capture the returned result
gen_image, success = model_t2i.generate(instruction, batch_size=1)
        # Check whether generation succeeded and the returned image is valid
if not success or gen_image is None:
logger.error("Image generation failed or returned None.")
raise ValueError("Image generation was unsuccessful.")
# Write cache
if state is None:
state = {}
output_md5 = hashlib.md5(gen_image.tobytes()).hexdigest()
        logger.info(f"T2I adding {output_md5}")
state[output_md5] = gen_image
return instruction, gen_image, state
def generate(prev_image, cur_image, cfg_scale, instruction, state):
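        """Edit the current image with the given instruction; the result is cached so Undo can revert it."""
        # Refuse further edits once the per-session cache limit is reached.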
if len(state.keys()) >= max_edit_iter:
return prev_image, cur_image, instruction, state
try:
if cur_image is None:
cur_image = prev_image
logger.info("Generating edited images ...")
if not instruction:
return prev_image, cur_image, instruction, state
logger.info("Running diffusion models ...")
            image_out = f"./cache/{'-'.join(instruction.split()[:10])[:50]}_{uuid.uuid4()}.jpg"  # unique cache path derived from the instruction
logger.info(f"Input size {cur_image.size}")
edited_image, success = model_edit.edit(cur_image, instruction, batch_size=1, cfg_scale=cfg_scale, filename=image_out)
if not success or edited_image is None:
logger.error("Image generation failed or returned None.")
raise ValueError("Image generation was unsuccessful.")
output_md5 = hashlib.md5(edited_image.tobytes()).hexdigest()
logger.info(f"EDIT adding {output_md5}")
state[output_md5] = edited_image
return cur_image, edited_image, instruction, state
except Exception as e:
logger.error(e)
return prev_image, cur_image, instruction, state
def reset():
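        # Defaults returned in the order of the bound outputs: cfg_scale, prev_image, cur_image, instruction, state.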
return 0.5, None, None, "", {}
def undo(prev_image, cur_image, instruction, state):
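        """Drop the current image from the cache and roll the display back to the previous image."""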
if cur_image is not None:
cur_md5 = hashlib.md5(cur_image.tobytes()).hexdigest()
if cur_md5 in state:
logger.info(f"UNDO removing {cur_md5}")
state.pop(cur_md5, None)
return prev_image, prev_image, instruction, state
def show_state(state):
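        # Report how many images are cached, or ask for a reset once the edit limit is reached.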
num_cache = len(state.keys())
return f"Num Cache: {num_cache}" if num_cache < max_edit_iter else "Max edit number reached. Please reset for testing."
with gr.Blocks(css="footer {visibility: hidden}") as demo:
state = gr.State({})
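        # Per-session cache: MD5 of the image bytes -> PIL image, used by Undo and the edit limit.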
with gr.Row():
with gr.Column(scale=2):
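                # prev_image stays hidden; it holds the pre-edit image so Undo can restore it.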
prev_image = gr.Image(label="Input Image", type="pil", interactive=True, visible=False, height=resolution, width=resolution)
cur_image = gr.Image(label="Edited Image", type="pil", interactive=True, height=resolution, width=resolution)
with gr.Column(scale=1):
with gr.Row():
generate_t2i_button = gr.Button("Generate")
generate_button = gr.Button("Edit")
reset_button = gr.Button("Reset")
undo_button = gr.Button("Undo")
with gr.Row():
instruction = gr.Textbox(lines=1, label="Caption (Generate) / Instruction (Edit)", interactive=True)
with gr.Row():
cfg_scale = gr.Slider(value=0.5, minimum=0.0, maximum=1.0, step=0.1, label="Edit Strength (CFG)", interactive=True)
with gr.Row():
output_label = gr.Label()
gr.Markdown(help_text)
# Function bindings
generate_t2i_button.click(generate_t2i, [instruction, state], [instruction, cur_image, state])
generate_button.click(generate, [prev_image, cur_image, cfg_scale, instruction, state], [prev_image, cur_image, instruction, state])
reset_button.click(reset, [], [cfg_scale, prev_image, cur_image, instruction, state])
undo_button.click(undo, [prev_image, cur_image, instruction, state], [prev_image, cur_image, instruction, state])
# Update state display
generate_t2i_button.click(show_state, [state], output_label)
generate_button.click(show_state, [state], output_label)
reset_button.click(show_state, [state], output_label)
undo_button.click(show_state, [state], output_label)
demo.launch(server_name="0.0.0.0", server_port=8024)
if __name__ == "__main__":
main()