""" Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ """ Requires Transformer 4.28 and above, implementation may change according the Llama implementation """ import logging import string from packaging import version import os from omegaconf import OmegaConf import torch from torch.cuda.amp import autocast as autocast import torch.nn as nn from torch.nn.modules.module import _IncompatibleKeys from peft import ( get_peft_model, LoraConfig, TaskType, ) import transformers import random from lavis.common.registry import registry from lavis.models.base_model import BaseModel from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train, LayerNorm from lavis.models.ulip_models.ULIP_models import ULIP_PointBERT from lavis.tasks.multimodal_classification import MultimodalClassificationTask from lavis.common.utils import is_url from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel from lavis.common.dist_utils import download_cached_file from lavis.processors.blip_processors import BlipCaptionProcessor class CastOutputToFloat(nn.Sequential): def forward(self, x): return super().forward(x).to(torch.float32) @registry.register_model("blip2_vicuna_xinstruct") class Blip2VicunaXInstruct(Blip2Base): """ BLIP2 Vicuna model. Supported model types: - vicuna7b - vicuna13b Usage: >>> from lavis.models import load_model >>> model = load_model("blip2_vicuna_xinstruct", "vicuna7b") """ PRETRAINED_MODEL_CONFIG_DICT = { "vicuna7b": "configs/models/blip2/blip2_xinstruct_vicuna7b.yaml", "vicuna13b": "configs/models/blip2/blip2_xinstruct_vicuna13b.yaml", } SEQUENCIAL_ENCODERS = [ "eva_clip_g", "beats" ] SEQUENCIAL_MODALITIES = [ "video", "audio" ] MODALITY_TO_CUE = { "image": " image: ", "pc": " 3d: ", "video": " video: ", "audio": " audio: ", } def __init__( self, modalities = ["image", "pc", "audio", "video"], use_cues=True, num_query_token=32, qformer_text_input=True, llm_text_input=False, apply_lemmatizer=False, ## encoders image_model="eva_clip_g", pc_model="ulip2_pointbert", video_model="eva_clip_g", audio_model="beats", image_encoder_kwargs = {"image_size": 224, "drop_path_rate": 0, "use_grad_checkpoint": False}, pc_encoder_kwargs = {}, video_encoder_kwargs = {}, audio_encoder_kwargs = {}, image_precision="fp16", pc_precision="fp16", video_precision="fp16", audio_precision="fp16", freeze_image=True, freeze_pc=True, freeze_video=True, freeze_audio=True, ## load pretrained parameters pretrained_image_qformer=None, pretrained_pc_qformer=None, pretrained_video_qformer=None, pretrained_audio_qformer=None, load_attention_image_qformer=False, load_attention_pc_qformer=False, load_attention_video_qformer=False, load_attention_audio_qformer=False, load_qformer_type_image="", load_qformer_type_pc="", load_qformer_type_video="", load_qformer_type_audio="", load_ln_type_image="", load_ln_type_pc="", load_ln_type_video="", load_ln_type_audio="", load_projection_image=True, load_projection_pc=True, load_projection_video=True, load_projection_audio=True, load_projection_type_image="", load_projection_type_pc="", load_projection_type_video="", load_projection_type_audio="", ## llm model parameters llm_model="", lora_model="", lora=False, ## generation parameters prompt="", prefix="", postfix="", max_txt_len=128, max_output_txt_len=256, special_qformer_input_prompt=False, enumerate_inputs=False, add_space=False, remove_start=False, 
        clean_tokenization=False,  # if True, removes the leading whitespace from cues and the start token from the prompt.
        ## shared Q-former setup
        shared_qformer=False,
        pretrained_shared_qformer=None,
        load_attention_shared_qformer=False,
        load_qformer_type_shared="",
        load_projection_shared=False,
        load_projection_type_shared="",
        encoder_projection_type_image="",
        encoder_projection_type_pc="",
        encoder_projection_type_video="",
        encoder_projection_type_audio="",
        shared_qformer_num_features=512,
        ## use cached features
        cached_audio=False,
        cached_image=False,
        cached_pc=False,
        cached_video=False,
        ## num features for modality (only needed in cached cases.)
        num_features_audio=768,
        num_features_image=1408,
        num_features_video=1408,
        num_features_pc=512,
        joint_video_audio=False,
        ## DisCRN
        use_caption=False,
        use_describe=False,
        ## classification setup
        predict_with_gen=False,
        format_candidates_prompt="{}",
        ## projection only parameters
        projection_only=False,
        projection_only_audio=False,
        projection_only_pc=False,
        projection_only_video=False,
        projection_only_image=False,
        projection_path_audio=False,
        projection_path_pc=False,
        projection_path_video=False,
        projection_path_image=False,
        proj_dim=1,
    ):
        super().__init__()

        transformers_version = version.parse(transformers.__version__)
        assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
        from transformers import LlamaTokenizer
        from lavis.models.blip2_models.modeling_llama import LlamaForCausalLM

        logging.info(f"Using modalities {modalities}")
        self.modalities = modalities

        logging.info(f"Shared Qformer is set to {shared_qformer}")
        self.shared_qformer = shared_qformer

        logging.info(f"Video-audio interleaving is set to {joint_video_audio}")
        self.joint_video_audio = joint_video_audio

        logging.info(f"Using Spacy en_core_web_sm lemmatizer is set to {apply_lemmatizer}")
        self._lemmatizer = None
        self.apply_lemmatizer = apply_lemmatizer

        logging.info(f"Qformer text input {qformer_text_input} and LLM Text Input {llm_text_input}")
        self.qformer_text_input = qformer_text_input
        self.llm_text_input = llm_text_input

        self.projection_only = projection_only
        self.proj_dim = proj_dim
        logging.info(f"Projection only setup is set to {projection_only} with dimension {proj_dim}")

        for modality in self.modalities:
            setattr(self, f"cached_{modality}", locals()[f"cached_{modality}"])
            if locals()[f"cached_{modality}"]:
                setattr(self, f"num_features_{modality}", locals()[f"num_features_{modality}"])
                logging.info(f"Using cached {modality} representation with {getattr(self, f'num_features_{modality}')} embedding dim.")

        ### Initialize modality encoders ###
        for modality in self.modalities:
            modality_model = locals()[f"{modality}_model"]
            modality_precision = locals()[f"{modality}_precision"]
            modality_kwargs = locals()[f"{modality}_encoder_kwargs"]
            modality_kwargs['load_ln_path'] = locals()[f"pretrained_shared_qformer"] if shared_qformer else \
                locals()[f"pretrained_{modality}_qformer"]
            setattr(self, f"projection_only_{modality}", locals()[f"projection_only_{modality}"])
            setattr(self, f"projection_path_{modality}", locals()[f"projection_path_{modality}"])
            modality_kwargs['load_ln_type'] = locals()[f"load_ln_type_{modality}"]
            if self.projection_only or locals()[f"projection_only_{modality}"]:
                modality_kwargs['load_ln_path'] = getattr(self, f"projection_path_{modality}")
                modality_kwargs['load_ln_type'] = modality
            setattr(self, f"load_ln_type_{modality}", locals()[f"load_ln_type_{modality}"])
            setattr(self, f"pretrained_{modality}_qformer", locals()[f"pretrained_{modality}_qformer"])
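            # Explanatory sketch (comments only, nothing executed): the assignments above
            # resolve the per-modality constructor arguments by name through locals(),
            # e.g. for modality == "image" they amount to
            #
            #   modality_model = image_model                # "eva_clip_g" by default
            #   modality_precision = image_precision        # "fp16"
            #   modality_kwargs = image_encoder_kwargs      # {"image_size": 224, ...}
            #
            # and the encoder itself is then built below via
            # getattr(self, f"init_{modality}_encoder"), i.e. self.init_image_encoder(...).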
modality_encoder, modality_ln = getattr(self, f"init_{modality}_encoder")( modality_model, precision=modality_precision, **modality_kwargs ) freeze_modality = locals()[f"freeze_{modality}"] cached_modality = locals()[f"cached_{modality}"] if cached_modality: setattr(self, f"{modality}_encoder", modality_encoder) setattr(self, f"{modality}_ln", modality_ln) continue if freeze_modality: for name, param in modality_encoder.named_parameters(): param.requires_grad = False modality_encoder = modality_encoder.eval() modality_encoder.train = disabled_train logging.info(f"freeze {modality} encoder") setattr(self, f"{modality}_encoder", modality_encoder) setattr(self, f"{modality}_ln", modality_ln) ##### Init QFormers #### self.tokenizer = self.init_tokenizer(truncation_side="left") # 30523 tokens. self.num_query_token = num_query_token if self.shared_qformer: logging.info(f"Initializing shared QFormer with {shared_qformer_num_features} \ number of features and query tokens of length {num_query_token}") setattr(self, f"pretrained_shared_qformer", pretrained_shared_qformer) setattr(self, f"load_qformer_type_shared", load_qformer_type_shared) self.shared_Qformer, self.shared_query_tokens = self.init_Qformer( num_query_token, shared_qformer_num_features, pretrained_qformer=pretrained_shared_qformer, load_attention=load_attention_shared_qformer, load_qformer_type=load_qformer_type_shared ) if not qformer_text_input: self.shared_Qformer.bert.embeddings.word_embeddings = None self.shared_Qformer.bert.embeddings.position_embeddings = None for layer in self.shared_Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None else: self.shared_Qformer.resize_token_embeddings(len(self.tokenizer)) self.shared_Qformer.cls = None # Map shared Qformer by reference to all modalities. for modality in self.modalities: setattr(self, f"{modality}_Qformer", self.shared_Qformer) setattr(self, f"{modality}_query_tokens", self.shared_query_tokens) encoder_proj_type=locals()[f"encoder_projection_type_{modality}"] setattr(self, f"encoder_projection_type_{modality}", locals()[f"encoder_projection_type_{modality}"]) modality_encoder_features = getattr(self, f"{modality}_encoder").num_features setattr(self, f"{modality}_encoder_projection", self.init_encoder_projection(modality_encoder_features, shared_qformer_num_features, pretrained_shared_qformer, encoder_proj_type)) else: for modality in self.modalities: if getattr(self,f"cached_{modality}"): modality_num_features = locals()[f"num_features_{modality}"] else: modality_num_features = getattr(self, f"{modality}_encoder").num_features setattr(self, f"pretrained_{modality}_qformer", locals()[f"pretrained_{modality}_qformer"]) setattr(self, f"load_qformer_type_{modality}", locals()[f"load_qformer_type_{modality}"]) setattr(self, f"projection_only_{modality}", locals()[f"projection_only_{modality}"]) setattr(self, f"projection_path_{modality}", locals()[f"projection_path_{modality}"]) if self.projection_only or locals()[f"projection_only_{modality}"]: logging.info(f"Initializing {modality} projection") setattr(self, f"pretrained_{modality}_qformer", False) if modality == 'audio' and proj_dim == 1: modality_num_features *= 256 # hack to get full beats embedding. define better. 
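                # Sketch of the projection-only path (assuming, as the factor above implies,
                # 768-dim BEATS features over 256 time steps): with proj_dim == 1 the
                # Q-Former for this modality is skipped and a single linear layer maps the
                # flattened encoder output to one value per query-token slot, roughly
                #
                #   proj = nn.Linear(768 * 256, num_query_token)   # init_vicuna_projection below
                #   out = proj(audio_feats.view(bs, -1))           # (bs, num_query_token)
                #
                # forward() later expands `out` into LLM-sized embeddings through
                # the corresponding {modality}_llm_proj layer.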
modality_projection = self.init_vicuna_projection( modality_num_features, num_query_token*proj_dim, load_projection_path=getattr(self, f"projection_path_{modality}"), load_projection_type=modality, projection_key=f"{modality}_projection" ) setattr(self, f"{modality}_projection", modality_projection) else: logging.info(f"Initializing {modality} QFormer and query tokens of length {num_query_token}") modality_qformer, modality_query_tokens = self.init_Qformer( num_query_token, modality_num_features, pretrained_qformer=locals()[f"pretrained_{modality}_qformer"], load_attention=locals()[f"load_attention_{modality}_qformer"], load_qformer_type=locals()[f"load_qformer_type_{modality}"] ) if not qformer_text_input: modality_qformer.bert.embeddings.word_embeddings = None modality_qformer.bert.embeddings.position_embeddings = None for layer in modality_qformer.bert.encoder.layer: layer.output = None layer.intermediate = None else: modality_qformer.resize_token_embeddings(len(self.tokenizer)) modality_qformer.cls = None setattr(self, f"{modality}_Qformer", modality_qformer) setattr(self, f"{modality}_query_tokens", modality_query_tokens) ### Set up LLM ### logging.info(f"Setting up llm model {llm_model}") self.lora = lora print(f"Lora is set to {self.lora}") self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left") self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) self.llm_tokenizer.add_special_tokens({'bos_token': ''}) self.llm_tokenizer.add_special_tokens({'eos_token': ''}) self.llm_tokenizer.add_special_tokens({'unk_token': ''}) if self.lora: # https://github.com/lxe/llama-peft-tuner/blob/main/finetune_peft.py self.llm_model = LlamaForCausalLM.from_pretrained( llm_model, load_in_8bit=True, torch_dtype=torch.float16 ) self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) self.peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=['q_proj', 'v_proj'] ) self.llm_model.gradient_checkpointing_enable() self.llm_model.enable_input_require_grads() self.llm_model.lm_head = CastOutputToFloat(self.llm_model.lm_head) self.llm_model.config.use_cache = False # silence the warnings. Please re-enable for inference! self.llm_hidden_size = self.llm_model.config.hidden_size self.llm_model = get_peft_model(self.llm_model, self.peft_config) self.lora_model = lora_model else: self.llm_model = LlamaForCausalLM.from_pretrained( llm_model, torch_dtype=torch.float16 ) self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) self.llm_hidden_size = self.llm_model.config.hidden_size for name, param in self.llm_model.named_parameters(): param.requires_grad = False # Load LM projections if self.shared_qformer and load_projection_shared: qformer = getattr(self, f"shared_Qformer") load_projection_path = locals()[f"load_projection_shared"] if load_projection_path: load_projection_path = locals()[f"pretrained_shared_qformer"] load_projection_type = locals()[f"load_projection_type_shared"] setattr(self, f"load_projection_shared", load_projection_path) setattr(self, f"load_projection_type_shared", locals()[f"load_projection_type_shared"]) logging.info(f"Loading shared Qformer projection.") proj = self.init_vicuna_projection( qformer.config.hidden_size, self.llm_hidden_size, load_projection_path=load_projection_path ) # Map projection by reference to all modalities. 
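            # Weight-tying note (illustrative; assumes init_vicuna_projection returns a
            # single nn.Linear-like module): the loop below assigns the *same* module
            # instance to every modality, e.g.
            #
            #   proj = nn.Linear(qformer.config.hidden_size, self.llm_hidden_size)
            #   self.image_llm_proj = proj
            #   self.audio_llm_proj = proj   # same object, so one set of weights/gradients
            #
            # mirroring how the shared Q-Former is mapped by reference above.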
for modality in self.modalities: setattr(self, f"{modality}_llm_proj", proj) else: for modality in self.modalities: load_projection_path = locals()[f"load_projection_{modality}"] if load_projection_path == True: load_projection_path = locals()[f"pretrained_{modality}_qformer"] load_projection_type = locals()[f"load_projection_type_{modality}"] setattr(self, f"load_projection_{modality}", load_projection_path) setattr(self, f"load_projection_type_{modality}", load_projection_type) if self.projection_only or getattr(self, f"projection_only_{modality}"): proj = self.init_vicuna_projection( self.num_query_token if proj_dim==1 else proj_dim, self.num_query_token*self.llm_hidden_size if proj_dim==1 else self.llm_hidden_size, load_projection_path=getattr(self, f"projection_path_{modality}"), load_projection_type=load_projection_type, ) else: qformer = getattr(self, f"{modality}_Qformer") proj = self.init_vicuna_projection( qformer.config.hidden_size, self.llm_hidden_size, load_projection_path=load_projection_path, load_projection_type=load_projection_type ) setattr(self, f"{modality}_llm_proj", proj) self.clean_tokenization = clean_tokenization logging.info(f"Clean tokenization is set to {self.clean_tokenization}") self.max_txt_len = max_txt_len self.max_output_txt_len = max_output_txt_len self.prompt = prompt self.prefix = prefix if self.prefix: self.tokenized_prefix = self.llm_tokenizer(self.prefix, return_tensors="pt") self.postfix = postfix if type(self.postfix) != str or not self.postfix: self.postfix = "" logging.info(f"Using prefix set to {self.prefix} and postfix set to {self.postfix}.") self.use_cues = use_cues logging.info(f"Using cues set to {self.use_cues}.") if self.use_cues: logging.info(f"Modality to cue {Blip2VicunaXInstruct.MODALITY_TO_CUE}") self.tokenized_cue = {} self.emb_cue = {} self.att_cue = {} for modality in self.modalities: if self.clean_tokenization: Blip2VicunaXInstruct.MODALITY_TO_CUE[modality] = Blip2VicunaXInstruct.MODALITY_TO_CUE[modality].lstrip() self.tokenized_cue[modality] = self.llm_tokenizer(Blip2VicunaXInstruct.MODALITY_TO_CUE[modality], return_tensors="pt") self.emb_cue[modality] = self.llm_model.get_input_embeddings()(self.tokenized_cue[modality].input_ids.to(self.device)) self.att_cue[modality] = self.tokenized_cue[modality].attention_mask.to(self.device) ## generation parameters self.use_caption=use_caption self.use_describe=use_describe self.predict_with_gen=predict_with_gen self.format_candidates_prompt=format_candidates_prompt self.special_qformer_input_prompt=special_qformer_input_prompt self.enumerate_inputs=enumerate_inputs self.add_space=add_space self.remove_start=remove_start if self.projection_only: self.qformer_text_input=False def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts): input_part_targets_len = [] llm_tokens = {"input_ids": [], "attention_mask": []} for i in range(input_ids.size(0)): this_input_ones = input_atts[i].sum() input_part_targets_len.append(this_input_ones) llm_tokens['input_ids'].append( torch.cat([ input_ids[i][:this_input_ones], output_ids[i][1:], input_ids[i][this_input_ones:] ]) ) llm_tokens['attention_mask'].append( torch.cat([ input_atts[i][:this_input_ones], output_atts[i][1:], input_atts[i][this_input_ones:] ]) ) llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids']) llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask']) return llm_tokens, input_part_targets_len def forward(self, samples): # print('-----------------') # print(samples["text_input"]) # 
print(samples["text_output"]) # print('-----------------') if samples == None or samples == {} or not any([modality in samples for modality in self.modalities]): return {"loss": torch.tensor(0.0)} random.shuffle(self.modalities) curr_modalities = [modality for modality in self.modalities if modality in samples] excess_modalities = [modality for modality in self.modalities if modality not in curr_modalities] # disable gradient in excess modalities dummy_loss = 0. for modality in excess_modalities: if self.shared_qformer: for name, param in getattr(self, f"{modality}_encoder_projection").named_parameters(): # param.requires_grad = False dummy_loss += param.sum()*0. for name, param in getattr(self,f"{modality}_ln").named_parameters(): # param.requires_grad = False dummy_loss += param.sum()*0. dummy_loss += getattr(self, f"{modality}_query_tokens").sum()*0. for name, param in getattr(self, f'{modality}_Qformer').named_parameters(): # param.requires_grad = False dummy_loss += param.sum()*0. for name, param in getattr(self, f'{modality}_llm_proj').named_parameters(): # param.requires_grad = False dummy_loss += param.sum()*0. embeds = {} query_tokens = {} data_atts = {} for modality in curr_modalities: data = samples[modality] ln = getattr(self, f"{modality}_ln") encoder = getattr(self, f"{modality}_encoder") if modality == "video" and self.video_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(2)): this_frame = data[:,:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][-1] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) # B, Token Size, LM EMB if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1) elif modality == 'audio' and self.audio_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(1)): this_frame = data[:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) # B, Token Size, LM EMB if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1) else: with self.maybe_autocast(): embeds[modality] = ln(encoder(data)) if len(embeds[modality].size()) == 2: # B, C, D embeds[modality] = embeds[modality].unsqueeze(1) # B, C if self.shared_qformer: embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality]) data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size, LM EMB if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(embeds[modality].shape[0], -1, -1) query_outputs = {} if self.qformer_text_input: text_Qformer = self.tokenizer( samples["text_input"] if not self.special_qformer_input_prompt else self.special_qformer_input_prompt, padding='longest', truncation=True, 
max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) Qformer_atts = {} query_atts = {} for modality in curr_modalities: # B, Token Size query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size + Inp Size Qformer_atts[modality] = torch.cat([query_atts[modality],text_Qformer.attention_mask],dim=1) if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num)]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num, self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids.repeat(num, 1), attention_mask=Qformer_atts[modality].repeat(num, 1), query_embeds=query_tokens[modality].repeat(num, 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids, attention_mask=Qformer_atts[modality], query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) else: for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num)]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num, self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality].repeat(num, 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: bs = embeds[modality].shape[0] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = 
getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), # pc data is floa16. encoder_attention_mask=data_atts[modality], return_dict=True, ) inputs_llm = {} atts_llm = {} for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: # num*bs, num query tokens, llm emb size if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num, self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]).reshape(bs*num, self.num_query_token, -1) inputs_llm[modality] = inputs_llm[modality].reshape(bs, num, self.num_query_token, -1).view(bs, num*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) # bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size inputs_llm[modality] = inputs_llm[modality].reshape(bs, num, self.num_query_token, -1).view(bs, num*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim == 1: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) self.llm_tokenizer.padding_side = "right" self.llm_tokenizer.truncation_side = 'left' if self.llm_text_input: text_input_tokens = self.llm_tokenizer( [f"{t}{self.postfix}" for t in samples['text_input']] if self.postfix else samples['text_input'], return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len, add_special_tokens= not self.clean_tokenization ).to(self.device) self.llm_tokenizer.truncation_side = 'right' text_output_tokens = self.llm_tokenizer( [t + self.llm_tokenizer.eos_token for t in samples['text_output']], return_tensors="pt", padding="longest", truncation=True, max_length=self.max_output_txt_len, ).to(self.device) if self.llm_text_input: llm_tokens, input_part_targets_len = self.concat_text_input_output( text_input_tokens.input_ids, text_input_tokens.attention_mask, text_output_tokens.input_ids, text_output_tokens.attention_mask, ) else: llm_tokens = text_output_tokens input_part_targets_len = [0 for _ in range(llm_tokens['input_ids'].shape[0])] # input length is 0 # do not apply loss to the padding targets = llm_tokens['input_ids'].masked_fill( llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100 ) # do not 
apply loss to the text input (i.e., instruction) for i, l in enumerate(input_part_targets_len): targets[i][:l] = -100 inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids']) bs = inputs_embeds.shape[0] att_list = [] inp_list = [] if self.prefix: att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)] inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)] for modality in curr_modalities: if self.use_cues: if self.prefix and self.clean_tokenization: att_list.extend([self.att_cue[modality][:,1:].repeat(bs, 1).to(self.device), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality][:,1:].repeat(bs, 1, 1).to(self.device), inputs_llm[modality]]) att_list.extend([self.att_cue[modality].repeat(bs, 1).to(self.device), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality].repeat(bs, 1, 1).to(self.device), inputs_llm[modality]]) else: att_list.extend([atts_llm[modality]]) inp_list.extend([inputs_llm[modality]]) # do not apply loss to the query tokens empty_targets = ( torch.ones(torch.cat(att_list, dim=1).size(), dtype=torch.long).to(self.device).fill_(-100) ) # append llm prompt + output to queries att_list.append(llm_tokens['attention_mask']) inp_list.append(inputs_embeds) inputs_embeds = torch.cat(inp_list, dim=1) attention_mask = torch.cat(att_list, dim=1) targets = torch.cat([empty_targets, targets], dim=1) with self.maybe_autocast(): outputs = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=targets, ) loss = dummy_loss+outputs.loss return {"loss": loss} def init_image_encoder(self, model_name, precision, **kwargs): load_ln_path = kwargs['load_ln_path'] del kwargs['load_ln_path'] load_ln_type=kwargs['load_ln_type'] del kwargs['load_ln_type'] encoder, _ = super().init_vision_encoder(model_name, kwargs['image_size'], kwargs['drop_path_rate'], kwargs['use_grad_checkpoint'], precision) ln = self.init_ln(encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type) return encoder, ln def init_pc_encoder( self, model_name, precision, **kwargs ): assert model_name in [ "ulip1_pointbert", "ulip2_pointbert", "ulip_shapenet", "ulip_objaverse", "objaverse_shapenet_k_1", "ulip2_scaledup" "" ], "pc model must be in [ulip1_pointbert,ulip2_pointbert]" load_ln_path = kwargs['load_ln_path'] del kwargs['load_ln_path'] load_ln_type=kwargs['load_ln_type'] del kwargs['load_ln_type'] if model_name == "ulip2_pointbert": pc_encoder = ULIP_PointBERT(ulip_v=2) elif model_name == "ulip_shapenet": pc_encoder = ULIP_PointBERT(ulip_v="shapenet") elif model_name == "ulip_objaverse": pc_encoder = ULIP_PointBERT(ulip_v="objaverse_k_1") elif model_name == "objaverse_shapenet_k_1": pc_encoder = ULIP_PointBERT(ulip_v="objaverse_shapenet_k_1") elif model_name == "ulip2_scaledup": pc_encoder = ULIP_PointBERT(ulip_v="ulip2_scaledup") else: pc_encoder = ULIP_PointBERT(ulip_v=1) ln_pc = self.init_ln(pc_encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type) self.pc_enc_name = model_name return pc_encoder, ln_pc def init_video_encoder( self, model_name, precision, **kwargs ): assert model_name in [ "eva_clip_g", "eva2_clip_L", "clip_L", ], "video_model must be in [eva_clip_g, eva2_clip_L, clip_L]" if model_name in ["eva_clip_g","eva2_clip_L","clip_L",]: video_encoder, ln_video = self.init_image_encoder( model_name, precision=precision, **kwargs ) self.video_enc_name = model_name return video_encoder, ln_video def init_audio_encoder( 
self, model_name, precision, **kwargs ): assert model_name in [ 'beats' ], "audio model must be in [beats]" load_ln_path = kwargs['load_ln_path'] del kwargs['load_ln_path'] load_ln_type=kwargs['load_ln_type'] del kwargs['load_ln_type'] if "beats" in model_name: from lavis.models.beats_encoder import BeatsEncoder if self.cached_audio: audio_encoder = lambda x: x ln_audio = self.init_ln(768, load_ln_path=load_ln_path, load_ln_type=load_ln_type) else: audio_encoder = BeatsEncoder(**kwargs) if not self.cached_audio: ln_audio = self.init_ln(audio_encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type) self.audio_enc_name = model_name return audio_encoder, ln_audio @torch.no_grad() def get_query_outputs( self, samples ): if samples == None or samples == {}: return curr_modalities = [modality for modality in self.modalities if modality in samples] if len(curr_modalities) == 0: print("Model modalities do not match sample modalities.") return # get batch size bs = None for modality in curr_modalities: data = samples[modality] bs = data.size(0) break if "prompt" in samples.keys(): prompt = samples["prompt"] elif "text_input" in samples.keys(): prompt = samples["text_input"] else: prompt = self.prompt if isinstance(prompt, str): prompt = [prompt] * bs else: assert len(prompt) == bs, "The number of prompts must be equal to the batch size." embeds = {} query_tokens = {} data_atts = {} for modality in curr_modalities: data = samples[modality] ln = getattr(self, f"{modality}_ln") encoder = getattr(self, f"{modality}_encoder") if modality == "video" and self.video_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(2)): this_frame = data[:,:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][-1] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) # B, Token Size, LM EMB query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1) elif modality == 'audio' and self.audio_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(1)): this_frame = data[:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) # B, Token Size, LM EMB if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1) else: with self.maybe_autocast(): embeds[modality] = ln(encoder(data)) if len(embeds[modality].size()) == 2: # B, C, D embeds[modality] = embeds[modality].unsqueeze(1) # B, C if self.shared_qformer: embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality]) data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size, LM EMB if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(embeds[modality].shape[0], -1, -1) query_outputs = {} if self.qformer_text_input: text_Qformer = self.tokenizer( prompt, 
padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) Qformer_atts = {} query_atts = {} num = {} for modality in curr_modalities: # B, Token Size if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size + Inp Size Qformer_atts[modality] = torch.cat([query_atts[modality],text_Qformer.attention_mask],dim=1) if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids.repeat(num[modality], 1), attention_mask=Qformer_atts[modality].repeat(num[modality], 1), query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids, attention_mask=Qformer_atts[modality], query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) else: num = {} for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num, self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if 
self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), # pc data is floa16. encoder_attention_mask=data_atts[modality], return_dict=True, ) for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num, self.num_query_token, -1) else: query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]).reshape(bs*num, self.num_query_token, -1) query_outputs[f'llm_proj_{modality}'] = query_outputs[f'llm_proj_{modality}'].reshape(bs, num[modality], self.num_query_token, -1).contiguous().view(bs, num[modality]*self.num_query_token, -1) query_outputs[modality] = query_outputs[modality].view(bs, num[modality]*self.num_query_token, -1) else: query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:]).contiguous().view(bs, num[modality]*self.num_query_token, -1) query_outputs[modality] = query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:].contiguous().view(bs, num[modality]*self.num_query_token, -1) else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim == 1: query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) else: query_outputs[f'llm_proj_{modality}']= getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) else: query_outputs[modality] = query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:] query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]) for modality in curr_modalities: query_outputs[f'embeds_{modality}'] = embeds[modality] return query_outputs @torch.no_grad() def generate( self, samples, use_nucleus_sampling=False, num_beams=5, max_length=256, min_length=1, top_p=0.9, repetition_penalty=1.5, length_penalty=1, num_captions=1, temperature=1, special_qformer_input_prompt=False ): self.llm_tokenizer.padding_side = "left" if samples == None or samples == {}: return if 'modalities' in samples: curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities'] elif self.joint_video_audio: curr_modalities = ["video", "audio"] else: curr_modalities = [modality for modality in self.modalities if modality in samples] if len(curr_modalities) == 0: print("Model modalities do not match sample modalities.") return # get batch size bs = None for modality in curr_modalities: data = samples[modality] if isinstance(data, torch.Tensor): bs = data.size(0) else: bs = len(data) break if "prompt" in samples.keys(): prompt = samples["prompt"] elif self.prompt and 'text_input' in samples and '{}' in self.prompt: 
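            # Illustrative example with a hypothetical template: if the model was built
            # with prompt="Question: {} Short answer:" and
            # samples["text_input"] == ["what is making the sound?"], the branch below yields
            #   prompt = ["Question: what is making the sound? Short answer:"]
            # i.e. the instruction template is filled with the raw text input before tokenization.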
prompt = [self.prompt.format(t) for t in samples["text_input"]] elif "text_input" in samples.keys(): prompt = samples["text_input"] else: prompt = self.prompt if isinstance(prompt, str): prompt = [prompt] * bs else: assert len(prompt) == bs, "The number of prompts must be equal to the batch size." # For TextCaps if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] if 'discrn' in samples and self.use_caption: ## discriminatory reasoning if self.postfix: prompt = [f'{t}{self.postfix}' for t in prompt] if self.enumerate_inputs: prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)] else: prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]]}{samples["baseline_captions"][i][0] if self.use_cues else " "}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)] llm_tokens = self.llm_tokenizer( prompt, padding="longest", return_tensors="pt" ).to(self.device) inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids) with self.maybe_autocast(): outputs = self.llm_model.generate( inputs_embeds=inputs_embeds, attention_mask=llm_tokens.attention_mask, do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=num_beams, max_length=max_length, min_length=min_length, repetition_penalty=repetition_penalty, length_penalty=length_penalty, num_return_sequences=num_captions, ) outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) output_text = [o.strip() for o in output_text] # print(output) return output_text query_tokens = {} for modality in curr_modalities: if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(bs, -1, -1) if self.qformer_text_input: if self.special_qformer_input_prompt or special_qformer_input_prompt: qformer_prompt = special_qformer_input_prompt if special_qformer_input_prompt else self.special_qformer_input_prompt qformer_prompt = [qformer_prompt] * len(prompt) if "text_input" in samples.keys(): if type(samples["text_input"][0]) == list: qformer_prompt = [qformer_prompt[i].format(*samples["text_input"][i]) for i in range(len(qformer_prompt))] else: qformer_prompt = [qformer_prompt[i].format(samples["text_input"][i]) for i in range(len(qformer_prompt))] text_Qformer = self.tokenizer( qformer_prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) elif self.use_describe: modality2prompt = { "video": "a short description of the video", "audio": "an audio that shows", "image": "a short image caption", "pc": "a 3d model of" } qformer_prompt = [modality2prompt[modality] for _ in samples['text_input']] text_Qformer = self.tokenizer( qformer_prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) else: text_Qformer = self.tokenizer( prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) Qformer_atts 
= {} query_atts = {} for modality in curr_modalities: if not getattr(self, f"projection_only_{modality}"): # B, Token Size query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size + Inp Size Qformer_atts[modality] = torch.cat([query_atts[modality],text_Qformer.attention_mask],dim=1) embeds = {} data_atts = {} for modality in curr_modalities: data = samples[modality] ln = getattr(self, f"{modality}_ln") encoder = getattr(self, f"{modality}_encoder") if modality == "video" and "clip" in self.video_enc_name: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(2)): this_frame = data[:,:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) elif modality == 'audio' and 'beats' in self.audio_enc_name: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(1)): this_frame = data[:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) else: with self.maybe_autocast(): embeds[modality] = ln(encoder(data)) if len(embeds[modality].size()) == 2: embeds[modality] = embeds[modality].unsqueeze(1) if self.shared_qformer: with self.maybe_autocast(): embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality]) data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device) query_outputs = {} num = {} if self.qformer_text_input: for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids.repeat(num[modality], 1), attention_mask=Qformer_atts[modality].repeat(num[modality], 1), query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: bs = embeds[modality].shape[0] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, 
f"{modality}_Qformer").bert( text_Qformer.input_ids, attention_mask=Qformer_atts[modality], query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) else: for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: bs = embeds[modality].shape[0] if self.projection_only or getattr(self, f"projection_only_{modality}"): with self.maybe_autocast(): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) inputs_llm = {} atts_llm = {} enumeration = {} for i,modality in enumerate(curr_modalities): if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num[modality], self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs*num, self.num_query_token, -1)) inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue # num*bs, num query tokens, llm emb size inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) # bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim == 1: inputs_llm[modality] = getattr(self, 
f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:]) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) if self.enumerate_inputs: enumeration[modality] = self.llm_tokenizer( [f"{'' if i == 0 else ' '}({chr(97+i)}) " for _ in prompt], return_tensors="pt", add_special_tokens=False if (i!= 0 or self.prefix) else True ).to(self.device) ## remove trailing whitespace prompt = [p.strip() for p in prompt] if 'dialog' in samples: llm_tokens = self.llm_tokenizer( [f"{d} {p}" if d else p for d, p in zip(samples['dialog'], prompt)], padding="longest", return_tensors="pt", add_special_tokens= not self.clean_tokenization ).to(self.device) else: llm_tokens = self.llm_tokenizer( [f"{p}{self.postfix}" for p in prompt] if self.postfix else prompt, padding="longest", return_tensors="pt", add_special_tokens= not self.clean_tokenization ).to(self.device) bs = llm_tokens.input_ids.shape[0] att_list = [] inp_list = [] if self.prefix: att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)] inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)] if self.joint_video_audio: for pos in range(num['video']): if self.enumerate_inputs: enumeration_pos = self.llm_tokenizer( [f"{'' if pos == 0 else ' '}({chr(97+pos)}) " for _ in prompt], return_tensors="pt", add_special_tokens=False if (pos!= 0 or self.prefix) else True ).to(self.device) enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration_pos.input_ids) enumeration_atts_llm = enumeration_pos.attention_mask.to(self.device) inp_list.extend([enumeration_inputs_llm]) att_list.extend([enumeration_atts_llm]) if self.use_cues: for modality in ['video', 'audio']: if self.clean_tokenization: if self.prefix or pos > 1 or self.enumerate_inputs or modality == 'audio': att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) continue att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) else: att_list.extend([atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) inp_list.extend([inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) else: for modality in curr_modalities: if self.enumerate_inputs: enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration[modality].input_ids.to(self.device)) enumeration_atts_llm = 
enumeration[modality].attention_mask.to(self.device) inp_list.extend([enumeration_inputs_llm]) att_list.extend([enumeration_atts_llm]) if self.use_cues: if self.clean_tokenization or self.remove_start: if (modality==curr_modalities[0] and not (self.prefix or self.enumerate_inputs)): att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([atts_llm[modality]]) inp_list.extend([inputs_llm[modality]]) if self.add_space: space_tok = self.llm_tokenizer( [f" " for _ in prompt], return_tensors="pt", add_special_tokens=False ) space_inputs_llm = self.llm_model.get_input_embeddings()(space_tok.input_ids.to(self.device)) space_atts_llm = space_tok.attention_mask.to(self.device) inp_list.extend([space_inputs_llm]) att_list.extend([space_atts_llm]) att_list.append(llm_tokens.attention_mask) inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids) inp_list.append(inputs_embeds) attention_mask = torch.cat(att_list, dim=1) inputs_embeds = torch.cat(inp_list, dim=1) with self.maybe_autocast(): outputs = self.llm_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=num_beams, max_length=max_length, min_length=min_length, repetition_penalty=repetition_penalty, length_penalty=length_penalty, num_return_sequences=num_captions, ) outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) output_text = [o.strip() for o in output_text] return output_text @torch.no_grad() def predict_answers( self, samples, num_beams=5, inference_method="generate", max_len=10, min_len=1, num_ans_candidates=128, answer_list=None, prompt="", length_penalty=-1, **kwargs ): if samples == None or samples == {}: return None # get batch size bs = None if 'modalities' in samples: curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities'] else: curr_modalities = [modality for modality in self.modalities if modality in samples] for modality in curr_modalities: data = samples[modality] if isinstance(data, torch.Tensor): bs = data.size(0) else: bs = len(data) break if "text_input" not in samples: samples["text_input"] = self.prompt if isinstance(samples["text_input"], str): samples["text_input"] = [samples["text_input"]] * bs text_input = samples['text_input'] if not prompt and self.prompt: prompt=self.prompt if prompt: if prompt.count("{}") == 2: if 'ocr_tokens' in samples: text_input = [ prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i]) for i in range(len(samples["text_input"]))] elif 'choices' in samples: text_input = [] for i in 
range(len(samples["text_input"])): this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])] this_choices = " ".join(this_choices) text_input.append(prompt.format(samples["text_input"][i], this_choices)) else: text_input = [prompt.format(question) for question in samples["text_input"]] samples["prompt"] = text_input if 'discrn' in samples and self.use_caption: ## discriminatory reasoning self.llm_tokenizer.padding_side = "left" text_input = samples['text_input'] if 'prompt' not in samples else samples['prompt'] if self.postfix: text_input = [f'{t}{self.postfix}' for t in text_input] if self.enumerate_inputs: prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {text_input[i]}' for i in range(bs)] else: prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]]}{samples["baseline_captions"][i][0] if self.use_cues else " "}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {text_input[i]}' for i in range(bs)] llm_tokens = self.llm_tokenizer( prompt, padding="longest", return_tensors="pt" ).to(self.device) with self.maybe_autocast(): outputs = self.llm_model.generate( inputs_embeds=self.llm_model.get_input_embeddings()(llm_tokens.input_ids), attention_mask=llm_tokens.attention_mask, do_sample=False, num_beams=num_beams, max_length=max_len, min_length=min_len, repetition_penalty=1.5, # eos_token_id=self.eos_token_id, length_penalty=length_penalty, ) outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) return output_text output_text = self.generate( samples, num_beams=num_beams, max_length=max_len, min_length=min_len, length_penalty=length_penalty ) if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: output_text = self._lemmatize(output_text) #vizwiz output_text = [o if o != "" else "unanswerable" for o in output_text] return output_text def predict( self, samples, candidates=None, n_segments=1, max_length=10, min_length=1, length_penalty=-1., special_qformer_input_prompt=False ): self.llm_tokenizer.padding_side = "left" if candidates == None: candidates = self.candidates else: self.candidates = candidates # for the output targets. 
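        # Two classification paths follow:
        #   1. If self.predict_with_gen is set, generate() produces free-form text and the first candidate
        #      string that appears as a word in the cleaned output (BlipCaptionProcessor.pre_caption) is
        #      taken as the predicted label.
        #   2. Otherwise, _predict_class() scores every candidate by its language-modeling loss against the
        #      multimodal prompt and returns the negated losses, so larger values indicate better candidates.
        # When `candidates` is a list of lists, each sample carries its own candidate set and is scored one at a time.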
if self.predict_with_gen: output = self.generate(samples,max_length=max_length,min_length=min_length,length_penalty=length_penalty) result = [] for text in output: text = BlipCaptionProcessor().pre_caption(text) pred_label = "" # default to an empty string for cand in candidates: cand = BlipCaptionProcessor().pre_caption(cand) if cand in text.split(" "): pred_label = cand break # stop as soon as we find a match result.append(pred_label) return {"predictions":result, "target": samples["label"]} # If candidates is a list of lists, each sample has its own candidate set, so we iterate over samples one by one if isinstance(candidates[0], list): results = [] for i in range(samples["image"].size(0)): this_sample = { "image": samples["image"][i].unsqueeze(0), "prompt": samples["prompt"], } if "text_input" in samples.keys(): this_sample["text_input"] = [samples["text_input"][i]] if 'context' in samples.keys(): this_sample['context'] = [samples["context"][i]] if 'history' in samples.keys(): this_sample['history'] = [samples["history"][i]] if 'caption' in samples.keys(): this_sample['caption'] = [samples["caption"][i]] this_result = self._predict_class(this_sample, candidates[i], n_segments, special_qformer_input_prompt) results.append(this_result) try: results = torch.cat(results, dim=0) except Exception: results = [res.tolist()[0] for res in results] return results return self._predict_class(samples, candidates, n_segments, special_qformer_input_prompt) def _predict_class( self, samples, candidates, n_segments=1, special_qformer_input_prompt=False, ): if len(samples) == 0: return None if "prompt" in samples: prompt = samples["prompt"] else: prompt = self.prompt candidates = [self.format_candidates_prompt.format(c) for c in candidates] if 'modalities' in samples: curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities'] else: curr_modalities = [modality for modality in self.modalities if modality in samples] # get batch size for modality in curr_modalities: data = samples[modality] if isinstance(data, torch.Tensor): bs = data.size(0) else: bs = len(data) break if isinstance(prompt, str): prompt = [prompt] * bs else: assert len(prompt) == bs, "The number of prompts must be equal to the batch size." if "text_input" in samples.keys(): if isinstance(samples["text_input"][0], list): prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))] else: prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))] # scienceqa if 'context' in samples.keys() and samples['context'] != '': prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))] # visual dialog if 'history' in samples.keys() and samples['history'][0] != '': prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))] if 'caption' in samples.keys() and samples['caption'][0] != '': prompt = [f'This image has the caption "{samples["caption"][i]}". 
{prompt[i]}' for i in range(len(prompt))] if 'discrn' in samples and self.use_caption: ## discriminatory reasoning if self.postfix: prompt = [f'{p}{self.postfix}' for p in prompt] if self.enumerate_inputs: prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)] else: prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]]}{samples["baseline_captions"][i][0] if self.use_cues else " "}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)] text_input_tokens = self.llm_tokenizer( prompt, padding="longest", return_tensors="pt" ).to(self.device) else: if not self.projection_only and not getattr(self, f"projection_only_{modality}"): query_tokens = {} for modality in self.modalities: if modality not in samples: continue query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(bs, -1, -1) if self.qformer_text_input: if self.special_qformer_input_prompt or special_qformer_input_prompt: qformer_prompt = special_qformer_input_prompt if special_qformer_input_prompt else self.special_qformer_input_prompt qformer_prompt = [qformer_prompt] * len(prompt) if "text_input" in samples.keys(): if type(samples["text_input"][0]) == list: qformer_prompt = [qformer_prompt[i].format(*samples["text_input"][i]) for i in range(len(qformer_prompt))] else: qformer_prompt = [qformer_prompt[i].format(samples["text_input"][i]) for i in range(len(qformer_prompt))] text_Qformer = self.tokenizer( qformer_prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) elif self.use_describe: modality2prompt = { "video": "a short description of the video", "audio": "an audio that shows", "image": "a short image caption", "pc": "a 3d model of" } qformer_prompt = [modality2prompt[modality] for _ in samples['text_input']] # qformer_prompt = [f'Describe the {Blip2VicunaXInstruct.MODALITY_TO_CUE[modality].replace(":", "").strip() if modality != "pc" else "3d model"}.' 
for _ in samples["text_input"]] text_Qformer = self.tokenizer( qformer_prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) else: text_Qformer = self.tokenizer( prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors="pt", ).to(self.device) Qformer_atts = {} query_atts = {} for modality in curr_modalities: # B, Token Size query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device) # B, Token Size + Inp Size Qformer_atts[modality] = torch.cat([query_atts[modality],text_Qformer.attention_mask],dim=1) embeds = {} data_atts = {} for modality in curr_modalities: data = samples[modality] ln = getattr(self, f"{modality}_ln") encoder = getattr(self, f"{modality}_encoder") if modality == "video" and "clip" in self.video_enc_name: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(2)): this_frame = data[:,:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) elif modality == 'audio' and 'beats' in self.audio_enc_name: embeds[modality] = [] data_atts[modality] = [] for j in range(data.size(1)): this_frame = data[:,j,:,:] with self.maybe_autocast(): embeds[modality].append(ln(encoder(this_frame))) if self.shared_qformer: embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) else: with self.maybe_autocast(): embeds[modality] = ln(encoder(data)) if len(embeds[modality].size()) == 2: # B, C, D embeds[modality] = embeds[modality].unsqueeze(1) # B, C if self.shared_qformer: embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality]) data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device) query_outputs = {} num = {} if self.qformer_text_input: for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids.repeat(num[modality], 1), attention_mask=Qformer_atts[modality].repeat(num[modality], 1), query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: bs = embeds[modality].shape[0] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = 
getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( text_Qformer.input_ids, attention_mask=Qformer_atts[modality], query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) else: for modality in curr_modalities: if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: num[modality] = len(embeds[modality]) bs = embeds[modality][0].shape[0] indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] reordered_embeds = torch.cat(embeds[modality])[indices] reordered_atts = torch.cat(data_atts[modality])[indices] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) continue query_output = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), encoder_hidden_states=reordered_embeds, encoder_attention_mask=reordered_atts, return_dict=True, ) query_outputs[modality] = query_output else: bs = embeds[modality].shape[0] if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) else: query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) continue query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( query_embeds=query_tokens[modality], encoder_hidden_states=embeds[modality].to(torch.float32), encoder_attention_mask=data_atts[modality], return_dict=True, ) inputs_llm = {} atts_llm = {} enumeration = {} # from pdb import set_trace; set_trace() for i,modality in enumerate(curr_modalities): if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim != 1: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num[modality], self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs*num[modality], self.num_query_token, -1)) inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue # num*bs, num query tokens, llm emb size inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) # bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality],
self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) else: if self.projection_only or getattr(self, f"projection_only_{modality}"): if self.proj_dim == 1: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) else: inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) continue inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:]) atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) if self.enumerate_inputs: enumeration[modality] = self.llm_tokenizer( [f"{'' if i == 0 else ' '}({chr(97+i)}) " for _ in prompt], return_tensors="pt", add_special_tokens=False if (i!= 0 or self.prefix) else True ).to(self.device) att_list = [] inp_list = [] if self.prefix: att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)] inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)] for modality in curr_modalities: if self.enumerate_inputs: enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration[modality].input_ids.to(self.device)) enumeration_atts_llm = enumeration[modality].attention_mask.to(self.device) inp_list.extend([enumeration_inputs_llm]) att_list.extend([enumeration_atts_llm]) if self.use_cues: if self.clean_tokenization or self.remove_start: if (modality==curr_modalities[0] and not (self.prefix or self.enumerate_inputs)): att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) else: att_list.extend([atts_llm[modality]]) inp_list.extend([inputs_llm[modality]]) if self.add_space: space_tok = self.llm_tokenizer( [f" " for _ in prompt], return_tensors="pt", add_special_tokens=False ) space_inputs_llm = self.llm_model.get_input_embeddings()(space_tok.input_ids.to(self.device)) space_atts_llm = space_tok.attention_mask.to(self.device) inp_list.extend([space_inputs_llm]) att_list.extend([space_atts_llm]) atts_llm = torch.cat(att_list, dim=1) empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(self.device).fill_(-100) inputs_llm = torch.cat(inp_list, dim=1) self.llm_tokenizer.padding_side = "right" self.llm_tokenizer.truncation_side = 'left' text_input_tokens = self.llm_tokenizer( [f"{p}{self.postfix}" for p in prompt] if self.postfix else prompt, padding="longest", return_tensors="pt", add_special_tokens= not 
self.clean_tokenization ).to(self.device) self.llm_tokenizer.truncation_side = 'right' n_cands = len(candidates) with self.maybe_autocast(): all_losses = [] for n in range(n_segments): seg_len = n_cands // n_segments if n == (n_segments - 1): seg_len = n_cands - seg_len * (n_segments - 1) start_i = n * (n_cands // n_segments) end_i = start_i + seg_len this_output_tokens = self.llm_tokenizer( candidates[start_i:end_i], return_tensors="pt", padding="longest", # truncation=True, # max_length=self.max_output_txt_len, ).to(self.device) this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0) this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0) this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1) this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1) this_llm_tokens, this_input_targets_len = self.concat_text_input_output( this_input_tokens_ids, this_input_tokens_atts, this_output_tokens_ids, this_output_tokens_atts ) this_llm_input_ids = this_llm_tokens['input_ids'] this_llm_atts = this_llm_tokens['attention_mask'] inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids) if self.use_caption: inputs_embeds = torch.cat([inputs_embeds], dim=1) attention_mask = torch.cat([this_llm_atts], dim=1) else: inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1) attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1) this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100) for i, l in enumerate(this_input_targets_len): this_targets[i][:l] = -100 if self.use_caption: torch.cat([this_targets], dim=1) else: this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1) outputs = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=this_targets, reduction="none", ) loss = outputs.loss loss = loss.reshape(bs, seg_len) all_losses.append(loss) all_losses = torch.cat(all_losses, dim=-1) all_losses = -all_losses output_class_ranks = torch.argsort(all_losses, dim=-1) return {"predictions": all_losses, "targets": torch.tensor([self.candidates.index(l) for l in samples["label"]])} def _lemmatize(self, answers): def apply(answer): doc = self.lemmatizer(answer) words = [] for token in doc: if token.pos_ in ["NOUN", "VERB"]: words.append(token.lemma_) else: words.append(token.text) answer = " ".join(words) return answer return [apply(answer) for answer in answers] @property def lemmatizer(self): if self._lemmatizer is None: try: import spacy self._lemmatizer = spacy.load("en_core_web_sm") except ImportError: logging.error( """ Please install spacy and en_core_web_sm model to apply lemmatization. 
python -m spacy download en_core_web_sm OR import spacy.cli spacy.cli.download("en_core_web_sm") """ ) exit(1) return self._lemmatizer def get_optimizer_params(self, weight_decay, lr_scale=1): return BaseModel.get_optimizer_params(self, weight_decay, lr_scale=lr_scale) @classmethod def from_config(cls, cfg): image_model = cfg.get("image_model","eva_clip_g") pc_model = cfg.get("pc_model","ulip2_pointbert") video_model = cfg.get("video_model","eva_clip_g") audio_model = cfg.get("audio_model","beats") pretrained_image_qformer = cfg.get("pretrained_image_qformer",None) pretrained_pc_qformer = cfg.get("pretrained_pc_qformer",None) pretrained_video_qformer = cfg.get("pretrained_video_qformer",None) pretrained_audio_qformer = cfg.get("pretrained_audio_qformer",None) load_attention_image_qformer = cfg.get("load_attention_image_qformer",False) load_attention_pc_qformer = cfg.get("load_attention_pc_qformer",False) load_attention_video_qformer = cfg.get("load_attention_video_qformer",False) load_attention_audio_qformer = cfg.get("load_attention_audio_qformer",False) load_qformer_type_image=cfg.get('load_qformer_type_image', "") load_qformer_type_pc=cfg.get('load_qformer_type_pc', "") load_qformer_type_video=cfg.get('load_qformer_type_video', "") load_qformer_type_audio=cfg.get('load_qformer_type_audio',"") load_projection_image=cfg.get('load_projection_image', True) load_projection_pc=cfg.get('load_projection_pc', True) load_projection_video=cfg.get('load_projection_video', True) load_projection_audio=cfg.get('load_projection_audio', True) load_projection_type_image=cfg.get('load_projection_type_image', "") load_projection_type_pc=cfg.get('load_projection_type_pc', "") load_projection_type_video=cfg.get('load_projection_type_video', "") load_projection_type_audio=cfg.get('load_projection_type_audio', "") load_ln_type_image=cfg.get('load_ln_type_image', "") load_ln_type_pc=cfg.get('load_ln_type_pc', "") load_ln_type_video=cfg.get('load_ln_type_video', "") load_ln_type_audio=cfg.get('load_ln_type_audio', "") image_encoder_kwargs = cfg.get("image_encoder_kwargs", {"image_size": 224, "drop_path_rate": 0, "use_grad_checkpoint": False}) pc_encoder_kwargs = cfg.get("pc_encoder_kwargs",{}) video_encoder_kwargs = cfg.get("video_encoder_kwargs",{}) audio_encoder_kwargs = cfg.get("audio_encoder_kwargs",{}) image_precision = cfg.get("image_precision","fp16") pc_precision = cfg.get("pc_precision","fp16") video_precision = cfg.get("video_precision","fp16") audio_precision = cfg.get("audio_precision","fp16") freeze_image = cfg.get("freeze_image",True) freeze_pc = cfg.get("freeze_pc",True) freeze_video = cfg.get("freeze_video",True) freeze_audio = cfg.get("freeze_audio",True) num_query_token = cfg.get("num_query_token") llm_model = cfg.get("llm_model") freeze_pc = cfg.get("freeze_pc", True) freeze_video = cfg.get("freeze_video", True) freeze_audio = cfg.get("freeze_audio", True) prompt = cfg.get("prompt", "") max_txt_len = cfg.get("max_txt_len", 128) max_output_txt_len = cfg.get("max_output_txt_len", 256) apply_lemmatizer = cfg.get("apply_lemmatizer", False) qformer_text_input = cfg.get("qformer_text_input", True) modalities = cfg.get("modalities", ["image"]) use_cues = cfg.get("use_cues", True) shared_qformer = cfg.get("shared_qformer",False) pretrained_shared_qformer = cfg.get("pretrained_shared_qformer", None) load_attention_shared_qformer = cfg.get("load_attention_shared_qformer", None) load_qformer_type_shared= cfg.get('load_qformer_type_shared',"") load_projection_shared= 
cfg.get('load_projection_shared',False) load_projection_type_shared= cfg.get('load_projection_type_shared',"") shared_qformer_num_features=cfg.get("shared_qformer_num_features", 512) encoder_projection_type_image=cfg.get("encoder_projection_type_image","") encoder_projection_type_video=cfg.get("encoder_projection_type_video","") encoder_projection_type_audio=cfg.get("encoder_projection_type_audio","") encoder_projection_type_pc=cfg.get("encoder_projection_type_pc","") llm_text_input = cfg.get("llm_text_input", True) lora = cfg.get("lora", False) prefix = cfg.get("prefix", "") postfix = cfg.get("postfix", "") cached_audio= cfg.get("cached_audio", False) cached_image= cfg.get("cached_image", False) cached_video= cfg.get("cached_video", False) cached_pc= cfg.get("cached_pc", False) num_features_audio=cfg.get('num_features_audio', 768) num_features_image=cfg.get('num_features_image', 1408) num_features_video=cfg.get('num_features_video', 14080) num_features_pc=cfg.get('num_features_depth', 512) joint_video_audio=cfg.get('joint_video_audio', False) use_caption=cfg.get('use_caption', False) use_describe=cfg.get('use_describe', False) predict_with_gen = cfg.get('predict_with_gen', False) format_candidates_prompt = cfg.get('format_candidates_prompt', "{}") special_qformer_input_prompt = cfg.get('special_qformer_input_prompt', False) enumerate_inputs = cfg.get('enumerate_inputs', False) add_space = cfg.get('add_space', True) projection_only = cfg.get('projection_only', False) lora_model = cfg.get('lora_model', '') projection_only_audio= cfg.get('projection_only_audio', False) projection_only_pc= cfg.get('projection_only_pc', False) projection_only_video= cfg.get('projection_only_video', False) projection_only_image= cfg.get('projection_only_image', False) projection_path_audio=cfg.get('projection_path_audio', False) projection_path_pc=cfg.get('projection_path_pc', False) projection_path_video=cfg.get('projection_path_video', False) projection_path_image=cfg.get('projection_path_image', False) remove_start=cfg.get('remove_start', False) proj_dim=cfg.get('proj_dim', 1) clean_tokenization=cfg.get('clean_tokenization', False) logging.info("Model Config Arguments:") logging.info(OmegaConf.to_yaml(cfg)) model = cls( image_model=image_model, pc_model=pc_model, video_model=video_model, audio_model=audio_model, pretrained_image_qformer=pretrained_image_qformer, pretrained_pc_qformer=pretrained_pc_qformer, pretrained_video_qformer=pretrained_video_qformer, pretrained_audio_qformer=pretrained_audio_qformer, load_attention_image_qformer=load_attention_image_qformer, load_attention_pc_qformer=load_attention_pc_qformer, load_attention_video_qformer=load_attention_video_qformer, load_attention_audio_qformer=load_attention_audio_qformer, load_qformer_type_image=load_qformer_type_image, load_qformer_type_pc=load_qformer_type_pc, load_qformer_type_video=load_qformer_type_video, load_qformer_type_audio=load_qformer_type_audio, load_projection_image=load_projection_image, load_projection_pc=load_projection_pc, load_projection_video=load_projection_video, load_projection_audio=load_projection_audio, load_projection_type_image=load_projection_type_image, load_projection_type_pc=load_projection_type_pc, load_projection_type_video=load_projection_type_video, load_projection_type_audio=load_projection_type_audio, load_ln_type_image=load_ln_type_image, load_ln_type_pc=load_ln_type_pc, load_ln_type_video=load_ln_type_video, load_ln_type_audio=load_ln_type_audio, image_encoder_kwargs = image_encoder_kwargs, pc_encoder_kwargs 
= pc_encoder_kwargs, video_encoder_kwargs = video_encoder_kwargs, audio_encoder_kwargs = audio_encoder_kwargs, image_precision=image_precision, pc_precision=pc_precision, video_precision=video_precision, audio_precision=audio_precision, freeze_image=freeze_image, freeze_pc=freeze_pc, freeze_video=freeze_video, freeze_audio=freeze_audio, num_query_token=num_query_token, llm_model=llm_model, lora_model=lora_model, lora = lora, prompt=prompt, max_txt_len=max_txt_len, max_output_txt_len=max_output_txt_len, apply_lemmatizer=apply_lemmatizer, qformer_text_input=qformer_text_input, modalities=modalities, use_cues=use_cues, llm_text_input=llm_text_input, shared_qformer=shared_qformer, pretrained_shared_qformer = pretrained_shared_qformer, load_attention_shared_qformer = load_attention_shared_qformer, shared_qformer_num_features=shared_qformer_num_features, load_qformer_type_shared= load_qformer_type_shared, load_projection_shared= load_projection_shared, encoder_projection_type_image=encoder_projection_type_image, encoder_projection_type_video=encoder_projection_type_video, encoder_projection_type_audio=encoder_projection_type_audio, encoder_projection_type_pc=encoder_projection_type_pc, projection_path_audio=projection_path_audio, projection_path_pc=projection_path_pc, projection_path_video=projection_path_video, projection_path_image=projection_path_image, load_projection_type_shared= load_projection_type_shared, prefix=prefix, postfix=postfix, cached_audio=cached_audio, cached_image=cached_image, cached_video=cached_video, cached_pc=cached_pc, num_features_audio=num_features_audio, num_features_image=num_features_image, num_features_video=num_features_video, num_features_pc=num_features_pc, joint_video_audio=joint_video_audio, use_caption=use_caption, use_describe=use_describe, predict_with_gen=predict_with_gen, format_candidates_prompt=format_candidates_prompt, special_qformer_input_prompt=special_qformer_input_prompt, enumerate_inputs=enumerate_inputs, add_space=add_space, projection_only=projection_only, projection_only_audio= projection_only_audio, projection_only_pc= projection_only_pc, projection_only_video= projection_only_video, projection_only_image= projection_only_image, remove_start= remove_start, proj_dim=proj_dim, clean_tokenization=clean_tokenization ) stage1_url_or_filename = cfg.get("stage1_url_or_filename","") if stage1_url_or_filename: model.load_from_pretrained(stage1_url_or_filename) model.load_checkpoint_from_config(cfg) return model @classmethod def init_ln(cls, num_features, load_ln_path=False, load_ln_type=""): ln = LayerNorm(num_features) if load_ln_path and load_ln_type: url_or_filename=load_ln_path logging.info(f"Loading pretrained layer norm weights from {url_or_filename} of type {load_ln_type}") if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if load_ln_type: load_ln_type = f"{load_ln_type}_ln" if "vision" not in load_ln_type else "ln_vision" loaded_state_dict = {} if 'model' in checkpoint: checkpoint = checkpoint['model'] for k in checkpoint.keys(): if load_ln_type in k: loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] ln.load_state_dict(loaded_state_dict, strict=False) return ln @classmethod def init_encoder_projection(cls, enc_num_features, shared_qformer_num_features, 
load_proj_path=False, load_proj_type=""): encoder_projection = nn.Linear(enc_num_features, shared_qformer_num_features) if load_proj_path and load_proj_type: url_or_filename=load_proj_path logging.info(f"Loading shared Qformer encoder projection weights from {url_or_filename} of type {load_proj_type}") if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if load_proj_type: load_proj_type = f"{load_proj_type}_" loaded_state_dict = {} if 'model' in checkpoint: checkpoint = checkpoint['model'] for k in checkpoint.keys(): if load_proj_type+'encoder_projection' in k: loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] encoder_projection.load_state_dict(loaded_state_dict, strict=False) return encoder_projection @classmethod def init_vicuna_projection(cls, input_size, output_size, load_projection_path=False, load_projection_type="", projection_key=None): proj = nn.Linear(input_size, output_size) if load_projection_path: url_or_filename=load_projection_path logging.info(f"Loading pretrained projection weights from {url_or_filename} of type {load_projection_type} with key {projection_key if projection_key else load_projection_type+'_llm_proj.'}") if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if load_projection_type: load_projection_type = f"{load_projection_type}_" loaded_state_dict = {} if 'model' in checkpoint: checkpoint = checkpoint['model'] for k in checkpoint.keys(): if projection_key: if projection_key in k: loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] else: if load_projection_type+'llm_proj.' 
in k: loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] proj.load_state_dict(loaded_state_dict, strict=False) return proj @classmethod def init_Qformer(cls, num_query_token, modality_width, cross_attention_freq=2, pretrained_qformer=None, load_attention=False, load_qformer_type=""): encoder_config = BertConfig.from_pretrained("bert-base-uncased") encoder_config.encoder_width = modality_width # insert cross-attention layer every other block encoder_config.add_cross_attention = True encoder_config.cross_attention_freq = cross_attention_freq encoder_config.query_length = num_query_token encoder_config.vocab_size += 1 # for special token [DEC] Qformer = BertLMHeadModel(config=encoder_config) query_tokens = nn.Parameter( torch.zeros(1, num_query_token, encoder_config.hidden_size) ) query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) if pretrained_qformer: url_or_filename=pretrained_qformer logging.info(f"Loading pretrained qformer weights and query tokens from {url_or_filename} of type {load_qformer_type}") if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if load_qformer_type: load_qformer_type = f"{load_qformer_type}_" loaded_state_dict = {} if 'model' in checkpoint: checkpoint = checkpoint['model'] for k in checkpoint.keys(): if load_qformer_type+'Qformer.' in k: if not load_attention and 'attention' in k: continue loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] Qformer.load_state_dict(loaded_state_dict, strict=False) query_tokens.data = checkpoint[load_qformer_type+'query_tokens'] return Qformer, query_tokens def get_state_dict(self, url_or_filename, **kwargs): if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if "model" in checkpoint.keys(): state_dict = checkpoint["model"] else: state_dict = checkpoint return state_dict def load_from_pretrained(self, url_or_filename, **kwargs): state_dict = self.get_state_dict(url_or_filename) self.load_state_dict(state_dict, strict=False) logging.info("load checkpoint from %s" % url_or_filename) def load_checkpoint(self, url_or_filename, **kwargs): """ Load from a finetuned checkpoint. This should expect no mismatch in the model keys and the checkpoint keys. 
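        In contrast to load_from_pretrained, which calls load_state_dict with strict=False, this
        method loads with strict=True.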
""" state_dict = self.get_state_dict(url_or_filename) self.load_state_dict(state_dict, strict=True) logging.info("load checkpoint from %s" % url_or_filename) def load_state_dict(self, state_dict, strict=True): # from pdb import set_trace; set_trace() unexpected_keys = [] missing_keys = [] if self.shared_qformer and not self.projection_only: ## Load Q-Former if it is not loaded from config if not getattr(self, "pretrained_shared_qformer"): shared_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if "shared_Qformer" == k.split('.')[0]} msg = self.shared_Qformer.load_state_dict(shared_qformer_state_dict, strict=strict) missing_keys.extend(msg.missing_keys) ## Load query tokens if "shared_query_tokens" not in state_dict: missing_keys.append("shared_query_tokens") else: self.shared_query_tokens = state_dict["shared_query_tokens"] missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) for modality in self.modalities: # Map shared Qformer by reference to all modalities. setattr(self, f"{modality}_Qformer", self.shared_Qformer) getattr(self, f"{modality}_query_tokens").data = state_dict[f"shared_query_tokens"] # load encoder projections modality_encoder_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_encoder_projection" in k.split('.')[0]} msg = getattr(self, f"{modality}_encoder_projection").load_state_dict(modality_encoder_projection_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) # load modality layer norm if getattr(self,f"load_ln_type_{modality}") == "vision": modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"ln_vision" in k.split('.')[0]} else: modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_ln" in k.split('.')[0]} msg = getattr(self, f"{modality}_ln").load_state_dict(modality_ln_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) ## Load Shared LLM projection if not loaded by config if not getattr(self, "load_projection_shared"): shared_llm_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"shared_llm_proj" in k.split('.')[0]} msg = self.shared_llm_proj.load_state_dict(shared_llm_projection_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) for modality in self.modalities: ## Map to modality projections by reference msg = setattr(self, f"{modality}_llm_proj", self.shared_llm_proj) else: for modality in self.modalities: ## Load Q-Former if not loaded from config if not getattr(self, f"pretrained_{modality}_qformer") or ((self.projection_only or getattr(self, f"projection_only_{modality}")) and not getattr(self, f"projection_path_{modality}")): if self.projection_only or getattr(self, f"projection_only_{modality}") : if not getattr(self, f"projection_path_{modality}"): logging.info(f"Loaded {modality} projection") modality_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_projection" == k.split('.')[0]} msg = getattr(self, f"{modality}_projection").load_state_dict(modality_qformer_state_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) else: modality_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_Qformer" == k.split('.')[0]} msg = getattr(self, 
f"{modality}_Qformer").load_state_dict(modality_qformer_state_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) ## Load query tokens if not self.projection_only and not getattr(self, f"projection_only_{modality}"): if f"{modality}_query_tokens" not in state_dict: missing_keys.append(f"{modality}_query_tokens") else: logging.info(f"Loaded {modality} query tokens") getattr(self, f"{modality}_query_tokens").data = state_dict[f"{modality}_query_tokens"] # load modality layer norm if not loaded from config if getattr(self,f"load_ln_type_{modality}") == "vision": logging.info(f"Loaded {modality} vision ln") modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"ln_vision" in k.split('.')[0]} else: modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_ln" in k.split('.')[0]} msg = getattr(self, f"{modality}_ln").load_state_dict(modality_ln_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) ## Load LLM projections if not loaded from config if not getattr(self, f"load_projection_{modality}") or (getattr(self, f"projection_only_{modality}") or self.projection_only): if not getattr(self, f"projection_path_{modality}"): logging.info(f"Loaded {modality} llm projection") modality_llm_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_llm_proj" in k.split('.')[0]} msg = getattr(self, f"{modality}_llm_proj").load_state_dict(modality_llm_projection_dict, strict=strict) missing_keys.extend(msg.missing_keys) unexpected_keys.extend(msg.unexpected_keys) ## llm model is loaded from pretrained lora_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"llm_model" in k.split('.')[0]} if not self.lora or len(lora_state_dict) == 0: unexpected_keys = [k for k in unexpected_keys if k.split('.')[0] != 'llm_model'] else: msg = self.llm_model.load_state_dict({'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"llm_model" in k.split('.')[0]}, strict=False) missing_keys.extend(["llm_model."+k for k in msg.missing_keys]) missing_keys = [k for k in missing_keys if 'encoder' not in k.split('.')[0]] missing_keys = [k for k in missing_keys if k.split('.')[0] != 'llm_model'] return _IncompatibleKeys(missing_keys, unexpected_keys) def before_evaluation(self, dataset, task_type, **kwargs): if task_type == MultimodalClassificationTask: self.candidates = dataset.classnames print(self.candidates)
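
# Illustrative usage sketch (not part of the library API; assumes a model instance built via
# Blip2VicunaXInstruct.from_config with the appropriate encoders and checkpoints, and inputs
# already preprocessed by the matching LAVIS processors):
#
#     samples = {
#         "image": image_tensor,                       # e.g. (B, 3, 224, 224) for the eva_clip_g encoder
#         "text_input": ["What is in the picture?"] * image_tensor.size(0),
#     }
#     answers = model.predict_answers(samples, num_beams=5, max_len=10)
#
#     # Zero-shot classification: candidates are ranked by LM loss in _predict_class
#     # (during evaluation, before_evaluation fills self.candidates from dataset.classnames).
#     samples["label"] = ["cat"] * image_tensor.size(0)
#     scores = model.predict(samples, candidates=["cat", "dog", "bird"])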