import numpy as np

import torch
import torch.nn.functional as F
from torch import nn

from transformers import AutoConfig
from transformers.tokenization_utils_base import BatchEncoding

from models.opt import OPTForCausalLM
import models.vit

def rank_answer(model, image, question_input, answer_ids, answer_atts, k, tokenizer):
    """Rank candidate answers for each question in two stages: first pick the
    top-k candidates by the probability of their first token, then re-score the
    full candidate sequences with the language-modeling loss and re-rank."""
    num_ques = question_input.input_ids.size(0)
    start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token shared by all answers

    start_ids = torch.cat((question_input.input_ids, start_ids), dim=1)
    attention_mask = torch.cat(
        (question_input.attention_mask,
         torch.ones((num_ques, 1), dtype=question_input.attention_mask.dtype,
                    device=question_input.attention_mask.device)),
        dim=1)

    start_input = BatchEncoding({'input_ids': start_ids, 'attention_mask': attention_mask})

    start_output = model(image, start_input, return_dict=True, mode='evaluate')
    
    logits = start_output.logits[:, -1, :]  # logits predicting the first answer token

    # topk_probs: top-k probabilities of the first answer token
    # topk_ids: [num_question, k] indices into the candidate answer list
    answer_first_token = answer_ids[:, 1]
    prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
    topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

    # gather token ids of the k selected answers: [num_question*k, answer_len]
    input_ids = []
    input_atts = []
    for topk_id in topk_ids:
        input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
        input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
    input_ids = torch.cat(input_ids, dim=0)
    input_atts = torch.cat(input_atts, dim=0)

    # repeat the question (and image) k times so each candidate answer gets its own row
    start_ids = tile(start_ids, 0, k)
    attention_mask = tile(attention_mask, 0, k)
    image = tile(image, 0, k)

    # prepend question + bos to every candidate answer
    input_ids = torch.cat((start_ids, input_ids), dim=1)
    input_atts = torch.cat((attention_mask, input_atts), dim=1)

    targets_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
    inputs = BatchEncoding({'input_ids': input_ids, 'attention_mask': input_atts})

    output = model(image, inputs, labels=targets_ids, return_dict=True, mode='train', reduction='none')

    answer_loss = output.loss
    answer_loss = answer_loss.view(input_ids.size(0), -1)

    # combine the first-token probability with the per-token log-likelihoods of
    # the full answer sequence (chain rule over tokens)
    topk_probs = topk_probs.view(-1, 1)
    log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

    log_probs_sum = log_probs.sum(1)
    log_probs_sum = log_probs_sum.view(num_ques, k)

    topk_probs = F.softmax(log_probs_sum, dim=-1)
    # re-rank the k candidates by their full-sequence scores
    topk_probs, rerank_id = topk_probs.topk(k, dim=1)
    topk_ids = torch.gather(topk_ids, 1, rerank_id)

    return topk_ids, topk_probs
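
# Illustrative call (a sketch; the variable names and shapes below are
# assumptions, not defined in this file): `question_input` is a tokenizer
# BatchEncoding for a batch of questions, and `answer_ids` / `answer_atts`
# come from tokenizing a fixed candidate-answer list, one answer per row,
# each row starting with the bos token:
#   topk_ids, topk_probs = rank_answer(model, image, question_input,
#                                      answer_ids, answer_atts, k=10,
#                                      tokenizer=tokenizer)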
    
def tile(x, dim, n_tile):
    """Repeat each slice of x along `dim` n_tile times, keeping the copies of
    each slice adjacent rather than cycling through the batch as
    `Tensor.repeat` does."""
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    # reorder so the n_tile copies of each original slice end up contiguous
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
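
# A quick check of the interleaving behaviour (illustrative only):
#   >>> tile(torch.tensor([[1], [2]]), 0, 2)
#   tensor([[1], [1], [2], [2]])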





class VisOPT(nn.Module):
    def __init__(self,
                 opt_model_name='facebook/opt-350m',
                 vision_model_name='vit_base_patch16_224',
                 use_vis_prefix=True,
                 start_layer_idx=11,
                 end_layer_idx=23,
                 return_hidden_state_vision=True,
                 injected_hidden_states=1,
                 ):
        super().__init__()
        print("Loading VisOPT ...")
        # text
        config_opt = AutoConfig.from_pretrained(opt_model_name)
        config_opt.use_vis_prefix = use_vis_prefix
        config_opt.start_layer_idx = start_layer_idx
        config_opt.end_layer_idx = end_layer_idx
        print(config_opt)
        print("Loading: ", opt_model_name)
        self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)
        
        # vision
        print("Loading: ", vision_model_name)
        vision_func = getattr(models.vit, vision_model_name)
        self.model_vision = vision_func(pretrained=True, return_hidden_state=return_hidden_state_vision)
        
        # connector
        self.injected_hidden_states = injected_hidden_states
        vis_dim = self.model_vision.embed_dim
        text_dim = config_opt.hidden_size
        self.connector = nn.ModuleList([nn.Linear(vis_dim, text_dim) for _ in range(injected_hidden_states)])

    def forward(self, image=None, text=None, mode='generate', return_dict=True, labels=None, reduction='mean', **generation_kwargs):
        
        if image is not None:
            image_embed, image_feat = self.model_vision(image, external_features=None)

            image_feat = list(image_feat)
            image_feat = image_feat[-self.injected_hidden_states:]

            # keep only the CLS token of each injected hidden state and project it
            # to the text hidden size; other pooling strategies could be explored
            for i in range(1, self.injected_hidden_states + 1):
                image_feat[-i] = self.connector[-i](image_feat[-i][:, 0, :].unsqueeze(1))
        else:
            image_feat = None

            
        if mode == 'train' or mode == 'evaluate':
            text_output = self.model_text(input_ids=text.input_ids, attention_mask=text.attention_mask,
                                          return_dict=return_dict, vis_prefix=image_feat,
                                          labels=labels, reduction=reduction)
            return text_output
        elif mode == 'generate':
            gen = self.model_text.generate(input_ids=text.input_ids, vis_prefix=image_feat, **generation_kwargs)
            return gen
        else:
            raise ValueError(f"Unknown mode: {mode!r}")
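

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: the prompt, the dummy image tensor,
# and max_new_tokens are assumptions, not fixed anywhere in this file).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer

    model = VisOPT()  # facebook/opt-350m + vit_base_patch16_224 by default
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')

    image = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch
    text = tokenizer(['Question: what is in the picture? Answer:'], return_tensors='pt')

    with torch.no_grad():
        gen = model(image=image, text=text, mode='generate', max_new_tokens=10)
    print(tokenizer.batch_decode(gen, skip_special_tokens=True))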