# Uploaded by TongkunGuan ("Upload 94 files", commit 841bef5, verified).
import torch
import torch.nn.functional as F
from torch import nn
from .backbone import build_backbone
import pdb
import numpy as np
from typing import Optional
class TokenOCR(nn.Module):
    """Align mask-pooled visual token features with frozen language embeddings.

    The backbone's ``'0'`` feature map is upsampled, projected to the language
    embedding width (2048) by ``ocr_mlp``, and pooled under per-token masks into one
    feature per text token. ``forward`` returns a cosine loss and an L2-distance loss
    between those pooled features and the embeddings of the matching ``input_ids``.
    """

    def __init__(self, backbone):
        """Initialize the model.

        Args:
            backbone: torch module returning a dict of feature maps; entry ``'0'`` is
                used and must have 2048 channels to match ``upsample``. See backbone.py.
        """
        super().__init__()
        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
        # Only ``language_embedding`` is registered at this point, so this freezes
        # exactly that table; the backbone and the modules below stay trainable.
        for p in self.parameters():
            p.requires_grad = False
        self.backbone = backbone
        # Not read anywhere in this class; kept in case external code inspects it.
        self.kb = True
        # Each ConvTranspose2d(kernel 4, stride 2, padding 1) doubles H and W.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=2048,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
            nn.ConvTranspose2d(
                in_channels=512,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
        )
        self.ocr_mlp = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        image_flags: Optional[torch.LongTensor] = None,
        mask_values: Optional[torch.LongTensor] = None,
        masks_flags: Optional[torch.LongTensor] = None,
        mask_nums: Optional[torch.LongTensor] = None,
    ):
        """Compute the token/embedding alignment losses.

        Args:
            pixel_values: image crops for the whole batch, stacked along dim 0.
            input_ids: (B, N) token ids; row ``i`` belongs to sample ``i``.
            image_flags: number of crops per sample (a trailing singleton dim is
                allowed). Must sum to ``pixel_values.shape[0]`` and to
                ``mask_values.shape[0]``.
            mask_values: (num_crops, N, H_m, W_m) per-token masks. They are resized
                to the feature-map resolution; a value > 0 marks a pixel as belonging
                to the token. The mask count per crop must equal N.
            masks_flags: unused; kept for interface compatibility.
            mask_nums: only reported in the error message on a shape mismatch.

        Returns:
            (loss1, loss2): mean ``1 - cosine_similarity`` and mean Euclidean
            distance between mask-pooled visual features and token embeddings,
            averaged over tokens whose mask is non-empty.

        Raises:
            ValueError: if ``image_flags`` does not sum to ``mask_values.shape[0]``.
        """
        crops_per_sample = image_flags.reshape(-1).tolist()
        if sum(crops_per_sample) != mask_values.shape[0]:
            raise ValueError(
                "image_flags must sum to the number of mask crops; got "
                f"mask_values.shape={tuple(mask_values.shape)}, "
                f"image_flags={crops_per_sample}, mask_nums={mask_nums}"
            )

        input_embeds = self.language_embedding(input_ids)  # (B, N, C)

        vit_embeds, (h, w) = self.extract_feature_custom(pixel_values)
        num_crops, _, dim = vit_embeds.shape
        # Restore the spatial grid, then group crops by the sample they belong to.
        per_sample_feats = vit_embeds.reshape(num_crops, h, w, dim).split(crops_per_sample)

        mask_values = F.interpolate(
            mask_values.float(), size=(h, w), mode='bilinear', align_corners=False
        )
        per_sample_masks = mask_values.split(crops_per_sample)

        token_features, target_embeds, valid = [], [], []
        for i, (feats, masks) in enumerate(zip(per_sample_feats, per_sample_masks)):
            num_masks = masks.shape[1]
            # (crops, masks, h, w) -> (masks, crops*h*w): one binary row per token.
            region = masks.permute(1, 0, 2, 3).reshape(num_masks, -1) > 0
            has_pixels = region.sum(-1) != 0
            # Empty masks would divide by zero (NaNs, notably in bfloat16); pool over
            # every pixel instead. These tokens are excluded from both losses below.
            region[~has_pixels] = True
            weights = region / region.sum(-1, keepdim=True)
            token_features.append(weights.to(feats) @ feats.reshape(-1, dim))
            target_embeds.append(input_embeds[i])
            valid.append(has_pixels)
        token_features = torch.cat(token_features)
        target_embeds = torch.cat(target_embeds)
        valid = torch.cat(valid)

        # Mean Euclidean distance over valid tokens.
        loss2 = (
            F.mse_loss(token_features, target_embeds, reduction='none')[valid]
            .sum(1)
            .sqrt()
            .mean()
        )
        # cosine_similarity is scale-invariant, so no explicit normalization is needed
        # (explicit division would also turn zero padding embeddings into NaN rows).
        similarity = F.cosine_similarity(token_features, target_embeds, dim=1)
        loss1 = (1 - similarity[valid]).mean()
        return loss1, loss2

    def extract_feature_custom(self, pixel_values):
        """Return ``(features, (h, w))`` for ``pixel_values``; delegates to ``forward_tokenocr``."""
        return self.forward_tokenocr(pixel_values)

    def forward_tokenocr(self, pixel_values):
        """Encode images into per-position token features.

        Returns:
            features: (batch, h * w, 2048) MLP projections of the upsampled map.
            (h, w): spatial size of the upsampled map (4x the backbone's ``'0'`` map
                with the current ``upsample`` stack).
        """
        feats = self.backbone(pixel_values)['0']
        feats = self.upsample(feats)
        h, w = feats.shape[-2:]
        # (B, C, h, w) -> (B, h*w, C) so the MLP acts on each spatial position.
        feats = feats.flatten(2).transpose(1, 2)
        return self.ocr_mlp(feats), (h, w)
class MLP(nn.Module):
    """Plain feed-forward network (FFN).

    Stacks ``num_layers`` ``nn.Linear`` layers mapping ``input_dim`` through
    ``hidden_dim`` to ``output_dim``, with ReLU after every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden]
        out_dims = [*hidden, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last:
                x = F.relu(x)
        return x
def build(args):
    """Build the backbone described by ``args`` and wrap it in a ``TokenOCR`` model."""
    return TokenOCR(build_backbone(args))