Spaces: Running on Zero
File size: 6,610 Bytes · 841bef5
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional

from .backbone import build_backbone
class TokenOCR(nn.Module):
    def __init__(self, backbone):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
        """
        super().__init__()
        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
        # Freeze the language embedding; only the modules registered below remain trainable.
        for p in self.parameters():
            p.requires_grad = False
        self.backbone = backbone
        init_tau = np.log(10)
        init_b = -2.71
        # self.t_prime = nn.Parameter(torch.ones([]) * init_tau)
        # self.b = nn.Parameter(torch.ones([]) * init_b)
        self.kb = True
        # Upsample the backbone feature map 4x spatially (two stride-2 transposed convs).
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=2048,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
            nn.ConvTranspose2d(
                in_channels=512,
                out_channels=512,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.SyncBatchNorm(512),
        )
        # Project the upsampled features into the 2048-d language-embedding space.
        self.ocr_mlp = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
        )
    def forward(self,
                pixel_values: torch.FloatTensor,
                input_ids: torch.LongTensor = None,
                image_flags: Optional[torch.LongTensor] = None,
                mask_values: Optional[torch.LongTensor] = None,
                masks_flags: Optional[torch.LongTensor] = None,
                mask_nums: Optional[torch.LongTensor] = None,
                ):
        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_embedding(input_ids).clone()

        # Visual features (vit_batch_size, h*w, 2048) and their grid size (h, w).
        vit_embeds, vit_embeds_shape = self.forward_tokenocr(pixel_values)
        nb, nl, nd = vit_embeds.shape
        h, w = vit_embeds_shape
        vit_embeds = vit_embeds.reshape(nb, h, w, nd)
        vit_embeds = vit_embeds.split(list(image_flags))  # [(vit_batch_size / B, h, w, C)] * B

        vit_batch_size = pixel_values.shape[0]
        B, N, C = input_embeds.shape
        assert sum(image_flags) == mask_values.shape[0], \
            f'mask/image count mismatch: {(mask_values.shape, image_flags, mask_nums)}'

        # Resize the per-token masks to the feature-grid resolution.
        mask_values = F.interpolate(mask_values.float(), size=(h, w), mode='bilinear', align_corners=False)
        masks = mask_values.split(list(image_flags))  # [(vit_batch_size / B, N, h, w)] * B
        masks_flags = masks_flags.chunk(B)

        token_features = []
        input_embeddings = []
        masked_input_ids = []
        masked_zero_bools = []
        for i, vit_embed in enumerate(vit_embeds):
            current_token = masks_flags[i].sum()
            mask = masks[i]
            limit_num = mask.shape[1]
            mask = mask.permute(1, 0, 2, 3).reshape(limit_num, -1) > 0
            max_cluster_index = mask.sum(-1)
            zero_bool = max_cluster_index != 0
            mask[~zero_bool] = 1  # avoid division by zero below (bfloat16 bug workaround)
            new_max_cluster_index = mask.sum(-1)
            # Row-normalize so the matmul averages the features under each token's mask.
            mask = mask / new_max_cluster_index.unsqueeze(-1)
            token_feature = torch.matmul(mask.to(vit_embed), vit_embed.reshape(-1, vit_embed.shape[-1]))
            token_features.extend(token_feature)
            input_embeddings.extend(input_embeds[i, :])
            masked_input_ids.extend(input_ids[i, zero_bool])
            masked_zero_bools.append(zero_bool)

        masked_zero_bools = torch.cat(masked_zero_bools)
        token_features = torch.stack(token_features)
        input_embeddings = torch.stack(input_embeddings)

        # loss2: mean L2 distance between pooled visual features and language embeddings.
        loss2 = F.mse_loss(token_features, input_embeddings, reduction='none')[masked_zero_bools].sum(1).sqrt().mean()

        # loss1: cosine-similarity loss on the normalized features.
        token_features = token_features / token_features.norm(dim=1, keepdim=True)
        input_embeddings = input_embeddings / input_embeddings.norm(dim=1, keepdim=True)
        similarity = F.cosine_similarity(token_features, input_embeddings, dim=1)
        loss1 = (1 - similarity[masked_zero_bools]).mean()

        # Optional SigLIP-style pairwise loss (disabled; needs the commented-out
        # t_prime/b parameters in __init__):
        # masked_input_ids = torch.stack(masked_input_ids)
        # label_matrix = (masked_input_ids.unsqueeze(0) == masked_input_ids.unsqueeze(1)).int()
        # label_matrix = 2 * label_matrix - 1
        # if self.kb:
        #     logits = (input_embeddings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.exp() + self.b
        # else:
        #     logits = (input_embeddings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.exp() - 8.9375
        # loss_s = -torch.sum(F.logsigmoid(label_matrix * logits)) / logits.shape[0]

        return loss1, loss2
    def forward_tokenocr(self, pixel_values):
        vit_embeds = self.backbone(pixel_values)
        vit_embeds = vit_embeds['0']  # (B, 2048, h, w) feature map from the backbone
        h, w = vit_embeds.shape[2], vit_embeds.shape[3]
        vit_embeds = self.upsample(vit_embeds)  # (B, 512, h*4, w*4)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1])
        vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1))  # (B, h*4 * w*4, 2048)
        return vit_embeds, (h * 4, w * 4)
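
# A minimal sketch (not part of the original file) of the mask-pooling step in
# TokenOCR.forward: each token's binary mask is row-normalized so that a single
# matmul averages the visual features under that token. All shapes below are
# illustrative assumptions.
def _demo_mask_pooling():
    feats = torch.randn(32 * 32, 512)      # flattened (h*w, C) visual features
    masks = torch.rand(10, 32 * 32) > 0.9  # 10 binary token masks over h*w pixels
    valid = masks.sum(-1) != 0             # tokens whose mask covers at least one pixel
    masks[~valid] = 1                      # avoid division by zero (cf. the bfloat16 workaround)
    weights = masks / masks.sum(-1, keepdim=True)
    pooled = weights.to(feats) @ feats     # (10, C): mean feature per token
    return pooled, valid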
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
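
# Example usage of MLP (illustrative dimensions, not a project config): a
# 3-layer MLP applies two hidden ReLU layers and a linear output head.
#     mlp = MLP(input_dim=512, hidden_dim=1024, output_dim=2048, num_layers=3)
#     out = mlp(torch.randn(4, 512))  # -> (4, 2048)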
def build(args):
    backbone = build_backbone(args)
    model = TokenOCR(backbone)
    return model
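
if __name__ == '__main__':
    # Hedged usage sketch, not part of the original file. The `backbone` field
    # on `args` is a hypothetical stand-in for whatever build_backbone expects;
    # SyncBatchNorm may also require a GPU/distributed setup on older PyTorch.
    from argparse import Namespace

    args = Namespace(backbone='resnet50')  # hypothetical backbone config
    model = build(args).eval()
    with torch.no_grad():
        images = torch.randn(1, 3, 448, 448)
        feats, (h, w) = model.forward_tokenocr(images)
        print(feats.shape, (h, w))  # expect (1, h*w, 2048) and the 4x-upsampled grid size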