"""Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from vq_gan_3d.utils import shift_dim, adopt_weight
from vq_gan_3d.model.lpips import LPIPS
from vq_gan_3d.model.codebook import Codebook
def silu(x):
    return x * torch.sigmoid(x)
class SiLU(nn.Module):
def __init__(self):
super(SiLU, self).__init__()
def forward(self, x):
return silu(x)
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1. - logits_real))
loss_fake = torch.mean(F.relu(1. + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) +
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
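
# Reference: with D(.) the patch-discriminator logits, the two losses above
# are the standard objectives
#   hinge:   L_D = 0.5 * (E[relu(1 - D(x_real))] + E[relu(1 + D(x_fake))])
#   vanilla: L_D = 0.5 * (E[softplus(-D(x_real))] + E[softplus(D(x_fake))])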
class VQGAN(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self._set_seed(0)
self.embedding_dim = cfg.model.embedding_dim
self.n_codes = cfg.model.n_codes
self.encoder = Encoder(cfg.model.n_hiddens, cfg.model.downsample,
cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.padding_type,
cfg.model.num_groups,
)
self.decoder = Decoder(
cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups)
self.enc_out_ch = self.encoder.out_channels
self.pre_vq_conv = SamePadConv3d(
self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type)
self.post_vq_conv = SamePadConv3d(
cfg.model.embedding_dim, self.enc_out_ch, 1)
self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim,
no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres)
        self.gan_feat_weight = cfg.model.gan_feat_weight
        # NOTE: the discriminator uses plain BatchNorm2d rather than
        # SyncBatchNorm.
        self.image_discriminator = NLayerDiscriminator(
            cfg.dataset.image_channels, cfg.model.disc_channels,
            cfg.model.disc_layers, norm_layer=nn.BatchNorm2d)
        if cfg.model.disc_loss_type == 'vanilla':
            self.disc_loss = vanilla_d_loss
        elif cfg.model.disc_loss_type == 'hinge':
            self.disc_loss = hinge_d_loss
        else:
            raise ValueError(
                f"unknown disc_loss_type {cfg.model.disc_loss_type!r}; "
                "expected 'vanilla' or 'hinge'")
self.perceptual_model = LPIPS()
self.image_gan_weight = cfg.model.image_gan_weight
self.perceptual_weight = cfg.model.perceptual_weight
self.l1_weight = cfg.model.l1_weight
def encode(self, x, include_embeddings=False, quantize=True):
h = self.pre_vq_conv(self.encoder(x))
if quantize:
vq_output = self.codebook(h)
if include_embeddings:
return vq_output['embeddings'], vq_output['encodings']
else:
return vq_output['encodings']
return h
def decode(self, latent, quantize=False):
if quantize:
vq_output = self.codebook(latent)
latent = vq_output['encodings']
h = F.embedding(latent, self.codebook.embeddings)
h = self.post_vq_conv(shift_dim(h, -1, 1))
return self.decoder(h)
def forward(self, global_step, x, optimizer_idx=None, log_image=False, gpu_id=0):
        B, C, T, H, W = x.shape
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z, gpu_id)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
        recon_loss = F.l1_loss(x_recon, x) * self.l1_weight

        # Select one random 2D frame from each 3D volume so the 2D
        # discriminator and LPIPS operate on image slices.
        frame_idx = torch.randint(0, T, [B], device=x.device)
        frame_idx_selected = frame_idx.reshape(-1, 1, 1, 1, 1).repeat(1, C, 1, H, W)
        frames = torch.gather(x, 2, frame_idx_selected).squeeze(2)
        frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2)
if log_image:
return frames, frames_recon, x, x_recon
        if optimizer_idx == 0:
            # Autoencoder update: train the "generator"
            # (encoder, decoder, and codebook).

            # Perceptual (LPIPS) loss on the sampled frames
            perceptual_loss = 0
            if self.perceptual_weight > 0:
                perceptual_loss = self.perceptual_model(
                    frames, frames_recon).mean() * self.perceptual_weight

            # Adversarial loss (enabled once global_step reaches
            # discriminator_iter_start)
            logits_image_fake, pred_image_fake = self.image_discriminator(
                frames_recon)
            g_image_loss = -torch.mean(logits_image_fake)
            g_loss = self.image_gan_weight * g_image_loss
            disc_factor = adopt_weight(
                global_step, threshold=self.cfg.model.discriminator_iter_start)
            aeloss = disc_factor * g_loss

            # GAN feature-matching loss: push the reconstruction to yield the
            # same intermediate discriminator features as the real frames.
            image_gan_feat_loss = 0
            feat_weights = 4.0 / (3 + 1)  # i.e. 4 / (n_layers + 1) for the default n_layers = 3
            if self.image_gan_weight > 0:
                logits_image_real, pred_image_real = self.image_discriminator(
                    frames)
                for i in range(len(pred_image_fake) - 1):
                    image_gan_feat_loss += feat_weights * F.l1_loss(
                        pred_image_fake[i], pred_image_real[i].detach())
            gan_feat_loss = disc_factor * self.gan_feat_weight * image_gan_feat_loss

            return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss, (g_image_loss, image_gan_feat_loss, vq_output['commitment_loss'], vq_output['perplexity'])
        if optimizer_idx == 1:
            # Discriminator update: real frames vs. detached reconstructions.
            logits_image_real, _ = self.image_discriminator(frames.detach())
            logits_image_fake, _ = self.image_discriminator(frames_recon.detach())
            d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
            disc_factor = adopt_weight(
                global_step, threshold=self.cfg.model.discriminator_iter_start)
            discloss = disc_factor * self.image_gan_weight * d_image_loss
            return discloss
        # No optimizer index: evaluation path, returning the per-sample
        # (unreduced) perceptual loss.
        perceptual_loss = self.perceptual_model(
            frames, frames_recon) * self.perceptual_weight
        return recon_loss, x_recon, vq_output, perceptual_loss
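
    # Training-loop sketch (assumed; mirrors the two-optimizer GAN setup this
    # forward() expects):
    #   opt_ae step:   recon + perceptual + aeloss + gan_feat  <- forward(step, x, optimizer_idx=0)
    #   opt_disc step: discloss                                <- forward(step, x, optimizer_idx=1)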
    def log_images(self, batch, **kwargs):
        log = dict()
        x = batch['data']
        # nn.Module has no .device attribute; use the parameters' device.
        x = x.to(next(self.parameters()).device)
        frames, frames_rec, _, _ = self(0, x, log_image=True)
        log["inputs"] = frames
        log["reconstructions"] = frames_rec
        return log
    def _set_seed(self, value):
        print('Random Seed:', value)
        random.seed(value)
        torch.manual_seed(value)
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
        np.random.seed(value)
        cudnn.deterministic = True
        # benchmark=True lets cuDNN pick nondeterministic algorithms, which
        # would defeat deterministic=True, so keep it off.
        cudnn.benchmark = False
        cudnn.enabled = True
def Normalize(in_channels, norm_type='group', num_groups=32):
assert norm_type in ['group', 'batch']
if norm_type == 'group':
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == 'batch':
return torch.nn.SyncBatchNorm(in_channels)
class Encoder(nn.Module):
def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32):
super().__init__()
n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
self.conv_blocks = nn.ModuleList()
max_ds = n_times_downsample.max()
self.conv_first = SamePadConv3d(
image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
for i in range(max_ds):
block = nn.Module()
in_channels = n_hiddens * 2**i
out_channels = n_hiddens * 2**(i+1)
stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
block.down = SamePadConv3d(
in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
block.res = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
self.conv_blocks.append(block)
n_times_downsample -= 1
self.final_block = nn.Sequential(
Normalize(out_channels, norm_type, num_groups=num_groups),
SiLU()
)
self.out_channels = out_channels
def forward(self, x):
h = self.conv_first(x)
for block in self.conv_blocks:
h = block.down(h)
h = block.res(h)
h = self.final_block(h)
return h
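
# Encoder shape sketch (illustrative; assumes downsample=(4, 4, 4)):
#   x: (B, image_channel, T, H, W)
#   conv_first         -> (B, n_hiddens,     T,   H,   W)
#   stage 0 (stride 2) -> (B, 2 * n_hiddens, T/2, H/2, W/2)
#   stage 1 (stride 2) -> (B, 4 * n_hiddens, T/4, H/4, W/4)
#   final_block (GroupNorm + SiLU) leaves the shape unchanged.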
class Decoder(nn.Module):
def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32):
super().__init__()
n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
max_us = n_times_upsample.max()
in_channels = n_hiddens*2**max_us
self.final_block = nn.Sequential(
Normalize(in_channels, norm_type, num_groups=num_groups),
SiLU()
)
self.conv_blocks = nn.ModuleList()
for i in range(max_us):
block = nn.Module()
in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
out_channels = n_hiddens*2**(max_us-i)
us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
block.up = SamePadConvTranspose3d(
in_channels, out_channels, 4, stride=us)
block.res1 = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
block.res2 = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
self.conv_blocks.append(block)
n_times_upsample -= 1
self.conv_last = SamePadConv3d(
out_channels, image_channel, kernel_size=3)
def forward(self, x):
h = self.final_block(x)
for i, block in enumerate(self.conv_blocks):
h = block.up(h)
h = block.res1(h)
h = block.res2(h)
h = self.conv_last(h)
return h
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
self.conv1 = SamePadConv3d(
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
self.dropout = torch.nn.Dropout(dropout)
        # norm2 normalizes the output of conv1, which has out_channels channels
        self.norm2 = Normalize(out_channels, norm_type, num_groups=num_groups)
self.conv2 = SamePadConv3d(
out_channels, out_channels, kernel_size=3, padding_type=padding_type)
if self.in_channels != self.out_channels:
self.conv_shortcut = SamePadConv3d(
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.dropout(h)  # previously constructed but never applied
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)
        return x + h
# Does not support dilation
class SamePadConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
# assumes that the input shape is divisible by stride
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
pad_input = []
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
pad_input.append((p // 2 + p % 2, p // 2))
pad_input = sum(pad_input, tuple())
self.pad_input = pad_input
self.padding_type = padding_type
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=0, bias=bias)
def forward(self, x):
return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
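
# Padding arithmetic sketch: along each axis the layer pads k - s in total
# (the extra unit goes on the leading side when k - s is odd), so an input
# length L divisible by s maps to exactly L / s; e.g. k=4, s=2 pads (1, 1)
# and gives out = (L + 2 - 4) // 2 + 1 = L // 2.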
class SamePadConvTranspose3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
pad_input = []
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
pad_input.append((p // 2 + p % 2, p // 2))
pad_input = sum(pad_input, tuple())
self.pad_input = pad_input
self.padding_type = padding_type
self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
stride=stride, bias=bias,
padding=tuple([k - 1 for k in kernel_size]))
def forward(self, x):
return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
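
# Transposed counterpart: the input is first padded by k - s, then
# ConvTranspose3d with padding = k - 1 trims the overlap:
#   out = (L + (k - s) - 1) * s - 2 * (k - 1) + k
# e.g. k=4, s=2 gives out = 2 * L, an exact 2x upsample per axis.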
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[-1], res[1:]
else:
            # no intermediate features collected in this mode
            return self.model(input), None
class NLayerDiscriminator3D(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
super(NLayerDiscriminator3D, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[-1], res[1:]
else:
            # no intermediate features collected in this mode
            return self.model(input), None
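

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config below is illustrative only: the
    # field names are the ones VQGAN.__init__ reads, but every value here is
    # an assumption, not a recommended training configuration. Note that
    # instantiating VQGAN builds LPIPS, which may fetch pretrained VGG
    # weights on first use.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        'model': {
            'embedding_dim': 8, 'n_codes': 64, 'n_hiddens': 16,
            'downsample': [2, 2, 2], 'norm_type': 'group',
            'padding_type': 'replicate', 'num_groups': 8,
            'no_random_restart': True, 'restart_thres': 1.0,
            'gan_feat_weight': 4.0, 'disc_channels': 64, 'disc_layers': 3,
            'disc_loss_type': 'hinge', 'image_gan_weight': 1.0,
            'perceptual_weight': 4.0, 'l1_weight': 4.0,
            'discriminator_iter_start': 10000,
        },
        'dataset': {'image_channels': 1},
    })
    model = VQGAN(cfg).eval()
    x = torch.randn(1, 1, 16, 16, 16)  # toy (B, C, T, H, W) volume
    with torch.no_grad():
        codes = model.encode(x)       # discrete code indices
        recon = model.decode(codes)   # reconstruct from the indices
    print(codes.shape, recon.shape)   # expect (1, 8, 8, 8) and (1, 1, 16, 16, 16)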