Spaces:
Running
Running
# Copyright (2024) Bytedance Ltd. and/or its affiliates | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from __future__ import annotations | |
import uuid | |
from loguru import logger | |
import hashlib | |
import gradio as gr | |
import io | |
import base64 | |
from caller import ( | |
SeedT2ICaller, | |
SeedEditCaller | |
) | |
from PIL import Image | |
# API key placeholder; it is empty here and not referenced anywhere else in
# the visible code.
API_KEY = ""

# Usage instructions and license/content notice (Markdown with a little inline
# HTML). main() hands this string to gr.Markdown so it is shown in the demo UI.
# The trailing backslashes inside the literal join the notice lines into one.
help_text = """
## How to use this Demo
1. Type-in the caption/instruction text box, and click "Generate" to generate an initial image using Seed_T2I_V14 (CFG and steps are not used here)
2. Type-in the caption/instruction text box, and click "Edit" to edit the current image.
3. Click Undo if you are not satisfied with the current results, and re-edit. Otherwise, edit will apply to current results.
4. Currently, we do not support too many rounds of editing [as shown in our video] since the current API hasn't been updated to the new model yet.
This is a demo with limited QPS and a simple interface.
For a better experience, please use Doubao/Dreamina APP.
<font size=2>Note: This demo is governed by the license of CC BY-NC \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受CC BY-NC的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
"""
def image2str(image):
    """Return an HTML snippet that embeds *image* inline as a PNG data URI.

    The image is written to an in-memory buffer through its ``save`` method
    (PNG format), base64-encoded, and wrapped in a left-floated ``<div>``
    containing a single ``<img>`` tag.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return (
        '<div style="float:left">'
        f'<img src="data:image/png;base64, {encoded}">'
        '</div>'
    )
def main():
    """Build and launch the Gradio demo for text-to-image generation and editing.

    Creates the ``SeedT2ICaller`` and ``SeedEditCaller`` clients, defines the
    button callbacks as closures over them, wires up the ``gr.Blocks`` UI and
    starts the server on ``0.0.0.0:8024``.

    The per-session ``state`` is a dict mapping the MD5 hex digest of an
    image's raw bytes to the image itself. Its size doubles as the counter
    that caps a session at ``max_edit_iter`` cached results.
    """
    resolution = 1024
    max_edit_iter = 3

    cfg_t2i = {
        "resolution": resolution
    }
    model_t2i = SeedT2ICaller(cfg_t2i)
    # Both callers receive the very same config dict object.
    cfg_edit = cfg_t2i
    model_edit = SeedEditCaller(cfg_edit)
    logger.info("All models loaded")

    def generate_t2i(instruction: str, state):
        """Generate an initial image from ``instruction`` and cache it.

        Returns ``(instruction, image, state)``.

        Raises:
            ValueError: if the model reports failure or returns no image.
        """
        logger.info("Generate images ...")
        # Call the model and capture both the image and the success flag.
        gen_image, success = model_t2i.generate(instruction, batch_size=1)
        # Reject unsuccessful or empty results.
        if not success or gen_image is None:
            logger.error("Image generation failed or returned None.")
            raise ValueError("Image generation was unsuccessful.")
        # Write cache
        if state is None:
            state = {}
        output_md5 = hashlib.md5(gen_image.tobytes()).hexdigest()
        logger.info(output_md5)
        state[output_md5] = gen_image
        return instruction, gen_image, state

    def generate(prev_image, cur_image, cfg_scale, instruction, state):
        """Apply an edit instruction to the current image.

        On success returns ``(old_current, edited, instruction, state)`` so the
        previous result is kept for undo. When the edit budget is exhausted,
        there is nothing to edit, the instruction is blank, or the edit fails,
        the inputs are handed back unchanged (failures are logged, not raised).
        """
        if state is None:
            state = {}
        if len(state) >= max_edit_iter:
            return prev_image, cur_image, instruction, state

        if cur_image is None:
            cur_image = prev_image
        logger.info("Generating edited images ...")
        if not instruction or not instruction.strip():
            return prev_image, cur_image, instruction, state
        if cur_image is None:
            # No image to edit yet; skip instead of failing on ``None.size``.
            logger.warning("No image available to edit.")
            return prev_image, cur_image, instruction, state

        try:
            logger.info("Running diffusion models ...")
            # The instruction is user-supplied text: keep only characters that
            # are safe in a file name so it cannot inject path separators or
            # ".." into the output path.
            slug = "-".join(instruction.split()[:10])[:50]
            safe_slug = "".join(
                ch if ch.isalnum() or ch in "-_" else "_" for ch in slug
            )
            image_out = f"./cache/{safe_slug}_{uuid.uuid4()}.jpg"
            logger.info(f"Input size {cur_image.size}")
            edited_image, success = model_edit.edit(
                cur_image,
                instruction,
                batch_size=1,
                cfg_scale=cfg_scale,
                filename=image_out,
            )
            if not success or edited_image is None:
                logger.error("Image generation failed or returned None.")
                raise ValueError("Image generation was unsuccessful.")
            output_md5 = hashlib.md5(edited_image.tobytes()).hexdigest()
            logger.info(f"EDIT adding {output_md5}")
            state[output_md5] = edited_image
            return cur_image, edited_image, instruction, state
        except Exception as e:
            # Best-effort UI callback: log with traceback and leave the
            # displayed images untouched rather than surfacing the error.
            logger.exception(e)
            return prev_image, cur_image, instruction, state

    def reset():
        """Return defaults for (cfg_scale, prev_image, cur_image, instruction, state)."""
        return 0.5, None, None, "", {}

    def undo(prev_image, cur_image, instruction, state):
        """Discard the current image and show the previous one again.

        The current image's cache entry is removed when its digest is found in
        ``state``. Returns ``(prev_image, prev_image, instruction, state)``.
        """
        if state is None:
            state = {}
        if cur_image is not None:
            # NOTE(review): this lookup only matches if the image bytes coming
            # back from the UI equal the bytes that were cached - confirm.
            cur_md5 = hashlib.md5(cur_image.tobytes()).hexdigest()
            if cur_md5 in state:
                logger.info(f"UNDO removing {cur_md5}")
                state.pop(cur_md5, None)
        return prev_image, prev_image, instruction, state

    def show_state(state):
        """Return the label text: the cache count, or a notice once the cap is hit."""
        num_cache = len(state) if state else 0
        if num_cache < max_edit_iter:
            return f"Num Cache: {num_cache}"
        return "Max edit number reached. Please reset for testing."

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        state = gr.State({})
        with gr.Row():
            with gr.Column(scale=2):
                prev_image = gr.Image(label="Input Image", type="pil", interactive=True, visible=False, height=resolution, width=resolution)
                cur_image = gr.Image(label="Edited Image", type="pil", interactive=True, height=resolution, width=resolution)
            with gr.Column(scale=1):
                with gr.Row():
                    generate_t2i_button = gr.Button("Generate")
                    generate_button = gr.Button("Edit")
                    reset_button = gr.Button("Reset")
                    undo_button = gr.Button("Undo")
                with gr.Row():
                    instruction = gr.Textbox(lines=1, label="Caption (Generate) / Instruction (Edit)", interactive=True)
                with gr.Row():
                    cfg_scale = gr.Slider(value=0.5, minimum=0.0, maximum=1.0, step=0.1, label="Edit Strength (CFG)", interactive=True)
                with gr.Row():
                    output_label = gr.Label()
        gr.Markdown(help_text)

        # Function bindings. The state label is refreshed with ``.then`` so it
        # runs after the callback that mutates ``state`` has finished; a
        # second independent ``.click`` listener could read the state before
        # it was updated and display a stale count.
        generate_t2i_button.click(
            generate_t2i,
            [instruction, state],
            [instruction, cur_image, state],
        ).then(show_state, [state], output_label)
        generate_button.click(
            generate,
            [prev_image, cur_image, cfg_scale, instruction, state],
            [prev_image, cur_image, instruction, state],
        ).then(show_state, [state], output_label)
        reset_button.click(
            reset,
            [],
            [cfg_scale, prev_image, cur_image, instruction, state],
        ).then(show_state, [state], output_label)
        undo_button.click(
            undo,
            [prev_image, cur_image, instruction, state],
            [prev_image, cur_image, instruction, state],
        ).then(show_state, [state], output_label)

    demo.launch(server_name="0.0.0.0", server_port=8024)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()