# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright (c) Facebook, Inc. All Rights Reserved import torch from torch import nn try: from transformers import AutoConfig, AutoTokenizer except ImportError: pass from . import transformermodel class MMPTModel(nn.Module): """An e2e wrapper of inference model. """ @classmethod def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"): import os from ..utils import recursive_config from ..tasks import Task config = recursive_config(config) mmtask = Task.config_task(config) checkpoint_path = os.path.join(config.eval.save_path, checkpoint) mmtask.build_model(checkpoint=checkpoint_path) # TODO(huxu): make the video encoder configurable. from ..processors.models.s3dg import S3D video_encoder = S3D('pretrained_models/s3d_dict.npy', 512) video_encoder.load_state_dict( torch.load('pretrained_models/s3d_howto100m.pth')) from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( config.dataset.bert_name, use_fast=config.dataset.use_fast ) from ..processors import Aligner aligner = Aligner(config.dataset) return ( MMPTModel(config, mmtask.model, video_encoder), tokenizer, aligner ) def __init__(self, config, model, video_encoder, **kwargs): super().__init__() self.max_video_len = config.dataset.max_video_len self.video_encoder = video_encoder self.model = model def forward(self, video_frames, caps, cmasks, return_score=False): bsz = video_frames.size(0) assert bsz == 1, "only bsz=1 is supported now." seq_len = video_frames.size(1) video_frames = video_frames.view(-1, *video_frames.size()[2:]) vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3)) vfeats = vfeats['video_embedding'] vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1)) padding = torch.zeros( bsz, self.max_video_len - seq_len, vfeats.size(-1)) vfeats = torch.cat([vfeats, padding], dim=1) vmasks = torch.cat([ torch.ones((bsz, seq_len), dtype=torch.bool), torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool) ], dim=1 ) output = self.model(caps, cmasks, vfeats, vmasks) if return_score: output = {"score": torch.bmm( output["pooled_video"][:, None, :], output["pooled_text"][:, :, None] ).squeeze(-1).squeeze(-1)} return output class MMFusion(nn.Module): """a MMPT wrapper class for MMBert style models. TODO: move isolated mask to a subclass. """ def __init__(self, config, **kwargs): super().__init__() transformer_config = AutoConfig.from_pretrained( config.dataset.bert_name) self.hidden_size = transformer_config.hidden_size self.is_train = False if config.dataset.train_path is not None: self.is_train = True # 0 means no iso; 1-12 means iso up to that layer. self.num_hidden_layers = transformer_config.num_hidden_layers self.last_iso_layer = 0 if config.dataset.num_iso_layer is not None: self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1 if config.model.mm_encoder_cls is not None: mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls) model_config = AutoConfig.from_pretrained(config.dataset.bert_name) model_config.max_video_len = config.dataset.max_video_len # TODO: a general way to add parameter for a model. model_config.use_seg_emb = config.model.use_seg_emb self.mm_encoder = mm_encoder_cls.from_pretrained( config.dataset.bert_name, config=model_config) elif config.model.video_encoder_cls is not None\ and config.model.text_encoder_cls is not None: video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls) model_config = AutoConfig.from_pretrained(config.dataset.bert_name) model_config.max_video_len = config.dataset.max_video_len # TODO: make each model a set of config class. if hasattr(model_config, "num_layers"): model_config.num_layers = config.model.num_hidden_video_layers else: model_config.num_hidden_layers = config.model.num_hidden_video_layers self.video_encoder = video_encoder_cls.from_pretrained( config.dataset.bert_name, config=model_config) # exact same NLP model from Huggingface. text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls) self.text_encoder = text_encoder_cls.from_pretrained( config.dataset.bert_name) else: raise ValueError("the encoder must be either MM or two backbones.") def forward( self, caps, cmasks, vfeats, vmasks, **kwargs ): raise NotImplementedError( "Please derive MMFusion module." ) def _mm_on_the_fly( self, cmasks, vmasks, attention_mask ): """helper function for mask, seg_ids and token_type_ids.""" if attention_mask is None: attention_mask = self._mm_attention_mask(cmasks, vmasks) """ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | """ token_type_ids = torch.cat( [ torch.zeros( (vmasks.size(0), vmasks.size(1) + 2), dtype=torch.long, device=vmasks.device, ), torch.ones( (cmasks.size(0), cmasks.size(1) - 2), dtype=torch.long, device=cmasks.device, ), ], dim=1, ) return attention_mask, token_type_ids def _mm_attention_mask(self, cmasks, vmasks): assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format( str(cmasks.size()), str(vmasks.size()), str(cmasks.size(0)), str(vmasks.size(0)), ) mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1) if self.last_iso_layer == 0: # hard attention mask. return mm_mask else: # a gpu iso mask; 0 : num_iso_layer is isolated; # num_iso_layer: are MM-fused. # make an iso layer batch_size = cmasks.size(0) iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks) mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1) iso_mm_masks = [] # hard attention mask. iso_mask = iso_mask[:, None, :, :].repeat( 1, self.last_iso_layer, 1, 1) iso_mm_masks.append(iso_mask) if self.last_iso_layer < self.num_hidden_layers: mm_mask = mm_mask[:, None, :, :].repeat( 1, self.num_hidden_layers - self.last_iso_layer, 1, 1 ) iso_mm_masks.append(mm_mask) iso_mm_masks = torch.cat(iso_mm_masks, dim=1) return iso_mm_masks def _make_iso_mask(self, batch_size, cmasks, vmasks): cls_self_mask = torch.cat( [ torch.ones( (batch_size, 1), dtype=torch.bool, device=cmasks.device), torch.zeros( (batch_size, cmasks.size(1) + vmasks.size(1) - 1), dtype=torch.bool, device=cmasks.device) ], dim=1) iso_video_mask = torch.cat( [ # [CLS] is not used. torch.zeros( (batch_size, 1), dtype=torch.bool, device=cmasks.device ), vmasks, # assume to be 1. cmasks[:, 1:2], # 2 means [CLS] + [SEP] torch.zeros( (batch_size, cmasks.size(1) - 2), dtype=torch.bool, device=cmasks.device, ), ], dim=1, ) iso_text_mask = torch.cat( [ torch.zeros( (batch_size, 2 + vmasks.size(1)), dtype=torch.bool, device=cmasks.device, ), # [CLS] is not used. cmasks[:, 2:], # assume to be 1. ], dim=1, ) cls_self_mask = cls_self_mask[:, None, :] iso_video_mask = iso_video_mask[:, None, :].repeat( 1, vmasks.size(1) + 1, 1) iso_text_mask = iso_text_mask[:, None, :].repeat( 1, cmasks.size(1) - 2, 1) return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1) def _pooling_vt_layer( self, layered_sequence_output, cmasks, vmasks ): layer_idx = self.last_iso_layer \ if self.last_iso_layer > 0 else self.num_hidden_layers hidden_state = layered_sequence_output[layer_idx] # also output pooled_video and pooled_text. batch_size = cmasks.size(0) # pool the modality. text_offset = vmasks.size(1) + 2 # [CLS] + [SEP] # video tokens + [SEP] video_outputs = hidden_state[:, 1:text_offset] video_attention_mask = torch.cat( [ vmasks, torch.ones( (batch_size, 1), dtype=torch.bool, device=vmasks.device), ], dim=1, ) assert video_outputs.size(1) == video_attention_mask.size(1) pooled_video = torch.sum( video_outputs * video_attention_mask.unsqueeze(-1), dim=1 ) / video_attention_mask.sum(1, keepdim=True) # pooled_video = torch.mean(video_outputs[0], dim=1) # text tokens + [SEP] text_attention_mask = cmasks[:, 2:] text_outputs = hidden_state[:, text_offset:] assert text_outputs.size(1) == text_attention_mask.size(1) pooled_text = torch.sum( text_outputs * text_attention_mask.unsqueeze(-1), dim=1 ) / text_attention_mask.sum(1, keepdim=True) return pooled_video, pooled_text class MMFusionMFMMLM(MMFusion): """forward function for MFM and MLM.""" def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, video_label=None, text_label=None, **kwargs ): output_hidden_states = False if self.is_train else True target_vfeats, non_masked_frame_mask = None, None if video_label is not None: target_vfeats = vfeats.masked_select( video_label.unsqueeze(-1)).view( -1, vfeats.size(-1) ) # mask video token. vfeats[video_label] = 0.0 non_masked_frame_mask = vmasks.clone() non_masked_frame_mask[video_label] = False attention_mask, token_type_ids = self._mm_on_the_fly( cmasks, vmasks, attention_mask) outputs = self.mm_encoder( input_ids=caps, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, masked_frame_labels=video_label, target_video_hidden_states=target_vfeats, non_masked_frame_mask=non_masked_frame_mask, masked_lm_labels=text_label, output_hidden_states=output_hidden_states, ) video_logits, text_logits = outputs[0], outputs[1] if self.is_train: # return earlier for training. return { "video_logits": video_logits, "text_logits": text_logits, } pooled_video, pooled_text = self._pooling_vt_layer( outputs[2], cmasks, vmasks) return {"pooled_video": pooled_video, "pooled_text": pooled_text} class MMFusionMTM(MMFusionMFMMLM): def __init__(self, config, **kwargs): super().__init__(config) """ For reproducibility: self.mm_encoder will be initialized then discarded. """ from .transformermodel import MMBertForMTM model_config = AutoConfig.from_pretrained(config.dataset.bert_name) model_config.max_video_len = config.dataset.max_video_len model_config.use_seg_emb = config.model.use_seg_emb self.mm_encoder = MMBertForMTM.from_pretrained( config.dataset.bert_name, config=model_config) class MMFusionShare(MMFusion): """A retrival wrapper using mm_encoder as both video/text backbone. TODO: move formally. """ def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, video_label=None, text_label=None, output_hidden_states=False, **kwargs ): pooled_video = self.forward_video( vfeats, vmasks, caps, cmasks, output_hidden_states ) pooled_text = self.forward_text( caps, cmasks, output_hidden_states ) return {"pooled_video": pooled_video, "pooled_text": pooled_text} def forward_video( self, vfeats, vmasks, caps, cmasks, output_hidden_states=False, **kwargs ): input_ids = caps[:, :2] attention_mask = torch.cat([ cmasks[:, :1], vmasks, cmasks[:, 1:2] ], dim=1) token_type_ids = torch.zeros( (vmasks.size(0), vmasks.size(1) + 2), dtype=torch.long, device=vmasks.device) outputs = self.mm_encoder( input_ids=input_ids, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True ) video_outputs = outputs[0] if output_hidden_states: return video_outputs batch_size = cmasks.size(0) video_attention_mask = torch.cat( [ torch.zeros( (batch_size, 1), dtype=torch.bool, device=vmasks.device), vmasks, torch.ones( (batch_size, 1), dtype=torch.bool, device=vmasks.device), ], dim=1, ) assert video_outputs.size(1) == video_attention_mask.size(1) video_attention_mask = video_attention_mask.type(video_outputs.dtype) \ / video_attention_mask.sum(1, keepdim=True) pooled_video = torch.bmm( video_outputs.transpose(2, 1), video_attention_mask.unsqueeze(2) ).squeeze(-1) return pooled_video # video_outputs def forward_text( self, caps, cmasks, output_hidden_states=False, **kwargs ): input_ids = torch.cat([ caps[:, :1], caps[:, 2:], ], dim=1) attention_mask = torch.cat([ cmasks[:, :1], cmasks[:, 2:] ], dim=1) token_type_ids = torch.cat([ torch.zeros( (cmasks.size(0), 1), dtype=torch.long, device=cmasks.device), torch.ones( (cmasks.size(0), cmasks.size(1) - 2), dtype=torch.long, device=cmasks.device) ], dim=1) outputs = self.mm_encoder( input_ids=input_ids, input_video_embeds=None, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True ) text_outputs = outputs[0] if output_hidden_states: return text_outputs batch_size = caps.size(0) # text tokens + [SEP] text_attention_mask = torch.cat([ torch.zeros( (batch_size, 1), dtype=torch.bool, device=cmasks.device), cmasks[:, 2:] ], dim=1) assert text_outputs.size(1) == text_attention_mask.size(1) text_attention_mask = text_attention_mask.type(text_outputs.dtype) \ / text_attention_mask.sum(1, keepdim=True) pooled_text = torch.bmm( text_outputs.transpose(2, 1), text_attention_mask.unsqueeze(2) ).squeeze(-1) return pooled_text # text_outputs class MMFusionSeparate(MMFusionShare): def forward_video( self, vfeats, vmasks, caps, cmasks, output_hidden_states=False, **kwargs ): input_ids = caps[:, :2] attention_mask = torch.cat([ cmasks[:, :1], vmasks, cmasks[:, 1:2] ], dim=1) token_type_ids = torch.zeros( (vmasks.size(0), vmasks.size(1) + 2), dtype=torch.long, device=vmasks.device) outputs = self.video_encoder( input_ids=input_ids, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True ) video_outputs = outputs[0] if output_hidden_states: return video_outputs batch_size = cmasks.size(0) video_attention_mask = torch.cat( [ torch.zeros( (batch_size, 1), dtype=torch.bool, device=vmasks.device), vmasks, torch.ones( (batch_size, 1), dtype=torch.bool, device=vmasks.device), ], dim=1, ) assert video_outputs.size(1) == video_attention_mask.size(1) video_attention_mask = video_attention_mask.type(video_outputs.dtype) \ / video_attention_mask.sum(1, keepdim=True) pooled_video = torch.bmm( video_outputs.transpose(2, 1), video_attention_mask.unsqueeze(2) ).squeeze(-1) return pooled_video # video_outputs def forward_text( self, caps, cmasks, output_hidden_states=False, **kwargs ): input_ids = torch.cat([ caps[:, :1], caps[:, 2:], ], dim=1) attention_mask = torch.cat([ cmasks[:, :1], cmasks[:, 2:] ], dim=1) # different from sharing, we use all-0 type. token_type_ids = torch.zeros( (cmasks.size(0), cmasks.size(1) - 1), dtype=torch.long, device=cmasks.device) outputs = self.text_encoder( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True ) text_outputs = outputs[0] if output_hidden_states: return text_outputs batch_size = caps.size(0) # text tokens + [SEP] text_attention_mask = torch.cat([ torch.zeros( (batch_size, 1), dtype=torch.bool, device=cmasks.device), cmasks[:, 2:] ], dim=1) assert text_outputs.size(1) == text_attention_mask.size(1) text_attention_mask = text_attention_mask.type(text_outputs.dtype) \ / text_attention_mask.sum(1, keepdim=True) pooled_text = torch.bmm( text_outputs.transpose(2, 1), text_attention_mask.unsqueeze(2) ).squeeze(-1) return pooled_text # text_outputs class MMFusionJoint(MMFusion): """fine-tuning wrapper for retrival task.""" def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, video_label=None, text_label=None, **kwargs ): # TODO (huxu): other ways to do negative examples; move the following # into your criterion forward. output_hidden_states = True attention_mask, token_type_ids = self._mm_on_the_fly( cmasks, vmasks, attention_mask) separate_forward_split = ( None if self.is_train else vmasks.size(1) + 2 ) # [CLS] + [SEP] outputs = self.mm_encoder( input_ids=caps, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=output_hidden_states, separate_forward_split=separate_forward_split, ) pooled_video, pooled_text = self._pooling_vt_layer( outputs[2], cmasks, vmasks) return {"pooled_video": pooled_video, "pooled_text": pooled_text} class MMFusionActionSegmentation(MMFusion): """Fine-tuning wrapper for action segmentation. TODO: rename this for VLM. """ def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, **kwargs ): # ActionLocalization assume of batch_size=1, squeeze it. caps = caps.view(-1, caps.size(-1)) cmasks = cmasks.view(-1, cmasks.size(-1)) vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3)) vmasks = vmasks.view(-1, vmasks.size(-1)) # this may not cover all shapes of attention_mask. attention_mask = attention_mask.view( -1, attention_mask.size(2), attention_mask.size(3)) \ if attention_mask is not None else None # TODO (huxu): other ways to do negative examples; move the following # into your criterion forward. output_hidden_states = True # video forwarding, text is dummy; never use attention_mask. attention_mask, token_type_ids = self._mm_on_the_fly( cmasks, vmasks, attention_mask) logits = self.mm_encoder( input_ids=caps, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=output_hidden_states, ) return {"logits": logits[0][:, 1:vmasks.size(1)+1]} class MMFusionActionLocalization(MMFusion): """fine-tuning model for retrival task.""" def __init__(self, config, **kwargs): super().__init__(config) tokenizer = AutoTokenizer.from_pretrained( config.dataset.bert_name) self.cls_token_id = tokenizer.cls_token_id self.sep_token_id = tokenizer.sep_token_id self.pad_token_id = tokenizer.pad_token_id def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, **kwargs ): # ActionLocalization assume of batch_size=1, squeeze it. caps = caps.squeeze(0) cmasks = cmasks.squeeze(0) vfeats = vfeats.squeeze(0) vmasks = vmasks.squeeze(0) attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None # TODO (huxu): other ways to do negative examples; move the following # into your criterion forward. output_hidden_states = True # a len1 dummy video token. dummy_vfeats = torch.zeros( (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype) dummy_vmasks = torch.ones( (caps.size(0), 1), dtype=torch.bool, device=vfeats.device) dummy_caps = torch.LongTensor( [[self.cls_token_id, self.sep_token_id, self.pad_token_id, self.sep_token_id]], ).to(caps.device).repeat(vfeats.size(0), 1) dummy_cmasks = torch.BoolTensor( [[0, 1, 0, 1]] # pad are valid for attention. ).to(caps.device).repeat(vfeats.size(0), 1) # video forwarding, text is dummy; never use attention_mask. attention_mask, token_type_ids = self._mm_on_the_fly( dummy_cmasks, vmasks, None) outputs = self.mm_encoder( input_ids=dummy_caps, input_video_embeds=vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=output_hidden_states, ) layer_idx = self.last_iso_layer \ if self.last_iso_layer > 0 else self.num_hidden_layers video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select( vmasks.unsqueeze(-1) ).view(-1, self.hidden_size) # text forwarding, video is dummy attention_mask, token_type_ids = self._mm_on_the_fly( cmasks, dummy_vmasks, None) outputs = self.mm_encoder( input_ids=caps, input_video_embeds=dummy_vfeats, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=output_hidden_states, ) _, pooled_text = self._pooling_vt_layer( outputs[2], cmasks, dummy_vmasks) # this line is not right. logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) return {"logits": logits} # --------------- MMFusionSeparate for end tasks --------------- class MMFusionSeparateActionSegmentation(MMFusionSeparate): """Fine-tuning wrapper for action segmentation.""" def forward( self, caps, cmasks, vfeats, vmasks, attention_mask=None, **kwargs ): # ActionLocalization assume of batch_size=1, squeeze it. caps = caps.view(-1, caps.size(-1)) cmasks = cmasks.view(-1, cmasks.size(-1)) vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3)) vmasks = vmasks.view(-1, vmasks.size(-1)) logits = self.forward_video( vfeats, vmasks, caps, cmasks, output_hidden_states=True ) return {"logits": logits[:, 1:vmasks.size(1)+1]} class MMFusionSeparateActionLocalization(MMFusionSeparate): def __init__(self, config, **kwargs): super().__init__(config) tokenizer = AutoTokenizer.from_pretrained( config.dataset.bert_name) self.cls_token_id = tokenizer.cls_token_id self.sep_token_id = tokenizer.sep_token_id self.pad_token_id = tokenizer.pad_token_id def forward( self, caps, cmasks, vfeats, vmasks, **kwargs ): # ActionLocalization assume of batch_size=1, squeeze it. caps = caps.squeeze(0) cmasks = cmasks.squeeze(0) vfeats = vfeats.squeeze(0) vmasks = vmasks.squeeze(0) # TODO (huxu): other ways to do negative examples; move the following # into your criterion forward. dummy_caps = torch.LongTensor( [[self.cls_token_id, self.sep_token_id, self.pad_token_id, self.sep_token_id]], ).to(caps.device).repeat(vfeats.size(0), 1) dummy_cmasks = torch.BoolTensor( [[0, 1, 0, 1]] # pad are valid for attention. ).to(caps.device).repeat(vfeats.size(0), 1) outputs = self.forward_video( vfeats, vmasks, dummy_caps, dummy_cmasks, output_hidden_states=True ) video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select( vmasks.unsqueeze(-1) ).view(-1, self.hidden_size) pooled_text = self.forward_text( caps, cmasks, output_hidden_states=False ) # this line is not right. logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) return {"logits": logits} class MMFusionShareActionLocalization(MMFusionShare): def __init__(self, config, **kwargs): super().__init__(config) tokenizer = AutoTokenizer.from_pretrained( config.dataset.bert_name) self.cls_token_id = tokenizer.cls_token_id self.sep_token_id = tokenizer.sep_token_id self.pad_token_id = tokenizer.pad_token_id def forward( self, caps, cmasks, vfeats, vmasks, **kwargs ): # ActionLocalization assume of batch_size=1, squeeze it. caps = caps.squeeze(0) cmasks = cmasks.squeeze(0) vfeats = vfeats.squeeze(0) vmasks = vmasks.squeeze(0) # TODO (huxu): other ways to do negative examples; move the following # into your criterion forward. dummy_caps = torch.LongTensor( [[self.cls_token_id, self.sep_token_id, self.pad_token_id, self.sep_token_id]], ).to(caps.device).repeat(vfeats.size(0), 1) dummy_cmasks = torch.BoolTensor( [[0, 1, 0, 1]] # pad are valid for attention. ).to(caps.device).repeat(vfeats.size(0), 1) outputs = self.forward_video( vfeats, vmasks, dummy_caps, dummy_cmasks, output_hidden_states=True ) video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select( vmasks.unsqueeze(-1) ).view(-1, self.hidden_size) pooled_text = self.forward_text( caps, cmasks, output_hidden_states=False ) # this line is not right. logits = torch.mm(video_seq, pooled_text.transpose(1, 0)) return {"logits": logits}