import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional

from .backbone import build_backbone

class TokenOCR(nn.Module):

    def __init__(self, backbone):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
        """
        super().__init__()
        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
        # Freeze everything registered so far (only the language embedding);
        # the modules assigned below are unaffected by this loop.
        for p in self.parameters():
            p.requires_grad = False

        self.backbone = backbone
        self.kb = True
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=2048,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.SyncBatchNorm(512),
            nn.ConvTranspose2d(
                in_channels=512,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.SyncBatchNorm(512),
        )
        self.ocr_mlp = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048)
        )

    def forward(self,
                pixel_values: torch.FloatTensor,
                input_ids: torch.LongTensor = None,
                image_flags: Optional[torch.LongTensor] = None,
                mask_values: Optional[torch.LongTensor] = None,
                masks_flags: Optional[torch.LongTensor] = None,
                mask_nums: Optional[torch.LongTensor] = None,
                ):
        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_embedding(input_ids).clone()

        vit_embeds, vit_embeds_shape = self.forward_tokenocr(pixel_values)
        nb, nl, nd = vit_embeds.shape
        h, w = vit_embeds_shape
        vit_embeds = vit_embeds.reshape(nb, h, w, nd)
        # Group the per-crop feature maps by sample: sample i owns image_flags[i] crops.
        vit_embeds = vit_embeds.split(image_flags.tolist())

        B, N, C = input_embeds.shape
        assert image_flags.sum() == mask_values.shape[0], \
            f'mask/image mismatch: {mask_values.shape}, {image_flags}, {mask_nums}'

        # Resize the per-token masks to the upsampled feature resolution.
        mask_values = F.interpolate(mask_values.float(), size=(h, w), mode='bilinear', align_corners=False)
        masks = mask_values.split(image_flags.tolist())

        token_features = []
        input_embeddings = []
        masked_zero_bools = []
        for i, vit_embed in enumerate(vit_embeds):
            mask = masks[i]
            limit_num = mask.shape[1]
            # Flatten each token's mask over all crops and pixels: (limit_num, crops*h*w).
            mask = mask.permute(1, 0, 2, 3).reshape(limit_num, -1) > 0
            max_cluster_index = mask.sum(-1)
            zero_bool = max_cluster_index != 0

            # Tokens with an empty mask get a uniform mask so the matmul below stays
            # well-defined; zero_bool excludes them from both losses.
            mask[~zero_bool] = 1
            new_max_cluster_index = mask.sum(-1)
            # Row-normalize so the matmul averages the visual features under each mask.
            mask = mask / new_max_cluster_index.unsqueeze(-1)
            token_feature = torch.matmul(mask.to(vit_embed), vit_embed.reshape(-1, vit_embed.shape[-1]))
            token_features.extend(token_feature)
            input_embeddings.extend(input_embeds[i, :])
            masked_zero_bools.append(zero_bool)

        masked_zero_bools = torch.cat(masked_zero_bools)
        token_features = torch.stack(token_features)
        input_embeddings = torch.stack(input_embeddings)

        # L2 distance between mask-pooled visual tokens and their text embeddings.
        loss2 = F.mse_loss(token_features, input_embeddings, reduction='none')[masked_zero_bools].sum(1).sqrt().mean()
        token_features = token_features / token_features.norm(dim=1, keepdim=True)
        input_embeddings = input_embeddings / input_embeddings.norm(dim=1, keepdim=True)

        # Cosine-similarity loss over the tokens that actually have a mask.
        similarity = F.cosine_similarity(token_features, input_embeddings, dim=1)
        loss1 = (1 - similarity[masked_zero_bools]).mean()

        return loss1, loss2

    def forward_tokenocr(self, pixel_values):
        vit_embeds = self.backbone(pixel_values)
        vit_embeds = vit_embeds['0']
        h, w = vit_embeds.shape[2], vit_embeds.shape[3]
        # Two stride-2 transposed convolutions upsample the feature map 4x.
        vit_embeds = self.upsample(vit_embeds)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1])
        # (N, C, H*W) -> (N, H*W, C), then project into the text-embedding space.
        vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1))
        return vit_embeds, (h * 4, w * 4)


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build(args):
    backbone = build_backbone(args)
    model = TokenOCR(backbone)
    return model
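

# A minimal smoke test -- a sketch under stated assumptions, not part of the
# training pipeline. The toy backbone below is hypothetical: it only mimics
# what this file relies on, a dict output whose '0' entry has 2048 channels
# at 1/16 resolution. All tensor shapes here are illustrative.
if __name__ == '__main__':
    class _ToyBackbone(nn.Module):
        def forward(self, x):
            # Stand-in for build_backbone(args): a random stride-16 feature map.
            return {'0': torch.randn(x.shape[0], 2048, x.shape[2] // 16, x.shape[3] // 16)}

    model = TokenOCR(_ToyBackbone())
    # SyncBatchNorm needs an initialized process group; swap in BatchNorm2d
    # so the sketch runs single-process on CPU.
    model.upsample[1] = nn.BatchNorm2d(512)
    model.upsample[3] = nn.BatchNorm2d(512)

    num_tokens = 7
    pixel_values = torch.randn(2, 3, 64, 64)            # two crops of one sample
    input_ids = torch.randint(0, 92553, (1, num_tokens))
    image_flags = torch.tensor([[2]])                   # sample 0 owns both crops
    h = w = 16                                          # 4 * (64 // 16) after upsampling
    mask_values = torch.randint(0, 2, (2, num_tokens, h, w))
    loss1, loss2 = model(pixel_values, input_ids=input_ids,
                         image_flags=image_flags, mask_values=mask_values)
    print(loss1.item(), loss2.item())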