# Outfits/models/transformer_model.py
import logging
import math
from collections import OrderedDict
import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image
from models.archs.transformer_arch import TransformerMultiHead
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
VectorQuantizerTexture)
logger = logging.getLogger('base')
class TransformerTextureAwareModel:
    """Texture-aware diffusion-based transformer model.

    Pretrained VQGAN/VAE models encode images and segmentation masks into
    discrete token maps; a multi-head transformer then denoises masked image
    tokens conditioned on the segmentation and texture tokens.
    """
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda')
self.is_train = opt['is_train']
# VQVAE for image
self.img_encoder = Encoder(
ch=opt['img_ch'],
num_res_blocks=opt['img_num_res_blocks'],
attn_resolutions=opt['img_attn_resolutions'],
ch_mult=opt['img_ch_mult'],
in_channels=opt['img_in_channels'],
resolution=opt['img_resolution'],
z_channels=opt['img_z_channels'],
double_z=opt['img_double_z'],
dropout=opt['img_dropout']).to(self.device)
self.img_decoder = Decoder(
in_channels=opt['img_in_channels'],
resolution=opt['img_resolution'],
z_channels=opt['img_z_channels'],
ch=opt['img_ch'],
out_ch=opt['img_out_ch'],
num_res_blocks=opt['img_num_res_blocks'],
attn_resolutions=opt['img_attn_resolutions'],
ch_mult=opt['img_ch_mult'],
dropout=opt['img_dropout'],
resamp_with_conv=True,
give_pre_end=False).to(self.device)
self.img_quantizer = VectorQuantizerTexture(
opt['img_n_embed'], opt['img_embed_dim'],
beta=0.25).to(self.device)
        self.img_quant_conv = torch.nn.Conv2d(opt['img_z_channels'],
                                              opt['img_embed_dim'],
                                              1).to(self.device)
        self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
                                                   opt['img_z_channels'],
                                                   1).to(self.device)
self.load_pretrained_image_vae()
# VAE for segmentation mask
self.segm_encoder = Encoder(
ch=opt['segm_ch'],
num_res_blocks=opt['segm_num_res_blocks'],
attn_resolutions=opt['segm_attn_resolutions'],
ch_mult=opt['segm_ch_mult'],
in_channels=opt['segm_in_channels'],
resolution=opt['segm_resolution'],
z_channels=opt['segm_z_channels'],
double_z=opt['segm_double_z'],
dropout=opt['segm_dropout']).to(self.device)
self.segm_quantizer = VectorQuantizer(
opt['segm_n_embed'],
opt['segm_embed_dim'],
beta=0.25,
sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt['segm_z_channels'],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
self.load_pretrained_segm_vae()
# define sampler
self._denoise_fn = TransformerMultiHead(
codebook_size=opt['codebook_size'],
segm_codebook_size=opt['segm_codebook_size'],
texture_codebook_size=opt['texture_codebook_size'],
bert_n_emb=opt['bert_n_emb'],
bert_n_layers=opt['bert_n_layers'],
bert_n_head=opt['bert_n_head'],
block_size=opt['block_size'],
latent_shape=opt['latent_shape'],
embd_pdrop=opt['embd_pdrop'],
resid_pdrop=opt['resid_pdrop'],
attn_pdrop=opt['attn_pdrop'],
num_head=opt['num_head']).to(self.device)
self.num_classes = opt['codebook_size']
self.shape = tuple(opt['latent_shape'])
self.num_timesteps = 1000
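        # the absorbing "mask" state reuses the index one past the last
        # codebook entry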
self.mask_id = opt['codebook_size']
self.loss_type = opt['loss_type']
self.mask_schedule = opt['mask_schedule']
self.sample_steps = opt['sample_steps']
self.init_training_settings()
def load_pretrained_image_vae(self):
        # load the pretrained VQGAN for images
img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
self.img_encoder.load_state_dict(
img_ae_checkpoint['encoder'], strict=True)
self.img_decoder.load_state_dict(
img_ae_checkpoint['decoder'], strict=True)
self.img_quantizer.load_state_dict(
img_ae_checkpoint['quantize'], strict=True)
self.img_quant_conv.load_state_dict(
img_ae_checkpoint['quant_conv'], strict=True)
self.img_post_quant_conv.load_state_dict(
img_ae_checkpoint['post_quant_conv'], strict=True)
self.img_encoder.eval()
self.img_decoder.eval()
self.img_quantizer.eval()
self.img_quant_conv.eval()
self.img_post_quant_conv.eval()
def load_pretrained_segm_vae(self):
        # load the pretrained VQGAN for the segmentation mask
segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
self.segm_encoder.load_state_dict(
segm_ae_checkpoint['encoder'], strict=True)
self.segm_quantizer.load_state_dict(
segm_ae_checkpoint['quantize'], strict=True)
self.segm_quant_conv.load_state_dict(
segm_ae_checkpoint['quant_conv'], strict=True)
self.segm_encoder.eval()
self.segm_quantizer.eval()
self.segm_quant_conv.eval()
def init_training_settings(self):
optim_params = []
for v in self._denoise_fn.parameters():
if v.requires_grad:
optim_params.append(v)
# set up optimizer
self.optimizer = torch.optim.Adam(
optim_params,
self.opt['lr'],
weight_decay=self.opt['weight_decay'])
self.log_dict = OrderedDict()
@torch.no_grad()
def get_quantized_img(self, image, texture_mask):
encoded_img = self.img_encoder(image)
encoded_img = self.img_quant_conv(encoded_img)
        # img_tokens_input: flattened indices in a single contiguous range,
        # used as the transformer input
        # img_tokens_gt_list: ground-truth indices for each of the 18
        # texture-aware codebooks
        _, _, [_, img_tokens_input, img_tokens_gt_list
               ] = self.img_quantizer(encoded_img, texture_mask)
# reshape the tokens
b = image.size(0)
img_tokens_input = img_tokens_input.view(b, -1)
img_tokens_gt_return_list = [
img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
]
return img_tokens_input, img_tokens_gt_return_list
@torch.no_grad()
def decode(self, quant):
quant = self.img_post_quant_conv(quant)
dec = self.img_decoder(quant)
return dec
@torch.no_grad()
def decode_image_indices(self, indices_list, texture_mask):
quant = self.img_quantizer.get_codebook_entry(
indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt['img_z_channels']))
dec = self.decode(quant)
return dec
def sample_time(self, b, device, method='uniform'):
if method == 'importance':
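            # NOTE: importance sampling reads self.Lt_count / self.Lt_history,
            # which are not initialized in __init__; this branch assumes they
            # are maintained elsewhere. _train_loss below uses 'uniform'.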
if not (self.Lt_count > 10).all():
return self.sample_time(b, device, method='uniform')
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
pt_all = Lt_sqrt / Lt_sqrt.sum()
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
pt = pt_all.gather(dim=0, index=t)
return t, pt
elif method == 'uniform':
t = torch.randint(
1, self.num_timesteps + 1, (b, ), device=device).long()
pt = torch.ones_like(t).float() / self.num_timesteps
return t, pt
        else:
            raise ValueError(f'Unknown time-sampling method: {method}')
    def q_sample(self, x_0, x_0_gt_list, t):
        # sample q(x_t | x_0): each token is independently replaced by the
        # mask token with probability t / T
        x_t = x_0.clone()
        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id
        # apply the same mask to every per-codebook ground-truth token map
x_0_gt_ignore_list = []
for x_0_gt in x_0_gt_list:
x_0_gt_ignore = x_0_gt.clone()
x_0_gt_ignore[torch.bitwise_not(mask)] = -1
x_0_gt_ignore_list.append(x_0_gt_ignore)
return x_t, x_0_gt_ignore_list, mask
def _train_loss(self, x_0, x_0_gt_list):
b, device = x_0.size(0), x_0.device
# choose what time steps to compute loss at
t, pt = self.sample_time(b, device, 'uniform')
# make x noisy and denoise
if self.mask_schedule == 'random':
x_t, x_0_gt_ignore_list, mask = self.q_sample(
x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
else:
raise NotImplementedError
        # predict p(x_0 | x_t): one set of logits per texture codebook
x_0_hat_logits_list = self._denoise_fn(
x_t, self.segm_tokens, self.texture_tokens, t=t)
# Always compute ELBO for comparison purposes
cross_entropy_loss = 0
for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
x_0_gt_ignore_list):
cross_entropy_loss += F.cross_entropy(
x_0_hat_logits.permute(0, 2, 1),
x_0_gt_ignore,
ignore_index=-1,
reduction='none').sum(1)
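        # importance-weighted variational bound, normalized to bits per
        # dimension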
vb_loss = cross_entropy_loss / t
vb_loss = vb_loss / pt
vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
if self.loss_type == 'elbo':
loss = vb_loss
elif self.loss_type == 'mlm':
denom = mask.float().sum(1)
denom[denom == 0] = 1 # prevent divide by 0 errors.
loss = cross_entropy_loss / denom
elif self.loss_type == 'reweighted_elbo':
weight = (1 - (t / self.num_timesteps))
loss = weight * cross_entropy_loss
loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError(f'Unknown loss type: {self.loss_type}')
return loss.mean(), vb_loss.mean()
def feed_data(self, data):
self.image = data['image'].to(self.device)
self.segm = data['segm'].to(self.device)
self.texture_mask = data['texture_mask'].to(self.device)
self.input_indices, self.gt_indices_list = self.get_quantized_img(
self.image, self.texture_mask)
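        # downsample the texture mask to the latent resolution to obtain a
        # per-position texture label for the transformer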
self.texture_tokens = F.interpolate(
self.texture_mask, size=self.shape,
mode='nearest').view(self.image.size(0), -1).long()
self.segm_tokens = self.get_quantized_segm(self.segm)
self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)
def optimize_parameters(self):
self._denoise_fn.train()
loss, vb_loss = self._train_loss(self.input_indices,
self.gt_indices_list)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.log_dict['loss'] = loss
self.log_dict['vb_loss'] = vb_loss
self._denoise_fn.eval()
@torch.no_grad()
def get_quantized_segm(self, segm):
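        # one-hot encode the segmentation map before the VAE encoder, then
        # quantize to discrete segmentation tokens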
segm_one_hot = F.one_hot(
segm.squeeze(1).long(),
num_classes=self.opt['segm_num_segm_classes']).permute(
0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
encoded_segm_mask = self.segm_encoder(segm_one_hot)
encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
_, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
return segm_tokens
def sample_fn(self, temp=1.0, sample_steps=None):
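        """Reverse (denoising) process: start from an all-mask token map and,
        at each timestep t, unmask a random subset of still-masked positions
        by sampling from the transformer's per-codebook logits.
        """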
self._denoise_fn.eval()
        b, device = self.image.size(0), self.device
x_t = torch.ones(
(b, np.prod(self.shape)), device=device).long() * self.mask_id
unmasked = torch.zeros_like(x_t, device=device).bool()
sample_steps = list(range(1, sample_steps + 1))
texture_mask_flatten = self.texture_tokens.view(-1)
        # min_encodings_indices_list collects the sampled per-codebook indices
        # and is later used to decode and visualize the image
min_encodings_indices_list = [
torch.full(
texture_mask_flatten.size(),
fill_value=-1,
dtype=torch.long,
device=texture_mask_flatten.device) for _ in range(18)
]
for t in reversed(sample_steps):
print(f'Sample timestep {t:4d}', end='\r')
t = torch.full((b, ), t, device=device, dtype=torch.long)
# where to unmask
changes = torch.rand(
x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
# don't unmask somewhere already unmasked
changes = torch.bitwise_xor(changes,
torch.bitwise_and(changes, unmasked))
# update mask with changes
unmasked = torch.bitwise_or(unmasked, changes)
x_0_logits_list = self._denoise_fn(
x_t, self.segm_tokens, self.texture_tokens, t=t)
changes_flatten = changes.view(-1)
ori_shape = x_t.shape # [b, h*w]
x_t = x_t.view(-1) # [b*h*w]
for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
if torch.sum(texture_mask_flatten[changes_flatten] ==
codebook_idx) > 0:
# scale by temperature
x_0_logits = x_0_logits / temp
x_0_dist = dists.Categorical(logits=x_0_logits)
x_0_hat = x_0_dist.sample().long()
x_0_hat = x_0_hat.view(-1)
                    # only replace the changed positions assigned to this codebook
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)
                    # x_t feeds back into the transformer, so offset each
                    # codebook's indices into a single contiguous range
                    # (a per-codebook size of 1024 is assumed here)
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
min_encodings_indices_list[codebook_idx][
changes_segm] = x_0_hat[changes_segm]
x_t = x_t.view(ori_shape) # [b, h*w]
min_encodings_indices_return_list = [
min_encodings_indices.view(ori_shape)
for min_encodings_indices in min_encodings_indices_list
]
self._denoise_fn.train()
return min_encodings_indices_return_list
def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
save_path):
# original image
ori_img = self.decode_image_indices(gt_indices, texture_mask)
# pred image
pred_img = self.decode_image_indices(predicted_indices, texture_mask)
img_cat = torch.cat([
image,
ori_img,
pred_img,
], dim=3).detach()
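        # map decoder output from [-1, 1] to [0, 1] before saving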
img_cat = ((img_cat + 1) / 2)
img_cat = img_cat.clamp_(0, 1)
save_image(img_cat, save_path, nrow=1, padding=4)
def inference(self, data_loader, save_dir):
self._denoise_fn.eval()
for _, data in enumerate(data_loader):
img_name = data['img_name']
self.feed_data(data)
b = self.image.size(0)
with torch.no_grad():
sampled_indices_list = self.sample_fn(
temp=1, sample_steps=self.sample_steps)
for idx in range(b):
self.get_vis(self.image[idx:idx + 1], [
gt_indices[idx:idx + 1]
for gt_indices in self.gt_indices_list
], [
sampled_indices[idx:idx + 1]
for sampled_indices in sampled_indices_list
], self.texture_mask[idx:idx + 1],
f'{save_dir}/{img_name[idx]}')
self._denoise_fn.train()
def get_current_log(self):
return self.log_dict
    def update_learning_rate(self, epoch, iters=None):
        """Update the learning rate according to opt['lr_decay'].

        Args:
            epoch (int): Current epoch.
            iters (int, optional): Current iteration; only used by the
                'warm_up' schedule.
        """
lr = self.optimizer.param_groups[0]['lr']
if self.opt['lr_decay'] == 'step':
lr = self.opt['lr'] * (
self.opt['gamma']**(epoch // self.opt['step']))
elif self.opt['lr_decay'] == 'cos':
lr = self.opt['lr'] * (
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
elif self.opt['lr_decay'] == 'linear':
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # linear decay that removes 95% of the base lr by the
                # turning point (hence 1 / 0.95 = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
else:
lr *= self.opt['gamma']
elif self.opt['lr_decay'] == 'schedule':
if epoch in self.opt['schedule']:
lr *= self.opt['gamma']
elif self.opt['lr_decay'] == 'warm_up':
if iters <= self.opt['warmup_iters']:
lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
else:
lr = self.opt['lr']
else:
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
# set learning rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
    def save_network(self, net, save_path):
        """Save a network's state dict.

        Args:
            net (nn.Module): Network to be saved.
            save_path (str): Destination path for the checkpoint.
        """
state_dict = net.state_dict()
torch.save(state_dict, save_path)
def load_network(self):
checkpoint = torch.load(self.opt['pretrained_sampler'])
self._denoise_fn.load_state_dict(checkpoint, strict=True)
self._denoise_fn.eval()
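

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, hedged example of driving this
# model for inference. The config loader and data-loader names below are
# hypothetical placeholders, not part of the original project.
#
#   opt = load_config('configs/sampler.yml')     # hypothetical config loader
#   model = TransformerTextureAwareModel(opt)    # builds VAEs + transformer
#   model.load_network()                         # loads opt['pretrained_sampler']
#   # val_loader must yield dicts with 'image', 'segm', 'texture_mask',
#   # and 'img_name' keys, as consumed by feed_data() / inference()
#   model.inference(val_loader, save_dir='results')
# ---------------------------------------------------------------------------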