Upload 4 files
data.py
ADDED
@@ -0,0 +1,54 @@
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset


class LRW(Dataset):
    """LRW facial-landmark / MFCC dataset wrapper."""

    def __init__(self, folds,
                 labels_file='./data/label_sorted.txt',
                 root_path='/beegfs/cy1355/lipread_datachunk_big/',
                 transform=None):
        """
        Args:
            labels_file (string): Path to the text file with labels.
            root_path (string): Path to the directory with the facial landmarks
                and audio features (MFCC).
            folds (string): 'train' / 'val' / 'test' indicator.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.folds = folds
        self.labels_file = labels_file
        self.root_path = root_path
        with open(self.labels_file) as myfile:
            self.data_dir = myfile.read().splitlines()

        self.video_file = os.path.join(self.root_path, 'video_' + self.folds + '.npy')
        self.audio_file = os.path.join(self.root_path, 'audio_' + self.folds + '.npy')

        self.video = npy_loader(self.video_file)
        self.audio = npy_loader(self.audio_file)

        print('Loading {} part'.format(self.folds))

    def __len__(self):
        return self.video.shape[0]

    def __getitem__(self, idx):
        vid = self.augmentation(self.video[idx, :, :, :])
        aud = self.audio[idx, :, :]
        labels = 0
        return (vid, aud), labels

    def augmentation(self, keypoints):
        # Scale the landmarks by 0.7 and shift them by a random offset in [1, 72]
        # during training; use a fixed offset of 38 for validation and testing.
        keypoints_move = keypoints * 0.7
        ones = torch.ones(keypoints.shape, dtype=torch.float)
        randint = torch.randint(1, 73, (1,), dtype=torch.float)
        if self.folds == 'train':
            d = keypoints_move + ones * randint
        else:
            d = keypoints_move + ones * 38
        return d


def npy_loader(file):
    # Load a .npy array from disk as a float tensor.
    return torch.tensor(np.load(file)).float()
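
For orientation, here is a minimal usage sketch of the loader on its own. It assumes the label file and the preprocessed video_train.npy / audio_train.npy chunks actually exist under the root path; the shapes in the comments are inferred from how main.py consumes the batches, not stated anywhere in the upload.

# Hypothetical usage sketch -- paths must exist locally and the shapes are inferred, not guaranteed.
from torch.utils.data import DataLoader
from data import LRW

loader = DataLoader(LRW('train', root_path='/beegfs/cy1355/lipread_datachunk_big/'),
                    batch_size=4, shuffle=True)
(vid, aud), labels = next(iter(loader))
# Inferred shapes: vid ~ (4, 29, 68, 2) landmark frames, aud ~ (4, 12, T) MFCC frames;
# labels is a dummy zero per sample.
print(vid.shape, aud.shape, labels)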
eval.py
ADDED
@@ -0,0 +1,8 @@
import torch


def main():
    # Placeholder: standalone evaluation entry point, not implemented yet.
    pass


def eval(epoch):
    # Placeholder: per-epoch evaluation hook, not implemented yet.
    pass
main.py
ADDED
@@ -0,0 +1,239 @@
import os
import cv2
import torch
import argparse
import glob

import numpy as np
import torch.nn.functional as F

from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

from model import FaceEncoder, AudioEncoder, FaceDecoder
from data import LRW
from PIL import Image
import datetime

parser = argparse.ArgumentParser(description='Lip Generator Example')
parser.add_argument('--data', type=str, default='/beegfs/cy1355/lipread_datachunk_big/', metavar='N',
                    help='data root directory')
parser.add_argument('--batch-size', type=int, default=1024, metavar='N',
                    help='input batch size for training (default: 1024)')
parser.add_argument('--epochs', type=int, default=1001, metavar='N',
                    help='number of epochs to train (default: 1001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

print('Run on {}'.format(device))


# Face landmarks and MFCC frames are embedded separately, concatenated,
# and decoded back to mouth landmarks.
face_encoder = FaceEncoder().to(device)
audio_encoder = AudioEncoder().to(device)
encoders_params = list(face_encoder.parameters()) + list(audio_encoder.parameters())
encoders_optimizer = optim.Adam(encoders_params, lr=1e-3, betas=(0.5, 0.999))
face_decoder = FaceDecoder().to(device)
decoder_optimizer = optim.Adam(face_decoder.parameters(), lr=1e-3, betas=(0.5, 0.999))
mse_loss = torch.nn.MSELoss()

dsets = {x: LRW(x, root_path=args.data) for x in ['train', 'val', 'test']}
dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
                for x in ['train', 'val', 'test']}
train_loader = dset_loaders['train']
val_loader = dset_loaders['val']
test_loader = dset_loaders['test']
dset_sizes = {x: len(dsets[x]) for x in ['train', 'val', 'test']}
print('\nStatistics: train: {}, val: {}, test: {}'.format(
    dset_sizes['train'], dset_sizes['val'], dset_sizes['test']))


def open_level(mouth_keypoints):
    # Lip-opening score: mean distance between matching upper- and lower-lip
    # points per frame, normalized across the frames of each clip, scaled by 256.
    u_index = [1, 2, 3, 4, 5, 13, 14, 15]
    l_index = [11, 10, 9, 8, 7, 19, 18, 17]
    upper_lip = mouth_keypoints[:, :, u_index, :]
    lower_lip = mouth_keypoints[:, :, l_index, :]

    # Coordinates
    u_x = upper_lip[:, :, :, 0]
    u_y = upper_lip[:, :, :, 1]
    l_x = lower_lip[:, :, :, 0]
    l_y = lower_lip[:, :, :, 1]

    distance = ((u_x - l_x) ** 2 + (u_y - l_y) ** 2) ** 0.5
    distance_mean = distance.mean(dim=2)
    distance_normed = distance_mean / distance_mean.sum(dim=1).unsqueeze(1)
    return distance_normed * 256


def train(epoch):
    face_encoder.train()
    audio_encoder.train()
    face_decoder.train()

    train_loss = 0
    train_loader.dataset.test_case = False

    for batch_idx, ((keypoints, mfcc), _) in enumerate(train_loader):

        batch_size = keypoints.shape[0]
        video_length = keypoints.shape[1]

        encoders_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        keypoints = keypoints.to(device)
        mfcc = mfcc.transpose(1, 2).to(device).view(-1, 12)

        # Landmarks 0-47 are the contour, brows, nose and eyes; 48-67 are the mouth.
        face_points = keypoints[:, :, :48].view(-1, 96)
        mouth_points = keypoints[:, :, 48:68].view(-1, 40)

        face_embedding = face_encoder(face_points)
        audio_embedding = audio_encoder(mfcc)

        # Shuffle face_embedding so the same audio is also paired with a face
        # embedding from a (possibly) different sample.
        shuffle_index = torch.randperm(batch_size)
        face_embedding_extended = face_embedding.view(batch_size, video_length, -1)
        shuffled_face_embedding = face_embedding_extended[shuffle_index].view(batch_size * video_length, -1)

        mixed_face_embedding = torch.cat((face_embedding, shuffled_face_embedding), dim=0)
        doubled_audio_embedding = torch.cat((audio_embedding, audio_embedding), dim=0)

        mixed_embedding = torch.cat((mixed_face_embedding, doubled_audio_embedding), dim=1).view(batch_size * 2, -1, 144)

        mixed_mouth_points_pred = face_decoder(mixed_embedding) * 255

        # Reconstruction loss on the correctly paired half of the batch.
        mouth_points_pred = mixed_mouth_points_pred[:batch_size * video_length]
        supervised_loss = mse_loss(mouth_points, mouth_points_pred)

        # The lip-opening profile predicted with a mismatched face should match
        # the one predicted with the true face, so the opening is driven by audio.
        shuffled_pred = mixed_mouth_points_pred[batch_size * video_length:].view(batch_size, video_length, 20, 2)
        open_score_shuffled = open_level(shuffled_pred)
        original_pred = mouth_points_pred.view(batch_size, video_length, 20, 2)
        open_score_normal = open_level(original_pred)

        # kld = torch.nn.KLDivLoss(reduction='batchmean')
        # log_prob = torch.nn.LogSoftmax(dim=1)(open_score_shuffled.transpose(0, 1))
        # prob = torch.nn.Softmax(dim=1)(open_score_normal.transpose(0, 1))

        adversarial_loss = mse_loss(open_score_shuffled, open_score_normal)
        # adversarial_loss = kld(log_prob, prob)

        loss = supervised_loss + adversarial_loss
        loss.backward()

        train_loss += loss.item()

        encoders_optimizer.step()
        decoder_optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                100. * batch_idx * args.batch_size / len(train_loader.dataset),
                loss.item() / len(mfcc)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))


def save_image_(tensor, fp, nrow=8, padding=2,
                normalize=False, range=None, scale_each=False, pad_value=0, format=None):
    """Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp: A filename (string) or file object.
        format (Optional): If omitted, the format to use is determined from the
            filename extension. If a file object was used instead of a filename,
            this parameter should always be used.
        Other arguments are documented in ``make_grid``.
    """
    grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


def image_from_tensor(tensor):
    # Rasterize a (batch, 68, 2) landmark tensor onto grey 256x256 canvases by
    # drawing every point and the segment to the next point.
    batch_size = tensor.shape[0]
    img = np.zeros((batch_size, 256, 256, 3), np.uint8) + 128
    for i in range(batch_size):
        for row in range(67):
            # OpenCV expects integer pixel coordinates.
            x_1 = int(tensor[i, row, 0])
            y_1 = int(tensor[i, row, 1])
            x_2 = int(tensor[i, row + 1, 0])
            y_2 = int(tensor[i, row + 1, 1])
            cv2.circle(img[i], (x_1, y_1), 1, (0, 0, 255), -1)
            cv2.circle(img[i], (x_2, y_2), 1, (0, 0, 255), -1)
            cv2.line(img[i], (x_1, y_1), (x_2, y_2), (0, 0, 255), 1)
    img = np.transpose(img, (0, 3, 1, 2))
    return torch.tensor(img)


def test(epoch):

    face_encoder.eval()
    audio_encoder.eval()
    face_decoder.eval()

    test_loss = 0

    test_loader.dataset.test_case = True

    with torch.no_grad():
        for batch_idx, ((keypoints, mfcc), _) in enumerate(test_loader):

            batch_size = keypoints.shape[0]

            keypoints = keypoints.to(device)
            mfcc = mfcc.transpose(1, 2).to(device).reshape(-1, 12)

            face_points = keypoints[:, :, :48].reshape(-1, 96)
            mouth_points = keypoints[:, :, 48:68].reshape(-1, 40)

            face_embedding = face_encoder(face_points)
            audio_embedding = audio_encoder(mfcc)

            embedding = torch.cat((face_embedding, audio_embedding), dim=1).view(batch_size, -1, 144)

            mouth_points_pred = face_decoder(embedding) * 255

            test_loss += mse_loss(mouth_points, mouth_points_pred)

            # if epoch % 10 == 0:
            # Save a side-by-side comparison of ground-truth and predicted
            # landmarks for the first frame of up to 8 samples.
            n = min(keypoints.size(0), 8)
            image_data = image_from_tensor(keypoints[:, 0, :, :])
            face_pred = torch.cat((face_points.view(batch_size, 29, 48, 2),
                                   mouth_points_pred.view(batch_size, 29, 20, 2)), dim=2)
            face_pred_batch = image_from_tensor(face_pred[:, 0, :, :])
            comparison = torch.cat([image_data[:n], face_pred_batch[:n]], dim=0)

            if not os.path.exists('./wav_results_base'):
                os.mkdir('./wav_results_base')
            save_image(comparison.cpu(),
                       './wav_results_base/reconstruction_' + str(epoch) + '_' + str(round(test_loss.item(), 6)) + '.png',
                       nrow=n)

            if not os.path.exists('./wav_saves_base'):
                os.mkdir('./wav_saves_base')
            torch.save(face_encoder, './wav_saves_base/face_encoder_{}.pt'.format(epoch))
            torch.save(audio_encoder, './wav_saves_base/audio_encoder_{}.pt'.format(epoch))
            torch.save(face_decoder, './wav_saves_base/face_decoder_{}.pt'.format(epoch))

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


if __name__ == "__main__":
    for epoch in range(0, args.epochs):
        print('Epoch {} starts at {}'.format(epoch, datetime.datetime.now()))
        train(epoch)
        test(epoch)
        print('Epoch {} ends at {}'.format(epoch, datetime.datetime.now()))
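
The embedding bookkeeping in train() is easy to lose track of, so here is a standalone sketch of just the shuffle-and-concatenate step with random tensors; the batch size and frame count are illustrative, everything else mirrors the code above.

# Standalone shape check of the training-time embedding bookkeeping (toy sizes).
import torch

B, T = 4, 29                                   # illustrative batch size and frames per clip
face_embedding = torch.randn(B * T, 16)        # FaceEncoder output, flattened over frames
audio_embedding = torch.randn(B * T, 128)      # AudioEncoder output, flattened over frames

shuffle_index = torch.randperm(B)
shuffled_face = face_embedding.view(B, T, -1)[shuffle_index].view(B * T, -1)

mixed_face = torch.cat((face_embedding, shuffled_face), dim=0)          # (2*B*T, 16)
doubled_audio = torch.cat((audio_embedding, audio_embedding), dim=0)    # (2*B*T, 128)
mixed = torch.cat((mixed_face, doubled_audio), dim=1).view(2 * B, -1, 144)

print(mixed.shape)  # torch.Size([8, 29, 144]) -- the sequence FaceDecoder receives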
model.py
ADDED
@@ -0,0 +1,81 @@
import torch
from torch import nn


class FaceEncoder(nn.Module):
    # Embeds 48 facial landmarks (96 coordinates) into a 16-dim vector.
    def __init__(self):
        super(FaceEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(96, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 16),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.encoder(x)


class AudioEncoder(nn.Module):
    # Embeds a 12-dim MFCC frame into a 128-dim vector.
    def __init__(self):
        super(AudioEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(12, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 128),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        return self.encoder(x)


class FaceDecoder(nn.Module):
    # Smooths the concatenated 144-dim (16 face + 128 audio) embedding sequence
    # with a 2-layer GRU, then decodes each step to 40 mouth coordinates in [0, 1].
    def __init__(self):
        super(FaceDecoder, self).__init__()
        h_GRU = 144
        self.stabilizer = nn.GRU(144, h_GRU, 2, batch_first=True, dropout=0.2)

        self.decoder = nn.Sequential(
            nn.Linear(144, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 40),
            nn.Sigmoid(),
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x, _ = self.stabilizer(x)
        return self.decoder(x.reshape(-1, 144))
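
As a quick sanity check of the interfaces, the sketch below pushes random tensors through the three modules; the batch and frame counts are illustrative, but the feature sizes (96 -> 16 for faces, 12 -> 128 for audio, 144 -> 40 for the decoder) come straight from the layers above.

# Shape sanity check with random inputs; the 4 clips x 29 frames figure is illustrative.
import torch
from model import FaceEncoder, AudioEncoder, FaceDecoder

B, T = 4, 29
face = FaceEncoder()(torch.randn(B * T, 96))     # -> (B*T, 16)
audio = AudioEncoder()(torch.randn(B * T, 12))   # -> (B*T, 128)
embedding = torch.cat((face, audio), dim=1).view(B, T, 144)
mouth = FaceDecoder()(embedding)                 # -> (B*T, 40), i.e. 20 mouth points in [0, 1]
print(face.shape, audio.shape, mouth.shape)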