import contextlib
import logging

import torch
import torch.nn as nn
from lavis.common.registry import registry
from lavis.models import Blip2OPT, load_preprocess
from omegaconf import OmegaConf


@registry.register_model("blip2_opt_det")
class Blip2OPTDet(Blip2OPT):
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.opt_tokenizer.add_special_tokens({"mask_token": "<mask>"})

    def maybe_autocast(self, dtype=torch.float16):
        # On CPU, autocast is disabled; on GPU, run under autocast with the given
        # dtype (torch.float16 by default).
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @torch.no_grad()
    def forward(self, samples,
                use_nucleus_sampling=False,
                num_beams=5,
                max_length=30,
                min_length=1,
                top_p=0.9,
                repetition_penalty=1.0,
                length_penalty=1.0,
                num_captions=1,
                temperature=1,
                task_button=None):
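        # Encode the image and fuse it with the learned query tokens via the Q-Former;
        # the result is projected into the OPT embedding space as a soft visual prefix.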
        image = samples["image"]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = self.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device)

        self.opt_tokenizer.padding_side = "right"

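        # Three input modes: raw text ("text_input"), pre-tokenized ids ("input_ids"),
        # or a "prompt" that is first completed with opt_model.generate() before scoring.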
        if "text_input" in samples.keys():
            # text = [t + "\n" for t in samples["text_input"]]
            text = [t for t in samples["text_input"]]
            opt_tokens = self.opt_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
            ).to(image.device)
            input_ids = opt_tokens.input_ids
            attention_mask = opt_tokens.attention_mask
            output_text = text
        elif "input_ids" in samples.keys():
            input_ids = samples["input_ids"]
            attention_mask = samples["attention_mask"]
            output_text = []
        else:
            assert "prompt" in samples.keys()
            prompt = samples["prompt"]
            assert len(prompt) == image.size(0)

            opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt", padding=True).to(
                image.device
            )
            input_ids = opt_tokens.input_ids
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

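            # generate() expands the batch by num_beams (or num_captions when sampling),
            # so repeat the query embeddings to match.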
            if use_nucleus_sampling:
                query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0)
                num_beams = 1
            else:
                query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0)

            with self.maybe_autocast():
                outputs = self.opt_model.generate(
                    input_ids=input_ids,
                    query_embeds=query_embeds,
                    attention_mask=attention_mask,
                    do_sample=use_nucleus_sampling,
                    top_p=top_p,
                    temperature=temperature,
                    num_beams=num_beams,
                    max_new_tokens=max_length,
                    min_length=min_length,
                    eos_token_id=self.eos_token_id,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_return_sequences=num_captions,
                )

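            # Keep only the generated continuation; the first prompt_length tokens echo the prompt.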
            prompt_length = opt_tokens.input_ids.shape[1]
            output_text = self.opt_tokenizer.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]
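            # For QA / captioning, re-tokenize the prompt together with the generated answer
            # (note: this indexing assumes a batch size of 1) so the forward pass below
            # scores the full prompt + answer sequence.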
            if task_button in ("Question Answering", "Captioning"):
                output_text_input = [prompt[0] + ' ' + output_text[0]]
                opt_tokens = self.opt_tokenizer(
                    output_text_input,
                    return_tensors="pt",
                    padding="longest",
                ).to(image.device)
            input_ids = opt_tokens.input_ids
            attention_mask = opt_tokens.attention_mask

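        # Prepend the projected query embeddings to the token embeddings and run the OPT
        # decoder once more; the query positions are then dropped so the returned logits
        # and hidden states line up with input_ids.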
        inputs_embeds = self.opt_model.model.decoder.embed_tokens(input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True
            )
        n_queries = query_tokens.shape[1]
        out_logits = outputs['logits'][:, n_queries:]
        out_hidden = outputs['hidden_states'][-1][:, n_queries:]
        return out_logits, out_hidden, input_ids, output_text


def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
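    """Load a registered LAVIS model by name/type plus its default preprocessors, if any."""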
    model_cls = registry.get_model_class(name)

    # load model
    model = model_cls.from_pretrained(model_type=model_type)

    if is_eval:
        model.eval()

    # load preprocess
    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is not None:
        preprocess_cfg = cfg.preprocess

        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {name} ({model_type}).
                This can happen if the model is not finetuned on downstream datasets,
                or it is not intended for direct use without finetuning.
            """
        )

    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors


class BLIP2Decoder(nn.Module):
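    """Thin wrapper that loads the frozen blip2_opt_det model and its image preprocessors."""
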
    def __init__(self, llm_name):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if llm_name not in ['pretrain_opt2.7b', 'caption_coco_opt2.7b',
                            'pretrain_opt6.7b', 'caption_coco_opt6.7b']:
            raise ValueError(f"{llm_name} is not supported yet")
        model_type = llm_name
        model, vis, _ = load_model_and_preprocess(name="blip2_opt_det",
                                                  model_type=model_type,
                                                  is_eval=True, device=self.device)
        self.model = model
        self.vis_processors = vis
        self.freeze_layers()

    def freeze_layers(self):
        for p in self.model.parameters():
            p.requires_grad = False
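

# Example usage (a minimal sketch; the "eval" processor key follows the usual LAVIS
# convention, and the image path and prompt below are placeholders):
#
#   from PIL import Image
#
#   decoder = BLIP2Decoder("pretrain_opt2.7b")
#   raw = Image.open("example.jpg").convert("RGB")
#   image = decoder.vis_processors["eval"](raw).unsqueeze(0).to(decoder.device)
#   samples = {"image": image, "prompt": ["Question: what is shown in the picture? Answer:"]}
#   logits, hidden, input_ids, text = decoder.model(samples, task_button="Question Answering")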