# Copyright (c) OpenMMLab. All rights reserved.
import random
import re
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModel

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MiniGPT4(BaseModel):
    """The multi-modality model of MiniGPT-4.

    The implementation of `MiniGPT-4 <https://arxiv.org/abs/2304.10592>`_.
    Modified from
    https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py

    Args:
        vision_encoder (dict): The config for the vision encoder.
        q_former_model (dict): The config for the Q-Former.
        lang_encoder (dict): The config for the language model.
        tokenizer (dict): The config for the tokenizer.
        task (str): The task to perform, which controls how the text is
            processed. Defaults to 'caption'.
        freeze_vit (bool): Freeze the training of the ViT. Defaults to True.
        freeze_q_former (bool): Freeze the training of the Q-Former. Defaults
            to True.
        num_query_token (int): Number of query tokens of the Q-Former.
            Defaults to 32.
        prompt_template (str): Prompt template of the model. Defaults to
            '###Human: {} ###Assistant: '.
        raw_prompts (list, optional): Prompts for training. Defaults to None.
        max_txt_len (int): Max token length while doing tokenization. Defaults
            to 32.
        end_sym (str): End symbol of the sequence. Defaults to '\\n'.
        generation_cfg (dict): The config of text generation. Defaults to
            dict().
        data_preprocessor (:obj:`BaseDataPreprocessor`, optional): Used for
            pre-processing data sampled by the dataloader to the format
            accepted by :meth:`forward`. Defaults to None.
        init_cfg (dict, optional): Initialization config dict. Defaults to
            None.
    """  # noqa

    def __init__(self,
                 vision_encoder: dict,
                 q_former_model: dict,
                 lang_encoder: dict,
                 tokenizer: dict,
                 task: str = 'caption',
                 freeze_vit: bool = True,
                 freeze_q_former: bool = True,
                 num_query_token: int = 32,
                 prompt_template: str = '###Human: {} ###Assistant: ',
                 raw_prompts: Optional[list] = None,
                 max_txt_len: int = 32,
                 end_sym: str = '\n',
                 generation_cfg: dict = dict(),
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
        data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.task = task
        logger = MMLogger.get_current_instance()

        # build vision model
        vision_encoder_weight = vision_encoder.pop('pretrained', None)
        self.vision_encoder = MODELS.build(vision_encoder)
        self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims)

        if vision_encoder_weight is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(self.vision_encoder, vision_encoder_weight)

        if freeze_vit:
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
        else:
            logger.warning('Please check `frozen_stages` in the dict of '
                           '`vision_encoder`. Also set it to -1 if you do '
                           'not want to freeze the ViT.')

        # build Qformer
        q_former_model_weight = q_former_model.pop('pretrained', None)
        self.q_former = MODELS.build(q_former_model)
        self.q_former.cls = None
        self.q_former.bert.embeddings.word_embeddings = None
        self.q_former.bert.embeddings.position_embeddings = None
        for layer in self.q_former.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, self.q_former.config.hidden_size))
        self.query_tokens.data.normal_(
            mean=0.0, std=self.q_former.config.initializer_range)

        if q_former_model_weight is not None:
            from mmengine.runner.checkpoint import CheckpointLoader
            state_dict = CheckpointLoader.load_checkpoint(
                q_former_model_weight)['state_dict']
            self.load_state_dict(state_dict, strict=False)

        if freeze_q_former:
            for name, param in self.q_former.named_parameters():
                param.requires_grad = False
            self.q_former.eval()
            self.query_tokens.requires_grad = False

        # build language model
        self.llama_tokenizer = TOKENIZER.build(tokenizer)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_model = MODELS.build(lang_encoder)
        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False

        # build linear projection layer
        self.llama_proj = nn.Linear(self.q_former.config.hidden_size,
                                    self.llama_model.config.hidden_size)
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]

        # set prompts
        if raw_prompts is not None:
            filtered_prompts = [
                raw_prompt for raw_prompt in raw_prompts
                if '<ImageHere>' in raw_prompt
            ]
            self.prompt_list = [
                prompt_template.format(p) for p in filtered_prompts
            ]
        else:
            self.prompt_list = []

        # update generation configs
        self.generation_cfg = dict(
            max_new_tokens=300,
            num_beams=1,
            do_sample=True,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            temperature=1.0,
            **generation_cfg)
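
        # NOTE: keys passed through ``generation_cfg`` override the defaults
        # above. A hypothetical config snippet (values are only an example):
        #
        #   generation_cfg=dict(
        #       max_new_tokens=64, do_sample=False, num_beams=3)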

        if hasattr(self, 'register_load_state_dict_post_hook'):
            self.register_load_state_dict_post_hook(
                self._load_llama_proj_hook)

    def encode_img(self,
                   images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to encode the images."""
        device = images.device
        x = self.vision_encoder(images)[0]
        image_embeds = self.ln_vision(x).to(device)
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long).to(device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.q_former.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_llama = self.llama_proj(query_output.last_hidden_state)
        atts_llama = torch.ones(
            inputs_llama.size()[:-1], dtype=torch.long).to(images.device)
        return inputs_llama, atts_llama
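
    # Shape sketch for ``encode_img`` (B = batch size, P = number of visual
    # patch tokens, Q = num_query_token; the exact sizes depend on the
    # configured backbones and are assumptions here):
    #   images        (B, 3, H, W)
    #   image_embeds  (B, P, vision_embed_dims)  after ViT + LayerNorm
    #   query_output  (B, Q, q_former_hidden)    after Q-Former cross-attn
    #   inputs_llama  (B, Q, llama_hidden)       after the linear projection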

    def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
                    prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to wrap the image embeddings with the prompt.

        Currently, the function only supports applying one prompt to all
        input images in a batch.

        Args:
            img_embeds (torch.Tensor): The embeddings of the input images.
            atts_img (torch.Tensor): Attention map of the image embeddings.
            prompt (str): The prompt of the batch data.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The embedding and attention
            map.
        """
        if prompt:
            batch_size = img_embeds.shape[0]
            p_before, p_after = prompt.split('<ImageHere>')
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors='pt',
                add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.llama_tokenizer(
                p_after, return_tensors='pt',
                add_special_tokens=False).to(img_embeds.device)
            p_before_embeds = self.llama_model.model.embed_tokens(
                p_before_tokens.input_ids).expand(batch_size, -1, -1)
            p_after_embeds = self.llama_model.model.embed_tokens(
                p_after_tokens.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat(
                [p_before_embeds, img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(
                -1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
        else:
            return img_embeds, atts_img
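
    # Example of the wrapping (the prompt text is illustrative; any prompt
    # containing the '<ImageHere>' placeholder is handled the same way):
    #   prompt   = '###Human: <ImageHere> Describe this image. ###Assistant: '
    #   p_before = '###Human: '
    #   p_after  = ' Describe this image. ###Assistant: '
    # The final sequence fed to the LLM is
    #   [p_before tokens] + [projected image queries] + [p_after tokens].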

    def loss(self,
             images: torch.Tensor,
             data_samples: Optional[List[DataSample]] = None) -> dict:
        """The forward function in training.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        img_embeds, atts_img = self.encode_img(images)

        if self.task == 'caption' and self.prompt_list:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
                                                    prompt)

        self.llama_tokenizer.padding_side = 'right'

        text = [t + self.end_sym for t in data_samples['text_input']]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False).to(images.device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id,
            -100)

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                       dtype=torch.long).to(images.device).fill_(
                           -100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device
                         ) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(
            to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds],
                                  dim=1)
        attention_mask = torch.cat(
            [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        return dict(loss=loss)
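
    # Label layout built in ``loss`` (per sample, lengths are illustrative):
    #   inputs_embeds : [bos][image/prompt embeddings][text embeddings]
    #   targets       : [-100][-100 ... -100         ][text token ids ]
    # Entries equal to -100 are ignored by the language-model loss, so only
    # the text tokens contribute to the autoregressive objective.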

    def predict(
            self,
            images: torch.Tensor,
            data_samples: Optional[List[DataSample]] = None,
            **kwargs) -> List[DataSample]:
        """The forward function in testing.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample], optional): The annotation data
                of every sample. Defaults to None.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        with torch.no_grad():
            img_embeds, atts_img = self.encode_img(images)

        if self.task == 'caption' and self.prompt_list:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
                                                    prompt)

        batch_size = img_embeds.shape[0]
        bos = torch.ones(
            [batch_size, 1], dtype=torch.long,
            device=img_embeds.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            eos_token_id=self.end_token_id,
            **self.generation_cfg)

        return self.post_process(outputs, data_samples)

    def post_process(
            self, outputs: torch.Tensor,
            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
        """Perform post-processing of the outputs for different tasks.

        Args:
            outputs (torch.Tensor): The generated outputs.
            data_samples (List[DataSample], optional): The annotation
                data of every sample.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        outputs = self.llama_tokenizer.batch_decode(
            outputs, skip_special_tokens=True)

        if data_samples is None:
            data_samples = [DataSample() for _ in range(len(outputs))]

        for output, data_sample in zip(outputs, data_samples):
            if self.task == 'caption':
                output = output.split('###')[0]
                output = output.split('Assistant:')[-1].strip()
                data_sample.pred_caption = output
            else:
                # raw output
                data_sample.pred_output = output
        return data_samples
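
    # Example of the caption clean-up (the decoded text is illustrative):
    #   decoded = 'A dog playing in the snow.###Human: ...'
    #   decoded.split('###')[0].split('Assistant:')[-1].strip()
    #   -> 'A dog playing in the snow.'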

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        """The unified entry for a forward process in both training and test.

        The method accepts the following modes:

        - "predict": Forward and return a list of data samples containing
          the prediction results.
        - "loss": Forward and return a dict of losses according to the given
          images and data samples.

        Args:
            images (torch.Tensor): The preprocessed image tensor of shape
                ``(N, C, H, W)``.
            data_samples (List[DataSample], optional): The annotation data
                of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.
        """
        if mode == 'loss':
            return self.loss(images, data_samples)
        elif mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')
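
    # Rough usage sketch; the names below are illustrative assumptions, and
    # in practice the model is normally built and run through the mmpretrain
    # config / inferencer machinery:
    #
    #   from mmpretrain.registry import MODELS
    #   model = MODELS.build(cfg.model)        # cfg defines a MiniGPT4 model
    #   images = torch.randn(1, 3, 224, 224)   # already preprocessed batch
    #   data_samples = model(images, mode='predict')
    #   print(data_samples[0].pred_caption)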

    @staticmethod
    def _load_llama_proj_hook(module, incompatible_keys):
        """Avoid missing-key warnings except for the LLaMA projection keys."""
        proj_patterns = [
            'vision_encoder.*',
            'ln_vision.*',
            'q_former.*',
            'query_tokens',
            'llama_model.*',
        ]
        for key in list(incompatible_keys.missing_keys):
            if any(re.match(pattern, key) for pattern in proj_patterns):
                incompatible_keys.missing_keys.remove(key)