# MultiTalk-Code: models/stage1_vocaset.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.lib.quantizer import VectorQuantizer
from models.lib.base_models import Transformer, LinearEmbedding, PositionalEncoding
from base import BaseModel
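
# Stage-1 model: a vector-quantized autoencoder over facial vertex motion.
# TransformerEncoder squashes the motion sequence into latent features,
# VectorQuantizer snaps them to discrete codebook entries, and
# TransformerDecoder reconstructs per-frame vertex displacements from the codes.
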
class VQAutoEncoder(BaseModel):
""" VQ-GAN model """
def __init__(self, args):
super().__init__()
self.encoder = TransformerEncoder(args)
self.decoder = TransformerDecoder(args, args.in_dim)
self.quantize = VectorQuantizer(args.n_embed,
args.zquant_dim,
beta=0.25)
self.args = args
    def encode(self, x, x_a=None):
        h = self.encoder(x) ## x --> z'
        # fold the per-face codebook axis into the sequence axis:
        # (B, L, face_quan_num * zquant_dim) -> (B, L * face_quan_num, zquant_dim)
        h = h.view(x.shape[0], -1, self.args.face_quan_num, self.args.zquant_dim)
        h = h.view(x.shape[0], -1, self.args.zquant_dim)
        quant, emb_loss, info = self.quantize(h) ## find the nearest codebook entries
        return quant, emb_loss, info
    def decode(self, quant):
        # quant: (B, C, L); regroup the per-frame codes into
        # (B, face_quan_num * zquant_dim, L) before decoding
        quant = quant.permute(0,2,1)
        quant = quant.view(quant.shape[0], -1, self.args.face_quan_num, self.args.zquant_dim).contiguous()
        quant = quant.view(quant.shape[0], -1, self.args.face_quan_num*self.args.zquant_dim).contiguous()
        quant = quant.permute(0,2,1).contiguous()
        dec = self.decoder(quant) ## z' --> x
        return dec
    def forward(self, x, template):
        template = template.unsqueeze(1) # (B, V*3) -> (B, 1, V*3)
        x = x - template # model motion as displacement from the neutral template
        ## x: (B, L, C)
        quant, emb_loss, info = self.encode(x)
        ## quant: (B, C, L)
        dec = self.decode(quant)
        dec = dec + template
        return dec, emb_loss, info
    def sample_step(self, x, x_a=None):
        quant_z, _, info = self.encode(x, x_a)
        x_sample_det = self.decode(quant_z)
        btc = quant_z.shape[0], quant_z.shape[2], quant_z.shape[1] # target (B, T, C) for decode_to_img
        indices = info[2] # indices of the selected codebook entries
        x_sample_check = self.decode_to_img(indices, btc)
        return x_sample_det, x_sample_check
def get_quant(self, x, x_a=None):
quant_z, _, info = self.encode(x, x_a)
indices = info[2]
return quant_z, indices
def get_distances(self, x):
h = self.encoder(x) ## x --> z'
d = self.quantize.get_distance(h)
return d
    def get_quant_from_d(self, d, btc):
        # pick the nearest codebook entry per latent from precomputed distances
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        x = self.decode_to_img(min_encoding_indices, btc)
        return x
@torch.no_grad()
def entry_to_feature(self, index, zshape):
index = index.long()
quant_z = self.quantize.get_codebook_entry(index.reshape(-1),
shape=None)
quant_z = torch.reshape(quant_z, zshape)
return quant_z
@torch.no_grad()
def decode_to_img(self, index, zshape):
index = index.long()
quant_z = self.quantize.get_codebook_entry(index.reshape(-1),
shape=None)
quant_z = torch.reshape(quant_z, zshape).permute(0,2,1) # B L 1 -> B L C -> B C L
x = self.decode(quant_z)
return x
@torch.no_grad()
def decode_logit(self, logits, zshape):
if logits.dim() == 3:
probs = F.softmax(logits, dim=-1)
_, ix = torch.topk(probs, k=1, dim=-1)
else:
ix = logits
ix = torch.reshape(ix, (-1,1))
x = self.decode_to_img(ix, zshape)
return x
    def get_logit(self, logits, sample=True, filter_value=-float('Inf'),
                  temperature=0.7, top_p=0.9, sample_idx=None):
        """ Samples indices from the logit distribution (used at test time).
        Note: `sample_idx`, `top_p` and `filter_value` are accepted for
        interface compatibility but are unused here; `sample_idx` is fixed
        to 0, so a single multinomial sample is drawn per position.
        """
        logits = logits / temperature
        sample_idx = 0
        probs = F.softmax(logits, dim=-1) # (B, N, embed_num)
        if sample:
            ## multinomial sampling: flatten (B, N) so each position is one row
            shape = probs.shape
            probs = probs.reshape(shape[0]*shape[1], shape[2])
            ix = torch.multinomial(probs, num_samples=sample_idx+1)
            probs = probs.reshape(shape[0], shape[1], shape[2])
            ix = ix.reshape(shape[0], shape[1])
        else:
            ## top-1: take the argmax, no sampling
            _, ix = torch.topk(probs, k=1, dim=-1)
        return ix, probs
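
# Note on temporal resolution: with quant_factor == 0 the squasher below keeps
# the sequence length (stride-1 conv); otherwise the first layer halves it
# (stride-2 conv) and each additional level halves it again via MaxPool1d(2),
# so the latent sequence is 2**quant_factor times shorter than the input.
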
class TransformerEncoder(nn.Module):
""" Encoder class for VQ-VAE with Transformer backbone """
def __init__(self, args):
super().__init__()
self.args = args
size = self.args.in_dim
dim = self.args.hidden_size
self.vertice_mapping = nn.Sequential(nn.Linear(size,dim), nn.LeakyReLU(self.args.neg, True))
if args.quant_factor == 0:
layers = [nn.Sequential(
nn.Conv1d(dim,dim,5,stride=1,padding=2,
padding_mode='replicate'),
nn.LeakyReLU(self.args.neg, True),
nn.InstanceNorm1d(dim, affine=args.INaffine)
)]
else:
layers = [nn.Sequential(
nn.Conv1d(dim,dim,5,stride=2,padding=2,
padding_mode='replicate'),
nn.LeakyReLU(self.args.neg, True),
nn.InstanceNorm1d(dim, affine=args.INaffine)
)]
for _ in range(1, args.quant_factor):
layers += [nn.Sequential(
nn.Conv1d(dim,dim,5,stride=1,padding=2,
padding_mode='replicate'),
nn.LeakyReLU(self.args.neg, True),
nn.InstanceNorm1d(dim, affine=args.INaffine),
nn.MaxPool1d(2)
)]
self.squasher = nn.Sequential(*layers)
self.encoder_transformer = Transformer(
in_size=self.args.hidden_size,
hidden_size=self.args.hidden_size,
num_hidden_layers=\
self.args.num_hidden_layers,
num_attention_heads=\
self.args.num_attention_heads,
intermediate_size=\
self.args.intermediate_size)
self.encoder_pos_embedding = PositionalEncoding(
self.args.hidden_size)
self.encoder_linear_embedding = LinearEmbedding(
self.args.hidden_size,
self.args.hidden_size)
    def forward(self, inputs):
        ## downsample to a patch-wise-length sequence before passing into the transformer
        dummy_mask = {'max_mask': None, 'mask_index': -1, 'mask': None}
        inputs = self.vertice_mapping(inputs)
        inputs = self.squasher(inputs.permute(0,2,1)).permute(0,2,1) # (B, L, C)
encoder_features = self.encoder_linear_embedding(inputs)
encoder_features = self.encoder_pos_embedding(encoder_features)
encoder_features = self.encoder_transformer((encoder_features, dummy_mask))
return encoder_features
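
# The expander mirrors the squasher: with quant_factor > 0 the first layer
# doubles the sequence length with a stride-2 ConvTranspose1d, and every later
# layer is followed by repeat_interleave(2) in forward, restoring the input length.
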
class TransformerDecoder(nn.Module):
""" Decoder class for VQ-VAE with Transformer backbone """
def __init__(self, args, out_dim, is_audio=False):
super().__init__()
self.args = args
size=self.args.hidden_size
dim=self.args.hidden_size
self.expander = nn.ModuleList()
if args.quant_factor == 0:
self.expander.append(nn.Sequential(
nn.Conv1d(size,dim,5,stride=1,padding=2,
padding_mode='replicate'),
nn.LeakyReLU(self.args.neg, True),
nn.InstanceNorm1d(dim, affine=args.INaffine)
))
        else:
            # ConvTranspose1d supports only zero padding in PyTorch, so the
            # padding_mode='replicate' passed in the original code (which
            # raises at construction time) is dropped here
            self.expander.append(nn.Sequential(
                nn.ConvTranspose1d(size,dim,5,stride=2,padding=2,
                                   output_padding=1),
                nn.LeakyReLU(self.args.neg, True),
                nn.InstanceNorm1d(dim, affine=args.INaffine)
            ))
num_layers = args.quant_factor+2 \
if is_audio else args.quant_factor
for _ in range(1, num_layers):
self.expander.append(nn.Sequential(
nn.Conv1d(dim,dim,5,stride=1,padding=2,
padding_mode='replicate'),
nn.LeakyReLU(self.args.neg, True),
nn.InstanceNorm1d(dim, affine=args.INaffine),
))
self.decoder_transformer = Transformer(
in_size=self.args.hidden_size,
hidden_size=self.args.hidden_size,
num_hidden_layers=\
self.args.num_hidden_layers,
num_attention_heads=\
self.args.num_attention_heads,
intermediate_size=\
self.args.intermediate_size)
self.decoder_pos_embedding = PositionalEncoding(
self.args.hidden_size)
self.decoder_linear_embedding = LinearEmbedding(
self.args.hidden_size,
self.args.hidden_size)
self.vertice_map_reverse = nn.Linear(args.hidden_size,out_dim)
    def forward(self, inputs):
        dummy_mask = {'max_mask': None, 'mask_index': -1, 'mask': None}
        ## upsample back to the original-length sequence before passing into the transformer
        for i, module in enumerate(self.expander):
            inputs = module(inputs)
            if i > 0:
                # layers after the first double the length by repetition
                inputs = inputs.repeat_interleave(2, dim=2)
        inputs = inputs.permute(0,2,1) # (B, C, L) -> (B, L, C)
decoder_features = self.decoder_linear_embedding(inputs)
decoder_features = self.decoder_pos_embedding(decoder_features)
decoder_features = self.decoder_transformer((decoder_features, dummy_mask))
pred_recon = self.vertice_map_reverse(decoder_features)
return pred_recon
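

if __name__ == "__main__":
    # Minimal shape-check sketch: the hyperparameter values below are
    # illustrative assumptions only (real values come from the project's
    # configs), and running it requires the repo's models.lib and base modules.
    from types import SimpleNamespace

    args = SimpleNamespace(
        in_dim=15069,        # assumed: 5023 vertices * 3 coordinates (VOCASET)
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=1536,
        quant_factor=0,      # keep temporal length unchanged
        neg=0.2,             # LeakyReLU negative slope
        INaffine=False,
        n_embed=256,
        zquant_dim=64,
        face_quan_num=16,    # hidden_size must equal face_quan_num * zquant_dim
    )
    model = VQAutoEncoder(args)
    x = torch.randn(2, 8, args.in_dim)      # (B, L, V*3) motion sequence
    template = torch.randn(2, args.in_dim)  # (B, V*3) neutral-face template
    rec, emb_loss, info = model(x, template)
    print(rec.shape)  # expect (2, 8, in_dim): displacements added back onto template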