Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
from typing import List, Optional, Tuple, Union | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
from transformers.modeling_outputs import SequenceClassifierOutput | |
from transformers import BridgeTowerPreTrainedModel, BridgeTowerModel | |
from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerTextModel | |
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` variant that always normalizes in float32.

    The input is upcast to float32 before the parent normalization runs and the
    result is cast back to the input's original dtype, so half-precision (fp16)
    activations are normalized in full precision.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class BridgeTowerImageFeatureExtractor(nn.Module):
    """ViT-style patch-embedding stem producing sequence-first image tokens.

    A strided convolution cuts the image into non-overlapping
    ``patch_size`` x ``patch_size`` patches. A learned class token is then
    prepended, learned positional embeddings are added, and LayerNorm is applied.

    Args:
        patch_size: side length of each square patch (conv kernel size and stride).
        width: embedding dimension of every token.
        resolution_after: image side length used to size the positional-embedding
            table, which holds ``(resolution_after // patch_size) ** 2 + 1`` rows
            (one per patch plus the class token). Inputs of another resolution
            will not broadcast against it in ``forward``.
        ckpt_path: optional checkpoint path. Its state dict, optionally nested
            under a ``"state_dict"`` key, is loaded strictly into this module.
    """

    def __init__(
        self,
        patch_size=14,
        width=1024,
        resolution_after=294,
        ckpt_path=None,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        if ckpt_path is not None:
            # Load onto CPU: this module's parameters live on CPU at this point and
            # load_state_dict copies into them. Without map_location, a checkpoint
            # saved from GPU fails to load on a CPU-only host.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            sd = torch.load(ckpt_path, map_location="cpu")
            if 'state_dict' in sd:
                sd = sd["state_dict"]
            print(f'Loading feature extractor checkpoint from {ckpt_path}')
            self.load_state_dict(sd)

    def forward(self, x: torch.Tensor):
        """Embed a batch of images into a sequence-first token tensor.

        Args:
            x: image batch for a 3-input-channel conv; presumably
                (batch, 3, resolution_after, resolution_after). TODO confirm with callers.

        Returns:
            Tensor of shape (grid ** 2 + 1, batch, width): the class token followed
            by the patch tokens, with positional embeddings added and LayerNorm applied.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # One copy of the learned class token per batch element, matching x's dtype/device.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        return x.permute(1, 0, 2)  # NLD -> LND
class BridgeTowerITCHead(nn.Module):
    """Single linear projection into the contrastive embedding space.

    Args:
        hidden_size: input feature size.
        embed_size: output (contrastive) feature size.
    """

    def __init__(self, hidden_size, embed_size):
        super().__init__()
        # Attribute name `fc` defines the state-dict keys (fc.weight / fc.bias).
        self.fc = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        return self.fc(x)
class _BridgeTowerTextModelWrapper(nn.Module):
    """Holds a ``BridgeTowerTextModel`` under the ``text_model`` attribute.

    Registering the model as a submodule named ``text_model`` prefixes its
    parameter names with ``text_model.``. NOTE(review): presumably this exists so
    that a parent storing this wrapper as ``bridgetower`` gets state-dict keys
    laid out like ``BridgeTowerModel`` (``bridgetower.text_model.*``). Confirm
    against the checkpoints being loaded.
    """

    def __init__(self, config):
        super().__init__()
        self.text_model = BridgeTowerTextModel(config)

    def forward(self, **kwargs):
        """Call the wrapped text model with every keyword argument unchanged."""
        wrapped = self.text_model
        return wrapped(**kwargs)
class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
    """Text-only BridgeTower tower producing L2-normalized contrastive text embeddings.

    Runs the BridgeTower text encoder, takes the final hidden state of the first
    token (index 0), projects it with ``itc_text_head`` and L2-normalizes it.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bridgetower = _BridgeTowerTextModelWrapper(config.text_config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        # Initialize weights and apply final processing (as BridgeTowerForITC does);
        # recent transformers versions rely on attributes set up in post_init().
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Return normalized contrastive embeddings for a batch of texts.

        Args:
            input_ids, attention_mask: forwarded to the text encoder.
            token_type_ids, head_mask, inputs_embeds: forwarded to the text encoder
                only when given, so default calls pass exactly what they always did.
            output_attentions, output_hidden_states, return_dict, labels: accepted
                for API compatibility. They are ignored because only the embedding
                is returned.

        Returns:
            torch.FloatTensor of shape (batch, config.contrastive_hidden_size),
            L2-normalized along the last dimension.
        """
        text_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        optional_inputs = {
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
        }
        text_inputs.update({name: value for name, value in optional_inputs.items() if value is not None})
        outputs = self.bridgetower(**text_inputs)
        # Index 0 is last_hidden_state for both a ModelOutput and a plain tuple
        # return, so this works whatever config.use_return_dict is.
        last_hidden_state = outputs[0]
        final_hidden_cls = last_hidden_state[:, 0, :]
        return F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
class BridgeTowerForITC(BridgeTowerPreTrainedModel):
    """BridgeTower backbone with image-text contrastive (ITC) projection heads.

    Projects three representations into ``config.contrastive_hidden_size`` and
    L2-normalizes each one:

    * text: first token of the final text hidden state (``itc_text_head``);
    * image: first token of the final vision hidden state. It first passes
      through the backbone's ``forward_post``, ``cross_modal_image_transform``
      and image token-type embedding, then goes to ``itc_image_head``;
    * cross-modal: the backbone's pooler output (``itc_cross_modal_head``, whose
      input width is ``2 * config.hidden_size``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bridgetower = BridgeTowerModel(config)
        self.itc_text_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_image_head = BridgeTowerITCHead(config.hidden_size, config.contrastive_hidden_size)
        self.itc_cross_modal_head = BridgeTowerITCHead(config.hidden_size * 2, config.contrastive_hidden_size)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
        """Compute normalized text, image and cross-modal contrastive embeddings.

        Args:
            input_ids ... image_embeds, output_attentions: forwarded to ``BridgeTowerModel``.
            output_hidden_states: whether to include the backbone's hidden states in
                the result; None falls back to ``config.output_hidden_states``. The
                backbone itself is always asked for hidden states because the heads
                need them.
            return_dict: None falls back to ``config.use_return_dict``.
            labels: accepted for API compatibility; unused (no loss is computed).

        Returns:
            The result is a ``SequenceClassifierOutput`` when ``return_dict`` is
            truthy, otherwise its ``to_tuple()`` form (``logits`` first, then any
            non-None hidden states / attentions). ``logits`` stacks the three
            embeddings along dim -2 in the order text, image, cross-modal.
            Presumably its shape is (batch, 3, contrastive_hidden_size); confirm.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # The heads below read per-modality hidden states and use attribute access
        # on the backbone output. So always request hidden states and a ModelOutput
        # here, independent of what the caller wants returned.
        outputs = self.bridgetower(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = outputs.pooler_output
        hidden_states_txt, hidden_states_img, _ = outputs.hidden_states
        final_hidden_txt = hidden_states_txt[-1]
        final_hidden_img = hidden_states_img[-1]

        # Re-create the image features the backbone feeds into its cross-modal layers:
        # post-processing, cross-modal projection, plus the image token-type (id 1) embedding.
        image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(final_hidden_img)
        image_token_type_embeddings = self.bridgetower.token_type_embeddings(
            torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
        ).expand_as(image_embeds_with_ln)
        final_hidden_img = (
            self.bridgetower.cross_modal_image_transform(image_embeds_with_ln)
            + image_token_type_embeddings
        )

        # Project the first token of each modality (and the pooled cross-modal output)
        # and L2-normalize so dot products are cosine similarities.
        final_hidden_txt = F.normalize(self.itc_text_head(final_hidden_txt[:, 0, :]), dim=-1, p=2)
        final_hidden_img = F.normalize(self.itc_image_head(final_hidden_img[:, 0, :]), dim=-1, p=2)
        final_hidden_cross = F.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2)

        logits = torch.stack([final_hidden_txt, final_hidden_img, final_hidden_cross], dim=-2)

        output = SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states if return_hidden_states else None,
            attentions=outputs.attentions,
        )
        if not return_dict:
            # Standard transformers tuple form: non-None fields in declaration order.
            return output.to_tuple()
        return output