# Copyright (2024) Bytedance Ltd. and/or its affiliates # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from __future__ import annotations import uuid from loguru import logger import hashlib import gradio as gr import io import base64 from caller import ( SeedT2ICaller, SeedEditCaller ) from PIL import Image API_KEY = "" help_text = """ ## How to use this Demo 1. Type-in the caption/instruction text box, and click "Generate" to generate an initial image using Seed_T2I_V14 (CFG and steps are not used here) 2. Type-in the caption/instruction text box, and click "Edit" to edit the current image. 3. Click Undo if you are not satisfied with the current results, and re-edit. Otherwise, edit will apply to current results. 4. Currently, we do not support too many rounds of editing [as shown in our video] since the current API hasn't been updated to the new model yet. This is a demo with limited QPS and a simple interface. For a better experience, please use Doubao/Dreamina APP. Note: This demo is governed by the license of CC BY-NC \ We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \ including hate speech, violence, pornography, deception, etc. \ (注:本演示受CC BY-NC的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\ 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。) """ def image2str(image): buf = io.BytesIO() image.save(buf, format="PNG") i_str = base64.b64encode(buf.getvalue()).decode() return f'
' def main(): resolution = 1024 max_edit_iter = 3 cfg_t2i = { "resolution": resolution } model_t2i = SeedT2ICaller(cfg_t2i) cfg_edit = cfg_t2i model_edit = SeedEditCaller(cfg_edit) logger.info("All models loaded") def generate_t2i(instruction: str, state): logger.info("Generate images ...") # 调用模型生成图像并捕获返回结果 gen_image, success = model_t2i.generate(instruction, batch_size=1) # 检查生成是否成功以及生成的图像是否有效 if not success or gen_image is None: logger.error("Image generation failed or returned None.") raise ValueError("Image generation was unsuccessful.") # Write cache if state is None: state = {} output_md5 = hashlib.md5(gen_image.tobytes()).hexdigest() logger.info(output_md5) state[output_md5] = gen_image return instruction, gen_image, state def generate(prev_image, cur_image, cfg_scale, instruction, state): if len(state.keys()) >= max_edit_iter: return prev_image, cur_image, instruction, state try: if cur_image is None: cur_image = prev_image logger.info("Generating edited images ...") if not instruction: return prev_image, cur_image, instruction, state logger.info("Running diffusion models ...") image_out = f"./cache/{'-'.join(instruction.split()[:10])[:50]}_{uuid.uuid4()}.jpg" logger.info(f"Input size {cur_image.size}") edited_image, success = model_edit.edit(cur_image, instruction, batch_size=1, cfg_scale=cfg_scale, filename=image_out) if not success or edited_image is None: logger.error("Image generation failed or returned None.") raise ValueError("Image generation was unsuccessful.") output_md5 = hashlib.md5(edited_image.tobytes()).hexdigest() logger.info(f"EDIT adding {output_md5}") state[output_md5] = edited_image return cur_image, edited_image, instruction, state except Exception as e: logger.error(e) return prev_image, cur_image, instruction, state def reset(): return 0.5, None, None, "", {} def undo(prev_image, cur_image, instruction, state): if cur_image is not None: cur_md5 = hashlib.md5(cur_image.tobytes()).hexdigest() if cur_md5 in state: logger.info(f"UNDO removing {cur_md5}") state.pop(cur_md5, None) return prev_image, prev_image, instruction, state def show_state(state): num_cache = len(state.keys()) return f"Num Cache: {num_cache}" if num_cache < max_edit_iter else "Max edit number reached. Please reset for testing." with gr.Blocks(css="footer {visibility: hidden}") as demo: state = gr.State({}) with gr.Row(): with gr.Column(scale=2): prev_image = gr.Image(label="Input Image", type="pil", interactive=True, visible=False, height=resolution, width=resolution) cur_image = gr.Image(label="Edited Image", type="pil", interactive=True, height=resolution, width=resolution) with gr.Column(scale=1): with gr.Row(): generate_t2i_button = gr.Button("Generate") generate_button = gr.Button("Edit") reset_button = gr.Button("Reset") undo_button = gr.Button("Undo") with gr.Row(): instruction = gr.Textbox(lines=1, label="Caption (Generate) / Instruction (Edit)", interactive=True) with gr.Row(): cfg_scale = gr.Slider(value=0.5, minimum=0.0, maximum=1.0, step=0.1, label="Edit Strength (CFG)", interactive=True) with gr.Row(): output_label = gr.Label() gr.Markdown(help_text) # Function bindings generate_t2i_button.click(generate_t2i, [instruction, state], [instruction, cur_image, state]) generate_button.click(generate, [prev_image, cur_image, cfg_scale, instruction, state], [prev_image, cur_image, instruction, state]) reset_button.click(reset, [], [cfg_scale, prev_image, cur_image, instruction, state]) undo_button.click(undo, [prev_image, cur_image, instruction, state], [prev_image, cur_image, instruction, state]) # Update state display generate_t2i_button.click(show_state, [state], output_label) generate_button.click(show_state, [state], output_label) reset_button.click(show_state, [state], output_label) undo_button.click(show_state, [state], output_label) demo.launch(server_name="0.0.0.0", server_port=8024) if __name__ == "__main__": main()