# demo inspired by https://huggingface.co/spaces/lambdalabs/image-mixer-demo
import argparse
import copy
import gradio as gr
import torch
from functools import partial
from itertools import chain
from torch import autocast
from pytorch_lightning import seed_everything

from basicsr.utils import tensor2img
from ldm.inference_base import DEFAULT_NEGATIVE_PROMPT, diffusion_inference, get_adapters, get_sd_models
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import ExtraCondition, get_cond_model
from ldm.modules.encoders.adapter import CoAdapterFuser
import os
from huggingface_hub import hf_hub_url
import subprocess
import shlex

torch.set_grad_enabled(False)

# Hugging Face repo -> files (relative to the repo root) that the demo needs.
urls = {
    'TencentARC/T2I-Adapter': [
        'third-party-models/body_pose_model.pth',
        'third-party-models/table5_pidinet.pth',
        'models/coadapter-canny-sd15v1.pth',
        'models/coadapter-color-sd15v1.pth',
        'models/coadapter-sketch-sd15v1.pth',
        'models/coadapter-style-sd15v1.pth',
        'models/coadapter-depth-sd15v1.pth',
        'models/coadapter-fuser-sd15v1.pth',
    ],
    'runwayml/stable-diffusion-v1-5': ['v1-5-pruned-emaonly.ckpt'],
    'andite/anything-v4.0': ['anything-v4.5-pruned.ckpt', 'anything-v4.0.vae.pt'],
}


def _download_checkpoints(repo_files, out_dir='models'):
    """Download every file in ``repo_files`` into ``out_dir`` unless it already exists.

    Files are stored flat under ``out_dir`` by their basename. Each download goes
    to a temporary ``<name>.part`` file that is renamed to its final name only
    when ``wget`` exits successfully. ``wget -O`` creates the target even when
    the transfer fails, so writing straight to the final path would leave a
    truncated file that every later start-up mistakes for a finished download.
    A failed download is cleaned up and reported, and the remaining files are
    still attempted (best-effort, as before).

    Args:
        repo_files: mapping of Hugging Face repo id -> list of file paths in that repo.
        out_dir: local directory receiving the files (created if missing).
    """
    os.makedirs(out_dir, exist_ok=True)
    for repo, files in repo_files.items():
        for file in files:
            url = hf_hub_url(repo, file)
            name_ckp = url.split('/')[-1]
            save_path = os.path.join(out_dir, name_ckp)
            if os.path.exists(save_path):
                continue
            tmp_path = save_path + '.part'
            # Argument list (no shell, no string splitting) so paths/URLs are passed verbatim.
            result = subprocess.run(['wget', url, '-O', tmp_path])
            if result.returncode == 0:
                os.replace(tmp_path, save_path)
            else:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                print(f'WARNING: failed to download {url} (wget exit code {result.returncode})')


_download_checkpoints(urls)

# Condition types offered in the UI; each maps to models/coadapter-<name>-sd15v1.pth
# and to an ``api.get_cond_<name>`` preprocessing function.
supported_cond = ['style', 'color', 'sketch', 'depth', 'canny']

# config
parser = argparse.ArgumentParser()
parser.add_argument(
    '--sd_ckpt',
    type=str,
    default='models/v1-5-pruned-emaonly.ckpt',
    help='path to checkpoint of stable diffusion model, both .ckpt and .safetensor are supported',
)
parser.add_argument(
    '--vae_ckpt',
    type=str,
    default=None,
    help='vae checkpoint, anime SD models usually have seperate vae ckpt that need to be loaded',
)
global_opt = parser.parse_args()
global_opt.config = 'configs/stable-diffusion/sd-v1-inference.yaml'
for cond_name in supported_cond:
    setattr(global_opt, f'{cond_name}_adapter_ckpt', f'models/coadapter-{cond_name}-sd15v1.pth')
global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
global_opt.max_resolution = 512 * 512
global_opt.sampler = 'ddim'
global_opt.cond_weight = 1.0
global_opt.C = 4
global_opt.f = 8
# TODO: expose style_cond_tau to users
global_opt.style_cond_tau = 1.0

# stable-diffusion model
sd_model, sampler = get_sd_models(global_opt)
# adapters and models to processing condition inputs; both are filled lazily by ``run``
adapters = {}
cond_models = {}
torch.cuda.empty_cache()

# fuser is indispensable
coadapter_fuser = CoAdapterFuser(unet_channels=[320, 640, 1280, 1280], width=768, num_head=8, n_layes=3)
# Load onto the CPU first so the checkpoint also loads on hosts without CUDA
# (the device above explicitly falls back to CPU); it is moved to the target device next.
coadapter_fuser.load_state_dict(torch.load('models/coadapter-fuser-sd15v1.pth', map_location='cpu'))
# nn.Module starts in training mode; switch to eval for inference (affects dropout/norm layers if any).
coadapter_fuser = coadapter_fuser.to(global_opt.device).eval()

# Number of trailing, condition-independent inputs that ``run`` receives:
# prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau.
NUM_GLOBAL_INPUTS = 8


def _prepare_conditions(opt, inps):
    """Load/offload adapters per the UI selection and preprocess the active conditions.

    Args:
        opt: per-request copy of ``global_opt``.
        inps: four tuples (radio values, "Image" uploads, condition-map uploads,
            weights), each with one entry per name in ``supported_cond``.

    Returns:
        ``(activated_conds, conds)``: the names of the selected conditions and their
        preprocessed inputs, aligned by index.
    """
    conds = []
    activated_conds = []
    for cond_name, (b, im1, im2, cond_weight) in zip(supported_cond, zip(*inps)):
        if b == 'Nothing':
            # Not used in this request: move an already-loaded adapter off the GPU.
            if cond_name in adapters:
                adapters[cond_name]['model'] = adapters[cond_name]['model'].cpu()
            continue

        activated_conds.append(cond_name)
        if cond_name in adapters:
            adapters[cond_name]['model'] = adapters[cond_name]['model'].to(opt.device)
        else:
            # First use of this condition: build its adapter.
            adapters[cond_name] = get_adapters(opt, getattr(ExtraCondition, cond_name))
        adapters[cond_name]['cond_weight'] = cond_weight

        process_cond_module = getattr(api, f'get_cond_{cond_name}')
        if b == 'Image':
            # A plain image: run the condition-extraction model (loaded on first use).
            if cond_name not in cond_models:
                cond_models[cond_name] = get_cond_model(opt, getattr(ExtraCondition, cond_name))
            conds.append(process_cond_module(opt, im1, 'image', cond_models[cond_name]))
        else:
            # The user supplied the condition map directly.
            conds.append(process_cond_module(opt, im2, cond_name, None))
    return activated_conds, conds


def _weighted_adapter_features(activated_conds, conds):
    """Run each active adapter and scale its output in place by the user's weight.

    Returns:
        dict mapping condition name -> adapter output (a list of features or a single one).
    """
    features = dict()
    for cond_name, cond in zip(activated_conds, conds):
        adapter = adapters[cond_name]
        cur_feats = adapter['model'](cond)
        weight = adapter['cond_weight']
        if isinstance(cur_feats, list):
            for i in range(len(cur_feats)):
                cur_feats[i] *= weight
        else:
            cur_feats *= weight
        features[cond_name] = cur_feats
    return features


def run(*args):
    """Gradio callback: generate images from the CoAdapter UI inputs.

    ``args`` is the flat input list wired up in ``submit.click``. It has two parts:
    - first, the per-condition widgets grouped by kind: all radios, then all
      "Image" uploads, then all condition-map uploads, then all weight sliders
      (each group has ``len(supported_cond)`` entries);
    - then the ``NUM_GLOBAL_INPUTS`` global controls.

    Returns:
        ``(ims, output_conds)``: the generated images and the preprocessed condition
        maps, each converted with ``tensor2img``.
    """
    n_conds = len(supported_cond)
    cond_args = args[:-NUM_GLOBAL_INPUTS]
    # Regroup the flat widget values into one tuple per widget kind.
    inps = [cond_args[i:i + n_conds] for i in range(0, len(cond_args), n_conds)]

    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        opt = copy.deepcopy(global_opt)
        (opt.prompt, opt.neg_prompt, opt.scale, opt.n_samples, opt.seed, opt.steps,
         opt.resize_short_edge, opt.cond_tau) = args[-NUM_GLOBAL_INPUTS:]
        # Slider values are used as counts/sizes/seeds; make sure they are ints
        # even if the frontend delivers them as floats.
        opt.n_samples = int(opt.n_samples)
        opt.seed = int(opt.seed)
        opt.steps = int(opt.steps)
        opt.resize_short_edge = int(opt.resize_short_edge)

        activated_conds, conds = _prepare_conditions(opt, inps)
        # NOTE(review): with no condition selected, ``features`` is empty; whether
        # CoAdapterFuser accepts an empty dict is not visible here — confirm.
        features = _weighted_adapter_features(activated_conds, conds)
        adapter_features, append_to_context = coadapter_fuser(features)

        output_conds = [tensor2img(cond, rgb2bgr=False) for cond in conds]

        ims = []
        seed_everything(opt.seed)
        for _ in range(opt.n_samples):
            result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
            ims.append(tensor2img(result, rgb2bgr=False))

        # Clear GPU memory cache so less likely to OOM
        torch.cuda.empty_cache()
        return ims, output_conds


def change_visible(im1, im2, val):
    """Show the upload widget matching the selected input type.

    Args:
        im1: the "Image" upload component.
        im2: the condition-map upload component.
        val: selected radio value — "Image", "Nothing", or the condition name.

    Returns:
        dict mapping each component to a ``gr.update`` setting its visibility.
    """
    outputs = {}
    if val == "Image":
        outputs[im1] = gr.update(visible=True)
        outputs[im2] = gr.update(visible=False)
    elif val == "Nothing":
        outputs[im1] = gr.update(visible=False)
        outputs[im2] = gr.update(visible=False)
    else:
        outputs[im1] = gr.update(visible=False)
        outputs[im2] = gr.update(visible=True)
    return outputs


DESCRIPTION = '''# CoAdapter
[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)

This gradio demo is for a simple experience of CoAdapter:
'''
with gr.Blocks(title="CoAdapter", css=".gr-box {border-color: #8136e2}") as demo:
    gr.Markdown(DESCRIPTION)

    # One entry per condition in ``supported_cond``; ``run`` relies on this grouping order.
    btns = []
    ims1 = []
    ims2 = []
    cond_weights = []
    with gr.Row():
        for cond_name in supported_cond:
            with gr.Box():
                with gr.Column():
                    btn1 = gr.Radio(
                        choices=["Image", cond_name, "Nothing"],
                        label=f"Input type for {cond_name}",
                        interactive=True,
                        value="Nothing",
                    )
                    im1 = gr.Image(source='upload', label="Image", interactive=True, visible=False, type="numpy")
                    im2 = gr.Image(source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
                    cond_weight = gr.Slider(
                        label="Condition weight", minimum=0, maximum=5, step=0.05, value=1, interactive=True)

                    # partial binds this iteration's widgets (avoids late-binding in the loop).
                    fn = partial(change_visible, im1, im2)
                    btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)

                    btns.append(btn1)
                    ims1.append(im1)
                    ims2.append(im2)
                    cond_weights.append(cond_weight)

    with gr.Column():
        prompt = gr.Textbox(label="Prompt")
        neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
        scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", value=7.5, minimum=1, maximum=20, step=0.1)
        n_samples = gr.Slider(label="Num samples", value=1, minimum=1, maximum=8, step=1)
        seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1)
        steps = gr.Slider(label="Steps", value=50, minimum=10, maximum=100, step=1)
        resize_short_edge = gr.Slider(label="Image resolution", value=512, minimum=320, maximum=1024, step=1)
        cond_tau = gr.Slider(
            label="timestamp parameter that determines until which step the adapter is applied",
            value=1.0,
            minimum=0.1,
            maximum=1.0,
            step=0.05)

    with gr.Row():
        submit = gr.Button("Generate")
    output = gr.Gallery().style(grid=2, height='auto')
    cond = gr.Gallery().style(grid=2, height='auto')

    # Input order must match what ``run`` unpacks: grouped per-condition widgets,
    # then the NUM_GLOBAL_INPUTS global controls.
    inps = list(chain(btns, ims1, ims2, cond_weights))
    inps.extend([prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau])
    submit.click(fn=run, inputs=inps, outputs=[output, cond])

demo.launch(server_port=43343, server_name='0.0.0.0')