# basso4's picture
# Upload 1459 files
# 3f9659e verified
# raw
# history blame
# 16.4 kB
import random
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from einops import rearrange
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from ldm.models.diffusion.ddpm import DDPM
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.util import instantiate_from_config, default
from ldm.models.diffusion.ddim import DDIMSampler
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LambdaLR
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that does nothing.

    Assigned over a frozen sub-model's ``train`` attribute so that later
    train/eval toggles leave that sub-model as it is. The ``mode`` argument
    is accepted for signature compatibility and deliberately ignored; the
    first positional argument is handed back unchanged.
    """
    unchanged = self
    return unchanged
# =============================================================
# 可训练部分 ControlNet
# =============================================================
class ControlNet(nn.Module):
    """Trainable ControlNet branch.

    Mirrors the encoder half (input blocks + middle block) of a UNet and
    returns one feature map per input block plus one for the middle block,
    each passed through a 1x1 conv wrapped in ``zero_module``. A small conv
    stack (``input_hint_block``) embeds the hint image; its output is added
    to the features after the first input block.

    Attention width is configured either by ``num_head_channels`` (channels
    per head, head count derived from the block width) or, when that is -1,
    by ``num_heads`` (head count, per-head width derived). At least one of
    the two must be set.
    """

    def __init__(
        self,
        in_channels,  # 9
        model_channels,  # 320
        hint_channels,  # 20
        attention_resolutions,  # [4, 2, 1]
        num_res_blocks,  # 2
        channel_mult=(1, 2, 4, 8),  # [1, 2, 4, 4]
        num_head_channels=-1,  # 64
        transformer_depth=1,  # 1
        context_dim=None,  # 768
        use_checkpoint=False,  # True
        dropout=0,
        conv_resample=True,
        dims=2,
        num_heads=-1,
        use_scale_shift_norm=False,
    ):
        super(ControlNet, self).__init__()
        if num_heads == -1 and num_head_channels == -1:
            raise ValueError("Either num_heads or num_head_channels has to be set")
        self.dims = dims
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels

        # Timestep-embedding MLP: model_channels -> 4 * model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # First input block projects the input to model_channels.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        # Hint encoder: three stride-2 convs (8x spatial downsampling) so the
        # output matches the first input block's resolution; the last conv is
        # wrapped in zero_module.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        # UNet encoder: ResBlocks (+ SpatialTransformer at the listed
        # downsampling factors), with a Downsample between levels.
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    n_heads, dim_head = self._attention_heads(ch, num_heads, num_head_channels)
                    layers.append(
                        SpatialTransformer(
                            ch, n_heads, dim_head, depth=transformer_depth, context_dim=context_dim)
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2

        n_heads, dim_head = self._attention_heads(ch, num_heads, num_head_channels)
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            SpatialTransformer(ch, n_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)

    @staticmethod
    def _attention_heads(ch, num_heads, num_head_channels):
        """Return ``(n_heads, dim_head)`` for an attention layer of width ``ch``.

        ``num_head_channels`` takes precedence; when it is -1 the head count
        ``num_heads`` is used instead and the per-head width is derived.

        Raises:
            ValueError: if both ``num_heads`` and ``num_head_channels`` are -1.
        """
        if num_head_channels == -1:
            if num_heads == -1:
                raise ValueError("Either num_heads or num_head_channels has to be set")
            return num_heads, ch // num_heads
        return ch // num_head_channels, num_head_channels

    def make_zero_conv(self, channels):
        """Build a ``channels -> channels`` 1x1 conv wrapped in ``zero_module``."""
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, reference_dino):
        """Compute the control residuals.

        Args:
            x: UNet input; cast to ``self.dtype`` (float32).
            hint: hint image fed to ``input_hint_block``.
            timesteps: diffusion timesteps, embedded with ``timestep_embedding``.
            reference_dino: cross-attention context for the SpatialTransformers.

        Returns:
            List with one tensor per input block followed by one for the
            middle block (``len(self.input_blocks) + 1`` entries).
        """
        context = reference_dino
        t_emb = timestep_embedding(timesteps, self.model_channels)
        emb = self.time_embed(t_emb)
        guided_hint = self.input_hint_block(hint, emb)

        outs = []
        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            # The embedded hint is injected once, after the first input block.
            if guided_hint is not None:
                h += guided_hint
                guided_hint = None
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))
        return outs
# =============================================================
# 固定参数部分 ControlledUnetModel
# =============================================================
class ControlledUnetModel(UNetModel):
    """UNet whose middle block and skip connections receive ControlNet residuals.

    The timestep embedding, input blocks and middle block run under
    ``torch.no_grad()``. Gradients therefore only flow through the control
    residuals and the output (decoder) blocks.
    """

    def forward(self, x, timesteps=None, context=None, control=None):
        """Predict the UNet output, optionally adding control residuals.

        Args:
            x: UNet input; cast to ``self.dtype`` for the network body.
            timesteps: diffusion timesteps.
            context: cross-attention context.
            control: optional list of residuals, consumed from the end. The
                last entry is added after the middle block, and each earlier
                entry is added to the matching skip connection. When ``None``,
                the UNet runs without control. The caller's list is not
                modified.

        Returns:
            ``self.out`` applied to the decoder features cast back to
            ``x.dtype``.
        """
        # Work on a copy so pop() does not consume the caller's list.
        control = list(control) if control is not None else None
        hs = []
        # Encoder half: frozen, no autograd graph is recorded.
        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels)
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)
            h = self.middle_block(h, emb, context)

        # Inject the middle-block control residual.
        if control is not None:
            h += control.pop()

        # Decoder half: skip connections popped in reverse order.
        for module in self.output_blocks:
            skip = hs.pop()
            if control is not None:
                skip = skip + control.pop()
            h = torch.cat([h, skip], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)
# =============================================================
# 主干网络 ControlLDM
# =============================================================
class ControlLDM(DDPM):
    """Reference-guided latent inpainting model with a ControlNet branch.

    Components built in ``__init__``:
      * ``self.model`` (created by ``DDPM.__init__``): the denoising network;
        ``apply_model`` passes it the CLIP context and the control residuals.
      * ``self.control_model``: ControlNet fed the hint and DINOv2 tokens.
      * ``self.first_stage_model``: frozen autoencoder used to encode/decode latents.
      * ``self.cond_stage_model``: frozen reference-image encoder.
      * ``self.dinov2_vitl14``: frozen DINOv2 ViT-L/14 loaded via ``torch.hub``.

    Only ``control_model`` and ``linear`` are handed to the optimizer.
    """

    def __init__(self,
                 control_stage_config,  # ControlNet
                 first_stage_config,  # AutoencoderKL
                 cond_stage_config,  # FrozenCLIPImageEmbedder
                 scale_factor=1.0,  # 0.18215
                 *args, **kwargs):
        # Set before the parent constructor runs; presumably read during
        # DDPM initialisation -- TODO confirm.
        self.num_timesteps_cond = 1
        super().__init__(*args, **kwargs)  # creates self.model and registers buffers
        self.control_model = instantiate_from_config(control_stage_config)
        self.instantiate_first_stage(first_stage_config)  # -> self.first_stage_model
        self.instantiate_cond_stage(cond_stage_config)  # -> self.cond_stage_model
        # Projects the 1024-d reference embedding to the 768-d context width.
        # NOTE(review): not included in configure_optimizers, so it is never
        # updated by training here.
        self.proj_out = nn.Linear(1024, 768)
        # Multiplier applied to autoencoder latents (0.18215 in the reference config).
        self.scale_factor = scale_factor
        # NOTE(review): neither vector is read in this class; trainable_vector
        # has requires_grad=True but is not passed to the optimizer.
        self.learnable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=False)
        self.trainable_vector = nn.Parameter(torch.randn((1, 1, 768)), requires_grad=True)
        # Frozen DINOv2 ViT-L/14; train() is replaced so parent train(mode)
        # calls leave it in eval mode.
        self.dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        self.dinov2_vitl14.eval()
        self.dinov2_vitl14.train = disabled_train
        for param in self.dinov2_vitl14.parameters():
            param.requires_grad = False
        # Projects 1024-d DINOv2 tokens to the 768-d ControlNet context width.
        # (Alternative, disabled: dinov2_vitg14 with nn.Linear(1536, 768).)
        self.linear = nn.Linear(1024, 768)

    def instantiate_first_stage(self, config):
        """Build the autoencoder from ``config`` and freeze it (eval mode, no grads)."""
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        """Build the reference encoder from ``config`` and freeze it (eval mode, no grads)."""
        model = instantiate_from_config(config)
        self.cond_stage_model = model.eval()
        self.cond_stage_model.train = disabled_train
        for param in self.cond_stage_model.parameters():
            param.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Run one training step and log the loss and the current learning rate."""
        z_new, reference, hint = self.get_input(batch)
        loss = self(z_new, reference, hint)
        self.log("loss",
                 loss,
                 prog_bar=True,
                 logger=True,
                 on_step=True,
                 on_epoch=True)
        self.log('lr_abs',
                 self.optimizers().param_groups[0]['lr'],
                 prog_bar=True,
                 logger=True,
                 on_step=True,
                 on_epoch=False)
        return loss

    @torch.no_grad()
    def get_input(self, batch):
        """Encode a batch into the denoiser's input.

        Returns:
            ``[z_new, reference, hint]``. ``z_new`` concatenates, along dim 1:
            the scaled latent of the target image, the scaled latent of the
            inpaint image, and the mask resized to the latent spatial size.
            Presumably each latent has 4 channels (``forward`` slices
            ``[:, :4]``) -- TODO confirm against the autoencoder config.
        """
        x, inpaint, mask, reference, hint = super().get_input(batch)
        # Target image -> scaled latent.
        encoder_posterior = self.first_stage_model.encode(x)
        z = self.scale_factor * (encoder_posterior.sample()).detach()
        # Inpaint (masked) image -> scaled latent.
        encoder_posterior_inpaint = self.first_stage_model.encode(inpaint)
        z_inpaint = self.scale_factor * (encoder_posterior_inpaint.sample()).detach()
        # Mask resized to the latent's spatial size.
        mask_resize = torchvision.transforms.Resize([z.shape[-2], z.shape[-1]])(mask)
        z_new = torch.cat((z, z_inpaint, mask_resize), dim=1)
        out = [z_new, reference, hint]
        return out

    def forward(self, z_new, reference, hint):
        """Return the mean noise-prediction loss for one batch.

        A random timestep is drawn per sample. Only the first 4 channels of
        ``z_new`` are noised; the remaining channels (inpaint latent + mask)
        are concatenated back unchanged as conditioning.
        """
        t = torch.randint(0, self.num_timesteps, (z_new.shape[0],), device=self.device).long()
        # Reference -> cond-stage embedding -> 768-d context for the UNet.
        reference_clip = self.cond_stage_model.encode(reference)
        reference_clip = self.proj_out(reference_clip)
        # Reference -> DINOv2 [CLS] + patch tokens -> 768-d context for ControlNet.
        dino = self.dinov2_vitl14(reference, is_training=True)
        dino_cls = dino["x_norm_clstoken"].unsqueeze(1)
        dino_patches = dino["x_norm_patchtokens"]
        reference_dino = torch.cat((dino_cls, dino_patches), dim=1)
        reference_dino = self.linear(reference_dino)
        # Forward diffusion on the target latent only.
        noise = torch.randn_like(z_new[:, :4, :, :])
        x_noisy = self.q_sample(x_start=z_new[:, :4, :, :], t=t, noise=noise)
        x_noisy = torch.cat((x_noisy, z_new[:, 4:, :, :]), dim=1)
        # NOTE(review): no conditioning dropout (e.g. for classifier-free
        # guidance) is applied; the reference embeddings are always used.
        # Confirm this is intended.
        model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
        loss = self.get_loss(model_output, noise, mean=False).mean([1, 2, 3])
        return loss.mean()

    def apply_model(self, x_noisy, hint, t, reference_clip, reference_dino):
        """Compute control residuals, then run ``self.model`` with them."""
        control = self.control_model(x_noisy, hint, t, reference_dino)
        model_output = self.model(x_noisy, t, reference_clip, control)
        return model_output

    def configure_optimizers(self):
        """AdamW over the ControlNet and the DINOv2 projection only.

        ``self.model``, ``proj_out`` and ``trainable_vector`` are not optimized here.
        """
        lr = self.learning_rate
        params = list(self.control_model.parameters()) + list(self.linear.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    @torch.no_grad()
    def sample_log(self, batch, ddim_steps=50, ddim_eta=0.):
        """Sample with DDIM and build a comparison grid for logging.

        Returns:
            dict with ``"mask"`` and ``"grid"``. ``"grid"`` stacks the target,
            the reference and the decoded sample along dim 2 (height, assuming
            NCHW), each resized to 512x512.
        """
        z_new, reference, hint = self.get_input(batch)
        x, _, mask, _, _ = super().get_input(batch)
        log = dict()
        # log["reference"] = reference
        # reconstruction = 1. / self.scale_factor * z_new[:,:4,:,:]
        # log["reconstruction"] = self.first_stage_model.decode(reconstruction)
        log["mask"] = mask
        # The sampler receives the inpaint latent and the resized mask separately.
        test_model_kwargs = {}
        test_model_kwargs['inpaint_image'] = z_new[:, 4:8, :, :]
        test_model_kwargs['inpaint_mask'] = z_new[:, 8:, :, :]
        ddim_sampler = DDIMSampler(self)
        shape = (self.channels, self.image_size, self.image_size)
        samples, _ = ddim_sampler.sample(ddim_steps,
                                         reference.shape[0],
                                         shape,
                                         hint,
                                         reference,
                                         verbose=False,
                                         eta=ddim_eta,
                                         test_model_kwargs=test_model_kwargs)
        # Undo the latent scaling before decoding.
        samples = 1. / self.scale_factor * samples
        x_samples = self.first_stage_model.decode(samples[:, :4, :, :])
        # log["samples"] = x_samples
        x = torchvision.transforms.Resize([512, 512])(x)
        reference = torchvision.transforms.Resize([512, 512])(reference)
        x_samples = torchvision.transforms.Resize([512, 512])(x_samples)
        log["grid"] = torch.cat((x, reference, x_samples), dim=2)
        return log