import math from typing import Callable, List, Optional, Tuple, Union from collections import defaultdict import os import shutil import cv2 import numpy as np import einops import torch import torch.nn as nn import torch.nn.functional as F from .xpos_relative_position import XPOS # Roformer with Xpos and Local Attention ViT from .common import OfflineOCR from ..utils import TextBlock, Quadrilateral, chunks from ..utils.generic import AvgMeter from ..utils.bubble import is_ignore # Roformer with Xpos class Model48pxOCR(OfflineOCR): _MODEL_MAPPING = { 'model': { 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr_ar_48px.ckpt', 'hash': '29daa46d080818bb4ab239a518a88338cbccff8f901bef8c9db191a7cb97671d', }, 'dict': { 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/alphabet-all-v7.txt', 'hash': 'f5722368146aa0fbcc9f4726866e4efc3203318ebb66c811d8cbbe915576538a', }, } def __init__(self, *args, **kwargs): os.makedirs(self.model_dir, exist_ok=True) if os.path.exists('ocr_ar_48px.ckpt'): shutil.move('ocr_ar_48px.ckpt', self._get_file_path('ocr_ar_48px.ckpt')) if os.path.exists('alphabet-all-v7.txt'): shutil.move('alphabet-all-v7.txt', self._get_file_path('alphabet-all-v7.txt')) super().__init__(*args, **kwargs) async def _load(self, device: str): with open(self._get_file_path('alphabet-all-v7.txt'), 'r', encoding = 'utf-8') as fp: dictionary = [s[:-1] for s in fp.readlines()] self.model = OCR(dictionary, 768) sd = torch.load(self._get_file_path('ocr_ar_48px.ckpt')) self.model.load_state_dict(sd) self.model.eval() self.device = device if (device == 'cuda' or device == 'mps'): self.use_gpu = True else: self.use_gpu = False if self.use_gpu: self.model = self.model.to(device) async def _unload(self): del self.model async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False, ignore_bubble: int = 0) -> List[TextBlock]: text_height = 48 max_chunk_size = 16 quadrilaterals = list(self._generate_text_direction(textlines)) region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals] out_regions = [] perm = range(len(region_imgs)) is_quadrilaterals = False if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral): perm = sorted(range(len(region_imgs)), key = lambda x: region_imgs[x].shape[1]) is_quadrilaterals = True ix = 0 for indices in chunks(perm, max_chunk_size): N = len(indices) widths = [region_imgs[i].shape[1] for i in indices] max_width = 4 * (max(widths) + 7) // 4 region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8) for i, idx in enumerate(indices): W = region_imgs[idx].shape[1] tmp = region_imgs[idx] region[i, :, : W, :]=tmp if verbose: os.makedirs('result/ocrs/', exist_ok=True) if quadrilaterals[idx][1] == 'v': cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE)) else: cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR)) ix += 1 image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5 image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W') if self.use_gpu: image_tensor = image_tensor.to(self.device) with torch.no_grad(): ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255) for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret): if prob < 0.2: continue has_fg = (fg_ind_pred[:, 1] > fg_ind_pred[:, 0]) has_bg = (bg_ind_pred[:, 1] > bg_ind_pred[:, 0]) seq = [] fr = AvgMeter() fg = AvgMeter() fb = AvgMeter() br = AvgMeter() bg = AvgMeter() bb = AvgMeter() for chid, c_fg, c_bg, h_fg, h_bg in zip(pred_chars_index, fg_pred, bg_pred, has_fg, has_bg) : ch = self.model.dictionary[chid] if ch == '': continue if ch == '': break if ch == '': ch = ' ' seq.append(ch) if h_fg.item() : fr(int(c_fg[0] * 255)) fg(int(c_fg[1] * 255)) fb(int(c_fg[2] * 255)) if h_bg.item() : br(int(c_bg[0] * 255)) bg(int(c_bg[1] * 255)) bb(int(c_bg[2] * 255)) else : br(int(c_fg[0] * 255)) bg(int(c_fg[1] * 255)) bb(int(c_fg[2] * 255)) txt = ''.join(seq) fr = min(max(int(fr()), 0), 255) fg = min(max(int(fg()), 0), 255) fb = min(max(int(fb()), 0), 255) br = min(max(int(br()), 0), 255) bg = min(max(int(bg()), 0), 255) bb = min(max(int(bb()), 0), 255) self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})') cur_region = quadrilaterals[indices[i]][0] if isinstance(cur_region, Quadrilateral): cur_region.text = txt cur_region.prob = prob cur_region.fg_r = fr cur_region.fg_g = fg cur_region.fg_b = fb cur_region.bg_r = br cur_region.bg_g = bg cur_region.bg_b = bb else: cur_region.text.append(txt) cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb])) out_regions.append(cur_region) if is_quadrilaterals: return out_regions return textlines class ConvNeXtBlock(nn.Module): r""" ConvNeXt Block. There are two equivalent implementations: (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back We use (2) as we find it slightly faster in PyTorch Args: dim (int): Number of input channels. drop_path (float): Stochastic depth rate. Default: 0.0 layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__(self, dim, layer_scale_init_value=1e-6, ks = 7, padding = 3): super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=padding, groups=dim) # depthwise conv self.norm = nn.BatchNorm2d(dim, eps=1e-6) self.pwconv1 = nn.Conv2d(dim, 4 * dim, 1, 1, 0) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Conv2d(4 * dim, dim, 1, 1, 0) self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(1, dim, 1, 1), requires_grad=True) if layer_scale_init_value > 0 else None def forward(self, x): input = x x = self.dwconv(x) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = input + x return x class ConvNext_FeatureExtractor(nn.Module) : def __init__(self, img_height = 48, in_dim = 3, dim = 512, n_layers = 12) -> None: super().__init__() base = dim // 8 self.stem = nn.Sequential( nn.Conv2d(in_dim, base, kernel_size = 7, stride = 1, padding = 3), nn.BatchNorm2d(base), nn.ReLU(), nn.Conv2d(base, base * 2, kernel_size = 2, stride = 2, padding = 0), nn.BatchNorm2d(base * 2), nn.ReLU(), nn.Conv2d(base * 2, base * 2, kernel_size = 3, stride = 1, padding = 1), nn.BatchNorm2d(base * 2), nn.ReLU(), ) self.block1 = self.make_layers(base * 2, 4) self.down1 = nn.Sequential( nn.Conv2d(base * 2, base * 4, kernel_size = 2, stride = 2, padding = 0), nn.BatchNorm2d(base * 4), nn.ReLU(), ) self.block2 = self.make_layers(base * 4, 12) self.down2 = nn.Sequential( nn.Conv2d(base * 4, base * 8, kernel_size = (2, 1), stride = (2, 1), padding = (0, 0)), nn.BatchNorm2d(base * 8), nn.ReLU(), ) self.block3 = self.make_layers(base * 8, 10, ks = 5, padding = 2) self.down3 = nn.Sequential( nn.Conv2d(base * 8, base * 8, kernel_size = (2, 1), stride = (2, 1), padding = (0, 0)), nn.BatchNorm2d(base * 8), nn.ReLU(), ) self.block4 = self.make_layers(base * 8, 8, ks = 3, padding = 1) self.down4 = nn.Sequential( nn.Conv2d(base * 8, base * 8, kernel_size = (3, 1), stride = (1, 1), padding = (0, 0)), nn.BatchNorm2d(base * 8), nn.ReLU(), ) def make_layers(self, dim, n, ks = 7, padding = 3) : layers = [] for i in range(n) : layers.append(ConvNeXtBlock(dim, ks = ks, padding = padding)) return nn.Sequential(*layers) def forward(self, x) : x = self.stem(x) # h//2, w//2 x = self.block1(x) x = self.down1(x) # h//4, w//4 x = self.block2(x) x = self.down2(x) # h//8, w//4 x = self.block3(x) x = self.down3(x) # h//16, w//4 x = self.block4(x) x = self.down4(x) return x def transformer_encoder_forward( self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, is_causal: bool = False) -> torch.Tensor: x = src if self.norm_first: x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) x = self.norm2(x + self._ff_block(x)) return x class XposMultiheadAttention(nn.Module): def __init__( self, embed_dim, num_heads, self_attention=False, encoder_decoder_attention=False, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim**-0.5 self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention assert self.self_attention ^ self.encoder_decoder_attention self.k_proj = nn.Linear(embed_dim, embed_dim, bias = True) self.v_proj = nn.Linear(embed_dim, embed_dim, bias = True) self.q_proj = nn.Linear(embed_dim, embed_dim, bias = True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias = True) self.xpos = XPOS(self.head_dim, embed_dim) self.batch_first = True self._qkv_same_embed_dim = True def reset_parameters(self): nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.out_proj.weight) nn.init.constant_(self.out_proj.bias, 0.0) def forward( self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights = False, is_causal = False, k_offset = 0, q_offset = 0 ): assert not is_causal bsz, tgt_len, embed_dim = query.size() src_len = tgt_len assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" key_bsz, src_len, _ = key.size() assert key_bsz == bsz, f"{query.size(), key.size()}" assert value is not None assert bsz, src_len == value.shape[:2] q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) q *= self.scaling q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2) q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim) k = k.reshape(bsz * self.num_heads, src_len, self.head_dim) v = v.reshape(bsz * self.num_heads, src_len, self.head_dim) if self.xpos is not None: k = self.xpos(k, offset=k_offset, downscale=True) # TODO: read paper q = self.xpos(q, offset=q_offset, downscale=False) attn_weights = torch.bmm(q, k.transpose(1, 2)) if attn_mask is not None: attn_weights = torch.nan_to_num(attn_weights) attn_mask = attn_mask.unsqueeze(0) attn_weights += attn_mask if key_padding_mask is not None: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( attn_weights ) attn = torch.bmm(attn_weights, v) attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1) attn = self.out_proj(attn) attn_weights = attn_weights.view( bsz, self.num_heads, tgt_len, src_len ).transpose(1, 0) if need_weights: return attn, attn_weights else : return attn, None def generate_square_subsequent_mask(sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) return mask class Beam: def __init__(self, char_seq = [], logprobs = []): # L if isinstance(char_seq, list): self.chars = torch.tensor(char_seq, dtype=torch.long) self.logprobs = torch.tensor(logprobs, dtype=torch.float32) else: self.chars = char_seq.clone() self.logprobs = logprobs.clone() def avg_logprob(self): return self.logprobs.mean().item() def sort_key(self): return -self.avg_logprob() def seq_end(self, end_tok): return self.chars.view(-1)[-1] == end_tok def extend(self, idx, logprob): return Beam( torch.cat([self.chars, idx.unsqueeze(0)], dim = -1), torch.cat([self.logprobs, logprob.unsqueeze(0)], dim = -1), ) DECODE_BLOCK_LENGTH = 8 class Hypothesis: def __init__(self, device, start_tok: int, end_tok: int, padding_tok: int, memory_idx: int, num_layers: int, embd_dim: int): self.device = device self.start_tok = start_tok self.end_tok = end_tok self.padding_tok = padding_tok self.memory_idx = memory_idx self.embd_size = embd_dim self.num_layers = num_layers # 1, L, E self.cached_activations = [torch.zeros(1, 0, self.embd_size).to(self.device)] * (num_layers + 1) self.out_idx = torch.LongTensor([start_tok]).to(self.device) self.out_logprobs = torch.FloatTensor([0]).to(self.device) self.length = 0 def seq_end(self): return self.out_idx.view(-1)[-1] == self.end_tok def logprob(self): return self.out_logprobs.mean().item() def sort_key(self): return -self.logprob() def prob(self): return self.out_logprobs.mean().exp().item() def __len__(self): return self.length def extend(self, idx, logprob): ret = Hypothesis(self.device, self.start_tok, self.end_tok, self.padding_tok, self.memory_idx, self.num_layers, self.embd_size) ret.cached_activations = [item.clone() for item in self.cached_activations] ret.length = self.length + 1 ret.out_idx = torch.cat([self.out_idx, torch.LongTensor([idx]).to(self.device)], dim = 0) ret.out_logprobs = torch.cat([self.out_logprobs, torch.FloatTensor([logprob]).to(self.device)], dim = 0) return ret def output(self): return self.cached_activations[-1] def next_token_batch( hyps: List[Hypothesis], memory: torch.Tensor, # N, H, W, C memory_mask: torch.BoolTensor, decoders: nn.ModuleList, embd: nn.Embedding ): layer: nn.TransformerDecoderLayer N = len(hyps) offset = len(hyps[0]) # N last_toks = torch.stack([item.out_idx[-1] for item in hyps]) # N, 1, E tgt: torch.FloatTensor = embd(last_toks).unsqueeze_(1) # N, L, E memory = torch.stack([memory[idx, :, :] for idx in [item.memory_idx for item in hyps]], dim = 0) for l, layer in enumerate(decoders): # TODO: keys and values are recomputed every time # N, L - 1, E combined_activations = torch.cat([item.cached_activations[l] for item in hyps], dim = 0) # N, L, E combined_activations = torch.cat([combined_activations, tgt], dim = 1) for i in range(N): hyps[i].cached_activations[l] = combined_activations[i: i + 1, :, :] # N, 1, E tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset = offset)[0] tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask = memory_mask, q_offset = offset)[0] tgt = tgt + layer._ff_block(layer.norm3(tgt)) #print(tgt[0, 0, 0]) for i in range(N): hyps[i].cached_activations[len(decoders)] = torch.cat([hyps[i].cached_activations[len(decoders)], tgt[i: i + 1, :, :]], dim = 1) # N, E return tgt.squeeze_(1) class OCR(nn.Module): def __init__(self, dictionary, max_len): super(OCR, self).__init__() self.max_len = max_len self.dictionary = dictionary self.dict_size = len(dictionary) n_decoders = 4 embd_dim = 320 nhead = 4 #self.backbone = LocalViT_FeatureExtractor(48, 3, dim = embd_dim, ff_dim = embd_dim * 4, n_layers = n_encoders) self.backbone = ConvNext_FeatureExtractor(48, 3, embd_dim) self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() for i in range(4) : encoder = nn.TransformerEncoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True) encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True) encoder.forward = transformer_encoder_forward self.encoders.append(encoder) for i in range(5) : decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True) decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True) decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True) self.decoders.append(decoder) self.embd = nn.Embedding(self.dict_size, embd_dim) self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15)) self.pred = nn.Linear(embd_dim, self.dict_size) self.pred.weight = self.embd.weight self.color_pred1 = nn.Sequential(nn.Linear(embd_dim, 64), nn.ReLU()) self.color_pred_fg = nn.Linear(64, 3) self.color_pred_bg = nn.Linear(64, 3) self.color_pred_fg_ind = nn.Linear(64, 2) self.color_pred_bg_ind = nn.Linear(64, 2) def forward(self, img: torch.FloatTensor, char_idx: torch.LongTensor, decoder_mask: torch.BoolTensor, encoder_mask: torch.BoolTensor ): memory = self.backbone(img) memory = einops.rearrange(memory, 'N C 1 W -> N W C') for layer in self.encoders : memory = layer(memory, src_key_padding_mask = encoder_mask) N, L = char_idx.shape char_embd = self.embd(char_idx) # N, L, D casual_mask = generate_square_subsequent_mask(L).to(img.device) decoded = char_embd for layer in self.decoders : decoded = layer(decoded, memory, tgt_mask = casual_mask, tgt_key_padding_mask = decoder_mask, memory_key_padding_mask = encoder_mask) pred_char_logits = self.pred(self.pred1(decoded)) color_feats = self.color_pred1(decoded) return pred_char_logits, \ self.color_pred_fg(color_feats), \ self.color_pred_bg(color_feats), \ self.color_pred_fg_ind(color_feats), \ self.color_pred_bg_ind(color_feats) def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384): N, C, H, W = img.shape assert H == 48 and C == 3 memory = self.backbone(img) memory = einops.rearrange(memory, 'N C 1 W -> N W C') valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths] input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device) for i, l in enumerate(valid_feats_length): input_mask[i, l:] = True for layer in self.encoders : memory = layer(layer, src = memory, src_key_padding_mask = input_mask) hypos = [Hypothesis(img.device, start_tok, end_tok, pad_tok, i, len(self.decoders), 320) for i in range(N)] # N, E decoded = next_token_batch(hypos, memory, input_mask, self.decoders, self.embd) # N, n_chars pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N, k pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) new_hypos: List[Hypothesis] = [] finished_hypos = defaultdict(list) for i in range(N): for k in range(beams_k): new_hypos.append(hypos[i].extend(pred_chars_index[i, k], pred_chars_values[i, k])) hypos = new_hypos for ixx in range(max_seq_length): # N * k, E decoded = next_token_batch(hypos, memory, torch.stack([input_mask[hyp.memory_idx] for hyp in hypos]) , self.decoders, self.embd) # N * k, n_chars pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N * k, k pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) hypos_per_sample = defaultdict(list) h: Hypothesis for i, h in enumerate(hypos): for k in range(beams_k): hypos_per_sample[h.memory_idx].append(h.extend(pred_chars_index[i, k], pred_chars_values[i, k])) hypos = [] # hypos_per_sample now contains N * k^2 hypos for i in hypos_per_sample.keys(): cur_hypos: List[Hypothesis] = hypos_per_sample[i] cur_hypos = sorted(cur_hypos, key = lambda a: a.sort_key())[: beams_k + 1] #print(cur_hypos[0].out_idx[-1]) to_added_hypos = [] sample_done = False for h in cur_hypos: if h.seq_end(): finished_hypos[i].append(h) if len(finished_hypos[i]) >= max_finished_hypos: sample_done = True break else: if len(to_added_hypos) < beams_k: to_added_hypos.append(h) if not sample_done: hypos.extend(to_added_hypos) if len(hypos) == 0: break # add remaining hypos to finished for i in range(N): if i not in finished_hypos: cur_hypos: List[Hypothesis] = hypos_per_sample[i] cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0] finished_hypos[i].append(cur_hypo) assert len(finished_hypos) == N result = [] for i in range(N): cur_hypos = finished_hypos[i] cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0] decoded = cur_hypo.output() color_feats = self.color_pred1(decoded) fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \ self.color_pred_fg(color_feats), \ self.color_pred_bg(color_feats), \ self.color_pred_fg_ind(color_feats), \ self.color_pred_bg_ind(color_feats) result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0])) return result import numpy as np def convert_pl_model(filename: str) : sd = torch.load(filename, map_location = 'cpu')['state_dict'] sd2 = {} for k, v in sd.items() : k: str k = k.removeprefix('model.') sd2[k] = v return sd2 def test_LocalViT_FeatureExtractor() : net = ConvNext_FeatureExtractor(48, 3, 320) inp = torch.randn(2, 3, 48, 512) out = net(inp) print(out.shape) def test_infer() : with open('alphabet-all-v7.txt', 'r') as fp : dictionary = [s[:-1] for s in fp.readlines()] model = OCR(dictionary, 32) model.eval() sd = convert_pl_model('epoch=0-step=13000.ckpt') model.load_state_dict(sd) model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print(params) img = cv2.cvtColor(cv2.imread('test3.png'), cv2.COLOR_BGR2RGB) ratio = img.shape[1] / float(img.shape[0]) new_w = int(round(ratio * 48)) #print(img.shape) img = cv2.resize(img, (new_w, 48), interpolation=cv2.INTER_AREA) img_torch = einops.rearrange((torch.from_numpy(img) / 127.5 - 1.0), 'h w c -> 1 c h w') with torch.no_grad() : idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch(img_torch, [new_w], 5, max_seq_length = 32)[0] txt = '' for i in idx : txt += dictionary[i] print(txt, prob) for chid, fg, bg, fg_ind, bg_ind in zip(idx, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]) : has_fg = (fg_ind[1] > fg_ind[0]).item() has_bg = (bg_ind[1] > bg_ind[0]).item() if has_fg : fg = np.clip((fg * 255).numpy(), 0, 255) if has_bg : bg = np.clip((bg * 255).numpy(), 0, 255) print(f'{dictionary[chid]} {fg if has_fg else "None"} {bg if has_bg else "None"}') if __name__ == "__main__": test_infer()