# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch import Tensor
from timm.models.helpers import named_apply

from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
from IndicPhotoOCR.utils.strhub.models.utils import init_weights

from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding


class PARSeq(nn.Module):
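    """Permuted Autoregressive Sequence (PARSeq) model for scene text recognition.

    An image encoder produces visual features ("memory"); a Transformer decoder with
    learned position queries predicts character tokens either autoregressively
    (``decode_ar=True``) or in a single parallel pass, optionally followed by
    iterative cloze-style refinement (``refine_iters`` > 0).
    """
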
def __init__(
self,
num_tokens: int,
max_label_length: int,
img_size: Sequence[int],
patch_size: Sequence[int],
embed_dim: int,
enc_num_heads: int,
enc_mlp_ratio: int,
enc_depth: int,
dec_num_heads: int,
dec_mlp_ratio: int,
dec_depth: int,
decode_ar: bool,
refine_iters: int,
dropout: float,
) -> None:
super().__init__()
self.max_label_length = max_label_length
self.decode_ar = decode_ar
self.refine_iters = refine_iters
self.encoder = Encoder(
img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
)
decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))
        # We don't predict <bos> or <pad>
self.head = nn.Linear(embed_dim, num_tokens - 2)
self.text_embed = TokenEmbedding(num_tokens, embed_dim)
# +1 for <eos>
self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
self.dropout = nn.Dropout(p=dropout)
# Encoder has its own init.
named_apply(partial(init_weights, exclude=['encoder']), self)
nn.init.trunc_normal_(self.pos_queries, std=0.02)

    @property
def _device(self) -> torch.device:
return next(self.head.parameters(recurse=False)).device

    @torch.jit.ignore
def no_weight_decay(self):
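        # Exclude the token embedding and the learned position queries (plus whatever
        # the encoder itself reports) from weight decay.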
param_names = {'text_embed.embedding.weight', 'pos_queries'}
enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
return self.encoder(img)

    def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[Tensor] = None,
tgt_padding_mask: Optional[Tensor] = None,
tgt_query: Optional[Tensor] = None,
tgt_query_mask: Optional[Tensor] = None,
):
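        """Run one decoder pass.

        ``tgt`` holds (N, L) context token ids, with position 0 being <bos> (the null
        context); ``memory`` is the encoder output used for cross-attention.
        ``tgt_query`` selects which position queries to decode (defaults to the first L
        learned queries); the optional masks restrict which context positions each
        query may attend to.
        """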
N, L = tgt.shape
# <bos> stands for the null context. We only supply position information for characters after <bos>.
null_ctx = self.text_embed(tgt[:, :1])
tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
if tgt_query is None:
tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
tgt_query = self.dropout(tgt_query)
return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

    def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
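        """Inference forward pass.

        Returns character logits of shape (bs, L, num_tokens - 2), where
        L <= max_label_length + 1 (the extra position accounts for <eos>).
        """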
testing = max_length is None
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
bs = images.shape[0]
# +1 for <eos> at end of sequence.
num_steps = max_length + 1
memory = self.encode(images)
# Query positions up to `num_steps`
pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
# Special case for the forward permutation. Faster than using `generate_attn_masks()`
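        # Boolean masks follow the PyTorch convention (True = attention blocked), so the
        # strict upper triangle below is the standard causal / forward-AR mask.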
tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)
if self.decode_ar:
tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
tgt_in[:, 0] = tokenizer.bos_id
logits = []
for i in range(num_steps):
j = i + 1 # next token index
# Efficient decoding:
# Input the context up to the ith token. We use only one query (at position = i) at a time.
# This works because of the lookahead masking effect of the canonical (forward) AR context.
# Past tokens have no access to future tokens, hence are fixed once computed.
tgt_out = self.decode(
tgt_in[:, :j],
memory,
tgt_mask[:j, :j],
tgt_query=pos_queries[:, i:j],
tgt_query_mask=query_mask[i:j, :j],
)
# the next token probability is in the output's ith token position
p_i = self.head(tgt_out)
logits.append(p_i)
if j < num_steps:
# greedy decode. add the next token index to the target input
tgt_in[:, j] = p_i.squeeze().argmax(-1)
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
break
logits = torch.cat(logits, dim=1)
else:
# No prior context, so input is just <bos>. We query all positions.
tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
logits = self.head(tgt_out)
if self.refine_iters:
# For iterative refinement, we always use a 'cloze' mask.
# We can derive it from the AR forward mask by unmasking the token context to the right.
query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
for i in range(self.refine_iters):
# Prior context is the previous output.
tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
# Mask tokens beyond the first EOS token.
tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
tgt_out = self.decode(
tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
)
logits = self.head(tgt_out)
return logits
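

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). The charset and the
    # hyperparameters below are assumed placeholders, roughly following a PARSeq-base
    # style configuration; substitute the values that match the checkpoint you load.
    charset = "0123456789abcdefghijklmnopqrstuvwxyz"  # assumed example charset
    tokenizer = Tokenizer(charset)  # assumes the bundled strhub Tokenizer(charset) API
    model = PARSeq(
        num_tokens=len(tokenizer),
        max_label_length=25,
        img_size=(32, 128),
        patch_size=(4, 8),
        embed_dim=384,
        enc_num_heads=6,
        enc_mlp_ratio=4,
        enc_depth=12,
        dec_num_heads=12,
        dec_mlp_ratio=4,
        dec_depth=1,
        decode_ar=True,
        refine_iters=1,
        dropout=0.1,
    ).eval()
    with torch.inference_mode():
        logits = model(tokenizer, torch.randn(2, 3, 32, 128))
    print(logits.shape)  # (2, <= max_label_length + 1, len(tokenizer) - 2)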