two_stream_lipnet.py  ADDED  (+113, -0)
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import math


class TwoStreamLipNet(torch.nn.Module):
    def __init__(self, dropout_p=0.5, coord_input_dim=40, coord_hidden_dim=128):
        super(TwoStreamLipNet, self).__init__()
        # Visual stream: 3D conv front-end over (C, T, H, W) video clips
        self.conv1 = nn.Conv3d(3, 32, (3, 5, 5), (1, 2, 2), (1, 2, 2))
        self.pool1 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))

        self.conv2 = nn.Conv3d(32, 64, (3, 5, 5), (1, 1, 1), (1, 2, 2))
        self.pool2 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))

        self.conv3 = nn.Conv3d(64, 96, (3, 3, 3), (1, 1, 1), (1, 1, 1))
        self.pool3 = nn.MaxPool3d((1, 2, 2), (1, 2, 2))

        # Bidirectional GRUs over the flattened per-frame visual features
        self.gru1 = nn.GRU(96 * 4 * 8, 256, 1, bidirectional=True)
        self.gru2 = nn.GRU(512, 256, 1, bidirectional=True)

        # Classifier over the concatenated visual and coordinate features
        # (27 + 1 output classes)
        self.FC = nn.Linear(512 + 2 * coord_hidden_dim, 27 + 1)
        self.dropout_p = dropout_p

        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(self.dropout_p)
        self.dropout3d = nn.Dropout3d(self.dropout_p)

        # New GRU layer for lip coordinates
        self.coord_gru = nn.GRU(
            coord_input_dim, coord_hidden_dim, 1, bidirectional=True
        )

        self._init()

    def _init(self):
        init.kaiming_normal_(self.conv1.weight, nonlinearity="relu")
        init.constant_(self.conv1.bias, 0)

        init.kaiming_normal_(self.conv2.weight, nonlinearity="relu")
        init.constant_(self.conv2.bias, 0)

        init.kaiming_normal_(self.conv3.weight, nonlinearity="relu")
        init.constant_(self.conv3.bias, 0)

        init.kaiming_normal_(self.FC.weight, nonlinearity="sigmoid")
        init.constant_(self.FC.bias, 0)

        # Gate-wise initialisation of the visual GRUs: uniform input-hidden
        # weights, orthogonal hidden-hidden weights, zero biases
        for m in (self.gru1, self.gru2):
            stdv = math.sqrt(2 / (96 * 3 * 6 + 256))
            for i in range(0, 256 * 3, 256):
                init.uniform_(
                    m.weight_ih_l0[i : i + 256],
                    -math.sqrt(3) * stdv,
                    math.sqrt(3) * stdv,
                )
                init.orthogonal_(m.weight_hh_l0[i : i + 256])
                init.constant_(m.bias_ih_l0[i : i + 256], 0)
                init.uniform_(
                    m.weight_ih_l0_reverse[i : i + 256],
                    -math.sqrt(3) * stdv,
                    math.sqrt(3) * stdv,
                )
                init.orthogonal_(m.weight_hh_l0_reverse[i : i + 256])
                init.constant_(m.bias_ih_l0_reverse[i : i + 256], 0)

    def forward(self, x, coords):
        # Branch 1: visual stream
        x = self.conv1(x)
        x = self.relu(x)
        x = self.dropout3d(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.dropout3d(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.relu(x)
        x = self.dropout3d(x)
        x = self.pool3(x)

        # (B, C, T, H, W) -> (T, B, C, H, W)
        x = x.permute(2, 0, 1, 3, 4).contiguous()
        # (T, B, C, H, W) -> (T, B, C*H*W)
        x = x.view(x.size(0), x.size(1), -1)

        self.gru1.flatten_parameters()
        self.gru2.flatten_parameters()

        x, h = self.gru1(x)
        x = self.dropout(x)
        x, h = self.gru2(x)
        x = self.dropout(x)

        # Branch 2: lip coordinates processed by their own GRU
        self.coord_gru.flatten_parameters()

        # (B, T, N, C) -> (T, B, N, C)
        coords = coords.permute(1, 0, 2, 3).contiguous()
        # (T, B, N, C) -> (T, B, N*C)
        coords = coords.view(coords.size(0), coords.size(1), -1)
        coords, _ = self.coord_gru(coords)
        coords = self.dropout(coords)

        # Combine the two branches along the feature dimension
        combined = torch.cat((x, coords), dim=2)

        x = self.FC(combined)
        # (T, B, classes) -> (B, T, classes)
        x = x.permute(1, 0, 2).contiguous()
        return x
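
A minimal shape check for the new module (a sketch, not part of this commit): the 75-frame clip length, 64x128 mouth-crop size, and 20 two-dimensional lip landmarks per frame are assumptions chosen so that the flattened conv output matches the hard-coded 96 * 4 * 8 GRU input size and the default coord_input_dim=40.

    import torch
    from two_stream_lipnet import TwoStreamLipNet

    model = TwoStreamLipNet().eval()
    video = torch.randn(2, 3, 75, 64, 128)  # (B, C, T, H, W) mouth-crop clips
    coords = torch.randn(2, 75, 20, 2)      # (B, T, N, 2) lip landmarks, N * 2 = coord_input_dim
    with torch.no_grad():
        out = model(video, coords)
    print(out.shape)  # torch.Size([2, 75, 28]): (B, T, 27 + 1 classes)

Both streams must share the same number of frames T, since their GRU outputs are concatenated per time step before the final linear layer.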