# Transformers
import re
import torch
from torch import nn
from utils.utils import *
from typing import Optional, Tuple, Union
from transformers import MambaForCausalLM
from transformers import LlavaNextForConditionalGeneration, LlavaForConditionalGeneration

class MambaCache:
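    """Per-layer recurrent state for incremental Mamba decoding.

    For every hidden layer, `conv_states[i]` holds the rolling convolution
    buffer of shape (batch, intermediate_size, conv_kernel) and `ssm_states[i]`
    holds the selective-SSM state of shape (batch, intermediate_size, state_size).
    """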
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }

# Dataclass & ModelOutput
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
@dataclass
class MambaCausalLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    tor_features: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None

class MeteorMambaForCausalLM(MambaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        # initialize extra projections for vision features and <tor> tokens
        self.vision_proj = self.build_vision_projector(1024, self.config.hidden_size)
        self.tor_proj = self.build_vision_projector(self.config.hidden_size, 4096)
        
        # replace the Mamba embedding table with one sized for the Meteor vocabulary
        self.backbone.embeddings = nn.Embedding(num_embeddings=92546,
                                                embedding_dim=self.config.hidden_size)

        # CLIP image normalization statistics, pre-scaled to the 0-255 pixel range
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1,-1,1,1) * 255
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1,-1,1,1) * 255

    def image_processor(self, images):
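        """Normalize raw pixel values with CLIP statistics.

        The mean/std buffers are pre-scaled by 255, so `images` is expected to
        hold raw pixel values in the 0-255 range with shape (N, 3, H, W).
        """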
        norm_images = (images - self.mean.to(images.device)) / self.std.to(images.device)
        return norm_images

    @staticmethod
    def build_vision_projector(mm_hidden_size, hidden_size):
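        """Build an MLP projector.

        With the fixed 'mlp2x_gelu' type this returns
        Linear(mm_hidden_size, hidden_size) -> GELU -> Linear(hidden_size, hidden_size).
        """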
        projector_type = 'mlp2x_gelu'
        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(mm_hidden_size, hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(hidden_size, hidden_size))
            return nn.Sequential(*modules)

        raise ValueError(f'Unknown projector type: {projector_type}')

    def eval_process(
        self,
        inputs,
        tokenizer,
        device,
        img_token_number,
    ):
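        """Tokenize a batch of evaluation samples into model-ready inputs.

        Each element of `inputs` is a dict with a 'question' string and,
        optionally, an 'image' tensor. Returns {"input_ids": ...} for text-only
        batches, plus {"image": pixel_values} when images are present.
        """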
        batched_image=[]
        batched_qa_prompt=[]
        for _input in inputs:

            # Visualization
            # imim = _input['image'].cpu().permute(1, 2, 0)

            # prepend <image> to the question when an image is provided but the token is missing;
            # the system prompt and <tor> prompt are added by make_instruction_for_mmamba below
            if 'image' in _input.keys() and '<image>' not in _input['question']: _input['question'] = '<image>\n' + _input['question']

            # make question, rationale, and answer
            question = make_instruction_for_mmamba(question=_input['question'])

            # expand the <image> placeholder into the bundled image tokens
            question = add_bundle_tokens(question, '<image>', img_token_number)

            # collect the batched images and prompts
            if 'image' in _input.keys() and _input['image'] is not None: batched_image.append(_input['image'].to(device))
            batched_qa_prompt.append(question)

        # For final outputs: tokenize the batched prompts
        qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)

        # [1] input_ids
        input_ids = qa_prompts.input_ids.to(device)

        # image or only text?
        if len(batched_image):
            # [2] pixel values: images may arrive as (C, H, W) or already batched
            # as (N, C, H, W), so promote 3-D tensors before concatenating
            new_batched_image = []
            for batched_image_element in batched_image:
                if batched_image_element.dim() == 3:
                    new_batched_image.append(batched_image_element.unsqueeze(0))
                else:
                    new_batched_image.append(batched_image_element)
            pixel_values = self.image_processor(torch.cat(new_batched_image, dim=0)).to(device)

            return {"input_ids": input_ids, "image": pixel_values}
        else:
            return {"input_ids": input_ids}


    def _merge_input_embeds_with_image_features(self, image_features, inputs_embeds, input_ids):
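        """Scatter projected image features into the text embedding sequence.

        `image_features` has shape (num_images, C, D), where C is the number of
        tokens per image; the positions of `config.image_token_index` in each
        sample of `input_ids` are overwritten in-place with the corresponding
        flattened image token embeddings.
        """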
        
        # batch index for image feature
        batch_ind_image_feature = 0

        # shape of image_features
        _, C, D = image_features.shape

        for ind, input_id in enumerate(input_ids):
            matching = torch.where(input_id==self.config.image_token_index)
            num_image_tokens_per_one_sample = len(matching[0]) // C
            inputs_embeds[ind][matching] = image_features[batch_ind_image_feature: batch_ind_image_feature+num_image_tokens_per_one_sample].view(-1, D)
            batch_ind_image_feature += num_image_tokens_per_one_sample

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_features: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        # labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict


        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if image_features is not None and input_ids.shape[1] != 1:
                image_features = self.vision_proj(image_features)
                self._merge_input_embeds_with_image_features(image_features, inputs_embeds, input_ids)

        mamba_outputs = self.backbone(
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )
        hidden_states = mamba_outputs[0]

        # logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        # if labels is not None:
        #     # move labels to correct device to enable model parallelism
        #     labels = labels.to(logits.device)
        #     # Shift so that tokens < n predict n
        #     shift_logits = logits[..., :-1, :].contiguous()
        #     shift_labels = labels[..., 1:].contiguous()
        #     # Flatten the tokens
        #     loss_fct = nn.CrossEntropyLoss()
        #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # if not return_dict:
        #     output = (logits,) + mamba_outputs[1:]
        #     return ((loss,) + output) if loss is not None else output
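
        # hidden states at <tor> token positions are projected to 4096 dims
        # (presumably the hidden size of the downstream language model) and
        # returned as `tor_features`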

        return MambaCausalLMOutput(
            loss=loss,
            cache_params=mamba_outputs.cache_params,
            tor_features=self.tor_proj(hidden_states[torch.where(input_ids==self.config.tor_token_index)]),
            hidden_states=mamba_outputs.hidden_states,
        )
    
    def prepare_inputs_for_generation(
        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, image_features=None, **kwargs
    ):
        # only use the last token of input_ids if the recurrent state is passed along.
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds, "image_features":image_features}
        else:
            model_inputs = {"input_ids": input_ids, "image_features":image_features}

        model_inputs["cache_params"] = cache_params
        return model_inputs
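

# A minimal usage sketch (illustrative only; the checkpoint path, tokenizer,
# vision-encoder features, and token counts below are assumptions, not part of
# this file):
#
#   model = MeteorMambaForCausalLM.from_pretrained("<mamba-checkpoint>").to("cuda")
#   batch = model.eval_process(inputs, tokenizer, device="cuda",
#                              img_token_number=NUM_IMAGE_TOKENS)
#   # `image_features` would come from a separate vision encoder whose output
#   # dim matches vision_proj's input (1024); it is not produced by this class.
#   out = model(input_ids=batch["input_ids"],
#               image_features=clip_features,
#               use_cache=True)
#   tor = out.tor_features  # projected <tor> hidden states, dim 4096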