|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import requests |
|
import random |
|
import os |
|
import sys |
|
import pickle |
|
from PIL import Image |
|
|
|
from tqdm.auto import tqdm |
|
from datetime import datetime |
|
|
|
import diffusers |
|
from diffusers import DDIMScheduler |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
import torch.nn.functional as F |
|
|
|
from utils import preprocess_mask, process_sketch, process_prompts, process_example |
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Front-end snippets handed to Gradio: the sketch-canvas mount point, the
# JavaScript that loads / reads the canvas, and a little custom CSS.
# ---------------------------------------------------------------------------

# Empty container element; it is looked up by id ("canvas-root") from
# `get_js_colors` below.
canvas_html = "<div id='canvas-root' style='max-width:400px; margin: 0 auto'></div>"

# Fetches the sketch-canvas script as text and injects it into <head> as a
# module script served from a Blob URL.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
  fetch(url)
    .then(res => res.text())
    .then(text => {
      const script = document.createElement('script');
      script.type = "module"
      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
      document.head.appendChild(script);
    });
}
"""

# Ignores the value Gradio passes in and returns the `_data` property of the
# canvas element instead.
get_js_colors = """
async (canvasData) => {
  const canvasEl = document.getElementById("canvas-root");
  return [canvasEl._data]
}
"""

# Custom CSS for gr.Blocks. The original text ended with a stray "<style>"
# line, which is HTML rather than CSS; it has been dropped.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level state shared between `process_generation` (which writes it
# before every run) and `mod_forward` (which reads it on every attention call).
# A `global` statement is a no-op at module scope, so none is used here.
# ---------------------------------------------------------------------------
sreg = 0          # self-attention modulation weight
creg = 0          # cross-attention modulation weight
sizereg = 0       # mask-area adaptive adjustment weight
COUNT = 0         # number of patched attention calls since the last reset
reg_sizes = {}    # keyed by res*res; filled in process_generation
creg_maps = {}    # keyed by res*res; cross-attention layout maps
sreg_maps = {}    # keyed by res*res; self-attention layout maps
text_cond = 0     # placeholder; replaced by the text-embedding tensor per run
device = "cuda"
MAX_COLORS = 12   # maximum number of segment rows created in the UI

pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16").to(device)

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.scheduler.set_timesteps(50)
# Snapshot of the 50-step schedule; `mod_forward` indexes into it.
timesteps = pipe.scheduler.timesteps
# Latent resolution used for the layout masks and the initial noise.
sp_sz = pipe.unet.sample_size

# NOTE(security): pickle.load can execute arbitrary code. Only ship a
# `valset.pkl` that comes from a trusted source.
with open('./valset.pkl', 'rb') as f:
    val_prompt = pickle.load(f)
|
|
|
|
|
|
|
|
|
def mod_forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
    """Attention forward pass with layout-guided modulation of the scores.

    Written to be installed as ``__call__`` on the pipeline's attention class,
    so ``self`` is an attention module. It reads the module-level state
    (``sreg``, ``creg``, ``sreg_maps``, ``creg_maps``, ``reg_sizes``,
    ``text_cond``, ``timesteps``) and increments the global ``COUNT`` once per
    call.

    Args:
        hidden_states: Input features, 3-D or 4-D; 4-D input is flattened
            over its last two dimensions and restored on return.
        encoder_hidden_states: ``None`` selects self-attention. Any other
            value selects cross-attention, but the value itself is discarded
            and the global ``text_cond`` is used in its place.
        attention_mask: Passed through ``prepare_attention_mask`` and used only
            in the un-modulated branch (``get_attention_scores``).
        temb: Forwarded to ``spatial_norm`` when the module has one.

    Returns:
        The attended features, in the same layout as ``hidden_states``.
    """
    global COUNT

    residual = hidden_states

    if self.spatial_norm is not None:
        hidden_states = self.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim
    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )
    attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states)

    is_self_attention = encoder_hidden_states is None
    # Cross-attention always attends to the globally prepared text embeddings,
    # not to whatever the caller passed in.
    encoder_hidden_states = hidden_states if is_self_attention else text_cond

    if self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

    key = self.to_k(encoder_hidden_states)
    value = self.to_v(encoder_hidden_states)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    # COUNT // 32 is used as the index into `timesteps`, i.e. every 32 calls
    # are treated as one scheduler step (presumably the number of attention
    # layers hit per UNet pass -- confirm if the model changes). Modulation is
    # applied only while that step index is below 30% of the 50 steps.
    if COUNT / 32 < 50 * 0.3:
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # Scaled dot-product scores; beta=0 means the `empty` buffer's contents
        # are never read.
        sim = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1],
                        dtype=query.dtype, device=query.device),
            query, key.transpose(-1, -2), beta=0, alpha=self.scale)
        # NOTE(review): `attention_mask` is not applied in this branch.

        # Time-dependent factor: (timestep / 1000) ** 5.
        treg = torch.pow(timesteps[COUNT // 32] / 1000, 5)

        # Self- and cross-attention differ only in which layout maps and which
        # user weight they use.
        if is_self_attention:
            region_maps, weight = sreg_maps, sreg
        else:
            region_maps, weight = creg_maps, creg

        # Only the second half of the batch dimension is modulated. `text_cond`
        # is built as [unconditional, conditional], so this is presumably the
        # text-conditioned half -- confirm against head_to_batch_dim's layout.
        half = sim.size(0) // 2
        cond_sim = sim[half:]  # a view: in-place updates below write into `sim`
        min_value = cond_sim.min(-1)[0].unsqueeze(-1)
        max_value = cond_sim.max(-1)[0].unsqueeze(-1)

        inside = region_maps[sim.size(1)].repeat(self.heads, 1, 1) > 0
        size_reg = reg_sizes[sim.size(1)].repeat(self.heads, 1, 1)
        strength = size_reg * weight * treg

        # Push scores where the layout map is positive towards the row maximum,
        # then push the remaining scores towards the row minimum. The second
        # update deliberately sees the result of the first.
        cond_sim += inside * strength * (max_value - cond_sim)
        cond_sim -= ~inside * strength * (cond_sim - min_value)

        attention_probs = sim.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)
    else:
        attention_probs = self.get_attention_scores(query, key, attention_mask)

    # Counted on every call, modulated or not.
    COUNT += 1

    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # to_out[0] then to_out[1], applied in sequence.
    hidden_states = self.to_out[0](hidden_states)
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if self.residual_connection:
        hidden_states = hidden_states + residual

    hidden_states = hidden_states / self.rescale_output_factor

    return hidden_states
|
|
|
# Install `mod_forward` as `__call__` on every class named "Attention" found in
# the UNet. The assignment is on the class, so it affects all instances of that
# class; each distinct class is patched once.
for _attention_cls in {
    _module.__class__
    for _module in pipe.unet.modules()
    if _module.__class__.__name__ == "Attention"
}:
    _attention_cls.__call__ = mod_forward
|
|
|
|
|
|
|
|
|
def _find_token_span(haystack, needle):
    """Return the first index where 1-D tensor `needle` occurs contiguously in 1-D tensor `haystack`, else None."""
    span = needle.numel()
    if span == 0 or span > haystack.numel():
        return None
    # Every length-`span` sliding window of `haystack`, one per row.
    windows = haystack.unfold(0, span, 1)
    hits = (windows == needle).all(dim=1).nonzero()
    return int(hits[0]) if hits.numel() else None


def process_generation(binary_matrixes, seed, creg_, sreg_, sizereg_, bsz, master_prompt, *prompts):
    """Generate images for a layout and its per-segment prompts.

    Fills the module-level state read by ``mod_forward`` (``creg``, ``sreg``,
    ``sizereg``, ``sreg_maps``, ``creg_maps``, ``reg_sizes``, ``text_cond``),
    resets ``COUNT`` to 0 and runs the pipeline.

    Args:
        binary_matrixes: One mask per segment, each passed to
            ``preprocess_mask``.
        seed: RNG seed for the initial latents; ``-1`` means unseeded.
        creg_: Cross-attention modulation weight.
        sreg_: Self-attention modulation weight.
        sizereg_: Mask-area adaptive adjustment weight.
        bsz: Number of images to generate.
        master_prompt: Full-image text; it must contain every segment prompt.
        *prompts: Segment prompts; only the first ``len(binary_matrixes)``
            are used.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        gr.Error: If a non-empty segment prompt's tokens are not found as a
            contiguous run inside the tokenized ``master_prompt``.
    """
    global creg, sreg, sizereg, sreg_maps, reg_sizes, creg_maps, text_cond, COUNT

    # UI sliders may deliver numbers as floats; list repetition and
    # manual_seed both need real ints.
    bsz = int(bsz)
    seed = int(seed)

    creg, sreg, sizereg = creg_, sreg_, sizereg_

    clipped_prompts = prompts[:len(binary_matrixes)]
    prompts = [master_prompt] + list(clipped_prompts)
    layouts = torch.cat([preprocess_mask(mask_, sp_sz, sp_sz, device) for mask_ in binary_matrixes])

    max_len = pipe.tokenizer.model_max_length
    text_input = pipe.tokenizer(prompts, padding="max_length", return_length=True,
                                return_overflowing_tokens=False, max_length=max_len,
                                truncation=True, return_tensors="pt")
    uncond_input = pipe.tokenizer([""] * bsz, padding="max_length", max_length=max_len,
                                  truncation=True, return_tensors="pt")
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        cond_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]

    # Side lengths sp_sz, sp_sz/2, sp_sz/4, sp_sz/8; the dicts are keyed by
    # res*res, which is what `mod_forward` looks up via `sim.size(1)`.
    resolutions = [sp_sz // 2 ** r for r in range(4)]

    # --- self-attention maps: per-position-pair layout products, summed over segments
    sreg_maps = {}
    reg_sizes = {}
    for res in resolutions:
        layouts_s = F.interpolate(layouts, (res, res), mode='nearest')
        layouts_s = (
            layouts_s.view(layouts_s.size(0), 1, -1) * layouts_s.view(layouts_s.size(0), -1, 1)
        ).sum(0).unsqueeze(0).repeat(bsz, 1, 1)
        reg_sizes[res * res] = 1 - sizereg * layouts_s.sum(-1, keepdim=True) / (res * res)
        sreg_maps[res * res] = layouts_s

    # --- cross-attention maps: one spatial mask per token position
    full_ids = text_input['input_ids'][0]
    num_tokens = full_ids.size(0)
    pww_maps = torch.zeros(1, num_tokens, sp_sz, sp_sz).to(device)
    for i in range(1, len(prompts)):
        # `length` includes two extra tokens (presumably BOS/EOS -- the ids
        # are sliced from index 1 accordingly).
        wlen = int(text_input['length'][i]) - 2
        if wlen <= 0:
            # Empty segment prompt: nothing to locate, nothing to modulate.
            continue
        widx = text_input['input_ids'][i][1:1 + wlen]
        start = _find_token_span(full_ids, widx)
        if start is None:
            raise gr.Error("Please check whether every segment prompt is included in the full text !")
        stop = start + wlen
        pww_maps[:, start:stop, :, :] = layouts[i - 1:i]
        # Replace the full-text embeddings of this span with the embeddings
        # obtained by encoding the segment prompt on its own.
        cond_embeddings[0][start:stop] = cond_embeddings[i][1:1 + wlen]

    creg_maps = {}
    for res in resolutions:
        layout_c = (
            F.interpolate(pww_maps, (res, res), mode='nearest')
            .view(1, num_tokens, -1)
            .permute(0, 2, 1)
            .repeat(bsz, 1, 1)
        )
        creg_maps[res * res] = layout_c

    # --- text condition used by every cross-attention call: [uncond, cond]
    text_cond = torch.cat([uncond_embeddings, cond_embeddings[:1].repeat(bsz, 1, 1)])

    # Restart the attention-call counter for this generation.
    COUNT = 0

    generator = None if seed == -1 else torch.Generator().manual_seed(seed)
    latents = torch.randn(bsz, 4, sp_sz, sp_sz, generator=generator).to(device)

    image = pipe(prompts[:1] * bsz, latents=latents).images

    return image
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: sketch canvas -> per-segment prompts -> full prompt and
# hyperparameters -> generated images.
# The "\U0001F63A" escapes are the cat-face emoji used in the labels.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    binary_matrixes = gr.State([])
    color_layout = gr.State([])
    gr.Markdown('''## DenseDiffusion: Dense Text-to-Image Generation with Attention Modulation''')
    gr.Markdown('''
#### \U0001F63A Instruction to generate images \U0001F63A <br>
(1) Create the image layout. <br>
(2) Label each segment with a text prompt. <br>
(3) Adjust the full text. The default full text is automatically concatenated from each segment's text. The default one works well, but refining the full text will further improve the result. <br>
(4) Check the generated images, and tune the hyperparameters if needed. <br>
- w<sup>c</sup> : The degree of attention modulation at cross-attention layers. <br>
- w<sup>s</sup> : The degree of attention modulation at self-attention layers. <br>
''')

    with gr.Row():
        with gr.Box(elem_id="main-image"):
            canvas_data = gr.JSON(value={}, visible=False)
            canvas = gr.HTML(canvas_html)
            button_run = gr.Button("(1) I've finished my sketch ! \U0001F63A", elem_id="main_button", interactive=True)

            # One hidden row (colour swatch + prompt box) per possible segment;
            # row 0 is the background.
            prompts = []
            colors = []
            color_row = [None] * MAX_COLORS
            with gr.Column(visible=False) as post_sketch:
                for n in range(MAX_COLORS):
                    with gr.Row(visible=False) as color_row[n]:
                        colors.append(gr.Image(shape=(100, 100),
                                               label="background" if n == 0 else "segment " + str(n),
                                               type="pil", image_mode="RGB", width=100, height=100))
                        if n == 0:
                            prompts.append(gr.Textbox(label="Prompt for the background (white region)", value=""))
                        else:
                            prompts.append(gr.Textbox(label="Prompt for the segment " + str(n)))

                get_genprompt_run = gr.Button("(2) I've finished segment labeling ! \U0001F63A", elem_id="prompt_button", interactive=True)

            with gr.Column(visible=False) as gen_prompt_vis:
                general_prompt = gr.Textbox(value='', label="(3) Textual Description for the entire image", interactive=True)
                with gr.Accordion("(4) Tune the hyperparameters", open=False):
                    creg_ = gr.Slider(label=" w\u1D9C (The degree of attention modulation at cross-attention layers) ", minimum=0, maximum=2., value=1.0, step=0.1)
                    sreg_ = gr.Slider(label=" w \u02E2 (The degree of attention modulation at self-attention layers) ", minimum=0, maximum=2., value=0.3, step=0.1)
                    sizereg_ = gr.Slider(label="The degree of mask-area adaptive adjustment", minimum=0, maximum=1., value=1., step=0.1)
                    bsz_ = gr.Slider(label="Number of Samples to generate", minimum=1, maximum=4, value=1, step=1)
                    seed_ = gr.Slider(label="Seed", minimum=-1, maximum=999999999, value=-1, step=1)

                final_run_btn = gr.Button("Generate ! \U0001F63A")

                # Hidden inputs used only by the examples below.
                layout_path = gr.Textbox(label="layout_path", visible=False)
                all_prompts = gr.Textbox(label="all_prompts", visible=False)

        with gr.Column():
            out_image = gr.Gallery(label="Result", columns=2, height='auto')

    # The JS hook replaces the (hidden) JSON input with the canvas data.
    button_run.click(process_sketch, inputs=[canvas_data],
                     outputs=[post_sketch, binary_matrixes, *color_row, *colors],
                     _js=get_js_colors, queue=False)

    get_genprompt_run.click(process_prompts, inputs=[binary_matrixes, *prompts],
                            outputs=[gen_prompt_vis, general_prompt], queue=False)

    final_run_btn.click(process_generation,
                        inputs=[binary_matrixes, seed_, creg_, sreg_, sizereg_, bsz_, general_prompt, *prompts],
                        outputs=out_image)

    # Each example: layout image path, "***"-joined full text + segment texts, seed.
    gr.Examples(
        examples=[['0.png', '***'.join([val_prompt[0]['textual_condition']] + val_prompt[0]['segment_descriptions']), 381940206],
                  ['1.png', '***'.join([val_prompt[1]['textual_condition']] + val_prompt[1]['segment_descriptions']), 307504592],
                  ['5.png', '***'.join([val_prompt[5]['textual_condition']] + val_prompt[5]['segment_descriptions']), 114972190]],
        inputs=[layout_path, all_prompts, seed_],
        outputs=[post_sketch, binary_matrixes, *color_row, *colors, *prompts, gen_prompt_vis, general_prompt, seed_],
        fn=process_example,
        run_on_click=True,
        label='\U0001F63A Examples \U0001F63A',
    )

    # Load the sketch-canvas script when the page opens.
    demo.load(None, None, None, _js=load_js)

demo.launch(debug=True)