# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch import Tensor

from timm.models.helpers import named_apply

from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
from IndicPhotoOCR.utils.strhub.models.utils import init_weights

from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding


class PARSeq(nn.Module):

    def __init__(
        self,
        num_tokens: int,
        max_label_length: int,
        img_size: Sequence[int],
        patch_size: Sequence[int],
        embed_dim: int,
        enc_num_heads: int,
        enc_mlp_ratio: int,
        enc_depth: int,
        dec_num_heads: int,
        dec_mlp_ratio: int,
        dec_depth: int,
        decode_ar: bool,
        refine_iters: int,
        dropout: float,
    ) -> None:
        super().__init__()

        self.max_label_length = max_label_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters

        self.encoder = Encoder(
            img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
        )
        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))

        # We don't predict <bos> nor <pad>
        self.head = nn.Linear(embed_dim, num_tokens - 2)
        self.text_embed = TokenEmbedding(num_tokens, embed_dim)

        # +1 for <eos>
        self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)
        # Encoder has its own init.
        named_apply(partial(init_weights, exclude=['encoder']), self)
        nn.init.trunc_normal_(self.pos_queries, std=0.02)

    @property
    def _device(self) -> torch.device:
        return next(self.head.parameters(recurse=False)).device

    @torch.jit.ignore
    def no_weight_decay(self):
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
        return self.encoder(img)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_padding_mask: Optional[Tensor] = None,
        tgt_query: Optional[Tensor] = None,
        tgt_query_mask: Optional[Tensor] = None,
    ):
        N, L = tgt.shape
        # <bos> stands for the null context. We only supply position information for characters after <bos>.
        null_ctx = self.text_embed(tgt[:, :1])
        tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
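    # Note: forward() below supports two decoding modes. With decode_ar=True, it
    # decodes autoregressively, querying one position at a time under a causal
    # (forward) mask; with decode_ar=False, it predicts all positions in a single
    # non-autoregressive pass. In both cases, refine_iters > 0 applies cloze-style
    # iterative refinement to the initial prediction.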
    def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        testing = max_length is None
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        bs = images.shape[0]
        # +1 for <eos> at end of sequence.
        num_steps = max_length + 1
        memory = self.encode(images)

        # Query positions up to `num_steps`
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)

        # Special case for the forward permutation. Faster than using `generate_attn_masks()`
        tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)

        if self.decode_ar:
            tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
            tgt_in[:, 0] = tokenizer.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1  # next token index
                # Efficient decoding:
                # Input the context up to the ith token. We use only one query (at position = i) at a time.
                # This works because of the lookahead masking effect of the canonical (forward) AR context.
                # Past tokens have no access to future tokens, hence are fixed once computed.
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                )
                # the next token probability is in the output's ith token position
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # greedy decode. add the next token index to the target input
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                    if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)
        else:
            # No prior context, so input is just <bos>. We query all positions.
            tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # For iterative refinement, we always use a 'cloze' mask.
            # We can derive it from the AR forward mask by unmasking the token context to the right.
            query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
            bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
            for i in range(self.refine_iters):
                # Prior context is the previous output.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                # Mask tokens beyond the first EOS token.
                tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
                tgt_out = self.decode(
                    tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
                )
                logits = self.head(tgt_out)

        return logits
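
if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. The hyperparameter
    # values below mirror the published PARSeq-S configuration but are assumptions
    # here; adjust them to your charset and training setup. Run as a module (e.g.
    # `python -m <package>.system`) so the relative `.modules` import resolves.
    charset = '0123456789abcdefghijklmnopqrstuvwxyz'  # hypothetical charset
    tokenizer = Tokenizer(charset)  # assumes the strhub Tokenizer(charset: str) signature
    model = PARSeq(
        num_tokens=len(tokenizer),  # charset plus the <eos>/<bos>/<pad> specials
        max_label_length=25,
        img_size=(32, 128),
        patch_size=(4, 8),
        embed_dim=384,
        enc_num_heads=6,
        enc_mlp_ratio=4,
        enc_depth=12,
        dec_num_heads=12,
        dec_mlp_ratio=4,
        dec_depth=1,
        decode_ar=True,
        refine_iters=1,
        dropout=0.1,
    )
    model.eval()
    with torch.no_grad():
        images = torch.rand(2, 3, 32, 128)  # dummy batch: (N, C, H, W) matching img_size
        logits = model(tokenizer, images)
    # Expected shape: (2, L, len(tokenizer) - 2), where L <= max_label_length + 1
    # (AR decoding may stop early once every sequence in the batch emits <eos>).
    print(logits.shape)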