PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
29c9ba5 verified
raw
history blame
30.6 kB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
try:
from transformers import AutoConfig, AutoTokenizer
except ImportError:
pass
from . import transformermodel
class MMPTModel(nn.Module):
"""An e2e wrapper of inference model.
"""
@classmethod
def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
import os
from ..utils import recursive_config
from ..tasks import Task
config = recursive_config(config)
mmtask = Task.config_task(config)
checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
mmtask.build_model(checkpoint=checkpoint_path)
# TODO(huxu): make the video encoder configurable.
from ..processors.models.s3dg import S3D
video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
video_encoder.load_state_dict(
torch.load('pretrained_models/s3d_howto100m.pth'))
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
config.dataset.bert_name, use_fast=config.dataset.use_fast
)
from ..processors import Aligner
aligner = Aligner(config.dataset)
return (
MMPTModel(config, mmtask.model, video_encoder),
tokenizer,
aligner
)
def __init__(self, config, model, video_encoder, **kwargs):
super().__init__()
self.max_video_len = config.dataset.max_video_len
self.video_encoder = video_encoder
self.model = model
def forward(self, video_frames, caps, cmasks, return_score=False):
bsz = video_frames.size(0)
assert bsz == 1, "only bsz=1 is supported now."
seq_len = video_frames.size(1)
video_frames = video_frames.view(-1, *video_frames.size()[2:])
vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
vfeats = vfeats['video_embedding']
vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
padding = torch.zeros(
bsz, self.max_video_len - seq_len, vfeats.size(-1))
vfeats = torch.cat([vfeats, padding], dim=1)
vmasks = torch.cat([
torch.ones((bsz, seq_len), dtype=torch.bool),
torch.zeros((bsz, self.max_video_len - seq_len), dtype=torch.bool)
],
dim=1
)
output = self.model(caps, cmasks, vfeats, vmasks)
if return_score:
output = {"score": torch.bmm(
output["pooled_video"][:, None, :],
output["pooled_text"][:, :, None]
).squeeze(-1).squeeze(-1)}
return output
class MMFusion(nn.Module):
"""a MMPT wrapper class for MMBert style models.
TODO: move isolated mask to a subclass.
"""
def __init__(self, config, **kwargs):
super().__init__()
transformer_config = AutoConfig.from_pretrained(
config.dataset.bert_name)
self.hidden_size = transformer_config.hidden_size
self.is_train = False
if config.dataset.train_path is not None:
self.is_train = True
# 0 means no iso; 1-12 means iso up to that layer.
self.num_hidden_layers = transformer_config.num_hidden_layers
self.last_iso_layer = 0
if config.dataset.num_iso_layer is not None:
self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
if config.model.mm_encoder_cls is not None:
mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
model_config.max_video_len = config.dataset.max_video_len
# TODO: a general way to add parameter for a model.
model_config.use_seg_emb = config.model.use_seg_emb
self.mm_encoder = mm_encoder_cls.from_pretrained(
config.dataset.bert_name, config=model_config)
elif config.model.video_encoder_cls is not None\
and config.model.text_encoder_cls is not None:
video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
model_config.max_video_len = config.dataset.max_video_len
# TODO: make each model a set of config class.
if hasattr(model_config, "num_layers"):
model_config.num_layers = config.model.num_hidden_video_layers
else:
model_config.num_hidden_layers = config.model.num_hidden_video_layers
self.video_encoder = video_encoder_cls.from_pretrained(
config.dataset.bert_name, config=model_config)
# exact same NLP model from Huggingface.
text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
self.text_encoder = text_encoder_cls.from_pretrained(
config.dataset.bert_name)
else:
raise ValueError("the encoder must be either MM or two backbones.")
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
**kwargs
):
raise NotImplementedError(
"Please derive MMFusion module."
)
def _mm_on_the_fly(
self,
cmasks,
vmasks,
attention_mask
):
"""helper function for mask, seg_ids and token_type_ids."""
if attention_mask is None:
attention_mask = self._mm_attention_mask(cmasks, vmasks)
"""
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
"""
token_type_ids = torch.cat(
[
torch.zeros(
(vmasks.size(0), vmasks.size(1) + 2),
dtype=torch.long,
device=vmasks.device,
),
torch.ones(
(cmasks.size(0), cmasks.size(1) - 2),
dtype=torch.long,
device=cmasks.device,
),
],
dim=1,
)
return attention_mask, token_type_ids
def _mm_attention_mask(self, cmasks, vmasks):
assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
str(cmasks.size()),
str(vmasks.size()),
str(cmasks.size(0)),
str(vmasks.size(0)),
)
mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
if self.last_iso_layer == 0:
# hard attention mask.
return mm_mask
else:
# a gpu iso mask; 0 : num_iso_layer is isolated;
# num_iso_layer: are MM-fused.
# make an iso layer
batch_size = cmasks.size(0)
iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
iso_mm_masks = []
# hard attention mask.
iso_mask = iso_mask[:, None, :, :].repeat(
1, self.last_iso_layer, 1, 1)
iso_mm_masks.append(iso_mask)
if self.last_iso_layer < self.num_hidden_layers:
mm_mask = mm_mask[:, None, :, :].repeat(
1, self.num_hidden_layers - self.last_iso_layer, 1, 1
)
iso_mm_masks.append(mm_mask)
iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
return iso_mm_masks
def _make_iso_mask(self, batch_size, cmasks, vmasks):
cls_self_mask = torch.cat(
[
torch.ones(
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
torch.zeros(
(batch_size, cmasks.size(1) + vmasks.size(1) - 1),
dtype=torch.bool, device=cmasks.device)
], dim=1)
iso_video_mask = torch.cat(
[
# [CLS] is not used.
torch.zeros(
(batch_size, 1), dtype=torch.bool, device=cmasks.device
),
vmasks,
# assume to be 1.
cmasks[:, 1:2],
# 2 means [CLS] + [SEP]
torch.zeros(
(batch_size, cmasks.size(1) - 2),
dtype=torch.bool,
device=cmasks.device,
),
],
dim=1,
)
iso_text_mask = torch.cat(
[
torch.zeros(
(batch_size, 2 + vmasks.size(1)),
dtype=torch.bool,
device=cmasks.device,
), # [CLS] is not used.
cmasks[:, 2:], # assume to be 1.
],
dim=1,
)
cls_self_mask = cls_self_mask[:, None, :]
iso_video_mask = iso_video_mask[:, None, :].repeat(
1, vmasks.size(1) + 1, 1)
iso_text_mask = iso_text_mask[:, None, :].repeat(
1, cmasks.size(1) - 2, 1)
return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)
def _pooling_vt_layer(
self,
layered_sequence_output,
cmasks,
vmasks
):
layer_idx = self.last_iso_layer \
if self.last_iso_layer > 0 else self.num_hidden_layers
hidden_state = layered_sequence_output[layer_idx]
# also output pooled_video and pooled_text.
batch_size = cmasks.size(0)
# pool the modality.
text_offset = vmasks.size(1) + 2 # [CLS] + [SEP]
# video tokens + [SEP]
video_outputs = hidden_state[:, 1:text_offset]
video_attention_mask = torch.cat(
[
vmasks,
torch.ones(
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
],
dim=1,
)
assert video_outputs.size(1) == video_attention_mask.size(1)
pooled_video = torch.sum(
video_outputs * video_attention_mask.unsqueeze(-1), dim=1
) / video_attention_mask.sum(1, keepdim=True)
# pooled_video = torch.mean(video_outputs[0], dim=1)
# text tokens + [SEP]
text_attention_mask = cmasks[:, 2:]
text_outputs = hidden_state[:, text_offset:]
assert text_outputs.size(1) == text_attention_mask.size(1)
pooled_text = torch.sum(
text_outputs * text_attention_mask.unsqueeze(-1), dim=1
) / text_attention_mask.sum(1, keepdim=True)
return pooled_video, pooled_text
class MMFusionMFMMLM(MMFusion):
"""forward function for MFM and MLM."""
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
video_label=None,
text_label=None,
**kwargs
):
output_hidden_states = False if self.is_train else True
target_vfeats, non_masked_frame_mask = None, None
if video_label is not None:
target_vfeats = vfeats.masked_select(
video_label.unsqueeze(-1)).view(
-1, vfeats.size(-1)
)
# mask video token.
vfeats[video_label] = 0.0
non_masked_frame_mask = vmasks.clone()
non_masked_frame_mask[video_label] = False
attention_mask, token_type_ids = self._mm_on_the_fly(
cmasks, vmasks, attention_mask)
outputs = self.mm_encoder(
input_ids=caps,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
masked_frame_labels=video_label,
target_video_hidden_states=target_vfeats,
non_masked_frame_mask=non_masked_frame_mask,
masked_lm_labels=text_label,
output_hidden_states=output_hidden_states,
)
video_logits, text_logits = outputs[0], outputs[1]
if self.is_train: # return earlier for training.
return {
"video_logits": video_logits,
"text_logits": text_logits,
}
pooled_video, pooled_text = self._pooling_vt_layer(
outputs[2], cmasks, vmasks)
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
class MMFusionMTM(MMFusionMFMMLM):
def __init__(self, config, **kwargs):
super().__init__(config)
"""
For reproducibility:
self.mm_encoder will be initialized then discarded.
"""
from .transformermodel import MMBertForMTM
model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
model_config.max_video_len = config.dataset.max_video_len
model_config.use_seg_emb = config.model.use_seg_emb
self.mm_encoder = MMBertForMTM.from_pretrained(
config.dataset.bert_name, config=model_config)
class MMFusionShare(MMFusion):
"""A retrival wrapper using mm_encoder as both video/text backbone.
TODO: move formally.
"""
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
video_label=None,
text_label=None,
output_hidden_states=False,
**kwargs
):
pooled_video = self.forward_video(
vfeats,
vmasks,
caps,
cmasks,
output_hidden_states
)
pooled_text = self.forward_text(
caps,
cmasks,
output_hidden_states
)
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
def forward_video(
self,
vfeats,
vmasks,
caps,
cmasks,
output_hidden_states=False,
**kwargs
):
input_ids = caps[:, :2]
attention_mask = torch.cat([
cmasks[:, :1],
vmasks,
cmasks[:, 1:2]
], dim=1)
token_type_ids = torch.zeros(
(vmasks.size(0), vmasks.size(1) + 2),
dtype=torch.long,
device=vmasks.device)
outputs = self.mm_encoder(
input_ids=input_ids,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True
)
video_outputs = outputs[0]
if output_hidden_states:
return video_outputs
batch_size = cmasks.size(0)
video_attention_mask = torch.cat(
[
torch.zeros(
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
vmasks,
torch.ones(
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
],
dim=1,
)
assert video_outputs.size(1) == video_attention_mask.size(1)
video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
/ video_attention_mask.sum(1, keepdim=True)
pooled_video = torch.bmm(
video_outputs.transpose(2, 1),
video_attention_mask.unsqueeze(2)
).squeeze(-1)
return pooled_video # video_outputs
def forward_text(
self,
caps,
cmasks,
output_hidden_states=False,
**kwargs
):
input_ids = torch.cat([
caps[:, :1], caps[:, 2:],
], dim=1)
attention_mask = torch.cat([
cmasks[:, :1],
cmasks[:, 2:]
], dim=1)
token_type_ids = torch.cat([
torch.zeros(
(cmasks.size(0), 1),
dtype=torch.long,
device=cmasks.device),
torch.ones(
(cmasks.size(0), cmasks.size(1) - 2),
dtype=torch.long,
device=cmasks.device)
], dim=1)
outputs = self.mm_encoder(
input_ids=input_ids,
input_video_embeds=None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True
)
text_outputs = outputs[0]
if output_hidden_states:
return text_outputs
batch_size = caps.size(0)
# text tokens + [SEP]
text_attention_mask = torch.cat([
torch.zeros(
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
cmasks[:, 2:]
], dim=1)
assert text_outputs.size(1) == text_attention_mask.size(1)
text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
/ text_attention_mask.sum(1, keepdim=True)
pooled_text = torch.bmm(
text_outputs.transpose(2, 1),
text_attention_mask.unsqueeze(2)
).squeeze(-1)
return pooled_text # text_outputs
class MMFusionSeparate(MMFusionShare):
def forward_video(
self,
vfeats,
vmasks,
caps,
cmasks,
output_hidden_states=False,
**kwargs
):
input_ids = caps[:, :2]
attention_mask = torch.cat([
cmasks[:, :1],
vmasks,
cmasks[:, 1:2]
], dim=1)
token_type_ids = torch.zeros(
(vmasks.size(0), vmasks.size(1) + 2),
dtype=torch.long,
device=vmasks.device)
outputs = self.video_encoder(
input_ids=input_ids,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True
)
video_outputs = outputs[0]
if output_hidden_states:
return video_outputs
batch_size = cmasks.size(0)
video_attention_mask = torch.cat(
[
torch.zeros(
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
vmasks,
torch.ones(
(batch_size, 1), dtype=torch.bool, device=vmasks.device),
],
dim=1,
)
assert video_outputs.size(1) == video_attention_mask.size(1)
video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
/ video_attention_mask.sum(1, keepdim=True)
pooled_video = torch.bmm(
video_outputs.transpose(2, 1),
video_attention_mask.unsqueeze(2)
).squeeze(-1)
return pooled_video # video_outputs
def forward_text(
self,
caps,
cmasks,
output_hidden_states=False,
**kwargs
):
input_ids = torch.cat([
caps[:, :1], caps[:, 2:],
], dim=1)
attention_mask = torch.cat([
cmasks[:, :1],
cmasks[:, 2:]
], dim=1)
# different from sharing, we use all-0 type.
token_type_ids = torch.zeros(
(cmasks.size(0), cmasks.size(1) - 1),
dtype=torch.long,
device=cmasks.device)
outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True
)
text_outputs = outputs[0]
if output_hidden_states:
return text_outputs
batch_size = caps.size(0)
# text tokens + [SEP]
text_attention_mask = torch.cat([
torch.zeros(
(batch_size, 1), dtype=torch.bool, device=cmasks.device),
cmasks[:, 2:]
], dim=1)
assert text_outputs.size(1) == text_attention_mask.size(1)
text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
/ text_attention_mask.sum(1, keepdim=True)
pooled_text = torch.bmm(
text_outputs.transpose(2, 1),
text_attention_mask.unsqueeze(2)
).squeeze(-1)
return pooled_text # text_outputs
class MMFusionJoint(MMFusion):
"""fine-tuning wrapper for retrival task."""
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
video_label=None,
text_label=None,
**kwargs
):
# TODO (huxu): other ways to do negative examples; move the following
# into your criterion forward.
output_hidden_states = True
attention_mask, token_type_ids = self._mm_on_the_fly(
cmasks, vmasks, attention_mask)
separate_forward_split = (
None if self.is_train else vmasks.size(1) + 2
) # [CLS] + [SEP]
outputs = self.mm_encoder(
input_ids=caps,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=output_hidden_states,
separate_forward_split=separate_forward_split,
)
pooled_video, pooled_text = self._pooling_vt_layer(
outputs[2], cmasks, vmasks)
return {"pooled_video": pooled_video, "pooled_text": pooled_text}
class MMFusionActionSegmentation(MMFusion):
"""Fine-tuning wrapper for action segmentation.
TODO: rename this for VLM.
"""
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
**kwargs
):
# ActionLocalization assume of batch_size=1, squeeze it.
caps = caps.view(-1, caps.size(-1))
cmasks = cmasks.view(-1, cmasks.size(-1))
vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
vmasks = vmasks.view(-1, vmasks.size(-1))
# this may not cover all shapes of attention_mask.
attention_mask = attention_mask.view(
-1, attention_mask.size(2), attention_mask.size(3)) \
if attention_mask is not None else None
# TODO (huxu): other ways to do negative examples; move the following
# into your criterion forward.
output_hidden_states = True
# video forwarding, text is dummy; never use attention_mask.
attention_mask, token_type_ids = self._mm_on_the_fly(
cmasks, vmasks, attention_mask)
logits = self.mm_encoder(
input_ids=caps,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=output_hidden_states,
)
return {"logits": logits[0][:, 1:vmasks.size(1)+1]}
class MMFusionActionLocalization(MMFusion):
"""fine-tuning model for retrival task."""
def __init__(self, config, **kwargs):
super().__init__(config)
tokenizer = AutoTokenizer.from_pretrained(
config.dataset.bert_name)
self.cls_token_id = tokenizer.cls_token_id
self.sep_token_id = tokenizer.sep_token_id
self.pad_token_id = tokenizer.pad_token_id
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
**kwargs
):
# ActionLocalization assume of batch_size=1, squeeze it.
caps = caps.squeeze(0)
cmasks = cmasks.squeeze(0)
vfeats = vfeats.squeeze(0)
vmasks = vmasks.squeeze(0)
attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None
# TODO (huxu): other ways to do negative examples; move the following
# into your criterion forward.
output_hidden_states = True
# a len1 dummy video token.
dummy_vfeats = torch.zeros(
(caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
dummy_vmasks = torch.ones(
(caps.size(0), 1), dtype=torch.bool,
device=vfeats.device)
dummy_caps = torch.LongTensor(
[[self.cls_token_id, self.sep_token_id,
self.pad_token_id, self.sep_token_id]],
).to(caps.device).repeat(vfeats.size(0), 1)
dummy_cmasks = torch.BoolTensor(
[[0, 1, 0, 1]] # pad are valid for attention.
).to(caps.device).repeat(vfeats.size(0), 1)
# video forwarding, text is dummy; never use attention_mask.
attention_mask, token_type_ids = self._mm_on_the_fly(
dummy_cmasks, vmasks, None)
outputs = self.mm_encoder(
input_ids=dummy_caps,
input_video_embeds=vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=output_hidden_states,
)
layer_idx = self.last_iso_layer \
if self.last_iso_layer > 0 else self.num_hidden_layers
video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
vmasks.unsqueeze(-1)
).view(-1, self.hidden_size)
# text forwarding, video is dummy
attention_mask, token_type_ids = self._mm_on_the_fly(
cmasks, dummy_vmasks, None)
outputs = self.mm_encoder(
input_ids=caps,
input_video_embeds=dummy_vfeats,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=output_hidden_states,
)
_, pooled_text = self._pooling_vt_layer(
outputs[2], cmasks, dummy_vmasks)
# this line is not right.
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
return {"logits": logits}
# --------------- MMFusionSeparate for end tasks ---------------
class MMFusionSeparateActionSegmentation(MMFusionSeparate):
"""Fine-tuning wrapper for action segmentation."""
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
attention_mask=None,
**kwargs
):
# ActionLocalization assume of batch_size=1, squeeze it.
caps = caps.view(-1, caps.size(-1))
cmasks = cmasks.view(-1, cmasks.size(-1))
vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
vmasks = vmasks.view(-1, vmasks.size(-1))
logits = self.forward_video(
vfeats,
vmasks,
caps,
cmasks,
output_hidden_states=True
)
return {"logits": logits[:, 1:vmasks.size(1)+1]}
class MMFusionSeparateActionLocalization(MMFusionSeparate):
def __init__(self, config, **kwargs):
super().__init__(config)
tokenizer = AutoTokenizer.from_pretrained(
config.dataset.bert_name)
self.cls_token_id = tokenizer.cls_token_id
self.sep_token_id = tokenizer.sep_token_id
self.pad_token_id = tokenizer.pad_token_id
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
**kwargs
):
# ActionLocalization assume of batch_size=1, squeeze it.
caps = caps.squeeze(0)
cmasks = cmasks.squeeze(0)
vfeats = vfeats.squeeze(0)
vmasks = vmasks.squeeze(0)
# TODO (huxu): other ways to do negative examples; move the following
# into your criterion forward.
dummy_caps = torch.LongTensor(
[[self.cls_token_id, self.sep_token_id,
self.pad_token_id, self.sep_token_id]],
).to(caps.device).repeat(vfeats.size(0), 1)
dummy_cmasks = torch.BoolTensor(
[[0, 1, 0, 1]] # pad are valid for attention.
).to(caps.device).repeat(vfeats.size(0), 1)
outputs = self.forward_video(
vfeats,
vmasks,
dummy_caps,
dummy_cmasks,
output_hidden_states=True
)
video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
vmasks.unsqueeze(-1)
).view(-1, self.hidden_size)
pooled_text = self.forward_text(
caps,
cmasks,
output_hidden_states=False
)
# this line is not right.
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
return {"logits": logits}
class MMFusionShareActionLocalization(MMFusionShare):
def __init__(self, config, **kwargs):
super().__init__(config)
tokenizer = AutoTokenizer.from_pretrained(
config.dataset.bert_name)
self.cls_token_id = tokenizer.cls_token_id
self.sep_token_id = tokenizer.sep_token_id
self.pad_token_id = tokenizer.pad_token_id
def forward(
self,
caps,
cmasks,
vfeats,
vmasks,
**kwargs
):
# ActionLocalization assume of batch_size=1, squeeze it.
caps = caps.squeeze(0)
cmasks = cmasks.squeeze(0)
vfeats = vfeats.squeeze(0)
vmasks = vmasks.squeeze(0)
# TODO (huxu): other ways to do negative examples; move the following
# into your criterion forward.
dummy_caps = torch.LongTensor(
[[self.cls_token_id, self.sep_token_id,
self.pad_token_id, self.sep_token_id]],
).to(caps.device).repeat(vfeats.size(0), 1)
dummy_cmasks = torch.BoolTensor(
[[0, 1, 0, 1]] # pad are valid for attention.
).to(caps.device).repeat(vfeats.size(0), 1)
outputs = self.forward_video(
vfeats,
vmasks,
dummy_caps,
dummy_cmasks,
output_hidden_states=True
)
video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
vmasks.unsqueeze(-1)
).view(-1, self.hidden_size)
pooled_text = self.forward_text(
caps,
cmasks,
output_hidden_states=False
)
# this line is not right.
logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
return {"logits": logits}