import torch
import torch.nn as nn
from fastai.vision import *  # provides ifnone

from .model_vision import BaseVision
from .model_language import BCNLanguage
from .model_semantic_visual_backbone_feature import BaseSemanticVisual_backbone_feature


class MATRN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.iter_size = ifnone(config.model_iter_size, 1)  # number of iterative refinement passes
        self.test_bh = ifnone(config.test_bh, None)  # which branch to report at test time
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.vision = BaseVision(config)  # visual recognition branch
        self.language = BCNLanguage(config)  # bidirectional cloze language model
        self.semantic_visual = BaseSemanticVisual_backbone_feature(config)  # semantic-visual fusion

    def forward(self, images, texts=None):
        v_res = self.vision(images)  # initial visual prediction
        a_res = v_res  # seed the alignment result with the visual output
        all_l_res, all_a_res = [], []
        for _ in range(self.iter_size):
            # Feed the current prediction to the language model as a soft
            # character distribution, with lengths clamped to a valid range.
            tokens = torch.softmax(a_res['logits'], dim=-1)
            lengths = a_res['pt_lengths']
            lengths.clamp_(2, self.max_length)  # at least one character plus the stop token
            l_res = self.language(tokens, lengths)
            all_l_res.append(l_res)
            lengths_l = l_res['pt_lengths']
            lengths_l.clamp_(2, self.max_length)

            # Detach the visual attention maps so gradients from the fusion
            # stage do not flow back into the vision branch through them.
            v_attn_input = v_res['attn_scores'].clone().detach()
            l_logits_input = None
            texts_input = None

            a_res = self.semantic_visual(
                l_res['feature'], v_res['backbone_feature'],
                lengths_l=lengths_l, v_attn=v_attn_input,
                l_logits=l_logits_input, texts=texts_input,
                training=self.training)

            # Record the per-branch predictions in a fixed order: visual
            # (v_logits), semantic (s_logits), then the fused result. The
            # test-time branch selection below indexes into this order.
            a_v_res = {'logits': a_res['v_logits'], 'pt_lengths': a_res['pt_v_lengths'],
                       'loss_weight': a_res['loss_weight'], 'name': 'alignment'}
            all_a_res.append(a_v_res)
            a_s_res = {'logits': a_res['s_logits'], 'pt_lengths': a_res['pt_s_lengths'],
                       'loss_weight': a_res['loss_weight'], 'name': 'alignment'}
            all_a_res.append(a_s_res)
            all_a_res.append(a_res)

        if self.training:
            return all_a_res, all_l_res, v_res

        # At test time, return the requested branch's prediction together
        # with the last language result and the vision result.
        if self.test_bh is None or self.test_bh == 'final':
            return a_res, all_l_res[-1], v_res  # fused prediction (default)
        elif self.test_bh == 'semantic':
            return all_a_res[-2], all_l_res[-1], v_res  # semantic branch
        elif self.test_bh == 'visual':
            return all_a_res[-3], all_l_res[-1], v_res  # visual branch
        else:
            raise NotImplementedError(f'unknown test_bh: {self.test_bh}')
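

# Minimal usage sketch (not part of the original file). Kept as comments
# because MATRN needs a fully populated config for its sub-modules; the
# SimpleNamespace fields below are assumptions beyond the attributes this
# file actually reads (model_iter_size, test_bh, dataset_max_length), and
# the input shape is a typical text-recognition crop size, not confirmed
# by this file.
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(model_iter_size=3, test_bh=None,
#                              dataset_max_length=25)  # plus sub-module fields
#     model = MATRN(config).eval()
#     images = torch.randn(2, 3, 32, 128)  # batch of text-line crops
#     with torch.no_grad():
#         a_res, l_res, v_res = model(images)
#     preds = a_res['logits'].argmax(dim=-1)  # greedy character ids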