# bridgetower-video-search / bridgetower_custom.py
# Author: shaoyent -- commit a1ebdce ("First update")
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import BridgeTowerPreTrainedModel, BridgeTowerModel
from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` variant that normalizes in float32 and casts back.

    The input is upcast to float32 before the parent ``forward`` runs, and the
    result is returned in the caller's original dtype. This keeps fp16
    activations usable with this layer.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class BridgeTowerImageFeatureExtractor(nn.Module):
    """Patch-embedding stem of a ViT-style image encoder.

    Splits an image into non-overlapping patches with a strided convolution.
    It prepends a learned class token, adds learned positional embeddings and
    applies a LayerNorm. The result is returned sequence-first.

    Args:
        patch_size: Side length of each square patch; used as both kernel size
            and stride of ``conv1``.
        width: Embedding dimension of every token.
        resolution_after: Image side length the positional table is sized for.
            ``(resolution_after // patch_size) ** 2 + 1`` positions are
            allocated (one per patch plus the class token).
        ckpt_path: Optional checkpoint to load into this module. It may be a raw
            state dict or a dict holding one under ``"state_dict"``.
    """

    def __init__(
        self,
        patch_size=14,
        width=1024,
        resolution_after=294,
        ckpt_path=None,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        if ckpt_path is not None:
            # map_location="cpu" lets a checkpoint saved from GPU tensors load on
            # a CPU-only machine. load_state_dict then copies the values into this
            # module's existing parameters.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            sd = torch.load(ckpt_path, map_location="cpu")
            if 'state_dict' in sd:
                sd = sd["state_dict"]
            print(f'Loading feature extractor checkpoint from {ckpt_path}')
            self.load_state_dict(sd)

    def forward(self, x: torch.Tensor):
        """Embed an image batch into a token sequence.

        Args:
            x: Image batch, assumed ``[batch, 3, H, W]``. The resulting token
               count (patches + 1) must match the positional table length,
               i.e. images sized for ``resolution_after`` -- TODO confirm.

        Returns:
            Tensor of shape ``[grid ** 2 + 1, batch, width]`` (sequence-first).
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Broadcast the single class embedding to one token per batch element.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        return x
class BridgeTowerITCHead(nn.Module):
    """Linear projection into the image-text-contrastive (ITC) embedding space.

    Args:
        hidden_size: Size of the incoming features (last dimension).
        embed_size: Size of the projected embedding.
    """

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        # Attribute name ``fc`` is part of the state-dict key; keep it.
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Project the last dimension of ``x`` from hidden_size to embed_size."""
        return self.fc(x)
class _BridgeTowerTextModelWrapper(nn.Module):
    """Container that holds a ``BridgeTowerTextModel`` under ``text_model``.

    The extra nesting level presumably exists so parameter names read
    ``<prefix>.text_model.*``, matching a full BridgeTower checkpoint -- confirm
    against the checkpoint in use.
    """

    def __init__(self, config):
        super().__init__()
        self.text_model = BridgeTowerTextModel(config)

    def forward(self, **kwargs):
        """Call the wrapped text model with ``kwargs``; return its output as-is."""
        encoder = self.text_model
        return encoder(**kwargs)
class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
    """Text-only BridgeTower encoder producing L2-normalized ITC text embeddings.

    Only the text tower and the text ITC projection head are built, so text can
    be embedded without instantiating the vision tower.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        # Initialize weights and apply final processing (as BridgeTowerForITC does).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Embed text into the contrastive space.

        Args:
            input_ids, attention_mask: Forwarded to the text model.
            token_type_ids, head_mask, inputs_embeds: Forwarded to the text
                model when not ``None``. Previously these were silently dropped.
            output_attentions, output_hidden_states, return_dict, labels:
                Accepted for API compatibility and ignored. Only the embedding
                is returned.

        Returns:
            The first-token state of the last layer, projected by
            ``itc_text_head`` and L2-normalized along the last dimension.
        """
        text_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Pass the optional inputs only when supplied, so the default call is
        # exactly the one the text model always received.
        optional_inputs = {
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
        }
        text_inputs.update({name: value for name, value in optional_inputs.items() if value is not None})
        # `.last_hidden_state` is read below, so always ask for a ModelOutput.
        outputs = self.bridgetower(**text_inputs, return_dict=True)
        final_hidden_cls = outputs.last_hidden_state[:, 0, :]
        final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
        return final_hidden_cls
class BridgeTowerForITC(BridgeTowerPreTrainedModel):
    """BridgeTower backbone with three image-text-contrastive (ITC) heads.

    The heads project three inputs into a shared ``contrastive_hidden_size``
    space:
    - the last text layer's first token;
    - the last image layer's first token, after the cross-modal image transform;
    - the backbone's pooler output.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bridgetower = BridgeTowerModel(config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        # This head consumes pooler_output, hence hidden_size * 2 input features.
        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        """Compute L2-normalized text, image and cross-modal ITC embeddings.

        Args:
            input_ids, attention_mask, token_type_ids, pixel_values, pixel_mask,
            head_mask, inputs_embeds, image_embeds, output_attentions:
                Forwarded unchanged to ``BridgeTowerModel``.
            output_hidden_states: Accepted for API compatibility. Hidden states
                are always requested from the backbone because the heads read the
                last text and image layers. The previous requirement that this be
                truthy (an ``assert``) is therefore gone, and the default call
                works.
            return_dict: Whether to return a ``SequenceClassifierOutput``.
                Defaults to ``config.use_return_dict``. When falsy, the same output
                is returned via ``to_tuple()``, i.e. ``(logits, hidden_states)``
                plus ``attentions`` when present.
            labels: Unused; no loss is computed.

        Returns:
            ``logits`` stacks the text, image and cross-modal embeddings (in that
            order) along dim -2: ``[batch, 3, contrastive_hidden_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request hidden states and a ModelOutput from the backbone. The
        # heads need the last per-modality layers, and the attribute accesses
        # below would fail on a plain tuple (the old return_dict=False crash).
        outputs = self.bridgetower(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = outputs.pooler_output
        hidden_states_txt, hidden_states_img, _ = outputs.hidden_states
        final_hidden_txt = hidden_states_txt[-1]
        final_hidden_img = hidden_states_img[-1]

        # Re-process the last image layer with the vision post-processing, the
        # cross-modal image transform and the image token-type embedding (id 1).
        # This presumably mirrors how the backbone prepares image tokens for its
        # cross-modal layers -- confirm against BridgeTowerModel.
        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
        token_type_device = self.bridgetower.token_type_embeddings.weight.device
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=token_type_device)
        ).expand_as(image_embeds_with_ln)
        final_hidden_img = (
            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
            + image_token_type_embeddings
        )

        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)
        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)

        output = SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if not return_dict:
            # HF convention: tuple of the non-None fields, in declaration order.
            return output.to_tuple()
        return output