# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch
import torch.utils.checkpoint  # used explicitly by the encoder below
from torch import nn

try:
    from transformers.modeling_bert import (
        BertPreTrainedModel,
        BertModel,
        BertEncoder,
        BertPredictionHeadTransform,
    )
except ImportError:
    # NOTE(review): the ImportError is swallowed, but the classes below
    # subclass these names at module level, so importing this module without
    # `transformers.modeling_bert` available will still fail (NameError).
    pass

from ..modules import VideoTokenMLP, MMBertEmbeddings


# --------------- fine-tuning models ---------------


class MMBertForJoint(BertPreTrainedModel):
    """A BertModel with isolated attention mask to separate modality."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Project video features to tokens and run the multi-modal encoder.

        `next_sentence_label` is accepted for interface compatibility and is
        not used. Returns the output tuple produced by :class:`MMBertModel`.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        video_tokens = self.videomlp(input_video_embeds)

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            separate_forward_split=separate_forward_split,
        )
        return outputs


class MMBertForTokenClassification(BertPreTrainedModel):
    """A BertModel similar to MMJointUni, with extra wrapper layer
    to be fine-tuned from other pretrained MMFusion model."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        # NOTE(review): `self.dropout` is constructed but never applied in
        # `forward`; confirm whether that is intended.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # TODO(huxu): 779 is the number of classes for COIN: move to config?
        self.classifier = nn.Linear(config.hidden_size, 779)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Run the multi-modal encoder and classify every output position.

        `next_sentence_label` is accepted for interface compatibility and is
        not used.

        Returns:
            A 1-tuple holding ``self.classifier`` applied to the encoder's
            sequence output (last dimension 779).
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        video_tokens = self.videomlp(input_video_embeds)

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            separate_forward_split=separate_forward_split,
        )
        return (self.classifier(outputs[0]),)


# ------------ pre-training models ----------------


class MMBertForEncoder(BertPreTrainedModel):
    """A BertModel for Contrastive Learning."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode text and (optionally) video with :class:`MMBertModel`.

        When `input_video_embeds` is None the encoder is run on text only.
        Returns the output tuple produced by :class:`MMBertModel`.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        if input_video_embeds is not None:
            video_tokens = self.videomlp(input_video_embeds)
        else:
            video_tokens = None

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs


class MMBertForMFMMLM(BertPreTrainedModel):
    """A BertModel with shared prediction head on MFM-MLM."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MFMMLMHead(config)
        self.hidden_size = config.hidden_size
        self.init_weights()

    def get_output_embeddings(self):
        """Return the text-vocabulary decoder of the prediction head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_frame_labels=None,
        target_video_hidden_states=None,
        non_masked_frame_mask=None,
        masked_lm_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the multi-modal input and score masked frames and tokens.

        Prediction scores are only computed when both `masked_frame_labels`
        and `masked_lm_labels` are given; otherwise they are None.

        Returns:
            ``(mfm_scores, prediction_scores) + outputs`` where ``outputs``
            is the tuple returned by :class:`MMBertModel`.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        if input_video_embeds is not None:
            video_tokens = self.videomlp(input_video_embeds)
        else:
            video_tokens = None

        # Defined on every path so the head call below never sees an
        # unbound local.
        non_masked_frame_hidden_states = None
        if target_video_hidden_states is not None:
            target_video_hidden_states = self.videomlp(
                target_video_hidden_states)
            # Flatten the projected tokens of the non-masked frames into a
            # 2-D (num_selected, hidden_size) matrix.
            non_masked_frame_hidden_states = video_tokens.masked_select(
                non_masked_frame_mask.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        mfm_scores, prediction_scores = None, None
        if masked_frame_labels is not None and masked_lm_labels is not None:
            # split the sequence.
            text_offset = masked_frame_labels.size(1) + 1  # [CLS]
            # remove [SEP] as not in video_label.
            video_sequence_output = sequence_output[:, 1:text_offset]
            text_sequence_output = torch.cat(
                [sequence_output[:, :1], sequence_output[:, text_offset:]],
                dim=1,
            )

            hidden_size = video_sequence_output.size(-1)
            selected_video_output = video_sequence_output.masked_select(
                masked_frame_labels.unsqueeze(-1)
            ).view(-1, hidden_size)

            # Only score the labelled text positions (labels equal to -100
            # are skipped) to speed up training.
            hidden_size = text_sequence_output.size(-1)
            labels_mask = masked_lm_labels != -100

            selected_text_output = text_sequence_output.masked_select(
                labels_mask.unsqueeze(-1)
            ).view(-1, hidden_size)
            mfm_scores, prediction_scores = self.cls(
                selected_video_output,
                target_video_hidden_states,
                non_masked_frame_hidden_states,
                selected_text_output,
            )

        output = (
            mfm_scores,
            prediction_scores,
        ) + outputs
        return output


class BertMFMMLMPredictionHead(nn.Module):
    """Shared head scoring masked frames (by dot product) and masked tokens."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly
        # resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Compute video (MFM) and text (MLM) logits.

        Video logits: column 0 is the dot product of each transformed masked
        frame with its own target; the remaining columns are dot products
        with every non-masked frame (``torch.mm`` requires 2-D inputs).
        Text logits: vocabulary scores from ``self.decoder``.

        Returns:
            ``(video_logits, text_logits)``; each is None when the
            corresponding hidden states are not given.
        """
        video_logits, text_logits = None, None
        if video_hidden_states is not None:
            video_hidden_states = self.transform(video_hidden_states)
            non_masked_frame_logits = torch.mm(
                video_hidden_states,
                non_masked_frame_hidden_states.transpose(1, 0)
            )
            # Row-wise dot product with the matching target:
            # (N, 1, H) @ (N, H, 1) -> (N, 1, 1) -> (N, 1).
            masked_frame_logits = torch.bmm(
                video_hidden_states.unsqueeze(1),
                target_video_hidden_states.unsqueeze(-1),
            ).squeeze(-1)
            video_logits = torch.cat(
                [masked_frame_logits, non_masked_frame_logits], dim=1
            )

        if text_hidden_states is not None:
            text_hidden_states = self.transform(text_hidden_states)
            text_logits = self.decoder(text_hidden_states)
        return video_logits, text_logits


class MFMMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertMFMMLMPredictionHead` as `.predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMFMMLMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to the prediction head; returns ``(video_logits, text_logits)``."""
        video_logits, text_logits = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        return video_logits, text_logits


class MMBertForMTM(MMBertForMFMMLM):
    """MFM-MLM model with the cross-modal MTM head instead of the MFM-MLM head."""

    def __init__(self, config):
        # Deliberately bypass MMBertForMFMMLM.__init__ so MFMMLMHead is never
        # built; the same submodules are created here with MTMHead instead.
        BertPreTrainedModel.__init__(self, config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MTMHead(config)
        self.hidden_size = config.hidden_size
        self.init_weights()


class BertMTMPredictionHead(nn.Module):
    """Head scoring both modalities against both frame and vocabulary targets."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Compute cross-modal video and text logits.

        Video logits columns: [own target frame | non-masked frames | vocab].
        Text logits columns: [vocab | non-masked frames].

        Returns:
            ``(video_logits, text_logits)``; each is None when the
            corresponding hidden states are not given.
        """
        # Both branches below score against the non-masked frames, so the
        # transpose is done once up front.
        non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
        video_logits, text_logits = None, None
        if video_hidden_states is not None:
            video_hidden_states = self.transform(video_hidden_states)

            masked_frame_logits = torch.bmm(
                video_hidden_states.unsqueeze(1),
                target_video_hidden_states.unsqueeze(-1),
            ).squeeze(-1)

            non_masked_frame_logits = torch.mm(
                video_hidden_states,
                non_masked_frame_hidden_states
            )
            video_on_vocab_logits = self.decoder(video_hidden_states)
            video_logits = torch.cat([
                masked_frame_logits,
                non_masked_frame_logits,
                video_on_vocab_logits], dim=1)

        if text_hidden_states is not None:
            text_hidden_states = self.transform(text_hidden_states)
            # text first so label does not need to be shifted.
            text_on_vocab_logits = self.decoder(text_hidden_states)
            text_on_video_logits = torch.mm(
                text_hidden_states,
                non_masked_frame_hidden_states
            )
            text_logits = torch.cat([
                text_on_vocab_logits,
                text_on_video_logits
            ], dim=1)

        return video_logits, text_logits


class MTMHead(nn.Module):
    """Thin wrapper exposing :class:`BertMTMPredictionHead` as `.predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMTMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to the prediction head; returns ``(video_logits, text_logits)``."""
        video_logits, text_logits = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        return video_logits, text_logits


class MMBertModel(BertModel):
    """MMBertModel has MMBertEmbedding to support video tokens."""

    def __init__(self, config, add_pooling_layer=True):
        # NOTE(review): `add_pooling_layer` is accepted but not forwarded to
        # BertModel.__init__; confirm whether that is intended.
        super().__init__(config)
        # overwrite embedding
        self.embeddings = MMBertEmbeddings(config)
        self.encoder = MultiLayerAttentionMaskBertEncoder(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Embed text (+ optional video tokens) and run the encoder.

        Exactly one of `input_ids` / `inputs_embeds` must be given. When
        `input_video_embeds` is given, its length (dim 1) is added to the
        text length to form the full sequence length.

        If `separate_forward_split` is set, the encoder is run twice,
        independently, on positions ``[:split]`` and ``[split:]``, and the
        results are concatenated along the sequence dimension. Merging
        attentions is not supported. The mask slicing indexes five
        dimensions, so this path presumably expects a 4-D (multi-layer)
        `attention_mask` — TODO confirm.

        Returns:
            ``(sequence_output, pooled_output) + extra encoder outputs``
            (always a tuple; `return_dict` is only passed through).

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds`
                are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids "
                "and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            if input_video_embeds is not None:
                input_shape = (
                    input_ids.size(0),
                    input_ids.size(1) + input_video_embeds.size(1),
                )
            else:
                input_shape = (
                    input_ids.size(0),
                    input_ids.size(1),
                )
        elif inputs_embeds is not None:
            if input_video_embeds is not None:
                input_shape = (
                    inputs_embeds.size(0),
                    inputs_embeds.size(1) + input_video_embeds.size(1),
                )
            else:
                # `input_ids` is None on this branch, so the shape must come
                # from `inputs_embeds` (previously read from `input_ids`,
                # which raised AttributeError).
                input_shape = (
                    inputs_embeds.size(0),
                    inputs_embeds.size(1),
                )
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None \
            else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case
        # we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to
        # [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (
                encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or
        # [num_hidden_layers x num_heads]
        # and head_mask is converted to shape
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length]

        head_mask = self.get_head_mask(
            head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids,
            input_video_embeds,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        def _encode(hidden, mask):
            # Single encoder call shared by the split and unsplit paths;
            # everything except the hidden states and mask is identical.
            return self.encoder(
                hidden,
                attention_mask=mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if separate_forward_split is not None:
            # First segment: positions [:split], attending only within itself.
            split_encoder_outputs = _encode(
                embedding_output[:, :separate_forward_split],
                extended_attention_mask[
                    :, :, :, :separate_forward_split, :separate_forward_split
                ],
            )
            assert (
                len(split_encoder_outputs) <= 2
            ), "we do not support merge on attention for now."
            # encoder_outputs[0]: list of sequence-output pieces.
            # encoder_outputs[1] (optional): one list of pieces per layer.
            encoder_outputs = []
            encoder_outputs.append([split_encoder_outputs[0]])
            if len(split_encoder_outputs) == 2:
                encoder_outputs.append([])
                for _all_hidden_states in split_encoder_outputs[1]:
                    encoder_outputs[-1].append([_all_hidden_states])

            # Second segment: positions [split:], attending only within itself.
            split_encoder_outputs = _encode(
                embedding_output[:, separate_forward_split:],
                extended_attention_mask[
                    :, :, :, separate_forward_split:, separate_forward_split:
                ],
            )

            assert (
                len(split_encoder_outputs) <= 2
            ), "we do not support merge on attention for now."
            # Concatenate both segments back along the sequence dimension.
            encoder_outputs[0].append(split_encoder_outputs[0])
            encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
            if len(split_encoder_outputs) == 2:
                for layer_idx, _all_hidden_states in enumerate(
                    split_encoder_outputs[1]
                ):
                    encoder_outputs[1][layer_idx].append(_all_hidden_states)
                    encoder_outputs[1][layer_idx] = torch.cat(
                        encoder_outputs[1][layer_idx], dim=1
                    )
            encoder_outputs = tuple(encoder_outputs)
        else:
            encoder_outputs = _encode(
                embedding_output, extended_attention_mask)

        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def get_extended_attention_mask(self, attention_mask, input_shape, device):
        """This is borrowed from `modeling_utils.py` with the support of
        multi-layer attention masks.
        The second dim is expected to be number of layers.
        See `MMAttentionMaskProcessor`.
        Makes broadcastable attention and causal masks so that future
        and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to,
                zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended additive attention mask. For a
            4-D input a head axis is inserted at dim 2 (giving a 5-D mask),
            the mask is cast to ``self.dtype`` and mapped to 0 for kept
            positions and -10000.0 for masked ones. Masks of any other rank
            are handled by the parent implementation.
        """
        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable
        # to all heads.
        if attention_mask.dim() == 4:
            extended_attention_mask = attention_mask[:, :, None, :, :]
            extended_attention_mask = extended_attention_mask.to(
                dtype=self.dtype
            )  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) \
                * -10000.0
            return extended_attention_mask
        return super().get_extended_attention_mask(
            attention_mask, input_shape, device
        )


class MultiLayerAttentionMaskBertEncoder(BertEncoder):
    """extend BertEncoder with the capability of
    multiple layers of attention mask."""

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run every layer, using a per-layer mask when one is supplied.

        A 5-D `attention_mask` is indexed on dim 1 by layer number; any other
        mask (including None) is passed unchanged to every layer.
        `return_dict` is accepted for interface compatibility only; the
        result is always a tuple.

        Returns:
            Tuple of ``hidden_states`` followed by ``all_hidden_states`` and
            ``all_attentions`` when requested (entries that are None are
            omitted).
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Loop-invariant checks, hoisted; the None guard lets the encoder be
        # called with its default `attention_mask=None`.
        per_layer_mask = (
            attention_mask is not None and attention_mask.dim() == 5
        )
        use_checkpointing = getattr(
            self.config, "gradient_checkpointing", False)

        def create_custom_forward(module):
            # torch.utils.checkpoint only forwards tensors positionally, so
            # the `output_attentions` flag is bound via this closure.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)

            return custom_forward

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_attention_mask = (
                attention_mask[:, i, :, :, :]
                if per_layer_mask
                else attention_mask
            )

            if use_checkpointing:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Record the final layer's output as well.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return tuple(
            v
            for v in [hidden_states, all_hidden_states, all_attentions]
            if v is not None
        )