# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch import Tensor
from timm.models.helpers import named_apply

from IndicPhotoOCR.utils.strhub.data.utils import Tokenizer
from IndicPhotoOCR.utils.strhub.models.utils import init_weights

from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding


class PARSeq(nn.Module):
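    """Permuted Autoregressive Sequence model (PARSeq) for scene text recognition.

    A ViT-style image encoder produces visual features ("memory") that a small
    Transformer decoder attends to through learned positional queries, predicting the
    character sequence either autoregressively or in a single parallel pass, optionally
    followed by iterative (cloze-style) refinement. See "Scene Text Recognition with
    Permuted Autoregressive Sequence Models" (Bautista & Atienza, ECCV 2022).
    """
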
def __init__(
self,
num_tokens: int,
max_label_length: int,
img_size: Sequence[int],
patch_size: Sequence[int],
embed_dim: int,
enc_num_heads: int,
enc_mlp_ratio: int,
enc_depth: int,
dec_num_heads: int,
dec_mlp_ratio: int,
dec_depth: int,
decode_ar: bool,
refine_iters: int,
dropout: float,
) -> None:
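        """Args:
            num_tokens: vocabulary size, i.e. number of characters plus the
                <eos>, <bos> and <pad> special tokens.
            max_label_length: maximum number of characters to decode.
            img_size, patch_size: input image and ViT patch sizes as (H, W).
            embed_dim, enc_*, dec_*: encoder/decoder Transformer hyperparameters.
            decode_ar: decode autoregressively if True, else in one parallel pass.
            refine_iters: number of iterative-refinement passes over the prediction.
            dropout: dropout probability used in the decoder and on the queries.
        """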
super().__init__()
self.max_label_length = max_label_length
self.decode_ar = decode_ar
self.refine_iters = refine_iters
self.encoder = Encoder(
img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
)
decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))
# We don't predict <bos> nor <pad>
self.head = nn.Linear(embed_dim, num_tokens - 2)
self.text_embed = TokenEmbedding(num_tokens, embed_dim)
# +1 for <eos>
self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
self.dropout = nn.Dropout(p=dropout)
# Encoder has its own init.
named_apply(partial(init_weights, exclude=['encoder']), self)
nn.init.trunc_normal_(self.pos_queries, std=0.02)

    @property
def _device(self) -> torch.device:
return next(self.head.parameters(recurse=False)).device

    @torch.jit.ignore
def no_weight_decay(self):
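        """Parameter names to exclude from weight decay (this module's plus the encoder's)."""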
param_names = {'text_embed.embedding.weight', 'pos_queries'}
enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
return self.encoder(img)

    def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[Tensor] = None,
tgt_padding_mask: Optional[Tensor] = None,
tgt_query: Optional[Tensor] = None,
tgt_query_mask: Optional[Tensor] = None,
):
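        """Run the Transformer decoder over the given context and positional queries.

        Args:
            tgt: context token ids, shape (N, L); position 0 is <bos> (the null context).
            memory: encoder output (visual features).
            tgt_mask: attention mask over the context tokens.
            tgt_padding_mask: padding mask over the context tokens.
            tgt_query: positional queries; defaults to the first L learned position queries.
            tgt_query_mask: attention mask applied from the queries to the context.
        """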
N, L = tgt.shape
# <bos> stands for the null context. We only supply position information for characters after <bos>.
null_ctx = self.text_embed(tgt[:, :1])
tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
if tgt_query is None:
tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
tgt_query = self.dropout(tgt_query)
return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

    def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
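        """Greedy inference: autoregressive decoding (or one parallel pass when `decode_ar`
        is False), optionally followed by `refine_iters` rounds of cloze-style refinement.

        Returns:
            Logits of shape (N, T, num_tokens - 2), where T <= max_label_length + 1 (decoding
            may stop early at <eos> during testing) and <bos>/<pad> are never predicted.
        """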
testing = max_length is None
max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
bs = images.shape[0]
# +1 for <eos> at end of sequence.
num_steps = max_length + 1
memory = self.encode(images)
# Query positions up to `num_steps`
pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)
# Special case for the forward permutation. Faster than using `generate_attn_masks()`
tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)
if self.decode_ar:
tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
tgt_in[:, 0] = tokenizer.bos_id
logits = []
for i in range(num_steps):
j = i + 1 # next token index
# Efficient decoding:
# Input the context up to the ith token. We use only one query (at position = i) at a time.
# This works because of the lookahead masking effect of the canonical (forward) AR context.
# Past tokens have no access to future tokens, hence are fixed once computed.
tgt_out = self.decode(
tgt_in[:, :j],
memory,
tgt_mask[:j, :j],
tgt_query=pos_queries[:, i:j],
tgt_query_mask=query_mask[i:j, :j],
)
# the next token probability is in the output's ith token position
p_i = self.head(tgt_out)
logits.append(p_i)
if j < num_steps:
# greedy decode. add the next token index to the target input
tgt_in[:, j] = p_i.squeeze().argmax(-1)
# Efficient batch decoding: If all output words have at least one EOS token, end decoding.
if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
break
logits = torch.cat(logits, dim=1)
else:
# No prior context, so input is just <bos>. We query all positions.
tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
logits = self.head(tgt_out)
if self.refine_iters:
# For iterative refinement, we always use a 'cloze' mask.
# We can derive it from the AR forward mask by unmasking the token context to the right.
query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
for i in range(self.refine_iters):
# Prior context is the previous output.
tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
# Mask tokens beyond the first EOS token.
tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
tgt_out = self.decode(
tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
)
logits = self.head(tgt_out)
return logits
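

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). The charset and the
    # hyperparameters below are assumptions loosely based on the reference PARSeq
    # configuration; substitute the values actually used for training/inference.
    tokenizer = Tokenizer('0123456789abcdefghijklmnopqrstuvwxyz')  # assumed charset
    model = PARSeq(
        num_tokens=len(tokenizer),  # vocab size; assumes Tokenizer exposes __len__ (chars + <eos>/<bos>/<pad>)
        max_label_length=25,
        img_size=(32, 128),
        patch_size=(4, 8),
        embed_dim=384,
        enc_num_heads=6,
        enc_mlp_ratio=4,
        enc_depth=12,
        dec_num_heads=12,
        dec_mlp_ratio=4,
        dec_depth=1,
        decode_ar=True,
        refine_iters=1,
        dropout=0.1,
    ).eval()
    with torch.no_grad():
        # Random input only checks shapes; real usage requires the training-time preprocessing.
        images = torch.rand(2, 3, 32, 128)  # (N, C, H, W)
        logits = model(tokenizer, images)
    print(logits.shape)  # (2, T <= 26, len(tokenizer) - 2)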