"""Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

from vq_gan_3d.utils import shift_dim, adopt_weight
from vq_gan_3d.model.lpips import LPIPS
from vq_gan_3d.model.codebook import Codebook
from vq_gan_3d.model.lpips import LPIPS
from vq_gan_3d.model.codebook import Codebook
def silu(x):
    # Equivalent to torch.nn.functional.silu; kept as a standalone helper.
    return x * torch.sigmoid(x)
class SiLU(nn.Module):
def __init__(self):
super(SiLU, self).__init__()
def forward(self, x):
return silu(x)
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1. - logits_real))
loss_fake = torch.mean(F.relu(1. + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) +
torch.mean(torch.nn.functional.softplus(logits_fake)))
return d_loss
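# Sanity check (illustrative): with logits_real = +2 and logits_fake = -2 both
# hinge terms are zero, so hinge_d_loss = 0; at logits_real = logits_fake = 0
# it is 1.0, while vanilla_d_loss gives 0.5 * 2 * softplus(0) = log 2 ~ 0.693.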
class MeanPooling(nn.Module):
def __init__(self, kernel_size=16):
super(MeanPooling, self).__init__()
# Define a 3D average pooling layer
self.pool = nn.AvgPool3d(kernel_size=kernel_size)
def forward(self, x):
# Apply average pooling
x = self.pool(x)
# Flatten the tensor to a single dimension per batch element
x = x.view(x.size(0), -1)
return x
class VQGAN(nn.Module):
def __init__(self):
super().__init__()
self._set_seed(0)
self.embedding_dim = 256
self.n_codes = 16384
self.encoder = Encoder(16, [4,4,4], 1, 'group', 'replicate', 32)
self.decoder = Decoder(16, [4,4,4], 1, 'group', 32)
self.enc_out_ch = self.encoder.out_channels
self.pre_vq_conv = SamePadConv3d(self.enc_out_ch, 256, 1, padding_type='replicate')
self.post_vq_conv = SamePadConv3d(256, self.enc_out_ch, 1)
        # NOTE: restart_thres=False is treated as 0, which likely disables the
        # codebook restart threshold (the TATS default is 1.0).
        self.codebook = Codebook(16384, 256, no_random_restart=False, restart_thres=False)
self.pooling = MeanPooling(kernel_size=16)
self.gan_feat_weight = 4
# TODO: Changed batchnorm from sync to normal
self.image_discriminator = NLayerDiscriminator(1, 64, 3, norm_layer=nn.BatchNorm2d)
self.disc_loss = hinge_d_loss
self.perceptual_model = LPIPS()
self.image_gan_weight = 1
self.perceptual_weight = 4
self.l1_weight = 4
def encode(self, x, include_embeddings=False, quantize=True):
h = self.pre_vq_conv(self.encoder(x))
if quantize:
vq_output = self.codebook(h)
if include_embeddings:
return vq_output['embeddings'], vq_output['encodings']
else:
return vq_output['encodings']
return h
def decode(self, latent, quantize=False):
if quantize:
vq_output = self.codebook(latent)
latent = vq_output['encodings']
h = F.embedding(latent, self.codebook.embeddings)
h = self.post_vq_conv(shift_dim(h, -1, 1))
return self.decoder(h)
    def feature_extraction(self, x):
        """Extract a pooled embedding for an input grid.

        Illustrative shapes (with the defaults above): a (B, 1, 64, 64, 64)
        grid encodes to a (B, 256, 16, 16, 16) latent; channels are permuted
        last before average pooling, which reduces it to a (B, 256) vector.
        """
        h = self.encode(x, include_embeddings=False, quantize=False)
        return self.pooling(h.permute(0, 2, 3, 4, 1))
    def forward(self, global_step, x, optimizer_idx=None, log_image=False, gpu_id=0):
        """Dispatch on mode: log_image=True returns (frames, frames_recon, x,
        x_recon); optimizer_idx=0 returns the generator-side losses;
        optimizer_idx=1 returns the discriminator loss; otherwise returns
        (recon_loss, x_recon, vq_output, perceptual_loss)."""
        B, C, T, H, W = x.shape
z = self.pre_vq_conv(self.encoder(x))
vq_output = self.codebook(z, gpu_id)
#vq_output['embeddings'] = torch.exp(vq_output['embeddings'])
x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
recon_loss = (F.l1_loss(x_recon, x) * self.l1_weight)
        # Select one random 2D frame from each 3D volume
frame_idx = torch.randint(0, T, [B]).to(gpu_id)
frame_idx_selected = frame_idx.reshape(-1,
1, 1, 1, 1).repeat(1, C, 1, H, W)
frames = torch.gather(x, 2, frame_idx_selected).squeeze(2)
frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2)
if log_image:
return frames, frames_recon, x, x_recon
if optimizer_idx == 0:
# Autoencoder - train the "generator"
# Perceptual loss
perceptual_loss = 0
if self.perceptual_weight > 0:
perceptual_loss = self.perceptual_model(
frames, frames_recon).mean() * self.perceptual_weight
# Discriminator loss (turned on after a certain epoch)
logits_image_fake, pred_image_fake = self.image_discriminator(
frames_recon)
g_image_loss = -torch.mean(logits_image_fake)
g_loss = self.image_gan_weight*g_image_loss
            disc_factor = adopt_weight(
                global_step, threshold=self.config['model']['discriminator_iter_start'])
aeloss = disc_factor * g_loss
# GAN feature matching loss - tune features such that we get the same prediction result on the discriminator
image_gan_feat_loss = 0
            feat_weights = 4.0 / (3 + 1)  # as in pix2pixHD: 4 / (n_layers + 1) with n_layers = 3
if self.image_gan_weight > 0:
logits_image_real, pred_image_real = self.image_discriminator(
frames)
                for i in range(len(pred_image_fake) - 1):
                    image_gan_feat_loss += feat_weights * \
                        F.l1_loss(pred_image_fake[i], pred_image_real[i].detach())
            gan_feat_loss = disc_factor * self.gan_feat_weight * image_gan_feat_loss
return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss, (g_image_loss, image_gan_feat_loss, vq_output['commitment_loss'], vq_output['perplexity'])
if optimizer_idx == 1:
# Train discriminator
logits_image_real, _ = self.image_discriminator(frames.detach())
logits_image_fake, _ = self.image_discriminator(
frames_recon.detach())
d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
            disc_factor = adopt_weight(
                global_step, threshold=self.config['model']['discriminator_iter_start'])
discloss = disc_factor * \
(self.image_gan_weight*d_image_loss)
return discloss
perceptual_loss = self.perceptual_model(
frames, frames_recon) * self.perceptual_weight
return recon_loss, x_recon, vq_output, perceptual_loss
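    # In a typical two-optimizer training loop (illustrative; the loop itself
    # lives outside this file), forward is called twice per batch: once with
    # optimizer_idx=0 to step the autoencoder/generator and once with
    # optimizer_idx=1 to step the discriminator.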
def load_checkpoint(self, ckpt_path):
# load checkpoint file
ckpt_dict = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        # load hyperparameters
self.config = ckpt_dict['hparams']['_content']
self.embedding_dim = self.config['model']['embedding_dim']
self.n_codes = self.config['model']['n_codes']
# instantiate modules
self.encoder = Encoder(
self.config['model']['n_hiddens'],
self.config['model']['downsample'],
self.config['dataset']['image_channels'],
self.config['model']['norm_type'],
self.config['model']['padding_type'],
self.config['model']['num_groups'],
)
self.decoder = Decoder(
self.config['model']['n_hiddens'],
self.config['model']['downsample'],
self.config['dataset']['image_channels'],
self.config['model']['norm_type'],
self.config['model']['num_groups']
)
self.enc_out_ch = self.encoder.out_channels
self.pre_vq_conv = SamePadConv3d(self.enc_out_ch, self.embedding_dim, 1, padding_type=self.config['model']['padding_type'])
self.post_vq_conv = SamePadConv3d(self.embedding_dim, self.enc_out_ch, 1)
self.codebook = Codebook(
self.n_codes,
self.embedding_dim,
no_random_restart=self.config['model']['no_random_restart'],
restart_thres=False
)
self.gan_feat_weight = self.config['model']['gan_feat_weight']
# TODO: Changed batchnorm from sync to normal
self.image_discriminator = NLayerDiscriminator(
self.config['dataset']['image_channels'],
self.config['model']['disc_channels'],
self.config['model']['disc_layers'],
norm_layer=nn.BatchNorm2d
)
self.disc_loss = hinge_d_loss
self.perceptual_model = LPIPS()
        self.image_gan_weight = self.config['model']['image_gan_weight']
self.perceptual_weight = self.config['model']['perceptual_weight']
self.l1_weight = self.config['model']['l1_weight']
# restore model weights
self.load_state_dict(ckpt_dict["MODEL_STATE"], strict=True)
# load RNG states each time the model and states are loaded from checkpoint
if 'rng' in self.config:
rng = self.config['rng']
            for key, value in rng.items():
                if key == 'torch_state':
                    torch.set_rng_state(value.cpu())
                elif key == 'cuda_state':
                    torch.cuda.set_rng_state(value.cpu())
                elif key == 'numpy_state':
                    np.random.set_state(value)
                elif key == 'python_state':
                    random.setstate(value)
                else:
                    print('Unrecognized RNG state key:', key)
def log_images(self, batch, **kwargs):
log = dict()
x = batch['data']
        x = x.to(next(self.parameters()).device)
        frames, frames_rec, _, _ = self(0, x, log_image=True)
log["inputs"] = frames
log["reconstructions"] = frames_rec
#log['mean_org'] = batch['mean_org']
#log['std_org'] = batch['std_org']
return log
def _set_seed(self, value):
print('Random Seed:', value)
random.seed(value)
torch.manual_seed(value)
torch.cuda.manual_seed(value)
torch.cuda.manual_seed_all(value)
np.random.seed(value)
cudnn.deterministic = True
        cudnn.benchmark = False  # benchmark must stay off for deterministic results
cudnn.enabled = True
def Normalize(in_channels, norm_type='group', num_groups=32):
assert norm_type in ['group', 'batch']
if norm_type == 'group':
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == 'batch':
return torch.nn.SyncBatchNorm(in_channels)
class Encoder(nn.Module):
    def __init__(self, n_hiddens=16, downsample=[2, 2, 2], image_channel=64, norm_type='group', padding_type='replicate', num_groups=32):
super().__init__()
n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
self.conv_blocks = nn.ModuleList()
max_ds = n_times_downsample.max()
self.conv_first = SamePadConv3d(
image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
for i in range(max_ds):
block = nn.Module()
in_channels = n_hiddens * 2**i
out_channels = n_hiddens * 2**(i+1)
stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
block.down = SamePadConv3d(
in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
block.res = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
self.conv_blocks.append(block)
n_times_downsample -= 1
self.final_block = nn.Sequential(
Normalize(out_channels, norm_type, num_groups=num_groups),
SiLU()
)
self.out_channels = out_channels
def forward(self, x):
h = self.conv_first(x)
for block in self.conv_blocks:
h = block.down(h)
h = block.res(h)
h = self.final_block(h)
return h
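# Illustrative shape walk (a sketch, assuming the arguments VQGAN passes in):
# a (B, 1, 64, 64, 64) input with downsample=[4, 4, 4] passes through two
# stride-2 blocks, giving (B, 64, 16, 16, 16) before pre_vq_conv.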
class Decoder(nn.Module):
    def __init__(self, n_hiddens=16, upsample=[4, 4, 4], image_channel=1, norm_type='group', num_groups=1):
super().__init__()
n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
max_us = n_times_upsample.max()
in_channels = n_hiddens*2**max_us
self.final_block = nn.Sequential(
Normalize(in_channels, norm_type, num_groups=num_groups),
SiLU()
)
self.conv_blocks = nn.ModuleList()
for i in range(max_us):
block = nn.Module()
in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
out_channels = n_hiddens*2**(max_us-i)
us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
block.up = SamePadConvTranspose3d(
in_channels, out_channels, 4, stride=us)
block.res1 = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
block.res2 = ResBlock(
out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
self.conv_blocks.append(block)
n_times_upsample -= 1
self.conv_last = SamePadConv3d(
out_channels, image_channel, kernel_size=3)
def forward(self, x):
h = self.final_block(x)
for i, block in enumerate(self.conv_blocks):
h = block.up(h)
h = block.res1(h)
h = block.res2(h)
h = self.conv_last(h)
return h
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
self.conv1 = SamePadConv3d(
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
self.dropout = torch.nn.Dropout(dropout)
self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups)
self.conv2 = SamePadConv3d(
out_channels, out_channels, kernel_size=3, padding_type=padding_type)
if self.in_channels != self.out_channels:
self.conv_shortcut = SamePadConv3d(
in_channels, out_channels, kernel_size=3, padding_type=padding_type)
def forward(self, x):
h = x
h = self.norm1(h)
h = silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x)
return x+h
# Does not support dilation
class SamePadConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
# assumes that the input shape is divisible by stride
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
pad_input = []
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
pad_input.append((p // 2 + p % 2, p // 2))
pad_input = sum(pad_input, tuple())
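        # Worked example (illustrative): kernel_size=(4, 4, 4), stride=(2, 2, 2)
        # gives total_pad=(2, 2, 2) and pad_input=(1, 1, 1, 1, 1, 1), i.e. one
        # voxel of padding on each side of W, H and D, so out = in / stride
        # whenever the input size is divisible by the stride.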
self.pad_input = pad_input
self.padding_type = padding_type
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=0, bias=bias)
def forward(self, x):
return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
class SamePadConvTranspose3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
pad_input = []
for p in total_pad[::-1]: # reverse since F.pad starts from last dim
pad_input.append((p // 2 + p % 2, p // 2))
pad_input = sum(pad_input, tuple())
self.pad_input = pad_input
self.padding_type = padding_type
self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
stride=stride, bias=bias,
padding=tuple([k - 1 for k in kernel_size]))
def forward(self, x):
return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[-1], res[1:]
else:
            return self.model(input), None
class NLayerDiscriminator3D(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
super(NLayerDiscriminator3D, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[-1], res[1:]
else:
            return self.model(input), None
def load_VQGAN(folder="../data/checkpoints/pretrained", ckpt_filename="VQGAN_43.pt"):
model = VQGAN()
model.load_checkpoint(os.path.join(folder, ckpt_filename))
model.eval()
return model
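

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). The
    # 64^3 single-channel grid and running without a checkpoint are assumptions
    # for illustration; instantiating VQGAN also builds LPIPS, which may fetch
    # pretrained perceptual weights on first use.
    model = VQGAN()
    model.eval()
    with torch.no_grad():
        grid = torch.randn(1, 1, 64, 64, 64)  # (B, C, T, H, W)
        emb = model.feature_extraction(grid)  # expected: (1, 256)
    print('embedding shape:', emb.shape)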