animatedaliensfans committed (verified)
Commit: 79eea03
Parent(s): 984f4d5

Upload 4 files

Files changed (4):
  1. data.py +54 -0
  2. eval.py +8 -0
  3. main.py +239 -0
  4. model.py +81 -0
data.py ADDED
@@ -0,0 +1,54 @@
import os
import glob
import numpy as np
import torch


class LRW():

    def __init__(self, folds,
                 labels_file='./data/label_sorted.txt',
                 root_path='/beegfs/cy1355/lipread_datachunk_big/',
                 transform=None):
        """
        Args:
            labels_file (string): Path to the text file with labels
            root_path (string): Path to the files with the facial landmarks and audio features (MFCC)
            folds (string): train / val / test indicator
            transform (callable, optional): Optional transform to be applied
                on a sample
        """
        self.folds = folds
        self.labels_file = labels_file
        self.root_path = root_path
        with open(self.labels_file) as myfile:
            self.data_dir = myfile.read().splitlines()

        self.video_file = os.path.join(self.root_path, 'video_' + self.folds + '.npy')
        self.audio_file = os.path.join(self.root_path, 'audio_' + self.folds + '.npy')

        self.video = npy_loader(self.video_file)
        self.audio = npy_loader(self.audio_file)

        print('Loading {} part'.format(self.folds))

    def __len__(self):
        return self.video.shape[0]

    def __getitem__(self, idx):
        vid = self.augmentation(self.video[idx, :, :, :])
        aud = self.audio[idx, :, :]
        labels = 0
        return (vid, aud), labels

    def augmentation(self, keypoints):
        # Scale the keypoints and add an offset: random per sample for training,
        # fixed (38) for validation/test so results are deterministic.
        keypoints_move = keypoints * 0.7
        ones = torch.ones(keypoints.shape, dtype=torch.float)
        randint = torch.randint(1, 73, (1,), dtype=torch.float)
        if self.folds == 'train':
            d = keypoints_move + ones * randint
        else:
            d = keypoints_move + ones * 38
        return d


def npy_loader(file):
    # Module-level helper: load a .npy file as a float tensor.
    return torch.tensor(np.load(file)).float()
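
A minimal usage sketch for LRW, assuming label_sorted.txt and the train/val/test .npy chunks exist under the root path and follow the layout implied by main.py (keypoints consumed as (batch, 29, 68, 2), MFCCs as (batch, 12, T)); paths and shapes here are illustrative only:

    import torch
    from data import LRW

    # Assumed paths/shapes; adjust to the actual dataset layout.
    train_set = LRW('train', root_path='/beegfs/cy1355/lipread_datachunk_big/')
    loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)

    (vid, aud), labels = next(iter(loader))
    print(vid.shape)  # expected e.g. torch.Size([4, 29, 68, 2]) -- per-frame facial keypoints
    print(aud.shape)  # expected e.g. torch.Size([4, 12, T])     -- MFCC features per clip
    print(labels)     # tensor of zeros; labels are unused in this pipeline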
eval.py ADDED
@@ -0,0 +1,8 @@
import torch


def main():
    pass


def eval(epoch):
    pass
main.py ADDED
@@ -0,0 +1,239 @@
import os
import cv2
import torch
import argparse
import glob

import numpy as np
import torch.nn.functional as F

from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

from model import FaceEncoder, AudioEncoder, FaceDecoder
from data import LRW
from PIL import Image
import datetime

parser = argparse.ArgumentParser(description='Lip Generator Example')
parser.add_argument('--data', type=str, default='/beegfs/cy1355/lipread_datachunk_big/', metavar='N',
                    help='data root directory')
parser.add_argument('--batch-size', type=int, default=1024, metavar='N',
                    help='input batch size for training (default: 1024)')
parser.add_argument('--epochs', type=int, default=1001, metavar='N',
                    help='number of epochs to train (default: 1001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

print('Run on {}'.format(device))


face_encoder = FaceEncoder().to(device)
audio_encoder = AudioEncoder().to(device)
encoders_params = list(face_encoder.parameters()) + list(audio_encoder.parameters())
encoders_optimizer = optim.Adam(encoders_params, lr=1e-3, betas=(0.5, 0.999))
face_decoder = FaceDecoder().to(device)
decoder_optimizer = optim.Adam(face_decoder.parameters(), lr=1e-3, betas=(0.5, 0.999))
mse_loss = torch.nn.MSELoss()

dsets = {x: LRW(x, root_path=args.data) for x in ['train', 'val', 'test']}
dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
                for x in ['train', 'val', 'test']}
train_loader = dset_loaders['train']
val_loader = dset_loaders['val']
test_loader = dset_loaders['test']
dset_sizes = {x: len(dsets[x]) for x in ['train', 'val', 'test']}
print('\nStatistics: train: {}, val: {}, test: {}'.format(dset_sizes['train'], dset_sizes['val'], dset_sizes['test']))


def open_level(mouth_keypoints):
    # Mouth-openness score: mean distance between matched upper- and lower-lip
    # keypoints per frame, normalised over each clip and scaled to [0, 256].
    u_index = [1, 2, 3, 4, 5, 13, 14, 15]
    l_index = [11, 10, 9, 8, 7, 19, 18, 17]
    upper_lip = mouth_keypoints[:, :, u_index, :]
    lower_lip = mouth_keypoints[:, :, l_index, :]

    # Coordinates
    u_x = upper_lip[:, :, :, 0]
    u_y = upper_lip[:, :, :, 1]
    l_x = lower_lip[:, :, :, 0]
    l_y = lower_lip[:, :, :, 1]

    distance = ((u_x - l_x) ** 2 + (u_y - l_y) ** 2) ** 0.5
    distance_mean = distance.mean(dim=2)
    distance_normed = distance_mean / distance_mean.sum(dim=1).unsqueeze(1)
    return distance_normed * 256


def train(epoch):
    face_encoder.train()
    audio_encoder.train()
    face_decoder.train()

    train_loss = 0
    train_loader.dataset.test_case = False

    for batch_idx, ((keypoints, mfcc), _) in enumerate(train_loader):

        batch_size = keypoints.shape[0]
        video_length = keypoints.shape[1]

        encoders_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        keypoints = keypoints.to(device)
        mfcc = mfcc.transpose(1, 2).to(device).reshape(-1, 12)

        face_points = keypoints[:, :, :48].reshape(-1, 96)
        mouth_points = keypoints[:, :, 48:68].reshape(-1, 40)

        face_embedding = face_encoder(face_points)
        audio_embedding = audio_encoder(mfcc)

        # Shuffle face_embedding across the batch to build mismatched (face, audio) pairs
        shuffle_index = torch.randperm(batch_size)
        face_embedding_extended = face_embedding.view(batch_size, video_length, -1)
        shuffled_face_embedding = face_embedding_extended[shuffle_index].view(batch_size * video_length, -1)

        mixed_face_embedding = torch.cat((face_embedding, shuffled_face_embedding), dim=0)
        doubled_audio_embedding = torch.cat((audio_embedding, audio_embedding), dim=0)

        mixed_embedding = torch.cat((mixed_face_embedding, doubled_audio_embedding), dim=1).view(batch_size * 2, -1, 144)

        mixed_mouth_points_pred = face_decoder(mixed_embedding) * 255

        mouth_points_pred = mixed_mouth_points_pred[:batch_size * video_length]
        supervised_loss = mse_loss(mouth_points, mouth_points_pred)

        shuffled_pred = mixed_mouth_points_pred[batch_size * video_length:].view(batch_size, video_length, 20, 2)
        open_score_shuffled = open_level(shuffled_pred)
        original_pred = mouth_points_pred.view(batch_size, video_length, 20, 2)
        open_score_normal = open_level(original_pred)

        # kld = torch.nn.KLDivLoss(reduction='batchmean')
        # log_prob = torch.nn.LogSoftmax(dim=1)(open_score_shuffled.transpose(0, 1))
        # prob = torch.nn.Softmax(dim=1)(open_score_normal.transpose(0, 1))

        adversarial_loss = mse_loss(open_score_shuffled, open_score_normal)
        # adversarial_loss = kld(log_prob, prob)

        loss = supervised_loss + adversarial_loss
        loss.backward()

        train_loss += loss.item()

        encoders_optimizer.step()
        decoder_optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                100. * batch_idx * args.batch_size / len(train_loader.dataset),
                loss.item() / len(mfcc)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))


def save_image_(tensor, fp, nrow=8, padding=2,
                normalize=False, range=None, scale_each=False, pad_value=0, format=None):
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp: A filename (string) or file object.
        format (Optional): If omitted, the format to use is determined from the filename
            extension. If a file object was used instead of a filename, this parameter
            should always be used.
        Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


def image_from_tensor(tensor):
    # Render the 68 facial keypoints of each sample as dots and connecting lines
    # on a grey 256x256 canvas, returned as a (N, 3, 256, 256) uint8 tensor.
    batch_size = tensor.shape[0]
    img = np.zeros((batch_size, 256, 256, 3), np.uint8) + 128
    for i in range(batch_size):
        for row in range(67):
            x_1 = int(tensor[i, row, 0])
            y_1 = int(tensor[i, row, 1])
            x_2 = int(tensor[i, row + 1, 0])
            y_2 = int(tensor[i, row + 1, 1])
            cv2.circle(img[i], (x_1, y_1), 1, (0, 0, 255), -1)
            cv2.circle(img[i], (x_2, y_2), 1, (0, 0, 255), -1)
            cv2.line(img[i], (x_1, y_1), (x_2, y_2), (0, 0, 255), 1)
    img = np.transpose(img, (0, 3, 1, 2))
    return torch.tensor(img)


def test(epoch):

    face_encoder.eval()
    audio_encoder.eval()
    face_decoder.eval()

    test_loss = 0

    test_loader.dataset.test_case = True

    with torch.no_grad():
        for batch_idx, ((keypoints, mfcc), _) in enumerate(test_loader):

            batch_size = keypoints.shape[0]

            keypoints = keypoints.to(device)
            mfcc = mfcc.transpose(1, 2).to(device).reshape(-1, 12)

            face_points = keypoints[:, :, :48].reshape(-1, 96)
            mouth_points = keypoints[:, :, 48:68].reshape(-1, 40)

            face_embedding = face_encoder(face_points)
            audio_embedding = audio_encoder(mfcc)

            embedding = torch.cat((face_embedding, audio_embedding), dim=1).view(batch_size, -1, 144)

            mouth_points_pred = face_decoder(embedding) * 255

            test_loss += mse_loss(mouth_points, mouth_points_pred)

        # if epoch % 10 == 0:
        n = min(keypoints.size(0), 8)
        image_data = image_from_tensor(keypoints[:, 0, :, :])
        face_pred = torch.cat((face_points.view(batch_size, 29, 48, 2),
                               mouth_points_pred.view(batch_size, 29, 20, 2)), dim=2)
        face_pred_batch = image_from_tensor(face_pred[:, 0, :, :])
        comparison = torch.cat([image_data[:n], face_pred_batch[:n]], dim=0)

        if not os.path.exists('./wav_results_base'):
            os.mkdir('./wav_results_base')
        save_image(comparison.cpu(), './wav_results_base/reconstruction_' + str(epoch) + '_' + str(round(test_loss.item(), 6)) + '.png', nrow=n)

        if not os.path.exists('./wav_saves_base'):
            os.mkdir('./wav_saves_base')
        torch.save(face_encoder, './wav_saves_base/face_encoder_{}.pt'.format(epoch))
        torch.save(audio_encoder, './wav_saves_base/audio_encoder_{}.pt'.format(epoch))
        torch.save(face_decoder, './wav_saves_base/face_decoder_{}.pt'.format(epoch))

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


if __name__ == "__main__":
    for epoch in range(0, args.epochs):
        print('Epoch {} starts at {}'.format(epoch, datetime.datetime.now()))
        train(epoch)
        test(epoch)
        print('Epoch {} ends at {}'.format(epoch, datetime.datetime.now()))
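
test() above pickles the full encoder/decoder modules each epoch with torch.save, so they can be reloaded for inference without rebuilding the networks, provided model.py is importable. A minimal sketch, assuming a checkpoint for the chosen epoch actually exists under ./wav_saves_base (epoch number and paths are placeholders):

    import torch

    epoch = 0  # placeholder: pick an epoch whose checkpoints exist
    face_encoder = torch.load('./wav_saves_base/face_encoder_{}.pt'.format(epoch), map_location='cpu')
    audio_encoder = torch.load('./wav_saves_base/audio_encoder_{}.pt'.format(epoch), map_location='cpu')
    face_decoder = torch.load('./wav_saves_base/face_decoder_{}.pt'.format(epoch), map_location='cpu')
    # Recent PyTorch versions may additionally need torch.load(..., weights_only=False),
    # since these checkpoints are whole pickled modules rather than state_dicts.
    face_encoder.eval()
    audio_encoder.eval()
    face_decoder.eval()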
model.py ADDED
@@ -0,0 +1,81 @@
import torch
from torch import nn


class FaceEncoder(nn.Module):
    def __init__(self):
        super(FaceEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(96, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 16),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.encoder(x)


class AudioEncoder(nn.Module):
    def __init__(self):
        super(AudioEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(12, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 128),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.encoder(x)


class FaceDecoder(nn.Module):
    def __init__(self):
        super(FaceDecoder, self).__init__()
        h_GRU = 144
        self.stabilizer = nn.GRU(144, h_GRU, 2, batch_first=True, dropout=0.2)

        self.decoder = nn.Sequential(
            nn.Linear(144, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 40),
            nn.Sigmoid(),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x, _ = self.stabilizer(x)
        return self.decoder(x.reshape(-1, 144))
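
Per frame, the three modules fit together as follows: FaceEncoder maps 48 face keypoints (96 values) to a 16-dim embedding, AudioEncoder maps 12 MFCC coefficients to a 128-dim embedding, and their concatenation (16 + 128 = 144) is fed, one sequence per clip, through the GRU and MLP of FaceDecoder to predict 20 mouth keypoints (40 values) in [0, 1], which main.py then scales by 255. A toy shape walkthrough with made-up batch and frame counts:

    import torch
    from model import FaceEncoder, AudioEncoder, FaceDecoder

    B, T = 2, 29                          # made-up: clips per batch, frames per clip
    face_points = torch.rand(B * T, 96)   # 48 face keypoints x 2 coords, flattened per frame
    mfcc = torch.rand(B * T, 12)          # 12 MFCC coefficients per frame

    face_emb = FaceEncoder()(face_points)   # -> (B*T, 16)
    audio_emb = AudioEncoder()(mfcc)        # -> (B*T, 128)
    emb = torch.cat((face_emb, audio_emb), dim=1).view(B, T, 144)

    mouth_pred = FaceDecoder()(emb)         # -> (B*T, 40), sigmoid output in [0, 1]
    print(mouth_pred.shape)                 # torch.Size([58, 40])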