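"""TokenOCR model definition: pools visual features under per-token masks and
aligns them with a frozen language embedding table via cosine and L2 losses.
The visual backbone is built in backbone.py."""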
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional

from .backbone import build_backbone
class TokenOCR(nn.Module):
    def __init__(self, backbone):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
        """
        super().__init__()
        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
        # Freeze the embedding table; modules registered after this loop
        # (backbone, upsample, ocr_mlp) remain trainable.
        for p in self.parameters():
            p.requires_grad = False
        self.backbone = backbone
        init_tau = np.log(10)
        init_b = -2.71
        # Learnable SigLIP temperature/bias, used only by the commented-out
        # pairwise loss in forward().
        # self.t_prime = nn.Parameter(torch.ones([]) * init_tau)
        # self.b = nn.Parameter(torch.ones([]) * init_b)
        self.kb = True
        # Two stride-2 transposed convolutions: (B, 2048, h, w) -> (B, 512, 4h, 4w).
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=2048,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
            nn.ConvTranspose2d(
                in_channels=512,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
        )
        # Project the upsampled visual features back to the 2048-d embedding space.
        self.ocr_mlp = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
        )
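        # Shape sketch (an illustrative assumption: a 448x448 input through a
        # stride-32 backbone gives a 14x14 map): backbone -> (B, 2048, 14, 14);
        # upsample -> (B, 512, 56, 56); ocr_mlp -> 2048-d per-pixel features,
        # matching language_embedding so visual and text tokens are comparable.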
    def forward(self,
                pixel_values: torch.FloatTensor,
                input_ids: torch.LongTensor = None,
                image_flags: Optional[torch.LongTensor] = None,
                mask_values: Optional[torch.LongTensor] = None,
                masks_flags: Optional[torch.LongTensor] = None,
                mask_nums: Optional[torch.LongTensor] = None,
                ):
        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_embedding(input_ids).clone()
        vit_embeds, vit_embeds_shape = self.forward_tokenocr(pixel_values)  # (vit_batch_size, h*w, 2048)
        nb, nl, nd = vit_embeds.shape
        h, w = vit_embeds_shape
        vit_embeds = vit_embeds.reshape(nb, h, w, nd)
        vit_embeds = vit_embeds.split(list(image_flags))  # [(vit_batch_size / B, h, w, C)] * B
        vit_batch_size = pixel_values.shape[0]
        B, N, C = input_embeds.shape
        assert sum(image_flags) == mask_values.shape[0], \
            f"mask/image count mismatch: {mask_values.shape}, {image_flags}, {mask_nums}"
        # Resize the token masks to the upsampled feature resolution (h, w).
        mask_values = torch.nn.functional.interpolate(
            mask_values.float(), size=(h, w), mode='bilinear', align_corners=False)
        masks = mask_values.split(list(image_flags))  # [(vit_batch_size / B, N, h, w)] * B
        masks_flags = masks_flags.chunk(B)
        token_features = []
        input_embeddings = []
        masked_input_ids = []
        masked_zero_bools = []
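        # Each mask marks the feature-map pixels belonging to one text token.
        # The loop below turns every mask into a row of averaging weights and
        # pools the visual features into one vector per token; e.g. a mask
        # covering 10 pixels becomes a row of ten 0.1 weights.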
        for i, vit_embed in enumerate(vit_embeds):
            current_token = masks_flags[i].sum()
            mask = masks[i]
            limit_num = mask.shape[1]
            mask = mask.permute(1, 0, 2, 3).reshape(limit_num, -1) > 0
            max_cluster_index = mask.sum(-1)
            zero_bool = max_cluster_index != 0
            # Give empty masks a uniform weight to avoid division by zero
            # (works around a bfloat16 issue); they are filtered out later
            # via zero_bool.
            mask[~zero_bool] = 1
            new_max_cluster_index = mask.sum(-1)
            mask = mask / new_max_cluster_index.unsqueeze(-1)
            # Average-pool the visual features inside each token mask.
            token_feature = torch.matmul(mask.to(vit_embed), vit_embed.reshape(-1, vit_embed.shape[-1]))
            token_features.extend(token_feature)
            input_embeddings.extend(input_embeds[i, :])
            masked_input_ids.extend(input_ids[i, zero_bool])
            masked_zero_bools.append(zero_bool)
        masked_zero_bools = torch.cat(masked_zero_bools)
        token_features = torch.stack(token_features)
        input_embeddings = torch.stack(input_embeddings)
        # L2 distance between pooled visual token features and language embeddings.
        loss2 = F.mse_loss(token_features, input_embeddings, reduction='none')[masked_zero_bools].sum(1).sqrt().mean()
        token_features = token_features / token_features.norm(dim=1, keepdim=True)
        input_embeddings = input_embeddings / input_embeddings.norm(dim=1, keepdim=True)
        # Cosine similarity between the normalized features.
        similarity = F.cosine_similarity(token_features, input_embeddings, dim=1)
        loss1 = (1 - similarity[masked_zero_bools]).mean()
        # Alternative SigLIP-style pairwise loss, kept for reference:
        # masked_input_ids = torch.stack(masked_input_ids)
        # label_matrix = (masked_input_ids.unsqueeze(0) == masked_input_ids.unsqueeze(1)).int()
        # label_matrix = 2 * label_matrix - 1
        # if self.kb:
        #     logits = (input_embeddings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embeddings.device).exp() + self.b.to(input_embeddings.device)
        # else:
        #     logits = (input_embeddings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embeddings.device).exp() - 8.9375
        # loss_s = -torch.sum(F.logsigmoid(label_matrix * logits)) / logits.shape[0]
        return loss1, loss2
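    # Toy illustration of the two objectives (hypothetical values, not part of
    # the training code):
    #
    #   t = torch.tensor([[3.0, 4.0]])  # pooled visual token feature
    #   e = torch.tensor([[0.0, 5.0]])  # frozen language embedding
    #   F.mse_loss(t, e, reduction='none').sum(1).sqrt()  # L2 distance: sqrt(9 + 1) ~ 3.16
    #   1 - F.cosine_similarity(t, e, dim=1)              # cosine loss: 1 - 20/25 = 0.2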
    def forward_tokenocr(self, pixel_values):
        vit_embeds = self.backbone(pixel_values)
        vit_embeds = vit_embeds['0']
        h, w = vit_embeds.shape[2], vit_embeds.shape[3]
        vit_embeds = self.upsample(vit_embeds)  # (B, 512, 4h, 4w)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1])
        vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1))  # (B, 16*h*w, 2048)
        return vit_embeds, (h * 4, w * 4)
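    # Usage sketch (assumes the backbone returns a dict whose key '0' holds a
    # (B, 2048, h, w) feature map, as with torchvision's IntermediateLayerGetter):
    #
    #   feats, (H, W) = model.forward_tokenocr(torch.randn(1, 3, 448, 448))
    #   # feats: (1, H * W, 2048), where H = 4h and W = 4w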
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
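# Example: a 3-layer MLP projecting 2048-d features to 256-d, with ReLU after
# every layer except the last:
#
#   mlp = MLP(input_dim=2048, hidden_dim=512, output_dim=256, num_layers=3)
#   y = mlp(torch.randn(4, 2048))  # -> (4, 256)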
def build(args):
    backbone = build_backbone(args)
    model = TokenOCR(backbone)
    return model
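# Minimal end-to-end sketch (assumes `args` carries whatever fields
# build_backbone expects, e.g. the backbone name; shapes are illustrative):
#
#   model = build(args)
#   pixel_values = torch.randn(2, 3, 448, 448)
#   feats, (H, W) = model.forward_tokenocr(pixel_values)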