import torch
import torch.nn as nn

from attentionLayer import attentionLayer
from convLayer import ConvLayer
from torchvggish import vggish
from visualEncoder import visualFrontend, visualConv1D, visualTCN


class locoencoder(nn.Module):
    """Audio-visual encoder: per-modality temporal frontends, cross-modal
    attention, and an interleaved conv/self-attention backend."""

    def __init__(self, cfg):
        super(locoencoder, self).__init__()
        self.cfg = cfg
        # Visual Temporal Encoder
        self.visualFrontend = visualFrontend(cfg)    # Visual Frontend
        self.visualTCN = visualTCN()    # Visual Temporal Network TCN
        self.visualConv1D = visualConv1D()    # Visual Temporal Network Conv1d

        # Pretrained VGGish audio encoder; weights are fetched from this URL
        urls = {
            'vggish':
                "https://github.com/harritaylor/torchvggish/releases/download/v0.1/vggish-10086976.pth"
        }
        self.audioEncoder = vggish.VGGish(urls, preprocess=False, postprocess=False)
        self.audio_pool = nn.AdaptiveAvgPool1d(1)    # pools features over the frequency axis

        # Audio-visual Cross Attention
        self.crossA2V = attentionLayer(d_model=128, nhead=8)
        self.crossV2A = attentionLayer(d_model=128, nhead=8)

        # Audio-visual Self Attention: interleaved ConvLayer / self-attention stack
        num_layers = self.cfg.av_layers
        layers = nn.ModuleList()
        for i in range(num_layers):
            layers.append(ConvLayer(cfg))
            layers.append(attentionLayer(d_model=256, nhead=8))
        self.convAV = layers

    def forward_visual_frontend(self, x):
        # x: (B, T, W, H) grayscale face crops with raw pixel values
        B, T, W, H = x.shape
        x = x.view(B * T, 1, 1, W, H)
        x = (x / 255 - 0.4161) / 0.1688    # normalize with dataset mean/std
        x = self.visualFrontend(x)
        x = x.view(B, T, 512)
        x = x.transpose(1, 2)    # (B, 512, T) for the temporal convolutions
        x = self.visualTCN(x)
        x = self.visualConv1D(x)
        x = x.transpose(1, 2)    # (B, T, 128)
        return x

    def forward_audio_frontend(self, x):
        # x: audio spectrogram with time on dim -2; 4 audio frames per video frame
        t = x.shape[-2]
        numFrames = t // 4
        pad = (8 - (t % 8)) % 8    # pad time axis to a multiple of 8 (no-op when already aligned)
        x = torch.nn.functional.pad(x, (0, 0, 0, pad), "constant")
        # x = x.unsqueeze(1).transpose(2, 3)
        x = self.audioEncoder(x)

        b, c, t2, freq = x.shape
        x = x.view(b * c, t2, freq)
        x = self.audio_pool(x)    # average over the frequency axis
        x = x.view(b, c, t2)[:, :, :numFrames]    # keep one feature per video frame
        x = x.permute(0, 2, 1)    # (B, T, C)
        return x

    def forward_cross_attention(self, x1, x2):
        # Each modality attends to the other (audio-to-visual and visual-to-audio)
        x1_c = self.crossA2V(src=x1, tar=x2, adjust=self.cfg.adjust_attention)
        x2_c = self.crossV2A(src=x2, tar=x1, adjust=self.cfg.adjust_attention)
        return x1_c, x2_c

    def forward_audio_visual_backend(self, x1, x2, b=1, s=1):
        x = torch.cat((x1, x2), 2)    # (B*S, T, 2C): fuse audio and visual features
        # convAV alternates ConvLayer (even indices) and self-attention (odd indices)
        for i, layer in enumerate(self.convAV):
            if i % 2 == 0:
                x, b, s = layer(x, b, s)
            else:
                x = layer(src=x, tar=x)

        x = torch.reshape(x, (-1, 256))    # flatten to per-frame 256-d embeddings
        return x

    def forward_audio_backend(self, x):
        # Flatten audio features to per-frame 128-d embeddings
        x = torch.reshape(x, (-1, 128))
        return x

    def forward_visual_backend(self, x):
        # Flatten visual features to per-frame 128-d embeddings
        x = torch.reshape(x, (-1, 128))
        return x