""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
""" | |
Requires Transformer 4.28 and above, implementation may change according the Llama implementation | |
""" | |
import logging
import string
from packaging import version
import os
from omegaconf import OmegaConf

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from torch.nn.modules.module import _IncompatibleKeys

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
)

import transformers
import random

from lavis.common.registry import registry
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train, LayerNorm
from lavis.models.ulip_models.ULIP_models import ULIP_PointBERT
from lavis.tasks.multimodal_classification import MultimodalClassificationTask
from lavis.common.utils import is_url
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from lavis.common.dist_utils import download_cached_file
from lavis.processors.blip_processors import BlipCaptionProcessor
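
# Casting the LM head's output back to fp32 is a common recipe for numerically
# stable loss computation when the base model runs in fp16/int8 (e.g., 8-bit
# loading with LoRA, as in the llama-peft-tuner setup referenced below).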
class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)

@registry.register_model("blip2_vicuna_xinstruct")
class Blip2VicunaXInstruct(Blip2Base):
    """
    BLIP2 Vicuna model.
    Supported model types:
        - vicuna7b
        - vicuna13b
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_vicuna_xinstruct", "vicuna7b")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "vicuna7b": "configs/models/blip2/blip2_xinstruct_vicuna7b.yaml",
        "vicuna13b": "configs/models/blip2/blip2_xinstruct_vicuna13b.yaml",
    }

    SEQUENCIAL_ENCODERS = [
        "eva_clip_g",
        "beats",
    ]

    SEQUENCIAL_MODALITIES = [
        "video",
        "audio",
    ]
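
    # Natural-language cue strings whose embeddings are prepended to each
    # modality's projected query tokens in the LLM input sequence.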
    MODALITY_TO_CUE = {
        "image": " image: ",
        "pc": " 3d: ",
        "video": " video: ",
        "audio": " audio: ",
    }

    def __init__(
        self,
        modalities=["image", "pc", "audio", "video"],
        use_cues=True,
        num_query_token=32,
        qformer_text_input=True,
        llm_text_input=False,
        apply_lemmatizer=False,
        ## encoders
        image_model="eva_clip_g",
        pc_model="ulip2_pointbert",
        video_model="eva_clip_g",
        audio_model="beats",
        image_encoder_kwargs={"image_size": 224, "drop_path_rate": 0, "use_grad_checkpoint": False},
        pc_encoder_kwargs={},
        video_encoder_kwargs={},
        audio_encoder_kwargs={},
        image_precision="fp16",
        pc_precision="fp16",
        video_precision="fp16",
        audio_precision="fp16",
        freeze_image=True,
        freeze_pc=True,
        freeze_video=True,
        freeze_audio=True,
        ## load pretrained parameters
        pretrained_image_qformer=None,
        pretrained_pc_qformer=None,
        pretrained_video_qformer=None,
        pretrained_audio_qformer=None,
        load_attention_image_qformer=False,
        load_attention_pc_qformer=False,
        load_attention_video_qformer=False,
        load_attention_audio_qformer=False,
        load_qformer_type_image="",
        load_qformer_type_pc="",
        load_qformer_type_video="",
        load_qformer_type_audio="",
        load_ln_type_image="",
        load_ln_type_pc="",
        load_ln_type_video="",
        load_ln_type_audio="",
        load_projection_image=True,
        load_projection_pc=True,
        load_projection_video=True,
        load_projection_audio=True,
        load_projection_type_image="",
        load_projection_type_pc="",
        load_projection_type_video="",
        load_projection_type_audio="",
        ## llm model parameters
        llm_model="",
        lora_model="",
        lora=False,
        ## generation parameters
        prompt="",
        prefix="",
        postfix="",
        max_txt_len=128,
        max_output_txt_len=256,
        special_qformer_input_prompt=False,
        enumerate_inputs=False,
        add_space=False,
        remove_start=False,
        clean_tokenization=False,  # if True, strips the leading whitespace from cues and the start token from the prompt.
        ## shared Q-Former setup
        shared_qformer=False,
        pretrained_shared_qformer=None,
        load_attention_shared_qformer=False,
        load_qformer_type_shared="",
        load_projection_shared=False,
        load_projection_type_shared="",
        encoder_projection_type_image="",
        encoder_projection_type_pc="",
        encoder_projection_type_video="",
        encoder_projection_type_audio="",
        shared_qformer_num_features=512,
        ## use cached features
        cached_audio=False,
        cached_image=False,
        cached_pc=False,
        cached_video=False,
        ## num features per modality (only needed in the cached case)
        num_features_audio=768,
        num_features_image=1408,
        num_features_video=1408,
        num_features_pc=512,
        joint_video_audio=False,
        ## DisCRN
        use_caption=False,
        use_describe=False,
        ## classification setup
        predict_with_gen=False,
        format_candidates_prompt="{}",
        ## projection-only parameters
        projection_only=False,
        projection_only_audio=False,
        projection_only_pc=False,
        projection_only_video=False,
        projection_only_image=False,
        projection_path_audio=False,
        projection_path_pc=False,
        projection_path_video=False,
        projection_path_image=False,
        proj_dim=1,
    ):
        super().__init__()
        transformers_version = version.parse(transformers.__version__)
        assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
        from transformers import LlamaTokenizer
        from lavis.models.blip2_models.modeling_llama import LlamaForCausalLM

        logging.info(f"Using modalities {modalities}")
        self.modalities = modalities

        logging.info(f"Shared Qformer is set to {shared_qformer}")
        self.shared_qformer = shared_qformer

        logging.info(f"Video-audio interleaving is set to {joint_video_audio}")
        self.joint_video_audio = joint_video_audio

        logging.info(f"Using spaCy en_core_web_sm lemmatizer is set to {apply_lemmatizer}")
        self._lemmatizer = None
        self.apply_lemmatizer = apply_lemmatizer

        logging.info(f"Qformer text input is {qformer_text_input} and LLM text input is {llm_text_input}")
        self.qformer_text_input = qformer_text_input
        self.llm_text_input = llm_text_input

        self.projection_only = projection_only
        self.proj_dim = proj_dim
        logging.info(f"Projection-only setup is set to {projection_only} with dimension {proj_dim}")

        for modality in self.modalities:
            setattr(self, f"cached_{modality}", locals()[f"cached_{modality}"])
            if locals()[f"cached_{modality}"]:
                setattr(self, f"num_features_{modality}", locals()[f"num_features_{modality}"])
                logging.info(f"Using cached {modality} representation with {getattr(self, f'num_features_{modality}')} embedding dim.")

        ### Initialize modality encoders ###
        for modality in self.modalities:
            modality_model = locals()[f"{modality}_model"]
            modality_precision = locals()[f"{modality}_precision"]
            modality_kwargs = locals()[f"{modality}_encoder_kwargs"]
            modality_kwargs['load_ln_path'] = locals()["pretrained_shared_qformer"] if shared_qformer else \
                locals()[f"pretrained_{modality}_qformer"]
            setattr(self, f"projection_only_{modality}", locals()[f"projection_only_{modality}"])
            setattr(self, f"projection_path_{modality}", locals()[f"projection_path_{modality}"])
            modality_kwargs['load_ln_type'] = locals()[f"load_ln_type_{modality}"]
            if self.projection_only or locals()[f"projection_only_{modality}"]:
                modality_kwargs['load_ln_path'] = getattr(self, f"projection_path_{modality}")
                modality_kwargs['load_ln_type'] = modality
            setattr(self, f"load_ln_type_{modality}", locals()[f"load_ln_type_{modality}"])
            setattr(self, f"pretrained_{modality}_qformer", locals()[f"pretrained_{modality}_qformer"])
            modality_encoder, modality_ln = getattr(self, f"init_{modality}_encoder")(
                modality_model,
                precision=modality_precision,
                **modality_kwargs,
            )
            freeze_modality = locals()[f"freeze_{modality}"]
            cached_modality = locals()[f"cached_{modality}"]
            if cached_modality:
                setattr(self, f"{modality}_encoder", modality_encoder)
                setattr(self, f"{modality}_ln", modality_ln)
                continue
            if freeze_modality:
                for name, param in modality_encoder.named_parameters():
                    param.requires_grad = False
                modality_encoder = modality_encoder.eval()
                modality_encoder.train = disabled_train
                logging.info(f"freeze {modality} encoder")
            setattr(self, f"{modality}_encoder", modality_encoder)
            setattr(self, f"{modality}_ln", modality_ln)

        ##### Init Q-Formers #####
        self.tokenizer = self.init_tokenizer(truncation_side="left")  # 30523 tokens.
        self.num_query_token = num_query_token

        if self.shared_qformer:
            logging.info(
                f"Initializing shared QFormer with {shared_qformer_num_features} features "
                f"and query tokens of length {num_query_token}"
            )
            self.pretrained_shared_qformer = pretrained_shared_qformer
            self.load_qformer_type_shared = load_qformer_type_shared
            self.shared_Qformer, self.shared_query_tokens = self.init_Qformer(
                num_query_token,
                shared_qformer_num_features,
                pretrained_qformer=pretrained_shared_qformer,
                load_attention=load_attention_shared_qformer,
                load_qformer_type=load_qformer_type_shared,
            )
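            # When the Q-Former gets no text input, its text branch is unused:
            # drop the word/position embeddings and each layer's feed-forward
            # sublayer so those weights are neither trained nor stored.
            # Otherwise, resize the embeddings to match the tokenizer.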
            if not qformer_text_input:
                self.shared_Qformer.bert.embeddings.word_embeddings = None
                self.shared_Qformer.bert.embeddings.position_embeddings = None
                for layer in self.shared_Qformer.bert.encoder.layer:
                    layer.output = None
                    layer.intermediate = None
            else:
                self.shared_Qformer.resize_token_embeddings(len(self.tokenizer))
            self.shared_Qformer.cls = None

            # Map shared Qformer by reference to all modalities.
            for modality in self.modalities:
                setattr(self, f"{modality}_Qformer", self.shared_Qformer)
                setattr(self, f"{modality}_query_tokens", self.shared_query_tokens)
                encoder_proj_type = locals()[f"encoder_projection_type_{modality}"]
                setattr(self, f"encoder_projection_type_{modality}", encoder_proj_type)
                modality_encoder_features = getattr(self, f"{modality}_encoder").num_features
                setattr(
                    self,
                    f"{modality}_encoder_projection",
                    self.init_encoder_projection(modality_encoder_features, shared_qformer_num_features, pretrained_shared_qformer, encoder_proj_type),
                )
        else:
            for modality in self.modalities:
                if getattr(self, f"cached_{modality}"):
                    modality_num_features = locals()[f"num_features_{modality}"]
                else:
                    modality_num_features = getattr(self, f"{modality}_encoder").num_features
                setattr(self, f"pretrained_{modality}_qformer", locals()[f"pretrained_{modality}_qformer"])
                setattr(self, f"load_qformer_type_{modality}", locals()[f"load_qformer_type_{modality}"])
                setattr(self, f"projection_only_{modality}", locals()[f"projection_only_{modality}"])
                setattr(self, f"projection_path_{modality}", locals()[f"projection_path_{modality}"])
                if self.projection_only or locals()[f"projection_only_{modality}"]:
                    logging.info(f"Initializing {modality} projection")
                    setattr(self, f"pretrained_{modality}_qformer", False)
                    if modality == 'audio' and proj_dim == 1:
                        modality_num_features *= 256  # hack to get the full BEATs embedding; define better.
                    modality_projection = self.init_vicuna_projection(
                        modality_num_features,
                        num_query_token * proj_dim,
                        load_projection_path=getattr(self, f"projection_path_{modality}"),
                        load_projection_type=modality,
                        projection_key=f"{modality}_projection",
                    )
                    setattr(self, f"{modality}_projection", modality_projection)
                else:
                    logging.info(f"Initializing {modality} QFormer and query tokens of length {num_query_token}")
                    modality_qformer, modality_query_tokens = self.init_Qformer(
                        num_query_token,
                        modality_num_features,
                        pretrained_qformer=locals()[f"pretrained_{modality}_qformer"],
                        load_attention=locals()[f"load_attention_{modality}_qformer"],
                        load_qformer_type=locals()[f"load_qformer_type_{modality}"],
                    )
                    if not qformer_text_input:
                        modality_qformer.bert.embeddings.word_embeddings = None
                        modality_qformer.bert.embeddings.position_embeddings = None
                        for layer in modality_qformer.bert.encoder.layer:
                            layer.output = None
                            layer.intermediate = None
                    else:
                        modality_qformer.resize_token_embeddings(len(self.tokenizer))
                    modality_qformer.cls = None
                    setattr(self, f"{modality}_Qformer", modality_qformer)
                    setattr(self, f"{modality}_query_tokens", modality_query_tokens)

        ### Set up LLM ###
        logging.info(f"Setting up llm model {llm_model}")
        self.lora = lora
        logging.info(f"Lora is set to {self.lora}")
        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})

        if self.lora:
            # https://github.com/lxe/llama-peft-tuner/blob/main/finetune_peft.py
            self.llm_model = LlamaForCausalLM.from_pretrained(
                llm_model,
                load_in_8bit=True,
                torch_dtype=torch.float16,
            )
            self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=['q_proj', 'v_proj'],
            )
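            # Standard prep for 8-bit + LoRA training: gradient checkpointing to
            # save memory, input grads so checkpointing works with frozen
            # embeddings, an fp32 head for a stable loss, and the KV cache off.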
            self.llm_model.gradient_checkpointing_enable()
            self.llm_model.enable_input_require_grads()
            self.llm_model.lm_head = CastOutputToFloat(self.llm_model.lm_head)
            self.llm_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
            self.llm_hidden_size = self.llm_model.config.hidden_size
            self.llm_model = get_peft_model(self.llm_model, self.peft_config)
            self.lora_model = lora_model
        else:
            self.llm_model = LlamaForCausalLM.from_pretrained(
                llm_model, torch_dtype=torch.float16
            )
            self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
            self.llm_hidden_size = self.llm_model.config.hidden_size
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = False

        # Load LM projections
        if self.shared_qformer and load_projection_shared:
            qformer = self.shared_Qformer
            load_projection_path = load_projection_shared
            if load_projection_path:
                load_projection_path = pretrained_shared_qformer
            load_projection_type = load_projection_type_shared
            self.load_projection_shared = load_projection_path
            self.load_projection_type_shared = load_projection_type_shared
            logging.info("Loading shared Qformer projection.")
            proj = self.init_vicuna_projection(
                qformer.config.hidden_size,
                self.llm_hidden_size,
                load_projection_path=load_projection_path,
            )
            # Map projection by reference to all modalities.
            for modality in self.modalities:
                setattr(self, f"{modality}_llm_proj", proj)
        else:
            for modality in self.modalities:
                load_projection_path = locals()[f"load_projection_{modality}"]
                if load_projection_path == True:
                    load_projection_path = locals()[f"pretrained_{modality}_qformer"]
                load_projection_type = locals()[f"load_projection_type_{modality}"]
                setattr(self, f"load_projection_{modality}", load_projection_path)
                setattr(self, f"load_projection_type_{modality}", load_projection_type)
                if self.projection_only or getattr(self, f"projection_only_{modality}"):
                    proj = self.init_vicuna_projection(
                        self.num_query_token if proj_dim == 1 else proj_dim,
                        self.num_query_token * self.llm_hidden_size if proj_dim == 1 else self.llm_hidden_size,
                        load_projection_path=getattr(self, f"projection_path_{modality}"),
                        load_projection_type=load_projection_type,
                    )
                else:
                    qformer = getattr(self, f"{modality}_Qformer")
                    proj = self.init_vicuna_projection(
                        qformer.config.hidden_size,
                        self.llm_hidden_size,
                        load_projection_path=load_projection_path,
                        load_projection_type=load_projection_type,
                    )
                setattr(self, f"{modality}_llm_proj", proj)

        self.clean_tokenization = clean_tokenization
        logging.info(f"Clean tokenization is set to {self.clean_tokenization}")

        self.max_txt_len = max_txt_len
        self.max_output_txt_len = max_output_txt_len
        self.prompt = prompt
        self.prefix = prefix
        if self.prefix:
            self.tokenized_prefix = self.llm_tokenizer(self.prefix, return_tensors="pt")

        self.postfix = postfix
        if not isinstance(self.postfix, str) or not self.postfix:
            self.postfix = ""
        logging.info(f"Using prefix set to {self.prefix} and postfix set to {self.postfix}.")

        self.use_cues = use_cues
        logging.info(f"Using cues set to {self.use_cues}.")
        if self.use_cues:
            logging.info(f"Modality to cue {Blip2VicunaXInstruct.MODALITY_TO_CUE}")
            self.tokenized_cue = {}
            self.emb_cue = {}
            self.att_cue = {}
            for modality in self.modalities:
                if self.clean_tokenization:
                    Blip2VicunaXInstruct.MODALITY_TO_CUE[modality] = Blip2VicunaXInstruct.MODALITY_TO_CUE[modality].lstrip()
                self.tokenized_cue[modality] = self.llm_tokenizer(Blip2VicunaXInstruct.MODALITY_TO_CUE[modality], return_tensors="pt")
                self.emb_cue[modality] = self.llm_model.get_input_embeddings()(self.tokenized_cue[modality].input_ids.to(self.device))
                self.att_cue[modality] = self.tokenized_cue[modality].attention_mask.to(self.device)

        ## generation parameters
        self.use_caption = use_caption
        self.use_describe = use_describe
        self.predict_with_gen = predict_with_gen
        self.format_candidates_prompt = format_candidates_prompt
        self.special_qformer_input_prompt = special_qformer_input_prompt
        self.enumerate_inputs = enumerate_inputs
        self.add_space = add_space
        self.remove_start = remove_start

        if self.projection_only:
            self.qformer_text_input = False
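
    # Splices each output sequence (minus its BOS) directly after the sample's
    # last non-padded input token, keeping padding at the end. For one sample:
    #   input [I1, I2, PAD], output [BOS, O1, O2, EOS] -> [I1, I2, O1, O2, EOS, PAD]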
    def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
        input_part_targets_len = []
        llm_tokens = {"input_ids": [], "attention_mask": []}
        for i in range(input_ids.size(0)):
            this_input_ones = input_atts[i].sum()
            input_part_targets_len.append(this_input_ones)
            llm_tokens['input_ids'].append(
                torch.cat([
                    input_ids[i][:this_input_ones],
                    output_ids[i][1:],
                    input_ids[i][this_input_ones:]
                ])
            )
            llm_tokens['attention_mask'].append(
                torch.cat([
                    input_atts[i][:this_input_ones],
                    output_atts[i][1:],
                    input_atts[i][this_input_ones:]
                ])
            )
        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
        return llm_tokens, input_part_targets_len

    def forward(self, samples):
        if samples is None or samples == {} or not any(modality in samples for modality in self.modalities):
            return {"loss": torch.tensor(0.0)}

        random.shuffle(self.modalities)
        curr_modalities = [modality for modality in self.modalities if modality in samples]
        excess_modalities = [modality for modality in self.modalities if modality not in curr_modalities]

        # Touch every parameter of the modalities absent from this batch with a
        # zero-weighted sum so that all parameters join the autograd graph
        # (avoids unused-parameter issues in distributed training).
        dummy_loss = 0.
        for modality in excess_modalities:
            if self.shared_qformer:
                for name, param in getattr(self, f"{modality}_encoder_projection").named_parameters():
                    dummy_loss += param.sum() * 0.
            for name, param in getattr(self, f"{modality}_ln").named_parameters():
                dummy_loss += param.sum() * 0.
            dummy_loss += getattr(self, f"{modality}_query_tokens").sum() * 0.
            for name, param in getattr(self, f'{modality}_Qformer').named_parameters():
                dummy_loss += param.sum() * 0.
            for name, param in getattr(self, f'{modality}_llm_proj').named_parameters():
                dummy_loss += param.sum() * 0.

        embeds = {}
        query_tokens = {}
        data_atts = {}
        for modality in curr_modalities:
            data = samples[modality]
            ln = getattr(self, f"{modality}_ln")
            encoder = getattr(self, f"{modality}_encoder")
            if modality == "video" and self.video_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(2)):
                    this_frame = data[:, :, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
                # B, Token Size, LM EMB
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1)
            elif modality == 'audio' and self.audio_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(1)):
                    this_frame = data[:, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
                # B, Token Size, LM EMB
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1)
            else:
                with self.maybe_autocast():
                    embeds[modality] = ln(encoder(data))
                if len(embeds[modality].size()) == 2:
                    # B, C -> B, 1, C
                    embeds[modality] = embeds[modality].unsqueeze(1)
                if self.shared_qformer:
                    embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality])
                data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device)
                # B, Token Size, LM EMB
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(embeds[modality].shape[0], -1, -1)

        query_outputs = {}
        if self.qformer_text_input:
            text_Qformer = self.tokenizer(
                samples["text_input"] if not self.special_qformer_input_prompt else self.special_qformer_input_prompt,
                padding='longest',
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(self.device)
            Qformer_atts = {}
            query_atts = {}
            for modality in curr_modalities:
                # B, Token Size
                query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device)
                # B, Token Size + Inp Size
                Qformer_atts[modality] = torch.cat([query_atts[modality], text_Qformer.attention_mask], dim=1)
                if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                    num = len(embeds[modality])
                    bs = embeds[modality][0].shape[0]
                    indices = [j_+r for r, j in enumerate([[i*bs for i in range(num)]]*bs) for j_ in j]
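                    # torch.cat stacks frame-major: [f0_s0, ..., f0_s(bs-1), f1_s0, ...].
                    # `indices` permutes this to sample-major order; e.g. with bs=2 and
                    # num=3, indices = [0, 2, 4, 1, 3, 5]: all frames of sample 0 first,
                    # then all frames of sample 1.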
                    reordered_embeds = torch.cat(embeds[modality])[indices]
                    reordered_atts = torch.cat(data_atts[modality])[indices]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1, keepdim=True)).view(bs*num, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0], -1))
                        continue
                    query_output = getattr(self, f"{modality}_Qformer").bert(
                        text_Qformer.input_ids.repeat(num, 1),
                        attention_mask=Qformer_atts[modality].repeat(num, 1),
                        query_embeds=query_tokens[modality].repeat(num, 1, 1),
                        encoder_hidden_states=reordered_embeds,
                        encoder_attention_mask=reordered_atts,
                        return_dict=True,
                    )
                    query_outputs[modality] = query_output
                else:
                    bs = embeds[modality].shape[0]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token, -1)
                        continue
                    query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert(
                        text_Qformer.input_ids,
                        attention_mask=Qformer_atts[modality],
                        query_embeds=query_tokens[modality],
                        encoder_hidden_states=embeds[modality].to(torch.float32),
                        encoder_attention_mask=data_atts[modality],
                        return_dict=True,
                    )
        else:
            for modality in curr_modalities:
                if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                    num = len(embeds[modality])
                    bs = embeds[modality][0].shape[0]
                    # Same frame-major -> sample-major reordering as above.
                    indices = [j_+r for r, j in enumerate([[i*bs for i in range(num)]]*bs) for j_ in j]
                    reordered_embeds = torch.cat(embeds[modality])[indices]
                    reordered_atts = torch.cat(data_atts[modality])[indices]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1, keepdim=True)).view(bs*num, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0], -1))
                        continue
                    query_output = getattr(self, f"{modality}_Qformer").bert(
                        query_embeds=query_tokens[modality].repeat(num, 1, 1),
                        encoder_hidden_states=reordered_embeds,
                        encoder_attention_mask=reordered_atts,
                        return_dict=True,
                    )
                    query_outputs[modality] = query_output
                else:
                    bs = embeds[modality].shape[0]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token, -1)
                        continue
                    query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert(
                        query_embeds=query_tokens[modality],
                        encoder_hidden_states=embeds[modality].to(torch.float32),  # pc data is float16.
                        encoder_attention_mask=data_atts[modality],
                        return_dict=True,
                    )

        inputs_llm = {}
        atts_llm = {}
        for modality in curr_modalities:
            if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                num = len(embeds[modality])  # recompute per modality; the value left over from the loop above may be stale.
                # num*bs, num query tokens, llm emb size
                if self.projection_only or getattr(self, f"projection_only_{modality}"):
                    if self.proj_dim != 1:
                        inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num, self.num_query_token, -1)
                    else:
                        inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]).reshape(bs*num, self.num_query_token, -1)
                    inputs_llm[modality] = inputs_llm[modality].reshape(bs, num, self.num_query_token, -1).view(bs, num*self.num_query_token, -1)
                    atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device)
                    continue
                inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:, :query_tokens[modality].size(1), :])
                # bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size
                inputs_llm[modality] = inputs_llm[modality].reshape(bs, num, self.num_query_token, -1).view(bs, num*self.num_query_token, -1)
                atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device)
            else:
                if self.projection_only or getattr(self, f"projection_only_{modality}"):
                    if self.proj_dim == 1:
                        inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1)
                    else:
                        inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1))
                    atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device)
                    continue
                inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:, :query_tokens[modality].size(1), :])
                atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device)

        self.llm_tokenizer.padding_side = "right"
        self.llm_tokenizer.truncation_side = 'left'
        if self.llm_text_input:
            text_input_tokens = self.llm_tokenizer(
                [f"{t}{self.postfix}" for t in samples['text_input']] if self.postfix else samples['text_input'],
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=not self.clean_tokenization,
            ).to(self.device)

        self.llm_tokenizer.truncation_side = 'right'
        text_output_tokens = self.llm_tokenizer(
            [t + self.llm_tokenizer.eos_token for t in samples['text_output']],
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_output_txt_len,
        ).to(self.device)

        if self.llm_text_input:
            llm_tokens, input_part_targets_len = self.concat_text_input_output(
                text_input_tokens.input_ids,
                text_input_tokens.attention_mask,
                text_output_tokens.input_ids,
                text_output_tokens.attention_mask,
            )
        else:
            llm_tokens = text_output_tokens
            input_part_targets_len = [0 for _ in range(llm_tokens['input_ids'].shape[0])]  # input length is 0

        # do not apply loss to the padding
        targets = llm_tokens['input_ids'].masked_fill(
            llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
        )

        # do not apply loss to the text input (i.e., instruction)
        for i, l in enumerate(input_part_targets_len):
            targets[i][:l] = -100

        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
        bs = inputs_embeds.shape[0]
        att_list = []
        inp_list = []
        if self.prefix:
            att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)]
            inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)]

        for modality in curr_modalities:
            if self.use_cues:
                if self.prefix and self.clean_tokenization:
                    # the cue was tokenized with a start token; drop it after a prefix
                    att_list.extend([self.att_cue[modality][:, 1:].repeat(bs, 1).to(self.device), atts_llm[modality]])
                    inp_list.extend([self.emb_cue[modality][:, 1:].repeat(bs, 1, 1).to(self.device), inputs_llm[modality]])
                else:
                    att_list.extend([self.att_cue[modality].repeat(bs, 1).to(self.device), atts_llm[modality]])
                    inp_list.extend([self.emb_cue[modality].repeat(bs, 1, 1).to(self.device), inputs_llm[modality]])
            else:
                att_list.extend([atts_llm[modality]])
                inp_list.extend([inputs_llm[modality]])

        # do not apply loss to the query tokens
        empty_targets = (
            torch.ones(torch.cat(att_list, dim=1).size(), dtype=torch.long).to(self.device).fill_(-100)
        )

        # append llm prompt + output to queries
        att_list.append(llm_tokens['attention_mask'])
        inp_list.append(inputs_embeds)

        inputs_embeds = torch.cat(inp_list, dim=1)
        attention_mask = torch.cat(att_list, dim=1)
        targets = torch.cat([empty_targets, targets], dim=1)

        with self.maybe_autocast():
            outputs = self.llm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = dummy_loss + outputs.loss
        return {"loss": loss}

    def init_image_encoder(self, model_name, precision, **kwargs):
        load_ln_path = kwargs.pop('load_ln_path')
        load_ln_type = kwargs.pop('load_ln_type')
        encoder, _ = super().init_vision_encoder(model_name, kwargs['image_size'], kwargs['drop_path_rate'], kwargs['use_grad_checkpoint'], precision)
        ln = self.init_ln(encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type)
        return encoder, ln

    def init_pc_encoder(self, model_name, precision, **kwargs):
        assert model_name in [
            "ulip1_pointbert",
            "ulip2_pointbert",
            "ulip_shapenet",
            "ulip_objaverse",
            "objaverse_shapenet_k_1",
            "ulip2_scaledup",
        ], "pc model must be in [ulip1_pointbert, ulip2_pointbert, ulip_shapenet, ulip_objaverse, objaverse_shapenet_k_1, ulip2_scaledup]"
        load_ln_path = kwargs.pop('load_ln_path')
        load_ln_type = kwargs.pop('load_ln_type')
        if model_name == "ulip2_pointbert":
            pc_encoder = ULIP_PointBERT(ulip_v=2)
        elif model_name == "ulip_shapenet":
            pc_encoder = ULIP_PointBERT(ulip_v="shapenet")
        elif model_name == "ulip_objaverse":
            pc_encoder = ULIP_PointBERT(ulip_v="objaverse_k_1")
        elif model_name == "objaverse_shapenet_k_1":
            pc_encoder = ULIP_PointBERT(ulip_v="objaverse_shapenet_k_1")
        elif model_name == "ulip2_scaledup":
            pc_encoder = ULIP_PointBERT(ulip_v="ulip2_scaledup")
        else:
            pc_encoder = ULIP_PointBERT(ulip_v=1)
        ln_pc = self.init_ln(pc_encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type)
        self.pc_enc_name = model_name
        return pc_encoder, ln_pc

    def init_video_encoder(self, model_name, precision, **kwargs):
        assert model_name in [
            "eva_clip_g",
            "eva2_clip_L",
            "clip_L",
        ], "video model must be in [eva_clip_g, eva2_clip_L, clip_L]"
        video_encoder, ln_video = self.init_image_encoder(
            model_name, precision=precision, **kwargs
        )
        self.video_enc_name = model_name
        return video_encoder, ln_video

    def init_audio_encoder(self, model_name, precision, **kwargs):
        assert model_name in ['beats'], "audio model must be in [beats]"
        load_ln_path = kwargs.pop('load_ln_path')
        load_ln_type = kwargs.pop('load_ln_type')
        if "beats" in model_name:
            from lavis.models.beats_encoder import BeatsEncoder
            if self.cached_audio:
                # cached features bypass the encoder; BEATs features are 768-d
                audio_encoder = lambda x: x
                ln_audio = self.init_ln(768, load_ln_path=load_ln_path, load_ln_type=load_ln_type)
            else:
                audio_encoder = BeatsEncoder(**kwargs)
                ln_audio = self.init_ln(audio_encoder.num_features, load_ln_path=load_ln_path, load_ln_type=load_ln_type)
        self.audio_enc_name = model_name
        return audio_encoder, ln_audio

    def get_query_outputs(self, samples):
        if samples is None or samples == {}:
            return
        curr_modalities = [modality for modality in self.modalities if modality in samples]
        if len(curr_modalities) == 0:
            logging.warning("Model modalities do not match sample modalities.")
            return

        # get batch size
        bs = None
        for modality in curr_modalities:
            data = samples[modality]
            bs = data.size(0)
            break

        if "prompt" in samples.keys():
            prompt = samples["prompt"]
        elif "text_input" in samples.keys():
            prompt = samples["text_input"]
        else:
            prompt = self.prompt
        if isinstance(prompt, str):
            prompt = [prompt] * bs
        else:
            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."

        embeds = {}
        query_tokens = {}
        data_atts = {}
        for modality in curr_modalities:
            data = samples[modality]
            ln = getattr(self, f"{modality}_ln")
            encoder = getattr(self, f"{modality}_encoder")
            if modality == "video" and self.video_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(2)):
                    this_frame = data[:, :, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
                # B, Token Size, LM EMB
                query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1)
            elif modality == 'audio' and self.audio_enc_name in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(1)):
                    this_frame = data[:, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
                # B, Token Size, LM EMB
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(data.size(0), -1, -1)
            else:
                with self.maybe_autocast():
                    embeds[modality] = ln(encoder(data))
                if len(embeds[modality].size()) == 2:
                    # B, C -> B, 1, C
                    embeds[modality] = embeds[modality].unsqueeze(1)
                if self.shared_qformer:
                    embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality])
                data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device)
                # B, Token Size, LM EMB
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(embeds[modality].shape[0], -1, -1)

        query_outputs = {}
        if self.qformer_text_input:
            text_Qformer = self.tokenizer(
                prompt,
                padding='longest',
                truncation=True,
                max_length=self.max_txt_len,
                return_tensors="pt",
            ).to(self.device)
            Qformer_atts = {}
            query_atts = {}
            num = {}
            for modality in curr_modalities:
                if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                    # B, Token Size
                    query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device)
                    # B, Token Size + Inp Size
                    Qformer_atts[modality] = torch.cat([query_atts[modality], text_Qformer.attention_mask], dim=1)
                if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                    num[modality] = len(embeds[modality])
                    bs = embeds[modality][0].shape[0]
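                    # Reorder the frame-major concatenation to sample-major order
                    # (same permutation as in forward()).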
                    indices = [j_+r for r, j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j]
                    reordered_embeds = torch.cat(embeds[modality])[indices]
                    reordered_atts = torch.cat(data_atts[modality])[indices]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1, keepdim=True)).view(bs*num[modality], self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0], -1))
                        continue
                    query_output = getattr(self, f"{modality}_Qformer").bert(
                        text_Qformer.input_ids.repeat(num[modality], 1),
                        attention_mask=Qformer_atts[modality].repeat(num[modality], 1),
                        query_embeds=query_tokens[modality].repeat(num[modality], 1, 1),
                        encoder_hidden_states=reordered_embeds,
                        encoder_attention_mask=reordered_atts,
                        return_dict=True,
                    )
                    query_outputs[modality] = query_output
                else:
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token, -1)
                        continue
                    query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert(
                        text_Qformer.input_ids,
                        attention_mask=Qformer_atts[modality],
                        query_embeds=query_tokens[modality],
                        encoder_hidden_states=embeds[modality].to(torch.float32),
                        encoder_attention_mask=data_atts[modality],
                        return_dict=True,
                    )
        else:
            num = {}
            for modality in curr_modalities:
                if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                    num[modality] = len(embeds[modality])
                    bs = embeds[modality][0].shape[0]
                    indices = [j_+r for r, j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j]
                    reordered_embeds = torch.cat(embeds[modality])[indices]
                    reordered_atts = torch.cat(data_atts[modality])[indices]
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1, keepdim=True)).view(bs*num[modality], self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0], -1))
                        continue
                    query_output = getattr(self, f"{modality}_Qformer").bert(
                        query_embeds=query_tokens[modality].repeat(num[modality], 1, 1),
                        encoder_hidden_states=reordered_embeds,
                        encoder_attention_mask=reordered_atts,
                        return_dict=True,
                    )
                    query_outputs[modality] = query_output
                else:
                    if self.projection_only or getattr(self, f"projection_only_{modality}"):
                        if self.proj_dim != 1:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token, -1)
                        else:
                            query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token, -1)
                        continue
                    query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert(
                        query_embeds=query_tokens[modality],
                        encoder_hidden_states=embeds[modality].to(torch.float32),  # pc data is float16.
                        encoder_attention_mask=data_atts[modality],
                        return_dict=True,
                    )

        for modality in curr_modalities:
            if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                if self.projection_only or getattr(self, f"projection_only_{modality}"):
                    if self.proj_dim != 1:
                        query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num[modality], self.num_query_token, -1)
                    else:
                        query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]).reshape(bs*num[modality], self.num_query_token, -1)
                    query_outputs[f'llm_proj_{modality}'] = query_outputs[f'llm_proj_{modality}'].reshape(bs, num[modality], self.num_query_token, -1).contiguous().view(bs, num[modality]*self.num_query_token, -1)
                    query_outputs[modality] = query_outputs[modality].view(bs, num[modality]*self.num_query_token, -1)
                else:
                    query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:, :query_tokens[modality].size(1), :]).contiguous().view(bs, num[modality]*self.num_query_token, -1)
                    query_outputs[modality] = query_outputs[modality]['last_hidden_state'][:, :query_tokens[modality].size(1), :].contiguous().view(bs, num[modality]*self.num_query_token, -1)
            else:
                if self.projection_only or getattr(self, f"projection_only_{modality}"):
                    if self.proj_dim == 1:
                        query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1)
                    else:
                        query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1))
                else:
                    query_outputs[modality] = query_outputs[modality].last_hidden_state[:, :query_tokens[modality].size(1), :]
                    query_outputs[f'llm_proj_{modality}'] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality])

        for modality in curr_modalities:
            query_outputs[f'embeds_{modality}'] = embeds[modality]
        return query_outputs

    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=256,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.5,
        length_penalty=1,
        num_captions=1,
        temperature=1,
        special_qformer_input_prompt=False,
    ):
        self.llm_tokenizer.padding_side = "left"

        if samples is None or samples == {}:
            return

        if 'modalities' in samples:
            curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities']
        elif self.joint_video_audio:
            curr_modalities = ["video", "audio"]
        else:
            curr_modalities = [modality for modality in self.modalities if modality in samples]
        if len(curr_modalities) == 0:
            logging.warning("Model modalities do not match sample modalities.")
            return

        # get batch size
        bs = None
        for modality in curr_modalities:
            data = samples[modality]
            if isinstance(data, torch.Tensor):
                bs = data.size(0)
            else:
                bs = len(data)
            break

        if "prompt" in samples.keys():
            prompt = samples["prompt"]
        elif self.prompt and 'text_input' in samples and '{}' in self.prompt:
            prompt = [self.prompt.format(t) for t in samples["text_input"]]
        elif "text_input" in samples.keys():
            prompt = samples["text_input"]
        else:
            prompt = self.prompt
        if isinstance(prompt, str):
            prompt = [prompt] * bs
        else:
            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."

        # For TextCaps
        if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
            prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]

        if 'discrn' in samples and self.use_caption:  ## discriminatory reasoning
            if self.postfix:
                prompt = [f'{t}{self.postfix}' for t in prompt]
            if self.enumerate_inputs:
                prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)]
            else:
                prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]]}{samples["baseline_captions"][i][0] if self.use_cues else " "}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)]
            llm_tokens = self.llm_tokenizer(
                prompt,
                padding="longest",
                return_tensors="pt"
            ).to(self.device)
            inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
            with self.maybe_autocast():
                outputs = self.llm_model.generate(
                    inputs_embeds=inputs_embeds,
                    attention_mask=llm_tokens.attention_mask,
                    do_sample=use_nucleus_sampling,
                    top_p=top_p,
                    temperature=temperature,
                    num_beams=num_beams,
                    max_length=max_length,
                    min_length=min_length,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_return_sequences=num_captions,
                )
            outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
            output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
            output_text = [o.strip() for o in output_text]
            return output_text

        query_tokens = {}
        for modality in curr_modalities:
            if not self.projection_only and not getattr(self, f"projection_only_{modality}"):
                query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(bs, -1, -1)

        if self.qformer_text_input:
            if self.special_qformer_input_prompt or special_qformer_input_prompt:
                qformer_prompt = special_qformer_input_prompt if special_qformer_input_prompt else self.special_qformer_input_prompt
                qformer_prompt = [qformer_prompt] * len(prompt)
                if "text_input" in samples.keys():
                    if isinstance(samples["text_input"][0], list):
                        qformer_prompt = [qformer_prompt[i].format(*samples["text_input"][i]) for i in range(len(qformer_prompt))]
                    else:
                        qformer_prompt = [qformer_prompt[i].format(samples["text_input"][i]) for i in range(len(qformer_prompt))]
                text_Qformer = self.tokenizer(
                    qformer_prompt,
                    padding='longest',
                    truncation=True,
                    max_length=self.max_txt_len,
                    return_tensors="pt",
                ).to(self.device)
            elif self.use_describe:
                modality2prompt = {
                    "video": "a short description of the video",
                    "audio": "an audio that shows",
                    "image": "a short image caption",
                    "pc": "a 3d model of",
                }
                qformer_prompt = [modality2prompt[modality] for _ in samples['text_input']]
                text_Qformer = self.tokenizer(
                    qformer_prompt,
                    padding='longest',
                    truncation=True,
                    max_length=self.max_txt_len,
                    return_tensors="pt",
                ).to(self.device)
            else:
                text_Qformer = self.tokenizer(
                    prompt,
                    padding='longest',
                    truncation=True,
                    max_length=self.max_txt_len,
                    return_tensors="pt",
                ).to(self.device)

            Qformer_atts = {}
            query_atts = {}
            for modality in curr_modalities:
                if not getattr(self, f"projection_only_{modality}"):
                    # B, Token Size
                    query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device)
                    # B, Token Size + Inp Size
                    Qformer_atts[modality] = torch.cat([query_atts[modality], text_Qformer.attention_mask], dim=1)

        embeds = {}
        data_atts = {}
        for modality in curr_modalities:
            data = samples[modality]
            ln = getattr(self, f"{modality}_ln")
            encoder = getattr(self, f"{modality}_encoder")
            if modality == "video" and "clip" in self.video_enc_name:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(2)):
                    this_frame = data[:, :, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
            elif modality == 'audio' and 'beats' in self.audio_enc_name:
                embeds[modality] = []
                data_atts[modality] = []
                for j in range(data.size(1)):
                    this_frame = data[:, j, :, :]
                    with self.maybe_autocast():
                        embeds[modality].append(ln(encoder(this_frame)))
                        if self.shared_qformer:
                            embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j])
                    data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device))
            else:
                with self.maybe_autocast():
                    embeds[modality] = ln(encoder(data))
                if len(embeds[modality].size()) == 2:
                    embeds[modality] = embeds[modality].unsqueeze(1)
                if self.shared_qformer:
                    with self.maybe_autocast():
                        embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality])
                data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device)

        query_outputs = {}
        num = {}
        if self.qformer_text_input:
            for modality in curr_modalities:
                if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS:
                    num[modality] = len(embeds[modality])
                    bs = embeds[modality][0].shape[0]
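                    # Reorder the frame-major concatenation to sample-major order
                    # (same permutation as in forward()).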
indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] | |
reordered_embeds = torch.cat(embeds[modality])[indices] | |
reordered_atts = torch.cat(data_atts[modality])[indices] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) | |
continue | |
query_output = getattr(self, f"{modality}_Qformer").bert( | |
text_Qformer.input_ids.repeat(num[modality], 1), | |
attention_mask=Qformer_atts[modality].repeat(num[modality], 1), | |
query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), | |
encoder_hidden_states=reordered_embeds, | |
encoder_attention_mask=reordered_atts, | |
return_dict=True, | |
) | |
query_outputs[modality] = query_output | |
else: | |
bs = embeds[modality].shape[0] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) | |
continue | |
query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( | |
text_Qformer.input_ids, | |
attention_mask=Qformer_atts[modality], | |
query_embeds=query_tokens[modality], | |
encoder_hidden_states=embeds[modality].to(torch.float32), | |
encoder_attention_mask=data_atts[modality], | |
return_dict=True, | |
) | |
else: | |
for modality in curr_modalities: | |
if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: | |
num[modality] = len(embeds[modality]) | |
bs = embeds[modality][0].shape[0] | |
indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] | |
reordered_embeds = torch.cat(embeds[modality])[indices] | |
reordered_atts = torch.cat(data_atts[modality])[indices] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) | |
continue | |
query_output = getattr(self, f"{modality}_Qformer").bert( | |
query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), | |
encoder_hidden_states=reordered_embeds, | |
encoder_attention_mask=reordered_atts, | |
return_dict=True, | |
) | |
query_outputs[modality] = query_output | |
else: | |
bs = embeds[modality].shape[0] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
with self.maybe_autocast(): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) | |
continue | |
query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( | |
query_embeds=query_tokens[modality], | |
encoder_hidden_states=embeds[modality].to(torch.float32), | |
encoder_attention_mask=data_atts[modality], | |
return_dict=True, | |
) | |
inputs_llm = {} | |
atts_llm = {} | |
enumeration = {} | |
for i,modality in enumerate(curr_modalities): | |
if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num[modality], self.num_query_token, -1) | |
else: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs*num, self.num_query_token, -1)) | |
inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
continue | |
# num*bs, num query tokens, llm emb size | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) | |
# bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size | |
inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
else: | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim == 1: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) | |
else: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
continue | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:]) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
if self.enumerate_inputs: | |
enumeration[modality] = self.llm_tokenizer( | |
[f"{'' if i == 0 else ' '}({chr(97+i)}) " for _ in prompt], | |
return_tensors="pt", | |
add_special_tokens=False if (i!= 0 or self.prefix) else True | |
).to(self.device) | |
## remove trailing whitespace | |
prompt = [p.strip() for p in prompt] | |
if 'dialog' in samples: | |
llm_tokens = self.llm_tokenizer( | |
[f"{d} {p}" if d else p for d, p in zip(samples['dialog'], prompt)], | |
padding="longest", | |
return_tensors="pt", | |
add_special_tokens= not self.clean_tokenization | |
).to(self.device) | |
else: | |
llm_tokens = self.llm_tokenizer( | |
[f"{p}{self.postfix}" for p in prompt] if self.postfix else prompt, | |
padding="longest", | |
return_tensors="pt", | |
add_special_tokens= not self.clean_tokenization | |
).to(self.device) | |
bs = llm_tokens.input_ids.shape[0] | |
att_list = [] | |
inp_list = [] | |
if self.prefix: | |
att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)] | |
inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)] | |
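# Joint video+audio: for each temporal position, append the video query tokens
# and then the audio query tokens (with optional enumeration and cue embeddings).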
if self.joint_video_audio: | |
for pos in range(num['video']): | |
if self.enumerate_inputs: | |
enumeration_pos = self.llm_tokenizer( | |
[f"{'' if pos == 0 else ' '}({chr(97+pos)}) " for _ in prompt], | |
return_tensors="pt", | |
add_special_tokens=False if (pos!= 0 or self.prefix) else True | |
).to(self.device) | |
enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration_pos.input_ids) | |
enumeration_atts_llm = enumeration_pos.attention_mask.to(self.device) | |
inp_list.extend([enumeration_inputs_llm]) | |
att_list.extend([enumeration_atts_llm]) | |
if self.use_cues: | |
for modality in ['video', 'audio']: | |
if self.clean_tokenization: | |
if self.prefix or pos > 1 or self.enumerate_inputs or modality == 'audio': | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) | |
inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) | |
continue | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) | |
inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) | |
else: | |
att_list.extend([atts_llm[modality].view(bs, num[modality], self.num_query_token)[:, pos, :]]) | |
inp_list.extend([inputs_llm[modality].view(bs, num[modality], self.num_query_token, -1)[:, pos, :, :]]) | |
else: | |
for modality in curr_modalities: | |
if self.enumerate_inputs: | |
enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration[modality].input_ids.to(self.device)) | |
enumeration_atts_llm = enumeration[modality].attention_mask.to(self.device) | |
inp_list.extend([enumeration_inputs_llm]) | |
att_list.extend([enumeration_atts_llm]) | |
if self.use_cues: | |
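# The [:, 1:] slices below drop the tokenizer's leading BOS token from the cue
# (e.g. " image: ") whenever the cue is not the first element of the sequence.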
if self.clean_tokenization or self.remove_start: | |
if (modality==curr_modalities[0] and not (self.prefix or self.enumerate_inputs)): | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([atts_llm[modality]]) | |
inp_list.extend([inputs_llm[modality]]) | |
if self.add_space: | |
space_tok = self.llm_tokenizer( | |
[f" " for _ in prompt], | |
return_tensors="pt", | |
add_special_tokens=False | |
) | |
space_inputs_llm = self.llm_model.get_input_embeddings()(space_tok.input_ids.to(self.device)) | |
space_atts_llm = space_tok.attention_mask.to(self.device) | |
inp_list.extend([space_inputs_llm]) | |
att_list.extend([space_atts_llm]) | |
att_list.append(llm_tokens.attention_mask) | |
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids) | |
inp_list.append(inputs_embeds) | |
attention_mask = torch.cat(att_list, dim=1) | |
inputs_embeds = torch.cat(inp_list, dim=1) | |
with self.maybe_autocast(): | |
outputs = self.llm_model.generate( | |
inputs_embeds=inputs_embeds, | |
attention_mask=attention_mask, | |
do_sample=use_nucleus_sampling, | |
top_p=top_p, | |
temperature=temperature, | |
num_beams=num_beams, | |
max_length=max_length, | |
min_length=min_length, | |
repetition_penalty=repetition_penalty, | |
length_penalty=length_penalty, | |
num_return_sequences=num_captions, | |
) | |
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) | |
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
output_text = [o.strip() for o in output_text] | |
return output_text | |
def predict_answers( | |
self, | |
samples, | |
num_beams=5, | |
inference_method="generate", | |
max_len=10, | |
min_len=1, | |
num_ans_candidates=128, | |
answer_list=None, | |
prompt="", | |
length_penalty=-1, | |
**kwargs | |
): | |
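"""
Generate short answers for the given samples. The prompt is formatted with the
question and, when present, OCR tokens or multiple-choice options; the
'discrn' + use_caption path instead scores caption-based prompts directly with
the LLM rather than going through the Q-Formers.
"""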
if samples is None or len(samples) == 0:
return None | |
# get batch size | |
bs = None | |
if 'modalities' in samples: | |
curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities'] | |
else: | |
curr_modalities = [modality for modality in self.modalities if modality in samples] | |
for modality in curr_modalities: | |
data = samples[modality] | |
if isinstance(data, torch.Tensor): | |
bs = data.size(0) | |
else: | |
bs = len(data) | |
break | |
if "text_input" not in samples: | |
samples["text_input"] = self.prompt | |
if isinstance(samples["text_input"], str): | |
samples["text_input"] = [samples["text_input"]] * bs | |
text_input = samples['text_input'] | |
if not prompt and self.prompt: | |
prompt=self.prompt | |
if prompt: | |
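# A prompt with two "{}" slots is filled with (OCR tokens, question) or
# (question, formatted choices); a single-slot prompt gets just the question.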
if prompt.count("{}") == 2: | |
if 'ocr_tokens' in samples: | |
text_input = [ | |
prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i]) | |
for i in range(len(samples["text_input"]))] | |
elif 'choices' in samples: | |
text_input = [] | |
for i in range(len(samples["text_input"])): | |
this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])] | |
this_choices = " ".join(this_choices) | |
text_input.append(prompt.format(samples["text_input"][i], this_choices)) | |
else: | |
text_input = [prompt.format(question) for question in samples["text_input"]] | |
samples["prompt"] = text_input | |
if 'discrn' in samples and self.use_caption: ## discriminatory reasoning | |
self.llm_tokenizer.padding_side = "left" | |
text_input = samples['text_input'] if 'prompt' not in samples else samples['prompt'] | |
if self.postfix: | |
text_input = [f'{t}{self.postfix}' for t in text_input] | |
if self.enumerate_inputs: | |
prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {text_input[i]}' for i in range(bs)] | |
else: | |
prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {text_input[i]}' for i in range(bs)]
llm_tokens = self.llm_tokenizer( | |
prompt, | |
padding="longest", | |
return_tensors="pt" | |
).to(self.device) | |
with self.maybe_autocast(): | |
outputs = self.llm_model.generate( | |
inputs_embeds=self.llm_model.get_input_embeddings()(llm_tokens.input_ids), | |
attention_mask=llm_tokens.attention_mask, | |
do_sample=False, | |
num_beams=num_beams, | |
max_length=max_len, | |
min_length=min_len, | |
repetition_penalty=1.5, | |
# eos_token_id=self.eos_token_id, | |
length_penalty=length_penalty, | |
) | |
outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) | |
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
return output_text | |
output_text = self.generate( | |
samples, | |
num_beams=num_beams, | |
max_length=max_len, | |
min_length=min_len, | |
length_penalty=length_penalty | |
) | |
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]: | |
output_text = self._lemmatize(output_text) | |
# VizWiz convention: map empty generations to "unanswerable"
output_text = [o if o != "" else "unanswerable" for o in output_text]
return output_text | |
def predict( | |
self, | |
samples, | |
candidates=None, | |
n_segments=1, | |
max_length=10, | |
min_length=1, | |
length_penalty=-1., | |
special_qformer_input_prompt=False | |
): | |
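"""
Classify `samples` against `candidates`. With predict_with_gen, generate text
and string-match it against the candidates; otherwise rank candidates by LM
loss via _predict_class. Per-sample candidate lists are scored one sample at
a time.
"""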
self.llm_tokenizer.padding_side = "left" | |
if candidates is None:
candidates = self.candidates | |
else: | |
self.candidates = candidates # for the output targets. | |
if self.predict_with_gen: | |
output = self.generate(samples,max_length=max_length,min_length=min_length,length_penalty=length_penalty) | |
result = [] | |
for text in output: | |
text = BlipCaptionProcessor().pre_caption(text) | |
pred_label = "" # default to an empty string | |
for cand in candidates: | |
cand = BlipCaptionProcessor().pre_caption(cand) | |
if cand in text.split(" "): | |
pred_label = cand | |
break # stop as soon as we find a match | |
result.append(pred_label) | |
return {"predictions":result, "target": samples["label"]} | |
# If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one | |
if isinstance(candidates[0], list):
results = [] | |
for i in range(samples["image"].size(0)): | |
this_sample = { | |
"image": samples["image"][i].unsqueeze(0), | |
"prompt": samples["prompt"], | |
} | |
if "text_input" in samples.keys(): | |
this_sample["text_input"] = [samples["text_input"][i]] | |
if 'context' in samples.keys(): | |
this_sample['context'] = [samples["context"][i]] | |
if 'history' in samples.keys(): | |
this_sample['history'] = [samples["history"][i]] | |
if 'caption' in samples.keys(): | |
this_sample['caption'] = [samples["caption"][i]] | |
this_result = self._predict_class(this_sample, candidates[i], n_segments, special_qformer_input_prompt) | |
results.append(this_result) | |
try:
results = torch.cat(results, dim=0)
except Exception:
# per-sample candidate lists may have ragged lengths; fall back to python lists
results = [res.tolist()[0] for res in results]
return results | |
return self._predict_class(samples, candidates, n_segments, special_qformer_input_prompt) | |
def _predict_class( | |
self, | |
samples, | |
candidates, | |
n_segments=1, | |
special_qformer_input_prompt=False, | |
): | |
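"""
Rank `candidates` for each sample by the language-modeling loss the LLM assigns
to each candidate continuation of the (multimodal) prompt, optionally splitting
the candidate set into n_segments chunks to bound memory.
"""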
if not samples:
return None | |
if "prompt" in samples: | |
prompt = samples["prompt"] | |
else: | |
prompt = self.prompt | |
candidates = [self.format_candidates_prompt.format(c) for c in candidates] | |
if 'modalities' in samples: | |
curr_modalities = samples['modalities'][0] if isinstance(samples['modalities'][0], list) else samples['modalities'] | |
else: | |
curr_modalities = [modality for modality in self.modalities if modality in samples] | |
# get batch size | |
for modality in curr_modalities: | |
data = samples[modality] | |
if isinstance(data, torch.Tensor): | |
bs = data.size(0) | |
else: | |
bs = len(data) | |
break | |
if isinstance(prompt, str): | |
prompt = [prompt] * bs | |
else: | |
assert len(prompt) == bs, "The number of prompts must be equal to the batch size." | |
if "text_input" in samples.keys(): | |
if type(samples["text_input"][0]) == list: | |
prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))] | |
else: | |
prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))] | |
# scienceqa | |
if 'context' in samples.keys() and samples['context'] != '': | |
prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))] | |
# visual dialog | |
if 'history' in samples.keys() and samples['history'][0] != '': | |
prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))] | |
if 'caption' in samples.keys() and samples['caption'][0] != '': | |
prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))] | |
if 'discrn' in samples and self.use_caption: ## discriminatory reasoning | |
if self.postfix: | |
prompt = [f'{p}{self.postfix}' for p in prompt] | |
if self.enumerate_inputs: | |
prompt = [f'{self.prefix}(a){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]} (b){Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)] | |
else: | |
prompt = [f'{self.prefix}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][0]] if self.use_cues else " "}{samples["baseline_captions"][i][0]}{Blip2VicunaXInstruct.MODALITY_TO_CUE[samples["modalities"][i][1]] if self.use_cues else " "}{samples["baseline_captions"][i][1]} {prompt[i]}' for i in range(bs)]
text_input_tokens = self.llm_tokenizer( | |
prompt, | |
padding="longest", | |
return_tensors="pt" | |
).to(self.device) | |
else: | |
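# NOTE: `modality` below is the loop variable left over from the batch-size
# loop above (which breaks after one iteration), i.e. the first modality in
# curr_modalities.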
if not self.projection_only and not getattr(self, f"projection_only_{modality}"): | |
query_tokens = {} | |
for modality in self.modalities: | |
if modality not in samples: | |
continue | |
query_tokens[modality] = getattr(self, f"{modality}_query_tokens").expand(bs, -1, -1) | |
if self.qformer_text_input: | |
if self.special_qformer_input_prompt or special_qformer_input_prompt: | |
qformer_prompt = special_qformer_input_prompt if special_qformer_input_prompt else self.special_qformer_input_prompt | |
qformer_prompt = [qformer_prompt] * len(prompt) | |
if "text_input" in samples.keys(): | |
if type(samples["text_input"][0]) == list: | |
qformer_prompt = [qformer_prompt[i].format(*samples["text_input"][i]) for i in range(len(qformer_prompt))] | |
else: | |
qformer_prompt = [qformer_prompt[i].format(samples["text_input"][i]) for i in range(len(qformer_prompt))] | |
text_Qformer = self.tokenizer( | |
qformer_prompt, | |
padding='longest', | |
truncation=True, | |
max_length=self.max_txt_len, | |
return_tensors="pt", | |
).to(self.device) | |
elif self.use_describe: | |
modality2prompt = { | |
"video": "a short description of the video", | |
"audio": "an audio that shows", | |
"image": "a short image caption", | |
"pc": "a 3d model of" | |
} | |
qformer_prompt = [modality2prompt[modality] for _ in samples['text_input']] | |
# qformer_prompt = [f'Describe the {Blip2VicunaXInstruct.MODALITY_TO_CUE[modality].replace(":", "").strip() if modality != "pc" else "3d model"}.' for _ in samples["text_input"]] | |
text_Qformer = self.tokenizer( | |
qformer_prompt, | |
padding='longest', | |
truncation=True, | |
max_length=self.max_txt_len, | |
return_tensors="pt", | |
).to(self.device) | |
else: | |
text_Qformer = self.tokenizer( | |
prompt, | |
padding='longest', | |
truncation=True, | |
max_length=self.max_txt_len, | |
return_tensors="pt", | |
).to(self.device) | |
Qformer_atts = {} | |
query_atts = {} | |
for modality in curr_modalities: | |
# B, Token Size | |
query_atts[modality] = torch.ones(query_tokens[modality].size()[:-1], dtype=torch.long).to(self.device) | |
# B, Token Size + Inp Size | |
Qformer_atts[modality] = torch.cat([query_atts[modality],text_Qformer.attention_mask],dim=1) | |
embeds = {} | |
data_atts = {} | |
for modality in curr_modalities: | |
data = samples[modality] | |
ln = getattr(self, f"{modality}_ln") | |
encoder = getattr(self, f"{modality}_encoder") | |
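# Sequential inputs (video frames / audio clips) are encoded one time step at a
# time and kept as lists; all other modalities are encoded in a single pass.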
if modality == "video" and "clip" in self.video_enc_name: | |
embeds[modality] = [] | |
data_atts[modality] = [] | |
for j in range(data.size(2)): | |
this_frame = data[:,:,j,:,:] | |
with self.maybe_autocast(): | |
embeds[modality].append(ln(encoder(this_frame))) | |
if self.shared_qformer: | |
embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) | |
data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) | |
elif modality == 'audio' and 'beats' in self.audio_enc_name: | |
embeds[modality] = [] | |
data_atts[modality] = [] | |
for j in range(data.size(1)): | |
this_frame = data[:,j,:,:] | |
with self.maybe_autocast(): | |
embeds[modality].append(ln(encoder(this_frame))) | |
if self.shared_qformer: | |
embeds[modality][j] = getattr(self, f"{modality}_encoder_projection")(embeds[modality][j]) | |
data_atts[modality].append(torch.ones(embeds[modality][j].size()[:-1], dtype=torch.long).to(self.device)) | |
else: | |
with self.maybe_autocast(): | |
embeds[modality] = ln(encoder(data)) | |
if len(embeds[modality].size()) == 2:
# B, D -> B, 1, D: add a singleton sequence dimension
embeds[modality] = embeds[modality].unsqueeze(1)
if self.shared_qformer: | |
embeds[modality] = getattr(self, f"{modality}_encoder_projection")(embeds[modality]) | |
data_atts[modality] = torch.ones(embeds[modality].size()[:-1], dtype=torch.long).to(self.device) | |
query_outputs = {} | |
num = {} | |
if self.qformer_text_input: | |
for modality in curr_modalities: | |
if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: | |
num[modality] = len(embeds[modality]) | |
bs = embeds[modality][0].shape[0] | |
indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] | |
reordered_embeds = torch.cat(embeds[modality])[indices] | |
reordered_atts = torch.cat(data_atts[modality])[indices] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) | |
continue | |
query_output = getattr(self, f"{modality}_Qformer").bert( | |
text_Qformer.input_ids.repeat(num[modality], 1), | |
attention_mask=Qformer_atts[modality].repeat(num[modality], 1), | |
query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), | |
encoder_hidden_states=reordered_embeds, | |
encoder_attention_mask=reordered_atts, | |
return_dict=True, | |
) | |
query_outputs[modality] = query_output | |
else: | |
bs = embeds[modality].shape[0] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) | |
continue | |
query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( | |
text_Qformer.input_ids, | |
attention_mask=Qformer_atts[modality], | |
query_embeds=query_tokens[modality], | |
encoder_hidden_states=embeds[modality].to(torch.float32), | |
encoder_attention_mask=data_atts[modality], | |
return_dict=True, | |
) | |
else: | |
for modality in curr_modalities: | |
if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: | |
num[modality] = len(embeds[modality]) | |
bs = embeds[modality][0].shape[0] | |
indices = [j_+r for r,j in enumerate([[i*bs for i in range(num[modality])]]*bs) for j_ in j] | |
reordered_embeds = torch.cat(embeds[modality])[indices] | |
reordered_atts = torch.cat(data_atts[modality])[indices] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.mean(1,keepdim=True)).view(bs*num[modality], self.num_query_token, -1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(reordered_embeds.view(reordered_embeds.shape[0],-1)) | |
continue | |
query_output = getattr(self, f"{modality}_Qformer").bert( | |
query_embeds=query_tokens[modality].repeat(num[modality], 1, 1), | |
encoder_hidden_states=reordered_embeds, | |
encoder_attention_mask=reordered_atts, | |
return_dict=True, | |
) | |
query_outputs[modality] = query_output | |
else: | |
bs = embeds[modality].shape[0] | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality].mean(1, keepdim=True)).reshape(bs, self.num_query_token,-1) | |
else: | |
query_outputs[modality] = getattr(self, f"{modality}_projection")(embeds[modality]).reshape(bs, self.num_query_token,-1) | |
continue | |
query_outputs[modality] = getattr(self, f"{modality}_Qformer").bert( | |
query_embeds=query_tokens[modality], | |
encoder_hidden_states=embeds[modality].to(torch.float32), | |
encoder_attention_mask=data_atts[modality], | |
return_dict=True, | |
) | |
inputs_llm = {} | |
atts_llm = {} | |
enumeration = {} | |
for i,modality in enumerate(curr_modalities): | |
if modality in Blip2VicunaXInstruct.SEQUENCIAL_MODALITIES and getattr(self, f'{modality}_enc_name') in Blip2VicunaXInstruct.SEQUENCIAL_ENCODERS: | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim != 1: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].unsqueeze(1)).reshape(bs*num[modality], self.num_query_token, -1) | |
else: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs*num[modality], self.num_query_token, -1))
inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
continue | |
# num*bs, num query tokens, llm emb size | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].last_hidden_state[:,:query_tokens[modality].size(1),:]) | |
# bs, num, num query tokens, llm emb size -> bs, num*num query tokens, llm emb size | |
inputs_llm[modality] = inputs_llm[modality].reshape(bs, num[modality], self.num_query_token, -1).view(bs, num[modality]*self.num_query_token, -1) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
else: | |
if self.projection_only or getattr(self, f"projection_only_{modality}"): | |
if self.proj_dim == 1: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].mean(-1)).reshape(bs, self.num_query_token, -1) | |
else: | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality].reshape(bs, self.num_query_token, -1)) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
continue | |
inputs_llm[modality] = getattr(self, f"{modality}_llm_proj")(query_outputs[modality]['last_hidden_state'][:,:query_tokens[modality].size(1),:]) | |
atts_llm[modality] = torch.ones(inputs_llm[modality].size()[:-1], dtype=torch.long).to(self.device) | |
if self.enumerate_inputs: | |
enumeration[modality] = self.llm_tokenizer( | |
[f"{'' if i == 0 else ' '}({chr(97+i)}) " for _ in prompt], | |
return_tensors="pt", | |
add_special_tokens=False if (i!= 0 or self.prefix) else True | |
).to(self.device) | |
att_list = [] | |
inp_list = [] | |
if self.prefix: | |
att_list = [self.tokenized_prefix.attention_mask.repeat(bs, 1).to(self.device)] | |
inp_list = [self.llm_model.get_input_embeddings()(self.tokenized_prefix.input_ids.to(self.device)).repeat(bs, 1, 1)] | |
for modality in curr_modalities: | |
if self.enumerate_inputs: | |
enumeration_inputs_llm = self.llm_model.get_input_embeddings()(enumeration[modality].input_ids.to(self.device)) | |
enumeration_atts_llm = enumeration[modality].attention_mask.to(self.device) | |
inp_list.extend([enumeration_inputs_llm]) | |
att_list.extend([enumeration_atts_llm]) | |
if self.use_cues: | |
if self.clean_tokenization or self.remove_start: | |
if (modality==curr_modalities[0] and not (self.prefix or self.enumerate_inputs)): | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask[:,1:]).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality][:,1:].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([torch.tensor(self.tokenized_cue[modality].attention_mask).to(self.device).repeat(atts_llm[modality].shape[0], 1), atts_llm[modality]]) | |
inp_list.extend([self.emb_cue[modality].to(self.device).repeat(inputs_llm[modality].shape[0], 1, 1), inputs_llm[modality]]) | |
else: | |
att_list.extend([atts_llm[modality]]) | |
inp_list.extend([inputs_llm[modality]]) | |
if self.add_space: | |
space_tok = self.llm_tokenizer( | |
[f" " for _ in prompt], | |
return_tensors="pt", | |
add_special_tokens=False | |
) | |
space_inputs_llm = self.llm_model.get_input_embeddings()(space_tok.input_ids.to(self.device)) | |
space_atts_llm = space_tok.attention_mask.to(self.device) | |
inp_list.extend([space_inputs_llm]) | |
att_list.extend([space_atts_llm]) | |
atts_llm = torch.cat(att_list, dim=1) | |
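# Positions covered by the multimodal prefix are labeled -100, the ignore_index
# of the LM loss, so only candidate tokens contribute to the scores.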
empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(self.device).fill_(-100) | |
inputs_llm = torch.cat(inp_list, dim=1) | |
self.llm_tokenizer.padding_side = "right" | |
self.llm_tokenizer.truncation_side = 'left' | |
text_input_tokens = self.llm_tokenizer( | |
[f"{p}{self.postfix}" for p in prompt] if self.postfix else prompt, | |
padding="longest", | |
return_tensors="pt", | |
add_special_tokens= not self.clean_tokenization | |
).to(self.device) | |
self.llm_tokenizer.truncation_side = 'right' | |
n_cands = len(candidates) | |
with self.maybe_autocast(): | |
all_losses = [] | |
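# Score candidates chunk by chunk: each segment pairs every sample with
# seg_len candidates (the last segment absorbs the remainder).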
for n in range(n_segments): | |
seg_len = n_cands // n_segments | |
if n == (n_segments - 1): | |
seg_len = n_cands - seg_len * (n_segments - 1) | |
start_i = n * (n_cands // n_segments) | |
end_i = start_i + seg_len | |
this_output_tokens = self.llm_tokenizer( | |
candidates[start_i:end_i], | |
return_tensors="pt", | |
padding="longest", | |
# truncation=True, | |
# max_length=self.max_output_txt_len, | |
).to(self.device) | |
this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0) | |
this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0) | |
this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1) | |
this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1) | |
this_llm_tokens, this_input_targets_len = self.concat_text_input_output( | |
this_input_tokens_ids, | |
this_input_tokens_atts, | |
this_output_tokens_ids, | |
this_output_tokens_atts | |
) | |
this_llm_input_ids = this_llm_tokens['input_ids'] | |
this_llm_atts = this_llm_tokens['attention_mask'] | |
inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids) | |
if self.use_caption:
# caption-based prompts: no multimodal query tokens are prepended
attention_mask = this_llm_atts
else:
inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1)
this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100) | |
for i, l in enumerate(this_input_targets_len): | |
this_targets[i][:l] = -100 | |
if not self.use_caption:
this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1)
outputs = self.llm_model( | |
inputs_embeds=inputs_embeds, | |
attention_mask=attention_mask, | |
return_dict=True, | |
labels=this_targets, | |
reduction="none", | |
) | |
loss = outputs.loss | |
loss = loss.reshape(bs, seg_len) | |
all_losses.append(loss) | |
all_losses = torch.cat(all_losses, dim=-1)
all_losses = -all_losses  # negate so higher scores mean lower LM loss
return {"predictions": all_losses, "targets": torch.tensor([self.candidates.index(l) for l in samples["label"]])} | |
def _lemmatize(self, answers): | |
def apply(answer): | |
doc = self.lemmatizer(answer) | |
words = [] | |
for token in doc: | |
if token.pos_ in ["NOUN", "VERB"]: | |
words.append(token.lemma_) | |
else: | |
words.append(token.text) | |
answer = " ".join(words) | |
return answer | |
return [apply(answer) for answer in answers] | |
@property
def lemmatizer(self):
if self._lemmatizer is None: | |
try: | |
import spacy | |
self._lemmatizer = spacy.load("en_core_web_sm") | |
except ImportError: | |
logging.error( | |
""" | |
Please install spacy and en_core_web_sm model to apply lemmatization. | |
python -m spacy download en_core_web_sm | |
OR | |
import spacy.cli | |
spacy.cli.download("en_core_web_sm") | |
""" | |
) | |
exit(1) | |
return self._lemmatizer | |
def get_optimizer_params(self, weight_decay, lr_scale=1): | |
return BaseModel.get_optimizer_params(self, weight_decay, lr_scale=lr_scale) | |
@classmethod
def from_config(cls, cfg):
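"""
Build the model from an OmegaConf config node. A minimal sketch (these keys
are read below; the values are illustrative only):
    cfg = OmegaConf.create({"llm_model": "lmsys/vicuna-7b-v1.1",
                            "num_query_token": 32,
                            "modalities": ["image"]})
    model = Blip2VicunaXInstruct.from_config(cfg)
"""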
image_model = cfg.get("image_model","eva_clip_g") | |
pc_model = cfg.get("pc_model","ulip2_pointbert") | |
video_model = cfg.get("video_model","eva_clip_g") | |
audio_model = cfg.get("audio_model","beats") | |
pretrained_image_qformer = cfg.get("pretrained_image_qformer",None) | |
pretrained_pc_qformer = cfg.get("pretrained_pc_qformer",None) | |
pretrained_video_qformer = cfg.get("pretrained_video_qformer",None) | |
pretrained_audio_qformer = cfg.get("pretrained_audio_qformer",None) | |
load_attention_image_qformer = cfg.get("load_attention_image_qformer",False) | |
load_attention_pc_qformer = cfg.get("load_attention_pc_qformer",False) | |
load_attention_video_qformer = cfg.get("load_attention_video_qformer",False) | |
load_attention_audio_qformer = cfg.get("load_attention_audio_qformer",False) | |
load_qformer_type_image=cfg.get('load_qformer_type_image', "") | |
load_qformer_type_pc=cfg.get('load_qformer_type_pc', "") | |
load_qformer_type_video=cfg.get('load_qformer_type_video', "") | |
load_qformer_type_audio=cfg.get('load_qformer_type_audio',"") | |
load_projection_image=cfg.get('load_projection_image', True) | |
load_projection_pc=cfg.get('load_projection_pc', True) | |
load_projection_video=cfg.get('load_projection_video', True) | |
load_projection_audio=cfg.get('load_projection_audio', True) | |
load_projection_type_image=cfg.get('load_projection_type_image', "") | |
load_projection_type_pc=cfg.get('load_projection_type_pc', "") | |
load_projection_type_video=cfg.get('load_projection_type_video', "") | |
load_projection_type_audio=cfg.get('load_projection_type_audio', "") | |
load_ln_type_image=cfg.get('load_ln_type_image', "") | |
load_ln_type_pc=cfg.get('load_ln_type_pc', "") | |
load_ln_type_video=cfg.get('load_ln_type_video', "") | |
load_ln_type_audio=cfg.get('load_ln_type_audio', "") | |
image_encoder_kwargs = cfg.get("image_encoder_kwargs", {"image_size": 224, "drop_path_rate": 0, "use_grad_checkpoint": False}) | |
pc_encoder_kwargs = cfg.get("pc_encoder_kwargs",{}) | |
video_encoder_kwargs = cfg.get("video_encoder_kwargs",{}) | |
audio_encoder_kwargs = cfg.get("audio_encoder_kwargs",{}) | |
image_precision = cfg.get("image_precision","fp16") | |
pc_precision = cfg.get("pc_precision","fp16") | |
video_precision = cfg.get("video_precision","fp16") | |
audio_precision = cfg.get("audio_precision","fp16") | |
freeze_image = cfg.get("freeze_image",True) | |
freeze_pc = cfg.get("freeze_pc",True) | |
freeze_video = cfg.get("freeze_video",True) | |
freeze_audio = cfg.get("freeze_audio",True) | |
num_query_token = cfg.get("num_query_token") | |
llm_model = cfg.get("llm_model") | |
freeze_pc = cfg.get("freeze_pc", True) | |
freeze_video = cfg.get("freeze_video", True) | |
freeze_audio = cfg.get("freeze_audio", True) | |
prompt = cfg.get("prompt", "") | |
max_txt_len = cfg.get("max_txt_len", 128) | |
max_output_txt_len = cfg.get("max_output_txt_len", 256) | |
apply_lemmatizer = cfg.get("apply_lemmatizer", False) | |
qformer_text_input = cfg.get("qformer_text_input", True) | |
modalities = cfg.get("modalities", ["image"]) | |
use_cues = cfg.get("use_cues", True) | |
shared_qformer = cfg.get("shared_qformer",False) | |
pretrained_shared_qformer = cfg.get("pretrained_shared_qformer", None) | |
load_attention_shared_qformer = cfg.get("load_attention_shared_qformer", None) | |
load_qformer_type_shared= cfg.get('load_qformer_type_shared',"") | |
load_projection_shared= cfg.get('load_projection_shared',False) | |
load_projection_type_shared= cfg.get('load_projection_type_shared',"") | |
shared_qformer_num_features=cfg.get("shared_qformer_num_features", 512) | |
encoder_projection_type_image=cfg.get("encoder_projection_type_image","") | |
encoder_projection_type_video=cfg.get("encoder_projection_type_video","") | |
encoder_projection_type_audio=cfg.get("encoder_projection_type_audio","") | |
encoder_projection_type_pc=cfg.get("encoder_projection_type_pc","") | |
llm_text_input = cfg.get("llm_text_input", True) | |
lora = cfg.get("lora", False) | |
prefix = cfg.get("prefix", "") | |
postfix = cfg.get("postfix", "") | |
cached_audio= cfg.get("cached_audio", False) | |
cached_image= cfg.get("cached_image", False) | |
cached_video= cfg.get("cached_video", False) | |
cached_pc= cfg.get("cached_pc", False) | |
num_features_audio=cfg.get('num_features_audio', 768) | |
num_features_image=cfg.get('num_features_image', 1408) | |
num_features_video=cfg.get('num_features_video', 14080) | |
num_features_pc=cfg.get('num_features_depth', 512) | |
joint_video_audio=cfg.get('joint_video_audio', False) | |
use_caption=cfg.get('use_caption', False) | |
use_describe=cfg.get('use_describe', False) | |
predict_with_gen = cfg.get('predict_with_gen', False) | |
format_candidates_prompt = cfg.get('format_candidates_prompt', "{}") | |
special_qformer_input_prompt = cfg.get('special_qformer_input_prompt', False) | |
enumerate_inputs = cfg.get('enumerate_inputs', False) | |
add_space = cfg.get('add_space', True) | |
projection_only = cfg.get('projection_only', False) | |
lora_model = cfg.get('lora_model', '') | |
projection_only_audio= cfg.get('projection_only_audio', False) | |
projection_only_pc= cfg.get('projection_only_pc', False) | |
projection_only_video= cfg.get('projection_only_video', False) | |
projection_only_image= cfg.get('projection_only_image', False) | |
projection_path_audio=cfg.get('projection_path_audio', False) | |
projection_path_pc=cfg.get('projection_path_pc', False) | |
projection_path_video=cfg.get('projection_path_video', False) | |
projection_path_image=cfg.get('projection_path_image', False) | |
remove_start=cfg.get('remove_start', False) | |
proj_dim=cfg.get('proj_dim', 1) | |
clean_tokenization=cfg.get('clean_tokenization', False) | |
logging.info("Model Config Arguments:") | |
logging.info(OmegaConf.to_yaml(cfg)) | |
model = cls( | |
image_model=image_model, | |
pc_model=pc_model, | |
video_model=video_model, | |
audio_model=audio_model, | |
pretrained_image_qformer=pretrained_image_qformer, | |
pretrained_pc_qformer=pretrained_pc_qformer, | |
pretrained_video_qformer=pretrained_video_qformer, | |
pretrained_audio_qformer=pretrained_audio_qformer, | |
load_attention_image_qformer=load_attention_image_qformer, | |
load_attention_pc_qformer=load_attention_pc_qformer, | |
load_attention_video_qformer=load_attention_video_qformer, | |
load_attention_audio_qformer=load_attention_audio_qformer, | |
load_qformer_type_image=load_qformer_type_image, | |
load_qformer_type_pc=load_qformer_type_pc, | |
load_qformer_type_video=load_qformer_type_video, | |
load_qformer_type_audio=load_qformer_type_audio, | |
load_projection_image=load_projection_image, | |
load_projection_pc=load_projection_pc, | |
load_projection_video=load_projection_video, | |
load_projection_audio=load_projection_audio, | |
load_projection_type_image=load_projection_type_image, | |
load_projection_type_pc=load_projection_type_pc, | |
load_projection_type_video=load_projection_type_video, | |
load_projection_type_audio=load_projection_type_audio, | |
load_ln_type_image=load_ln_type_image, | |
load_ln_type_pc=load_ln_type_pc, | |
load_ln_type_video=load_ln_type_video, | |
load_ln_type_audio=load_ln_type_audio, | |
image_encoder_kwargs = image_encoder_kwargs, | |
pc_encoder_kwargs = pc_encoder_kwargs, | |
video_encoder_kwargs = video_encoder_kwargs, | |
audio_encoder_kwargs = audio_encoder_kwargs, | |
image_precision=image_precision, | |
pc_precision=pc_precision, | |
video_precision=video_precision, | |
audio_precision=audio_precision, | |
freeze_image=freeze_image, | |
freeze_pc=freeze_pc, | |
freeze_video=freeze_video, | |
freeze_audio=freeze_audio, | |
num_query_token=num_query_token, | |
llm_model=llm_model, | |
lora_model=lora_model, | |
lora = lora, | |
prompt=prompt, | |
max_txt_len=max_txt_len, | |
max_output_txt_len=max_output_txt_len, | |
apply_lemmatizer=apply_lemmatizer, | |
qformer_text_input=qformer_text_input, | |
modalities=modalities, | |
use_cues=use_cues, | |
llm_text_input=llm_text_input, | |
shared_qformer=shared_qformer, | |
pretrained_shared_qformer = pretrained_shared_qformer, | |
load_attention_shared_qformer = load_attention_shared_qformer, | |
shared_qformer_num_features=shared_qformer_num_features, | |
load_qformer_type_shared= load_qformer_type_shared, | |
load_projection_shared= load_projection_shared, | |
encoder_projection_type_image=encoder_projection_type_image, | |
encoder_projection_type_video=encoder_projection_type_video, | |
encoder_projection_type_audio=encoder_projection_type_audio, | |
encoder_projection_type_pc=encoder_projection_type_pc, | |
projection_path_audio=projection_path_audio, | |
projection_path_pc=projection_path_pc, | |
projection_path_video=projection_path_video, | |
projection_path_image=projection_path_image, | |
load_projection_type_shared= load_projection_type_shared, | |
prefix=prefix, | |
postfix=postfix, | |
cached_audio=cached_audio, | |
cached_image=cached_image, | |
cached_video=cached_video, | |
cached_pc=cached_pc, | |
num_features_audio=num_features_audio, | |
num_features_image=num_features_image, | |
num_features_video=num_features_video, | |
num_features_pc=num_features_pc, | |
joint_video_audio=joint_video_audio, | |
use_caption=use_caption, | |
use_describe=use_describe, | |
predict_with_gen=predict_with_gen, | |
format_candidates_prompt=format_candidates_prompt, | |
special_qformer_input_prompt=special_qformer_input_prompt, | |
enumerate_inputs=enumerate_inputs, | |
add_space=add_space, | |
projection_only=projection_only, | |
projection_only_audio= projection_only_audio, | |
projection_only_pc= projection_only_pc, | |
projection_only_video= projection_only_video, | |
projection_only_image= projection_only_image, | |
remove_start= remove_start, | |
proj_dim=proj_dim, | |
clean_tokenization=clean_tokenization | |
) | |
stage1_url_or_filename = cfg.get("stage1_url_or_filename","") | |
if stage1_url_or_filename: | |
model.load_from_pretrained(stage1_url_or_filename) | |
model.load_checkpoint_from_config(cfg) | |
return model | |
@classmethod
def init_ln(cls, num_features, load_ln_path=False, load_ln_type=""):
ln = LayerNorm(num_features) | |
if load_ln_path and load_ln_type: | |
url_or_filename=load_ln_path | |
logging.info(f"Loading pretrained layer norm weights from {url_or_filename} of type {load_ln_type}") | |
if is_url(url_or_filename): | |
cached_file = download_cached_file( | |
url_or_filename, check_hash=False, progress=True | |
) | |
checkpoint = torch.load(cached_file, map_location="cpu") | |
elif os.path.isfile(url_or_filename): | |
checkpoint = torch.load(url_or_filename, map_location="cpu") | |
else: | |
raise RuntimeError("checkpoint url or path is invalid") | |
if load_ln_type: | |
load_ln_type = f"{load_ln_type}_ln" if "vision" not in load_ln_type else "ln_vision" | |
loaded_state_dict = {} | |
if 'model' in checkpoint: | |
checkpoint = checkpoint['model'] | |
for k in checkpoint.keys(): | |
if load_ln_type in k: | |
loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] | |
ln.load_state_dict(loaded_state_dict, strict=False) | |
return ln | |
@classmethod
def init_encoder_projection(cls, enc_num_features, shared_qformer_num_features, load_proj_path=False, load_proj_type=""):
encoder_projection = nn.Linear(enc_num_features, shared_qformer_num_features) | |
if load_proj_path and load_proj_type: | |
url_or_filename=load_proj_path | |
logging.info(f"Loading shared Qformer encoder projection weights from {url_or_filename} of type {load_proj_type}") | |
if is_url(url_or_filename): | |
cached_file = download_cached_file( | |
url_or_filename, check_hash=False, progress=True | |
) | |
checkpoint = torch.load(cached_file, map_location="cpu") | |
elif os.path.isfile(url_or_filename): | |
checkpoint = torch.load(url_or_filename, map_location="cpu") | |
else: | |
raise RuntimeError("checkpoint url or path is invalid") | |
if load_proj_type: | |
load_proj_type = f"{load_proj_type}_" | |
loaded_state_dict = {} | |
if 'model' in checkpoint: | |
checkpoint = checkpoint['model'] | |
for k in checkpoint.keys(): | |
if load_proj_type+'encoder_projection' in k: | |
loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] | |
encoder_projection.load_state_dict(loaded_state_dict, strict=False) | |
return encoder_projection | |
@classmethod
def init_vicuna_projection(cls, input_size, output_size, load_projection_path=False, load_projection_type="", projection_key=None):
proj = nn.Linear(input_size, output_size) | |
if load_projection_path: | |
url_or_filename=load_projection_path | |
logging.info(f"Loading pretrained projection weights from {url_or_filename} of type {load_projection_type} with key {projection_key if projection_key else load_projection_type+'_llm_proj.'}") | |
if is_url(url_or_filename): | |
cached_file = download_cached_file( | |
url_or_filename, check_hash=False, progress=True | |
) | |
checkpoint = torch.load(cached_file, map_location="cpu") | |
elif os.path.isfile(url_or_filename): | |
checkpoint = torch.load(url_or_filename, map_location="cpu") | |
else: | |
raise RuntimeError("checkpoint url or path is invalid") | |
if load_projection_type: | |
load_projection_type = f"{load_projection_type}_" | |
loaded_state_dict = {} | |
if 'model' in checkpoint: | |
checkpoint = checkpoint['model'] | |
for k in checkpoint.keys(): | |
if projection_key: | |
if projection_key in k: | |
loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] | |
else: | |
if load_projection_type+'llm_proj.' in k: | |
loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] | |
proj.load_state_dict(loaded_state_dict, strict=False) | |
return proj | |
@classmethod
def init_Qformer(cls, num_query_token, modality_width, cross_attention_freq=2, pretrained_qformer=None, load_attention=False, load_qformer_type=""):
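"""
Build a BERT-based Q-Former whose cross-attention layers (one every
cross_attention_freq blocks) attend to modality_width-dim encoder features,
together with learnable query tokens; optionally warm-start both from
`pretrained_qformer`, keeping only keys prefixed with `load_qformer_type`.
"""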
encoder_config = BertConfig.from_pretrained("bert-base-uncased") | |
encoder_config.encoder_width = modality_width | |
# insert cross-attention layer every other block | |
encoder_config.add_cross_attention = True | |
encoder_config.cross_attention_freq = cross_attention_freq | |
encoder_config.query_length = num_query_token | |
encoder_config.vocab_size += 1 # for special token [DEC] | |
Qformer = BertLMHeadModel(config=encoder_config) | |
query_tokens = nn.Parameter( | |
torch.zeros(1, num_query_token, encoder_config.hidden_size) | |
) | |
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) | |
if pretrained_qformer: | |
url_or_filename=pretrained_qformer | |
logging.info(f"Loading pretrained qformer weights and query tokens from {url_or_filename} of type {load_qformer_type}") | |
if is_url(url_or_filename): | |
cached_file = download_cached_file( | |
url_or_filename, check_hash=False, progress=True | |
) | |
checkpoint = torch.load(cached_file, map_location="cpu") | |
elif os.path.isfile(url_or_filename): | |
checkpoint = torch.load(url_or_filename, map_location="cpu") | |
else: | |
raise RuntimeError("checkpoint url or path is invalid") | |
if load_qformer_type: | |
load_qformer_type = f"{load_qformer_type}_" | |
loaded_state_dict = {} | |
if 'model' in checkpoint: | |
checkpoint = checkpoint['model'] | |
for k in checkpoint.keys(): | |
if load_qformer_type+'Qformer.' in k: | |
if not load_attention and 'attention' in k: | |
continue | |
loaded_state_dict['.'.join(k.split('.')[1:])] = checkpoint[k] | |
Qformer.load_state_dict(loaded_state_dict, strict=False) | |
query_tokens.data = checkpoint[load_qformer_type+'query_tokens'] | |
return Qformer, query_tokens | |
def get_state_dict(self, url_or_filename, **kwargs): | |
if is_url(url_or_filename): | |
cached_file = download_cached_file( | |
url_or_filename, check_hash=False, progress=True | |
) | |
checkpoint = torch.load(cached_file, map_location="cpu") | |
elif os.path.isfile(url_or_filename): | |
checkpoint = torch.load(url_or_filename, map_location="cpu") | |
else: | |
raise RuntimeError("checkpoint url or path is invalid") | |
if "model" in checkpoint.keys(): | |
state_dict = checkpoint["model"] | |
else: | |
state_dict = checkpoint | |
return state_dict | |
def load_from_pretrained(self, url_or_filename, **kwargs): | |
state_dict = self.get_state_dict(url_or_filename) | |
self.load_state_dict(state_dict, strict=False) | |
logging.info("load checkpoint from %s" % url_or_filename) | |
def load_checkpoint(self, url_or_filename, **kwargs): | |
""" | |
Load from a finetuned checkpoint. | |
This should expect no mismatch in the model keys and the checkpoint keys. | |
""" | |
state_dict = self.get_state_dict(url_or_filename) | |
self.load_state_dict(state_dict, strict=True) | |
logging.info("load checkpoint from %s" % url_or_filename) | |
def load_state_dict(self, state_dict, strict=True): | |
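"""
Route checkpoint entries to the shared or per-modality Q-Formers, query
tokens, layer norms, and LLM projections, skipping modules already loaded
from config. Frozen encoders and (unless LoRA weights are present) the LLM
are excluded from the reported keys; returns _IncompatibleKeys like
nn.Module.load_state_dict.
"""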
unexpected_keys = [] | |
missing_keys = [] | |
if self.shared_qformer and not self.projection_only: | |
## Load Q-Former if it is not loaded from config | |
if not getattr(self, "pretrained_shared_qformer"): | |
shared_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if "shared_Qformer" == k.split('.')[0]} | |
msg = self.shared_Qformer.load_state_dict(shared_qformer_state_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
## Load query tokens | |
if "shared_query_tokens" not in state_dict: | |
missing_keys.append("shared_query_tokens") | |
else: | |
self.shared_query_tokens = state_dict["shared_query_tokens"] | |
unexpected_keys.extend(msg.unexpected_keys)
for modality in self.modalities: | |
# Map shared Qformer by reference to all modalities. | |
setattr(self, f"{modality}_Qformer", self.shared_Qformer) | |
getattr(self, f"{modality}_query_tokens").data = state_dict[f"shared_query_tokens"] | |
# load encoder projections | |
modality_encoder_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_encoder_projection" in k.split('.')[0]} | |
msg = getattr(self, f"{modality}_encoder_projection").load_state_dict(modality_encoder_projection_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
# load modality layer norm | |
if getattr(self,f"load_ln_type_{modality}") == "vision": | |
modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"ln_vision" in k.split('.')[0]} | |
else: | |
modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_ln" in k.split('.')[0]} | |
msg = getattr(self, f"{modality}_ln").load_state_dict(modality_ln_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
## Load Shared LLM projection if not loaded by config | |
if not getattr(self, "load_projection_shared"): | |
shared_llm_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"shared_llm_proj" in k.split('.')[0]} | |
msg = self.shared_llm_proj.load_state_dict(shared_llm_projection_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
for modality in self.modalities: | |
## Map to modality projections by reference | |
setattr(self, f"{modality}_llm_proj", self.shared_llm_proj)
else: | |
for modality in self.modalities: | |
## Load Q-Former if not loaded from config | |
if not getattr(self, f"pretrained_{modality}_qformer") or ((self.projection_only or getattr(self, f"projection_only_{modality}")) and not getattr(self, f"projection_path_{modality}")): | |
if self.projection_only or getattr(self, f"projection_only_{modality}") : | |
if not getattr(self, f"projection_path_{modality}"): | |
logging.info(f"Loaded {modality} projection") | |
modality_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_projection" == k.split('.')[0]} | |
msg = getattr(self, f"{modality}_projection").load_state_dict(modality_qformer_state_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
else: | |
modality_qformer_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_Qformer" == k.split('.')[0]} | |
msg = getattr(self, f"{modality}_Qformer").load_state_dict(modality_qformer_state_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
## Load query tokens | |
if not self.projection_only and not getattr(self, f"projection_only_{modality}"): | |
if f"{modality}_query_tokens" not in state_dict: | |
missing_keys.append(f"{modality}_query_tokens") | |
else: | |
logging.info(f"Loaded {modality} query tokens") | |
getattr(self, f"{modality}_query_tokens").data = state_dict[f"{modality}_query_tokens"] | |
# load modality layer norm if not loaded from config | |
if getattr(self,f"load_ln_type_{modality}") == "vision": | |
logging.info(f"Loaded {modality} vision ln") | |
modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"ln_vision" in k.split('.')[0]} | |
else: | |
modality_ln_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_ln" in k.split('.')[0]} | |
msg = getattr(self, f"{modality}_ln").load_state_dict(modality_ln_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
## Load LLM projections if not loaded from config | |
if not getattr(self, f"load_projection_{modality}") or (getattr(self, f"projection_only_{modality}") or self.projection_only): | |
if not getattr(self, f"projection_path_{modality}"): | |
logging.info(f"Loaded {modality} llm projection") | |
modality_llm_projection_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"{modality}_llm_proj" in k.split('.')[0]} | |
msg = getattr(self, f"{modality}_llm_proj").load_state_dict(modality_llm_projection_dict, strict=strict) | |
missing_keys.extend(msg.missing_keys) | |
unexpected_keys.extend(msg.unexpected_keys) | |
## llm model is loaded from pretrained | |
lora_state_dict = {'.'.join(k.split('.')[1:]):v for k,v in state_dict.items() if f"llm_model" in k.split('.')[0]} | |
if not self.lora or len(lora_state_dict) == 0: | |
unexpected_keys = [k for k in unexpected_keys if k.split('.')[0] != 'llm_model'] | |
else: | |
msg = self.llm_model.load_state_dict(lora_state_dict, strict=False)
missing_keys.extend(["llm_model."+k for k in msg.missing_keys]) | |
missing_keys = [k for k in missing_keys if 'encoder' not in k.split('.')[0]] | |
missing_keys = [k for k in missing_keys if k.split('.')[0] != 'llm_model'] | |
return _IncompatibleKeys(missing_keys, unexpected_keys) | |
def before_evaluation(self, dataset, task_type, **kwargs): | |
if task_type == MultimodalClassificationTask: | |
self.candidates = dataset.classnames | |
logging.info(f"Classification candidates: {self.candidates}")