import logging
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.transformer_arch import TransformerMultiHead
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')


class TransformerTextureAwareModel():
    """Texture-aware diffusion-based transformer model.

    VQGAN image tokens are iteratively masked and re-predicted by a multi-head
    transformer conditioned on segmentation and texture tokens.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

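        # Pretrained image VQGAN: encoder, decoder, and texture-aware quantizer.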
        self.img_encoder = Encoder(
            ch=opt['img_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            double_z=opt['img_double_z'],
            dropout=opt['img_dropout']).to(self.device)
        self.img_decoder = Decoder(
            in_channels=opt['img_in_channels'],
            resolution=opt['img_resolution'],
            z_channels=opt['img_z_channels'],
            ch=opt['img_ch'],
            out_ch=opt['img_out_ch'],
            num_res_blocks=opt['img_num_res_blocks'],
            attn_resolutions=opt['img_attn_resolutions'],
            ch_mult=opt['img_ch_mult'],
            dropout=opt['img_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.img_quantizer = VectorQuantizerTexture(
            opt['img_n_embed'], opt['img_embed_dim'],
            beta=0.25).to(self.device)
        self.img_quant_conv = torch.nn.Conv2d(opt['img_z_channels'],
                                              opt['img_embed_dim'],
                                              1).to(self.device)
        self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
                                                   opt['img_z_channels'],
                                                   1).to(self.device)
        self.load_pretrained_image_vae()

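        # Pretrained segmentation map encoder and vector quantizer (no decoder needed).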
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt['segm_z_channels'],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_vae()

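        # Multi-head transformer that predicts masked image tokens, conditioned
        # on segmentation and texture tokens.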
        self._denoise_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)

        self.num_classes = opt['codebook_size']
        self.shape = tuple(opt['latent_shape'])
        self.num_timesteps = 1000

        # the mask token id sits one past the last valid codebook index
        self.mask_id = opt['codebook_size']
        self.loss_type = opt['loss_type']
        self.mask_schedule = opt['mask_schedule']

        self.sample_steps = opt['sample_steps']

        self.init_training_settings()

    def load_pretrained_image_vae(self):
        # load pretrained image VQGAN weights and keep the modules in eval mode
        img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
        self.img_encoder.load_state_dict(
            img_ae_checkpoint['encoder'], strict=True)
        self.img_decoder.load_state_dict(
            img_ae_checkpoint['decoder'], strict=True)
        self.img_quantizer.load_state_dict(
            img_ae_checkpoint['quantize'], strict=True)
        self.img_quant_conv.load_state_dict(
            img_ae_checkpoint['quant_conv'], strict=True)
        self.img_post_quant_conv.load_state_dict(
            img_ae_checkpoint['post_quant_conv'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.img_quantizer.eval()
        self.img_quant_conv.eval()
        self.img_post_quant_conv.eval()

    def load_pretrained_segm_vae(self):
        # load pretrained segmentation VQVAE weights and keep the modules in eval mode
        segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
        self.segm_encoder.load_state_dict(
            segm_ae_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_ae_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_ae_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def init_training_settings(self):
        # only the denoising transformer is optimized; the VQ models stay frozen
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()

    @torch.no_grad()
    def get_quantized_img(self, image, texture_mask):
        encoded_img = self.img_encoder(image)
        encoded_img = self.img_quant_conv(encoded_img)

        # the texture-aware quantizer returns the combined input tokens and one
        # ground-truth token map per texture codebook
        _, _, [_, img_tokens_input,
               img_tokens_gt_list] = self.img_quantizer(encoded_img,
                                                        texture_mask)

        # flatten the spatial token maps to (batch, height * width)
        b = image.size(0)
        img_tokens_input = img_tokens_input.view(b, -1)
        img_tokens_gt_return_list = [
            img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
        ]

        return img_tokens_input, img_tokens_gt_return_list

    @torch.no_grad()
    def decode(self, quant):
        quant = self.img_post_quant_conv(quant)
        dec = self.img_decoder(quant)
        return dec

    @torch.no_grad()
    def decode_image_indices(self, indices_list, texture_mask):
        quant = self.img_quantizer.get_codebook_entry(
            indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt['img_z_channels']))
        dec = self.decode(quant)

        return dec

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            # importance sampling relies on Lt_history / Lt_count buffers that
            # are not set up in this class; _train_loss uses 'uniform'
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # index 0 is never sampled; copy the value at t=1
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt

        else:
            raise ValueError(f'Unknown time sampling method: {method}')

    def q_sample(self, x_0, x_0_gt_list, t):
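        # forward corruption q(x_t | x_0): each token is independently replaced
        # by the mask id with probability t / T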
        x_t = x_0.clone()

        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id

        # only masked positions contribute to the loss; unmasked targets are
        # set to the ignore index (-1)
        x_0_gt_ignore_list = []
        for x_0_gt in x_0_gt_list:
            x_0_gt_ignore = x_0_gt.clone()
            x_0_gt_ignore[torch.bitwise_not(mask)] = -1
            x_0_gt_ignore_list.append(x_0_gt_ignore)

        return x_t, x_0_gt_ignore_list, mask

    def _train_loss(self, x_0, x_0_gt_list):
        b, device = x_0.size(0), x_0.device

        # choose a random timestep for each sample
        t, pt = self.sample_time(b, device, 'uniform')

        # corrupt x_0 to x_t by masking tokens
        if self.mask_schedule == 'random':
            x_t, x_0_gt_ignore_list, mask = self.q_sample(
                x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
        else:
            raise NotImplementedError(
                f'Unknown mask schedule: {self.mask_schedule}')

        # predict the original tokens for every texture codebook
        x_0_hat_logits_list = self._denoise_fn(
            x_t, self.segm_tokens, self.texture_tokens, t=t)

        # cross-entropy over masked positions, summed across codebooks
        cross_entropy_loss = 0
        for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
                                                 x_0_gt_ignore_list):
            cross_entropy_loss += F.cross_entropy(
                x_0_hat_logits.permute(0, 2, 1),
                x_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent division by zero
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError(f'Unknown loss type: {self.loss_type}')

        return loss.mean(), vb_loss.mean()

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.input_indices, self.gt_indices_list = self.get_quantized_img(
            self.image, self.texture_mask)

        # downsample the texture mask to the latent resolution and flatten it
        self.texture_tokens = F.interpolate(
            self.texture_mask, size=self.shape,
            mode='nearest').view(self.image.size(0), -1).long()

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)

    def optimize_parameters(self):
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.input_indices,
                                         self.gt_indices_list)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss'] = loss
        self.log_dict['vb_loss'] = vb_loss

        self._denoise_fn.eval()

    @torch.no_grad()
    def get_quantized_segm(self, segm):
        # one-hot encode the segmentation map before passing it to the encoder
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens

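    # Reverse process: start from a fully masked sequence and, over
    # `sample_steps` iterations, unmask positions with tokens sampled from the
    # transformer's predictions.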
    def sample_fn(self, temp=1.0, sample_steps=None):
        self._denoise_fn.eval()

        # start from a fully masked token sequence
        b, device = self.image.size(0), self.device
        x_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_mask_flatten = self.texture_tokens.view(-1)

        # one index map per texture codebook; the hard-coded 18 is assumed to
        # match the number of texture codebooks in the quantizer
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # select positions to unmask at this step with probability 1/t,
            # never re-sampling positions that were already unmasked
            changes = torch.rand(
                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self._denoise_fn(
                x_t, self.segm_tokens, self.texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape
            x_t = x_t.view(-1)
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature and sample token predictions
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # positions selected for unmasking that belong to this codebook
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # offset into the combined token space; 1024 is assumed to be
                    # the size of each per-texture codebook
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self._denoise_fn.train()

        return min_encodings_indices_return_list

    def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
                save_path):
        # decode ground-truth and predicted token maps back to images
        ori_img = self.decode_image_indices(gt_indices, texture_mask)
        pred_img = self.decode_image_indices(predicted_indices, texture_mask)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        # map from [-1, 1] to [0, 1] before saving
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                sampled_indices_list = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
            for idx in range(b):
                gt_indices_sample = [
                    gt_indices[idx:idx + 1]
                    for gt_indices in self.gt_indices_list
                ]
                sampled_indices_sample = [
                    sampled_indices[idx:idx + 1]
                    for sampled_indices in sampled_indices_list
                ]
                self.get_vis(self.image[idx:idx + 1], gt_indices_sample,
                             sampled_indices_sample,
                             self.texture_mask[idx:idx + 1],
                             f'{save_dir}/{img_name[idx]}')

        self._denoise_fn.train()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update learning rate.

        Args:
            epoch (int): Current epoch.
            iters (int): Current iteration. Only required by the 'warm_up'
                decay mode. Default: None.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # linear decay down to roughly 5% of the initial lr at the
                # turning point (1 / 1.0526 is approximately 0.95), then
                # exponential decay afterwards
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))

        # set the updated learning rate on all parameter groups
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, net, save_path):
        """Save networks.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Path to the saved checkpoint.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
        # load a pretrained denoising transformer for inference
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()
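
# Usage sketch (illustrative only): `opt`, `train_loader`, and `val_loader` are
# assumed to come from the surrounding training script; they are not defined in
# this module.
#
#     model = TransformerTextureAwareModel(opt)
#     for data in train_loader:
#         model.feed_data(data)
#         model.optimize_parameters()
#     model.inference(val_loader, save_dir='./results')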