ibrahim313 committed
Commit b8abf33 · verified · Parent: 6fea6b1

Upload 4 files
Files changed (4)
  1. models/__init__.py +2 -0
  2. models/conv.py +44 -0
  3. models/syncnet.py +66 -0
  4. models/wav2lip.py +184 -0
models/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
+from .syncnet import SyncNet_color
models/conv.py ADDED
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+        )
+        self.act = nn.LeakyReLU(0.01, inplace=True)
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
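These three wrappers are the only primitives used by syncnet.py and wav2lip.py below. A quick shape check, as a minimal sketch that is not part of the commit, assuming the models/ package layout listed above:

import torch
from models.conv import Conv2d, Conv2dTranspose

x = torch.randn(2, 64, 24, 24)

# residual=True adds the input back before the ReLU, so it requires cin == cout
# and a shape-preserving kernel/stride/padding combination such as 3/1/1.
block = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)
assert block(x).shape == x.shape

# kernel 3, stride 2, padding 1, output_padding 1 exactly doubles H and W,
# mirroring the stride-2 Conv2d downsampling used by the encoders below.
up = Conv2dTranspose(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
assert up(x).shape == (2, 32, 48, 48)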
models/syncnet.py ADDED
@@ -0,0 +1,66 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv2d
+
+class SyncNet_color(nn.Module):
+    def __init__(self):
+        super(SyncNet_color, self).__init__()
+
+        self.face_encoder = nn.Sequential(
+            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
+
+            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+    def forward(self, audio_sequences, face_sequences):  # audio_sequences := (B, dim, T)
+        face_embedding = self.face_encoder(face_sequences)
+        audio_embedding = self.audio_encoder(audio_sequences)
+
+        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
+        face_embedding = face_embedding.view(face_embedding.size(0), -1)
+
+        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
+        face_embedding = F.normalize(face_embedding, p=2, dim=1)
+
+        return audio_embedding, face_embedding
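Working backwards from the strides above, the face encoder expects a 5-frame RGB window of the lower half of a 96x96 face crop, stacked channel-wise as (B, 15, 48, 96), and the audio encoder expects a mel-spectrogram chunk of shape (B, 1, 80, 16); both collapse to (B, 512, 1, 1) before flattening. A minimal smoke test, not part of the commit, with shapes inferred from the layer arithmetic:

import torch
from models import SyncNet_color

net = SyncNet_color().eval()              # eval() so BatchNorm uses running stats
mel = torch.randn(2, 1, 80, 16)           # (B, 1, n_mels, T) mel chunk
faces = torch.randn(2, 15, 48, 96)        # 5 frames x 3 channels, lower half of a 96x96 crop
with torch.no_grad():
    audio_emb, face_emb = net(mel, faces)
print(audio_emb.shape, face_emb.shape)    # both (2, 512), L2-normalized

Since both embeddings are L2-normalized, sync confidence is typically scored with their cosine similarity, e.g. torch.nn.functional.cosine_similarity(audio_emb, face_emb).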
models/wav2lip.py ADDED
@@ -0,0 +1,184 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+import math
+
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+class Wav2Lip(nn.Module):
+    def __init__(self):
+        super(Wav2Lip, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
+
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
+                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+        self.face_decoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
+
+            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
+
+            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
+
+            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
+
+            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
+
+        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+                                          nn.Sigmoid())
+
+    def forward(self, audio_sequences, face_sequences):
+        # audio_sequences = (B, T, 1, 80, 16)
+        B = audio_sequences.size(0)
+
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+        audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)
+
+        if input_dim_size > 4:
+            x = torch.split(x, B, dim=0) # [(B, C, H, W)]
+            outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
+
+        else:
+            outputs = x
+
+        return outputs
+
+class Wav2Lip_disc_qual(nn.Module):
+    def __init__(self):
+        super(Wav2Lip_disc_qual, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
+
+            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
+                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
+                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
+                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
+                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+
+        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
+        self.label_noise = .0
+
+    def get_lower_half(self, face_sequences):
+        return face_sequences[:, :, face_sequences.size(2)//2:]
+
+    def to_2d(self, face_sequences):
+        B = face_sequences.size(0)
+        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+        return face_sequences
+
+    def perceptual_forward(self, false_face_sequences):
+        false_face_sequences = self.to_2d(false_face_sequences)
+        false_face_sequences = self.get_lower_half(false_face_sequences)
+
+        false_feats = false_face_sequences
+        for f in self.face_encoder_blocks:
+            false_feats = f(false_feats)
+
+        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
+                                                 torch.ones((len(false_feats), 1)).cuda())
+
+        return false_pred_loss
+
+    def forward(self, face_sequences):
+        face_sequences = self.to_2d(face_sequences)
+        face_sequences = self.get_lower_half(face_sequences)
+
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+
+        return self.binary_pred(x).view(len(x), -1)
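A minimal end-to-end smoke test for both classes, not part of the commit, with shapes inferred from the encoders above. The generator accepts either a single frame (B, 6, 96, 96) or a sequence (B, 6, T, 96, 96) paired with one mel chunk per output frame; in the original Wav2Lip training setup the 6 face channels are a masked target frame concatenated channel-wise with a reference frame. Note that perceptual_forward calls .cuda() unconditionally, so the CPU-friendly path below uses only forward:

import torch
from models import Wav2Lip, Wav2Lip_disc_qual

B, T = 2, 5
audio = torch.randn(B, T, 1, 80, 16)      # one (1, 80, 16) mel chunk per output frame
faces = torch.randn(B, 6, T, 96, 96)      # 6 = masked target + reference frame, channel-wise

gen = Wav2Lip().eval()
with torch.no_grad():
    out = gen(audio, faces)
print(out.shape)                           # (2, 3, 5, 96, 96); Sigmoid keeps values in (0, 1)

disc = Wav2Lip_disc_qual().eval()
with torch.no_grad():
    scores = disc(out)                     # scores only the lower (mouth) half of each frame
print(scores.shape)                        # (B*T, 1) = (10, 1) realness probabilities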