ngocson2002 committed on
Commit d710c3f · 1 Parent(s): 8eb6782

Create modeling_vivqa.py

Files changed (1)
  1. modeling_vivqa.py +206 -0
modeling_vivqa.py ADDED
@@ -0,0 +1,206 @@
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
+ from torchscale.component.multiway_network import MutliwayEmbedding
+ from torchscale.component.embedding import PositionalEmbedding
+ from torchscale.architecture.encoder import Encoder
+ from transformers import PreTrainedModel
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch
+ import math
+ from transformers import AutoModel
+ from transformers.utils.generic import ModelOutput
+ from dataclasses import dataclass
+ from typing import Optional
+ from efficientnet_pytorch import EfficientNet
+ from lavis.common.registry import registry
+ # NOTE: ViVQAConfig is referenced below as config_class; it is assumed to be
+ # provided by a companion configuration_vivqa.py in this repository.
+ from .configuration_vivqa import ViVQAConfig
+
+ class BartPhoExtractor(nn.Module):
+     """Text feature extractor built on the pre-trained vinai/bartpho-word model."""
+     def __init__(self):
+         super(BartPhoExtractor, self).__init__()
+         self.bartpho_word = AutoModel.from_pretrained("vinai/bartpho-word")
+
+     def forward(self, input_ids, attention_mask):
+         last_hidden_states = self.bartpho_word(input_ids, attention_mask)
+         features = last_hidden_states[0]  # last hidden state, shape (batch, seq_len, 1024)
+         return features
+
+ class Blip2EfficientExtractor(nn.Module):
+     """Vision feature extractor: global BLIP-2 Q-Former features concatenated with pooled EfficientNet-B7 features."""
+     def __init__(self):
+         super(Blip2EfficientExtractor, self).__init__()
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # BLIP-2 feature extractor (LAVIS)
+         self.model_blip2 = registry.get_model_class(name="blip2_feature_extractor").from_pretrained(model_type="pretrain").to(self.device)
+         if self.device == "cpu" or self.device == torch.device("cpu"):
+             self.model_blip2 = self.model_blip2.float()
+         self.model_blip2.eval()
+
+         # EfficientNet-B7 backbone for local (grid) features
+         self.model_efficient = EfficientNet.from_pretrained('efficientnet-b7').to(self.device)
+         self.pooling1 = nn.AdaptiveAvgPool2d((1, 32))
+         self.pooling2 = nn.AdaptiveAvgPool2d((1, 768))
+
+     def forward(self, images):
+         # Global features from the BLIP-2 Q-Former queries: (batch, num_queries, 768)
+         global_features = self.model_blip2.extract_features(samples={"image": images}, mode="image").image_embeds
+
+         # Local features from EfficientNet, pooled down to (batch, 32, 768)
+         local_features = self.model_efficient.extract_features(images)
+         local_features = self.pooling1(local_features)
+         local_features = local_features.permute(0, 3, 2, 1)
+         local_features = self.pooling2(local_features)
+         batch_size = images.shape[0]
+         local_features = local_features.reshape(batch_size, local_features.shape[1], -1)
+
+         # Concatenate global and local tokens along the sequence dimension
+         v = torch.cat([global_features, local_features], dim=1)
+         return v
+
+ @dataclass
+ class ViVQAOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+
+ def trunc_normal_(tensor, mean=0., std=1.):
+     __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
+
+ class Pooler(nn.Module):
+     def __init__(self, input_features, output_features, norm_layer):
+         super().__init__()
+         self.norm = norm_layer(input_features)
+         self.dense = nn.Linear(input_features, output_features)
+         self.activation = nn.Tanh()
+
+     def forward(self, x):
+         cls_rep = x[:, 0, :]
+         cls_rep = self.norm(cls_rep)
+         pooled_output = self.dense(cls_rep)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+ class ViVQABEiT3(PreTrainedModel):
+     def __init__(self, args):
+         super().__init__(args)
+         assert args.multiway
+         assert not args.share_encoder_input_output_embed
+
+         self.text_embed = BartPhoExtractor()
+
+         # The vision extractor is kept frozen
+         self.vision_embed = Blip2EfficientExtractor()
+         for param in self.vision_embed.parameters():
+             param.requires_grad = False
+
+         # Project BARTpho features (1024-d) to the encoder embedding dimension (768-d)
+         self.linear = nn.Linear(1024, 768)
+
+         # being consistent with Fairseq, which starts from 2 for position embedding
+         num_position_embeddings = 64
+         embed_positions = MutliwayEmbedding(
+             modules=[
+                 PositionalEmbedding(num_position_embeddings + 2, args.encoder_embed_dim),
+                 PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
+             ],
+             dim=1,
+         )
+         self.encoder = Encoder(
+             args,
+             embed_tokens=None,
+             embed_positions=embed_positions,
+             output_projection=None,
+             is_encoder_decoder=False,
+         )
+
+     def forward(self, textual_tokens, visual_tokens, text_padding_position):
+         x1 = self.vision_embed(visual_tokens)
+         multiway_split_position = x1.size(1)
+
+         x2 = self.text_embed(textual_tokens, text_padding_position)
+         x2 = self.linear(x2)
+
+         x = torch.cat([x1, x2], dim=1)
+         encoder_padding_mask = None
+         if text_padding_position is not None:
+             # Vision tokens are never padded; only the text part carries padding
+             encoder_padding_mask = torch.cat(
+                 [
+                     torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
+                     text_padding_position,
+                 ],
+                 dim=1,
+             )
+         encoder_out = self.encoder(
+             src_tokens=None,
+             encoder_padding_mask=encoder_padding_mask,
+             token_embeddings=x,
+             multiway_split_position=multiway_split_position,
+         )
+         encoder_out["multiway_split_position"] = multiway_split_position
+         return encoder_out
+
+ class BEiT3Wrapper(PreTrainedModel):
+     def __init__(self, args, **kwargs):
+         super().__init__(args)
+         self.beit3 = ViVQABEiT3(args)
+         self.apply(self._init_weights)
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def get_num_layers(self):
+         return self.beit3.encoder.num_layers
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+
+ class BEiT3ForVietnameseVisualQuestionAnswering(BEiT3Wrapper):
+     config_class = ViVQAConfig
+
+     def __init__(
+         self,
+         args,
+         num_classes=353,
+         **kwargs
+     ):
+         super(BEiT3ForVietnameseVisualQuestionAnswering, self).__init__(args=args)
+         embed_dim = args.encoder_embed_dim
+         self.pooler = Pooler(
+             input_features=embed_dim,
+             output_features=embed_dim,
+             norm_layer=nn.LayerNorm,
+         )
+         self.pooler.apply(self._init_weights)
+         # Classification head over num_classes candidate answers
+         self.head = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim * 2),
+             nn.LayerNorm(embed_dim * 2),
+             nn.GELU(),
+             nn.Linear(embed_dim * 2, num_classes),
+         )
+         self.head.apply(self._init_weights)
+
+     def forward(self, image, question, padding_mask, labels=None, **kwargs):
+         outputs = self.beit3(
+             textual_tokens=question,
+             visual_tokens=image,
+             text_padding_position=padding_mask,
+         )
+         x = outputs["encoder_out"]
+         cls_rep = self.pooler(x)
+         logits = self.head(cls_rep)
+
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(logits, labels)
+
+         return ViVQAOutput(
+             loss=loss,
+             logits=logits,
+         )
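
A minimal usage sketch of the forward interface defined above (not part of the commit itself): it assumes ViVQAConfig comes from a companion configuration_vivqa.py and supplies the torchscale encoder settings (multiway=True, encoder_embed_dim=768, max_source_positions, ...), and that questions are tokenized with the vinai/bartpho-word tokenizer; the image size and the sample question are illustrative only.

    import torch
    from transformers import AutoTokenizer

    from configuration_vivqa import ViVQAConfig          # assumed companion config file
    from modeling_vivqa import BEiT3ForVietnameseVisualQuestionAnswering

    config = ViVQAConfig()                               # assumed to set multiway=True, encoder_embed_dim=768, ...
    model = BEiT3ForVietnameseVisualQuestionAnswering(config, num_classes=353).eval()

    tokenizer = AutoTokenizer.from_pretrained("vinai/bartpho-word")
    enc = tokenizer("Chiếc xe có màu gì?", return_tensors="pt")   # "What color is the car?"

    image = torch.randn(1, 3, 224, 224)                  # preprocessed image batch (size is illustrative)
    padding_mask = (enc.attention_mask == 0)             # True where the question tokens are padding

    with torch.no_grad():
        out = model(image=image, question=enc.input_ids, padding_mask=padding_mask)

    print(out.logits.shape)                              # (1, 353) scores over the answer set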