import random

import torch
import torch.nn as nn
import torchvision

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)
from ldm.models.diffusion.ddpm import DDPM
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ldm.util import instantiate_from_config
def disabled_train(self, mode=True):
    """Overwrite model.train so frozen submodules stay in eval mode
    regardless of the train/eval state of the parent module."""
    return self
# =============================================================
# Trainable part: ControlNet
# =============================================================
class ControlNet(nn.Module):
def __init__(
self,
in_channels, # 9
model_channels, # 320
hint_channels, # 20
attention_resolutions, # [4,2,1]
num_res_blocks, # 2
channel_mult=(1, 2, 4, 8), # [1,2,4,4]
num_head_channels=-1, # 64
transformer_depth=1, # 1
context_dim=None, # 768
use_checkpoint=False, # True
dropout=0,
conv_resample=True,
dims=2,
num_heads=-1,
use_scale_shift_norm=False):
super(ControlNet, self).__init__()
self.dims = dims
self.in_channels = in_channels
self.model_channels = model_channels
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.use_checkpoint = use_checkpoint
self.dtype = torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
        # timestep encoder
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
        # input encoder
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
        # hint encoder: three stride-2 convs downsample the hint by 8x to the
        # latent resolution; the final zero-initialized conv makes this branch
        # a no-op at the start of training
self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)
        # UNet encoder (input blocks)
input_block_chans = [model_channels]
ch = model_channels
ds = 1
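        # Build the encoder: each level stacks num_res_blocks ResBlocks, adds a
        # SpatialTransformer whenever the current downsampling factor ds is in
        # attention_resolutions, and downsamples between levels; every block
        # output gets a matching 1x1 zero-initialized conv.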
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
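        # middle block at the lowest resolution, with its own zero conv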
num_heads = ch // num_head_channels
dim_head = num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
    def make_zero_conv(self, channels):
        # 1x1 conv initialized to zero, so the control branch contributes
        # nothing until training moves its weights away from zero
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, reference_dino):
# ๅค„็†่พ“ๅ…ฅ
context = reference_dino
t_emb = timestep_embedding(timesteps, self.model_channels)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb)
        # predict the control residuals
outs = []
h = x.type(self.dtype)
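        # the encoded hint is added to the features once, right after the first
        # input block; every block output then passes through its zero conv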
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
# =============================================================
# Frozen part: ControlledUnetModel
# =============================================================
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None):
hs = []
# UNet ็š„ไธŠๅŠ้ƒจๅˆ†
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
        # inject the control residual into the middle-block output
        if control is not None:
            h += control.pop()
        # decoder half of the UNet; guard against control=None so the model
        # can still run as a plain UNet without a ControlNet attached
        for i, module in enumerate(self.output_blocks):
            if control is None:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
        # output projection
h = h.type(x.dtype)
h = self.out(h)
return h
# =============================================================
# Main model: ControlLDM
# =============================================================
class ControlLDM(DDPM):
def __init__(self,
control_stage_config, # ControlNet
first_stage_config, # AutoencoderKL
cond_stage_config, # FrozenCLIPImageEmbedder
condi_stage_config, # FrozenCLIPTextEmbedder
scale_factor=1.0, # 0.18215
*args, **kwargs):
        self.num_timesteps_cond = 1
        super().__init__(*args, **kwargs)  # sets up self.model and the register_buffer calls
        self.control_model = instantiate_from_config(control_stage_config)  # self.control_model
        self.instantiate_first_stage(first_stage_config)  # self.first_stage_model: AutoencoderKL
        self.instantiate_cond_stage(cond_stage_config)  # self.cond_stage_model: FrozenCLIPImageEmbedder
        self.instantiate_condi_stage(condi_stage_config)  # self.condi_stage_model: FrozenCLIPTextEmbedder
        self.proj_out = nn.Linear(1024, 768)  # projects CLIP image features to the UNet context dim
        self.scale_factor = scale_factor  # 0.18215
        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=False)
        self.trainable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
        self.dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        self.dinov2_vitl14.eval()
        self.dinov2_vitl14.train = disabled_train
        for param in self.dinov2_vitl14.parameters():
            param.requires_grad = False
        self.linear = nn.Linear(1024, 768)
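        # Instantiated once here rather than inside forward(): the original
        # code rebuilt CrossAttention with fresh random weights on every
        # training step, so the fusion layer could never learn anything.
        self.cross_att = CrossAttention()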
    # AutoencoderKL is frozen (not trained)
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
    # FrozenCLIPImageEmbedder is frozen (not trained)
def instantiate_cond_stage(self, config):
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
    # FrozenCLIPTextEmbedder is frozen (not trained)
    def instantiate_condi_stage(self, config):
        model = instantiate_from_config(config)
        self.condi_stage_model = model.eval()
        self.condi_stage_model.train = disabled_train
        for param in self.condi_stage_model.parameters():
            param.requires_grad = False
    # training step
def training_step(self, batch, batch_idx):
        z_new, reference, hint, cloth_annotation = self.get_input(batch)  # load data
        loss = self(z_new, reference, hint, cloth_annotation)  # compute the loss
        self.log("loss",  # log the loss
loss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
        self.log('lr_abs',  # log the learning rate
self.optimizers().param_groups[0]['lr'],
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False)
return loss
# ๅŠ ่ฝฝๆ•ฐๆฎ
@torch.no_grad()
def get_input(self, batch):
        # load the raw batch
        x, inpaint, mask, reference, hint, cloth_annotation = super().get_input(batch)
        # encode the ground truth with AutoencoderKL
        encoder_posterior = self.first_stage_model.encode(x)
        z = self.scale_factor * (encoder_posterior.sample()).detach()
        # encode the inpaint image with AutoencoderKL
        encoder_posterior_inpaint = self.first_stage_model.encode(inpaint)
        z_inpaint = self.scale_factor * (encoder_posterior_inpaint.sample()).detach()
        # resize the mask to the latent resolution
        mask_resize = torchvision.transforms.Resize([z.shape[-2], z.shape[-1]])(mask)
        # assemble z_new: channels 0-3 latent, 4-7 inpaint latent, 8 resized mask
        z_new = torch.cat((z, z_inpaint, mask_resize), dim=1)
out = [z_new, reference, hint, cloth_annotation]
return out
# ่ฎก็ฎ—ๆŸๅคฑ
def forward(self, z_new, reference, hint, cloth_annotation):
# ้šๆœบๆ—ถ้—ด t
t = torch.randint(0, self.num_timesteps, (z_new.shape[0],), device=self.device).long()
# CLIP ๅค„็† reference
reference_clip = self.cond_stage_model.encode(reference)
reference_clip = self.proj_out(reference_clip)
# CLIP text reference
reference_clip_text = self.condi_stage_model.encode(cloth_annotation)
#apply CrossAttention to combine features
cross_att = CrossAttention().to('cuda')
reference_clip = cross_att(reference_clip, reference_clip_text, reference_clip_text)
        # encode the reference with DINOv2 (CLS token + patch tokens)
        dino = self.dinov2_vitl14(reference, is_training=True)
dino1 = dino["x_norm_clstoken"].unsqueeze(1)
dino2 = dino["x_norm_patchtokens"]
reference_dino = torch.cat((dino1, dino2), dim=1)
reference_dino = self.linear(reference_dino)
# ้šๆœบๅŠ ๅ™ช
noise = torch.randn_like(z_new[:,:4,:,:])
x_noisy = self.q_sample(x_start=z_new[:,:4,:,:], t=t, noise=noise)
x_noisy = torch.cat((x_noisy, z_new[:,4:,:,:]),dim=1)
        # with probability 0.2, drop the reference conditioning and use the fixed
        # placeholder embedding instead (classifier-free guidance); the original
        # if/else branches were identical, and self.learnable_vector is presumably
        # intended for this unconditional case
        if random.uniform(0, 1) < 0.2:
            reference_clip = self.learnable_vector.repeat(x_noisy.shape[0], 1, 1)
        # predict the noise
        model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
# ่ฎก็ฎ—ๆŸๅคฑ
loss = self.get_loss(model_output, noise, mean=False).mean([1, 2, 3])
loss = loss.mean()
return loss
    # noise prediction
    def apply_model(self, x_noisy, hint, t, reference_clip, reference_dino):
        # predict the control residuals
        control = self.control_model(x_noisy, hint, t, reference_dino)
        # run the frozen PBE UNet with the control injected
        model_output = self.model(x_noisy, t, reference_clip, control)
        return model_output
    # optimizer
    def configure_optimizers(self):
        lr = self.learning_rate
        # only the ControlNet, the DINO projection, and the cross-attention
        # fusion are trained; everything else stays frozen
        params = list(self.control_model.parameters()) + list(self.linear.parameters()) + list(self.cross_att.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
    # sampling
@torch.no_grad()
def sample_log(self, batch, ddim_steps=50, ddim_eta=0.):
z_new, reference, hint, cloth_annotation = self.get_input(batch)
x, _, mask, _, _, _ = super().get_input(batch)
log = dict()
# log["reference"] = reference
# reconstruction = 1. / self.scale_factor * z_new[:,:4,:,:]
# log["reconstruction"] = self.first_stage_model.decode(reconstruction)
log["mask"] = mask
        # pass the inpaint latent (channels 4:8) and the resized mask (channel 8:)
        # through to the sampler
        test_model_kwargs = {}
        test_model_kwargs['inpaint_image'] = z_new[:, 4:8, :, :]
        test_model_kwargs['inpaint_mask'] = z_new[:, 8:, :, :]
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, _ = ddim_sampler.sample(ddim_steps,
reference.shape[0],
shape,
hint,
reference,
verbose=False,
eta=ddim_eta,
test_model_kwargs=test_model_kwargs)
samples = 1. / self.scale_factor * samples
x_samples = self.first_stage_model.decode(samples[:,:4,:,:])
# log["samples"] = x_samples
x = torchvision.transforms.Resize([512, 512])(x)
reference = torchvision.transforms.Resize([512, 512])(reference)
x_samples = torchvision.transforms.Resize([512, 512])(x_samples)
log["grid"] = torch.cat((x, reference, x_samples), dim=2)
return log
# CrossAttention applies cross-attention between two embeddings: an image
# embedding (query) and a text embedding (key/value).
class CrossAttention(nn.Module):
def __init__(
self,
embed_dim: int=768,
num_heads: int=8
):
"""
Initializes a CrossAttention layer using multi-head attention.
Args:
embed_dim (int): Dimensionality of the embeddings, which should match
the size of both reference_clip and reference_clip_text.
num_heads (int): Number of attention heads. Using multiple heads allows
the model to focus on different parts of the input embeddings.
"""
super(CrossAttention, self).__init__()
self.cross_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
def forward(self, query, key, value):
"""
Applies cross-attention to the query, key, and value inputs.
Args:
query (Tensor): The query tensor (in this case, reference_clip).
Shape should be [batch_size, seq_length, embed_dim].
key (Tensor): The key tensor (in this case, reference_clip_text).
Shape should be [batch_size, seq_length, embed_dim].
value (Tensor): The value tensor (in this case, reference_clip_text).
Shape should be [batch_size, seq_length, embed_dim].
Returns:
Tensor: The attention output after combining reference_clip and
reference_clip_text through cross-attention.
Shape is [batch_size, seq_length, embed_dim].
"""
        # move inputs onto the same device as the attention weights instead of
        # hard-coding 'cuda', so the module also runs on CPU
        device = next(self.parameters()).device
        query = query.to(device)
        key = key.to(device)
        value = value.to(device)
# Apply cross-attention, where `query` attends to `key` and `value`.
        # `attn_output` contains the fused embeddings; the attention weights are discarded
        attn_output, _ = self.cross_attn(query, key, value)
return attn_output
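# A minimal sanity check for CrossAttention (hypothetical shapes, assuming
# CLIP image tokens as queries and CLIP text tokens as keys/values, both
# already projected to 768 dims):
#
#   attn = CrossAttention(embed_dim=768, num_heads=8)
#   img = torch.randn(2, 257, 768)   # image-token queries
#   txt = torch.randn(2, 77, 768)    # text-token keys/values
#   out = attn(img, txt, txt)        # -> [2, 257, 768]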