import torch
import torch.nn as nn
import torch.nn.functional as F

from models.lib.quantizer import VectorQuantizer
from models.lib.base_models import Transformer, LinearEmbedding, PositionalEncoding
from base import BaseModel


class VQAutoEncoder(BaseModel):
    """ VQ-GAN model: transformer encoder/decoder with a vector-quantized bottleneck """

    def __init__(self, args):
        super().__init__()
        self.encoder = TransformerEncoder(args)
        self.decoder = TransformerDecoder(args, args.in_dim)
        self.quantize = VectorQuantizer(args.n_embed,
                                        args.zquant_dim,
                                        beta=0.25)
        self.args = args

    def encode(self, x, x_a=None):
        h = self.encoder(x)
        # split each hidden vector into face_quan_num codes of width zquant_dim
        # (implicitly assumes hidden_size == face_quan_num * zquant_dim), then
        # flatten so every code is quantized independently
        h = h.view(x.shape[0], -1, self.args.face_quan_num, self.args.zquant_dim)
        h = h.view(x.shape[0], -1, self.args.zquant_dim)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info
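
    # Shape sketch (illustrative numbers only, not fixed by this file): with
    # face_quan_num=16 and zquant_dim=64 (so hidden_size=1024), an encoder
    # output of (B, T', 1024) becomes (B, T'*16, 64), i.e. 16 codebook
    # lookups per squashed frame.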

    def decode(self, quant):
        # regroup the per-code channels frame-wise, then hand the decoder a
        # channels-first (B, C, T') tensor
        quant = quant.permute(0,2,1)
        quant = quant.view(quant.shape[0], -1, self.args.face_quan_num, self.args.zquant_dim).contiguous()
        quant = quant.view(quant.shape[0], -1, self.args.face_quan_num*self.args.zquant_dim).contiguous()
        quant = quant.permute(0,2,1).contiguous()
        dec = self.decoder(quant)
        return dec

    def forward(self, x, template):
        # model the per-frame displacement from the neutral template
        template = template.unsqueeze(1)
        x = x - template

        quant, emb_loss, info = self.encode(x)
        dec = self.decode(quant)

        dec = dec + template
        return dec, emb_loss, info

    def sample_step(self, x, x_a=None):
        quant_z, _, info = self.encode(x, x_a)
        # decode directly from the quantized features...
        x_sample_det = self.decode(quant_z)
        btc = quant_z.shape[0], quant_z.shape[2], quant_z.shape[1]
        indices = info[2]
        # ...and also round-trip through the codebook indices as a consistency check
        x_sample_check = self.decode_to_img(indices, btc)
        return x_sample_det, x_sample_check

    def get_quant(self, x, x_a=None):
        quant_z, _, info = self.encode(x, x_a)
        indices = info[2]
        return quant_z, indices

    def get_distances(self, x):
        h = self.encoder(x)
        d = self.quantize.get_distance(h)
        return d

    def get_quant_from_d(self, d, btc):
        # pick the nearest codebook entry per position, then decode
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        x = self.decode_to_img(min_encoding_indices, btc)
        return x

    @torch.no_grad()
    def entry_to_feature(self, index, zshape):
        index = index.long()
        quant_z = self.quantize.get_codebook_entry(index.reshape(-1),
                                                   shape=None)
        quant_z = torch.reshape(quant_z, zshape)
        return quant_z

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        index = index.long()
        quant_z = self.quantize.get_codebook_entry(index.reshape(-1),
                                                   shape=None)
        quant_z = torch.reshape(quant_z, zshape).permute(0,2,1)
        x = self.decode(quant_z)
        return x

    @torch.no_grad()
    def decode_logit(self, logits, zshape):
        # 3-D input is treated as logits (greedy argmax); anything else is
        # assumed to already be codebook indices
        if logits.dim() == 3:
            probs = F.softmax(logits, dim=-1)
            _, ix = torch.topk(probs, k=1, dim=-1)
        else:
            ix = logits
        ix = torch.reshape(ix, (-1,1))
        x = self.decode_to_img(ix, zshape)
        return x

    def get_logit(self, logits, sample=True, filter_value=-float('Inf'),
                  temperature=0.7, top_p=0.9, sample_idx=None):
        """ Samples the distribution of logits (used at test time).
        Note: filter_value and top_p are accepted but currently unused,
        i.e. no nucleus filtering is applied.
        """
        logits = logits / temperature
        if sample_idx is None:
            sample_idx = 0

        probs = F.softmax(logits, dim=-1)
        if sample:
            # draw sample_idx+1 multinomial samples per position
            shape = probs.shape
            probs = probs.reshape(shape[0]*shape[1], shape[2])
            ix = torch.multinomial(probs, num_samples=sample_idx+1)
            # keep the sample_idx-th draw for each position
            ix = ix[:, sample_idx]
            probs = probs.reshape(shape[0], shape[1], shape[2])
            ix = ix.reshape(shape[0], shape[1])
        else:
            # greedy: take the most probable entry
            _, ix = torch.topk(probs, k=1, dim=-1)
        return ix, probs
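
# Usage sketch for the index round-trip (illustrative only; `model` and the
# input `x` are assumptions, not fixed by this file):
#
#   quant_z, indices = model.get_quant(x)                       # codebook ids
#   btc = quant_z.shape[0], quant_z.shape[2], quant_z.shape[1]  # as in sample_step
#   x_rec = model.decode_to_img(indices, btc)                   # motion from ids
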

class TransformerEncoder(nn.Module):
    """ Encoder class for VQ-VAE with Transformer backbone """

    def __init__(self, args):
        super().__init__()
        self.args = args
        size = self.args.in_dim
        dim = self.args.hidden_size
        self.vertice_mapping = nn.Sequential(nn.Linear(size, dim), nn.LeakyReLU(self.args.neg, True))
        if args.quant_factor == 0:
            # keep the temporal resolution unchanged
            layers = [nn.Sequential(
                nn.Conv1d(dim, dim, 5, stride=1, padding=2,
                          padding_mode='replicate'),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine)
            )]
        else:
            # first stage halves the temporal resolution via a stride-2 conv
            layers = [nn.Sequential(
                nn.Conv1d(dim, dim, 5, stride=2, padding=2,
                          padding_mode='replicate'),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine)
            )]
        # each extra stage halves it again via max-pooling, for a total
        # downsampling factor of 2**quant_factor
        for _ in range(1, args.quant_factor):
            layers += [nn.Sequential(
                nn.Conv1d(dim, dim, 5, stride=1, padding=2,
                          padding_mode='replicate'),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine),
                nn.MaxPool1d(2)
            )]
        self.squasher = nn.Sequential(*layers)
        self.encoder_transformer = Transformer(
            in_size=self.args.hidden_size,
            hidden_size=self.args.hidden_size,
            num_hidden_layers=self.args.num_hidden_layers,
            num_attention_heads=self.args.num_attention_heads,
            intermediate_size=self.args.intermediate_size)
        self.encoder_pos_embedding = PositionalEncoding(
            self.args.hidden_size)
        self.encoder_linear_embedding = LinearEmbedding(
            self.args.hidden_size,
            self.args.hidden_size)

    def forward(self, inputs):
        # no masking is used; the transformer still expects a mask dict
        dummy_mask = {'max_mask': None, 'mask_index': -1, 'mask': None}
        inputs = self.vertice_mapping(inputs)
        # Conv1d wants (B, C, T): squash time, then return to (B, T', C)
        inputs = self.squasher(inputs.permute(0,2,1)).permute(0,2,1)

        encoder_features = self.encoder_linear_embedding(inputs)
        encoder_features = self.encoder_pos_embedding(encoder_features)
        encoder_features = self.encoder_transformer((encoder_features, dummy_mask))

        return encoder_features
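
# Downsampling arithmetic (sketch; quant_factor=3 is an illustrative value):
# a 64-frame clip is squashed to 64 / 2**3 = 8 latent frames before
# quantization, and the decoder's expander mirrors this with one x2
# upsampling step per squashing stage.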

class TransformerDecoder(nn.Module):
    """ Decoder class for VQ-VAE with Transformer backbone """

    def __init__(self, args, out_dim, is_audio=False):
        super().__init__()
        self.args = args
        size = self.args.hidden_size
        dim = self.args.hidden_size
        self.expander = nn.ModuleList()
        if args.quant_factor == 0:
            self.expander.append(nn.Sequential(
                nn.Conv1d(size, dim, 5, stride=1, padding=2,
                          padding_mode='replicate'),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine)
            ))
        else:
            # first stage doubles the temporal resolution; ConvTranspose1d
            # only supports zero padding, so no padding_mode is passed here
            self.expander.append(nn.Sequential(
                nn.ConvTranspose1d(size, dim, 5, stride=2, padding=2,
                                   output_padding=1),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine)
            ))
        num_layers = args.quant_factor + 2 if is_audio else args.quant_factor

        # remaining stages keep the resolution; the extra x2 upsampling
        # happens in forward() via repeat_interleave
        for _ in range(1, num_layers):
            self.expander.append(nn.Sequential(
                nn.Conv1d(dim, dim, 5, stride=1, padding=2,
                          padding_mode='replicate'),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine),
            ))
        self.decoder_transformer = Transformer(
            in_size=self.args.hidden_size,
            hidden_size=self.args.hidden_size,
            num_hidden_layers=self.args.num_hidden_layers,
            num_attention_heads=self.args.num_attention_heads,
            intermediate_size=self.args.intermediate_size)
        self.decoder_pos_embedding = PositionalEncoding(
            self.args.hidden_size)
        self.decoder_linear_embedding = LinearEmbedding(
            self.args.hidden_size,
            self.args.hidden_size)

        self.vertice_map_reverse = nn.Linear(args.hidden_size, out_dim)

    def forward(self, inputs):
        # no masking is used; the transformer still expects a mask dict
        dummy_mask = {'max_mask': None, 'mask_index': -1, 'mask': None}

        # upsample back to the original temporal length
        for i, module in enumerate(self.expander):
            inputs = module(inputs)
            if i > 0:
                inputs = inputs.repeat_interleave(2, dim=2)
        inputs = inputs.permute(0,2,1)
        decoder_features = self.decoder_linear_embedding(inputs)
        decoder_features = self.decoder_pos_embedding(decoder_features)

        decoder_features = self.decoder_transformer((decoder_features, dummy_mask))
        pred_recon = self.vertice_map_reverse(decoder_features)
        return pred_recon
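

# Smoke-test sketch (illustrative only; `args` is a hypothetical namespace and
# every value below is a placeholder, not the repository's configuration):
#
#   if __name__ == "__main__":
#       from argparse import Namespace
#       args = Namespace(in_dim=15069, hidden_size=1024, n_embed=256,
#                        zquant_dim=64, face_quan_num=16, quant_factor=3,
#                        neg=0.2, INaffine=False, num_hidden_layers=6,
#                        num_attention_heads=8, intermediate_size=1536)
#       model = VQAutoEncoder(args)
#       motion = torch.randn(1, 64, args.in_dim)      # (batch, frames, verts*3)
#       template = torch.randn(1, args.in_dim)        # neutral face
#       recon, emb_loss, info = model(motion, template)
#       print(recon.shape)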