# Uploaded by: animatedaliensfans
# Commit message: "Upload 4 files"
# Commit: 79eea03 (verified)
import os
import cv2
import torch
import argparse
import glob
import numpy as np
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from model import FaceEncoder, AudioEncoder, FaceDecoder
from data import LRW
from PIL import Image
import datetime
# ---------------------------------------------------------------------------
# Command-line configuration, RNG seeding and device selection.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Lip Generator Example')
parser.add_argument('--data', type=str, default='/beegfs/cy1355/lipread_datachunk_big/', metavar='N',
                    help='data root directory')
parser.add_argument('--batch-size', type=int, default=1024, metavar='N',
                    help='input batch size for training (default: 1024)')
parser.add_argument('--epochs', type=int, default=1001, metavar='N',
                    help='number of epochs to train (default: 1001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()

# CUDA is used only when it is both available and not disabled by the user.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Apply the --seed option so that runs are repeatable; without this call the
# option is accepted but has no effect.
torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")
# Extra DataLoader options that are only passed when running on the GPU.
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
print('Run on {}'.format(device))
# ---------------------------------------------------------------------------
# Networks and optimizers. Both encoders are updated by a single Adam
# optimizer; the decoder has an Adam optimizer of its own.
# ---------------------------------------------------------------------------
face_encoder = FaceEncoder().to(device)
audio_encoder = AudioEncoder().to(device)
encoders_params = [*face_encoder.parameters(), *audio_encoder.parameters()]
encoders_optimizer = optim.Adam(encoders_params, lr=1e-3, betas=(0.5, 0.999))

face_decoder = FaceDecoder().to(device)
decoder_optimizer = optim.Adam(face_decoder.parameters(), lr=1e-3, betas=(0.5, 0.999))

mse_loss = nn.MSELoss()

# ---------------------------------------------------------------------------
# Datasets and loaders, keyed by split name. Every split is shuffled.
# ---------------------------------------------------------------------------
_SPLITS = ('train', 'val', 'test')

dsets = dict((name, LRW(name, root_path=args.data)) for name in _SPLITS)
dset_loaders = {
    name: torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    for name, dataset in dsets.items()
}
train_loader, val_loader, test_loader = (dset_loaders[name] for name in _SPLITS)

dset_sizes = {name: len(dataset) for name, dataset in dsets.items()}
print('\nStatistics: train: {}, val: {}, test: {}'.format(
    *(dset_sizes[name] for name in _SPLITS)))
def open_level(mouth_keypoints, eps=1e-8):
    """Score how far the mouth is open in each frame, normalised per sequence.

    Eight upper-lip landmarks are paired with eight lower-lip landmarks. The
    Euclidean distance of every pair is averaged within a frame, and the
    per-frame means are divided by their sum over the frame axis, so the
    scores of one sequence add up to 256.

    Args:
        mouth_keypoints: 4-D tensor indexed as ``[batch, frame, landmark, xy]``.
            The landmark axis needs at least 20 entries (the largest index
            used is 19); x is read from position 0 and y from position 1 of
            the last axis.
        eps: Lower bound applied to the per-sequence normaliser. It only
            matters when every lip pair coincides in every frame of a
            sequence, where the unguarded division would yield NaN.

    Returns:
        Tensor of shape ``[batch, frame]`` with the normalised scores.
    """
    # Paired landmark indices: u_index[k] sits opposite l_index[k].
    u_index = [1, 2, 3, 4, 5, 13, 14, 15]
    l_index = [11, 10, 9, 8, 7, 19, 18, 17]

    upper_lip = mouth_keypoints[:, :, u_index, :]
    lower_lip = mouth_keypoints[:, :, l_index, :]

    # Coordinates
    u_x = upper_lip[:, :, :, 0]
    u_y = upper_lip[:, :, :, 1]
    l_x = lower_lip[:, :, :, 0]
    l_y = lower_lip[:, :, :, 1]

    distance = ((u_x - l_x) ** 2 + (u_y - l_y) ** 2) ** 0.5
    distance_mean = distance.mean(dim=2)

    # Guard the normaliser so an all-closed sequence gives zeros, not NaN.
    total = distance_mean.sum(dim=1, keepdim=True).clamp(min=eps)
    distance_normed = distance_mean / total
    return distance_normed * 256
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    For every batch, the first 48 landmarks of each frame and the MFCC frames
    are encoded, concatenated and decoded into the 20 mouth landmarks. The
    decoder is run on two versions of the batch stacked together:

    * the original face embeddings, and
    * the face embeddings of a randomly permuted batch,

    both paired with the same audio embeddings. The loss is the sum of an MSE
    between ground-truth and predicted mouth points (original half only) and
    an MSE between the ``open_level`` scores of the two halves.

    Args:
        epoch: Epoch number; used only in the log output.
    """
    face_encoder.train()
    audio_encoder.train()
    face_decoder.train()
    train_loss = 0
    train_loader.dataset.test_case = False

    for batch_idx, ((keypoints, mfcc), _) in enumerate(train_loader):
        batch_size = keypoints.shape[0]
        video_length = keypoints.shape[1]
        encoders_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        keypoints = keypoints.to(device)
        # reshape() rather than view(): the transposed / sliced tensors are not
        # guaranteed to be contiguous, and view() raises on layouts it cannot
        # express without a copy. This also matches what test() does.
        mfcc = mfcc.transpose(1, 2).to(device).reshape(-1, 12)
        face_points = keypoints[:, :, :48].reshape(-1, 96)
        mouth_points = keypoints[:, :, 48:68].reshape(-1, 40)

        face_embedding = face_encoder(face_points)
        audio_embedding = audio_encoder(mfcc)

        # Shuffle face_embedding across the batch axis (whole sequences move
        # together) so the same audio is paired with a different face.
        shuffle_index = torch.randperm(batch_size)
        face_embedding_extended = face_embedding.reshape(batch_size, video_length, -1)
        shuffled_face_embedding = face_embedding_extended[shuffle_index].reshape(
            batch_size * video_length, -1)

        # First half: original faces. Second half: shuffled faces.
        mixed_face_embedding = torch.cat((face_embedding, shuffled_face_embedding), dim=0)
        doubled_audio_embedding = torch.cat((audio_embedding, audio_embedding), dim=0)
        mixed_embedding = torch.cat(
            (mixed_face_embedding, doubled_audio_embedding), dim=1
        ).reshape(batch_size * 2, -1, 144)

        mixed_mouth_points_pred = face_decoder(mixed_embedding) * 255

        # Supervised term: only the original half has a ground truth.
        mouth_points_pred = mixed_mouth_points_pred[:batch_size * video_length]
        supervised_loss = mse_loss(mouth_points, mouth_points_pred)

        # Consistency term: mouth opening should not depend on which face
        # was paired with the audio.
        shuffled_pred = mixed_mouth_points_pred[batch_size * video_length:].reshape(
            batch_size, video_length, 20, 2)
        open_score_shuffled = open_level(shuffled_pred)
        original_pred = mouth_points_pred.reshape(batch_size, video_length, 20, 2)
        open_score_normal = open_level(original_pred)
        # NOTE: a KLDivLoss between (log-)softmaxed open scores was tried here
        # as an alternative to the MSE below and is currently disabled.
        adversarial_loss = mse_loss(open_score_shuffled, open_score_normal)

        loss = supervised_loss + adversarial_loss
        loss.backward()
        train_loss += loss.item()
        encoders_optimizer.step()
        decoder_optimizer.step()

        if batch_idx % args.log_interval == 0:
            # mse_loss already averages over elements, so the batch loss is
            # reported as-is instead of being divided a second time.
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                100. * batch_idx * args.batch_size / len(train_loader.dataset),
                loss.item()))

    # Each accumulated term is a per-batch mean, so the epoch average is the
    # sum divided by the number of batches (guarded for an empty loader).
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / max(len(train_loader), 1)))
def save_image_(tensor, fp, nrow=8, padding=2,
                normalize=False, range=None, scale_each=False, pad_value=0, format=None):
    """Write a tensor (or list of tensors) to an image file.

    The input is first arranged into a single grid with ``make_grid``; the
    grid is then scaled by 255, rounded, clamped to [0, 255], converted to a
    uint8 height x width x channel array and saved through PIL.

    Args:
        tensor (Tensor or list): Image(s) to save, passed to ``make_grid``.
        fp: A filename (string) or file object to write to.
        nrow, padding, normalize, range, scale_each, pad_value: Forwarded
            unchanged to ``make_grid``; see its documentation.
        format (Optional): Image format handed to ``Image.save``. If omitted,
            PIL picks it from the filename extension; it should be given
            whenever ``fp`` is a file object.
    """
    grid_options = dict(nrow=nrow, padding=padding, pad_value=pad_value,
                        normalize=normalize, range=range, scale_each=scale_each)
    grid = make_grid(tensor, **grid_options)

    # Scale to [0, 255]; the extra 0.5 makes the later integer cast round to
    # the nearest value instead of truncating.
    scaled = grid.mul(255).add(0.5)
    scaled.clamp_(0, 255)

    # Channel-first -> channel-last, then hand a uint8 array to PIL.
    pixels = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    Image.fromarray(pixels).save(fp, format=format)
def image_from_tensor(tensor):
    """Rasterise batches of 2-D landmarks into 256x256 images.

    Every landmark is drawn as a small filled dot and consecutive landmarks
    are joined by a line, all with colour value ``(0, 0, 255)`` on a uniform
    grey (128) background.

    Args:
        tensor: Tensor or array indexed as ``[batch, point, xy]``; x is read
            from position 0 and y from position 1 of the last axis.
            Coordinates are truncated to integers before drawing.

    Returns:
        uint8 ``torch.Tensor`` of shape ``(batch, 3, 256, 256)``.
    """
    # OpenCV's drawing functions need plain integer pixel coordinates, so the
    # landmarks are moved to host memory and cast once up front. Passing
    # tensor scalars (possibly float or on the GPU) straight to cv2 is not
    # reliable.
    if torch.is_tensor(tensor):
        tensor = tensor.detach().cpu().numpy()
    points = np.asarray(tensor).astype(np.int64)

    batch_size = points.shape[0]
    img = np.full((batch_size, 256, 256, 3), 128, dtype=np.uint8)
    colour = (0, 0, 255)

    for i in range(batch_size):
        canvas = img[i]  # a view: cv2 draws into `img` in place
        coords = [(int(x), int(y)) for x, y in points[i, :, :2]]
        for point in coords:
            cv2.circle(canvas, point, 1, colour, -1)
        # Join each landmark to the next one in index order.
        for start, end in zip(coords, coords[1:]):
            cv2.line(canvas, start, end, colour, 1)

    # Channel-last (cv2 layout) -> channel-first (torch layout).
    img = np.transpose(img, (0, 3, 1, 2))
    return torch.tensor(img)
def test(epoch):
    """Evaluate on ``test_loader``, then save a preview image and checkpoints.

    The mouth landmarks are predicted from the face and audio embeddings of
    every test batch and compared with the ground truth by MSE. After the
    loop, one comparison image is written to ``./wav_results_base``. Its top
    row is the first frame of up to eight ground-truth samples of the last
    batch; its bottom row is the same frames with the predicted mouth. The
    three networks are then saved to ``./wav_saves_base``.

    Args:
        epoch: Epoch number; used in the output file names.
    """
    face_encoder.eval()
    audio_encoder.eval()
    face_decoder.eval()
    test_loss = 0.0
    num_batches = 0
    test_loader.dataset.test_case = True

    with torch.no_grad():
        for (keypoints, mfcc), _ in test_loader:
            batch_size = keypoints.shape[0]
            video_length = keypoints.shape[1]
            keypoints = keypoints.to(device)
            mfcc = mfcc.transpose(1, 2).to(device).reshape(-1, 12)
            face_points = keypoints[:, :, :48].reshape(-1, 96)
            mouth_points = keypoints[:, :, 48:68].reshape(-1, 40)

            face_embedding = face_encoder(face_points)
            audio_embedding = audio_encoder(mfcc)
            embedding = torch.cat((face_embedding, audio_embedding), dim=1).reshape(
                batch_size, -1, 144)
            mouth_points_pred = face_decoder(embedding) * 255

            test_loss += mse_loss(mouth_points, mouth_points_pred).item()
            num_batches += 1

        # The preview uses the tensors of the last batch, so it is skipped
        # when the loader produced nothing.
        if num_batches:
            n = min(keypoints.size(0), 8)
            image_data = image_from_tensor(keypoints[:, 0, :, :])
            # Use the actual frame count instead of a hard-coded 29.
            face_pred = torch.cat(
                (face_points.reshape(batch_size, video_length, 48, 2),
                 mouth_points_pred.reshape(batch_size, video_length, 20, 2)),
                dim=2)
            face_pred_batch = image_from_tensor(face_pred[:, 0, :, :])
            comparison = torch.cat([image_data[:n], face_pred_batch[:n]], dim=0).cpu()
            # save_image multiplies by 255 internally, i.e. it expects floats
            # in [0, 1]; the rendered images are uint8 in [0, 255].
            if comparison.dtype == torch.uint8:
                comparison = comparison.float().div(255)

            os.makedirs('./wav_results_base', exist_ok=True)
            # The file name carries the summed (not yet averaged) batch loss.
            save_image(comparison,
                       './wav_results_base/reconstruction_' + str(epoch) + '_'
                       + str(round(test_loss, 6)) + '.png',
                       nrow=n)

    # Whole modules are saved (not state_dicts), so loading them requires the
    # same class definitions to be importable.
    os.makedirs('./wav_saves_base', exist_ok=True)
    torch.save(face_encoder, './wav_saves_base/face_encoder_{}.pt'.format(epoch))
    torch.save(audio_encoder, './wav_saves_base/audio_encoder_{}.pt'.format(epoch))
    torch.save(face_decoder, './wav_saves_base/face_decoder_{}.pt'.format(epoch))

    # Each accumulated term is already a per-batch mean, so average over the
    # number of batches (guarded for an empty loader).
    test_loss /= max(num_batches, 1)
    print('====> Test set loss: {:.4f}'.format(test_loss))
if __name__ == "__main__":
    # One training pass followed by one evaluation pass per epoch, with a
    # wall-clock timestamp printed before and after each epoch.
    start_message = 'Epoch {} starts at {}'
    end_message = 'Epoch {} ends at {}'
    for epoch in range(args.epochs):
        print(start_message.format(epoch, datetime.datetime.now()))
        train(epoch)
        test(epoch)
        print(end_message.format(epoch, datetime.datetime.now()))