from transformers import PreTrainedModel

from config_loconet import LoCoNetConfig
from loconet_encoder import locoencoder
from loss_multi import lossAV, lossA, lossV


class loconet(PreTrainedModel):
    """LoCoNet active speaker detection model wrapped as a Hugging Face
    PreTrainedModel."""

    config_class = LoCoNetConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = locoencoder(config)
        # Loss heads referenced in forward(): the fused audio-visual head
        # plus the two auxiliary unimodal heads.
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        # visualFeature: (B, S, T, ...) — batch, candidate speakers, time.
        b, s, t = visualFeature.shape[:3]
        # Fold the speaker axis into the batch so every face track is
        # encoded independently: (B * S, T, ...).
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])
        if labels is not None:
            labels = labels.view(b * s, *labels.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)    # B, C, T, 4
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Tile the audio embedding once per candidate speaker so it lines
        # up with the flattened (B * S) visual stream.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)
        num_frames = masks.sum()    # total valid (unpadded) frames; currently unused

        if labels is not None:
            labels = labels.reshape((-1))
            masks = masks.reshape((-1))
            # The fused audio-visual head also reports a precision value;
            # the audio-only and visual-only heads give auxiliary losses.
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
            nlossA = self.lossA.forward(outsA, labels, masks)
            nlossV = self.lossV.forward(outsV, labels, masks)

            # Weighted sum: the audio-visual loss dominates, the unimodal
            # losses regularize each stream.
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV

            return {"loss": nloss, "logits": outsAV}

        return {"logits": outsAV}