|
|
|
import math
|
|
from typing import Callable, List, Optional, Tuple, Union
|
|
from collections import defaultdict
|
|
import os
|
|
import shutil
|
|
import cv2
|
|
import numpy as np
|
|
import einops
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .xpos_relative_position import XPOS
|
|
|
|
|
|
|
|
from .common import OfflineOCR
|
|
from ..utils import TextBlock, Quadrilateral, chunks
|
|
from ..utils.generic import AvgMeter
|
|
from ..utils.bubble import is_ignore
|
|
|
|
|
|
|
|
class Model48pxOCR(OfflineOCR):
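    """48px text-line OCR model: ConvNeXt backbone + XPOS transformer encoder/decoder with
    beam-search decoding and per-character foreground/background color prediction."""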
|
|
_MODEL_MAPPING = {
|
|
'model': {
|
|
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr_ar_48px.ckpt',
|
|
'hash': '29daa46d080818bb4ab239a518a88338cbccff8f901bef8c9db191a7cb97671d',
|
|
},
|
|
'dict': {
|
|
'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/alphabet-all-v7.txt',
|
|
'hash': 'f5722368146aa0fbcc9f4726866e4efc3203318ebb66c811d8cbbe915576538a',
|
|
},
|
|
}
|
|
|
|
def __init__(self, *args, **kwargs):
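        # Move any model files left over in the working directory into the managed model directory.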
|
|
os.makedirs(self.model_dir, exist_ok=True)
|
|
if os.path.exists('ocr_ar_48px.ckpt'):
|
|
shutil.move('ocr_ar_48px.ckpt', self._get_file_path('ocr_ar_48px.ckpt'))
|
|
if os.path.exists('alphabet-all-v7.txt'):
|
|
shutil.move('alphabet-all-v7.txt', self._get_file_path('alphabet-all-v7.txt'))
|
|
super().__init__(*args, **kwargs)
|
|
|
|
async def _load(self, device: str):
|
|
with open(self._get_file_path('alphabet-all-v7.txt'), 'r', encoding = 'utf-8') as fp:
|
|
dictionary = [s[:-1] for s in fp.readlines()]
|
|
|
|
self.model = OCR(dictionary, 768)
|
|
        sd = torch.load(self._get_file_path('ocr_ar_48px.ckpt'), map_location = 'cpu')
|
|
self.model.load_state_dict(sd)
|
|
self.model.eval()
|
|
self.device = device
|
|
        self.use_gpu = device in ('cuda', 'mps')
        if self.use_gpu:
            self.model = self.model.to(device)
|
|
|
|
|
|
async def _unload(self):
|
|
del self.model
|
|
|
|
async def _infer(self, image: np.ndarray, textlines: List[Quadrilateral], args: dict, verbose: bool = False, ignore_bubble: int = 0) -> List[TextBlock]:
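        # Render every detected text line as a 48 px high strip, OCR the strips in width-sorted
        # batches, then attach the decoded text and estimated colors to each region.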
|
|
text_height = 48
|
|
max_chunk_size = 16
|
|
|
|
quadrilaterals = list(self._generate_text_direction(textlines))
|
|
region_imgs = [q.get_transformed_region(image, d, text_height) for q, d in quadrilaterals]
|
|
out_regions = []
|
|
|
|
perm = range(len(region_imgs))
|
|
is_quadrilaterals = False
|
|
if len(quadrilaterals) > 0 and isinstance(quadrilaterals[0][0], Quadrilateral):
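            # Sort regions by width so each batch pads strips of similar size.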
|
|
perm = sorted(range(len(region_imgs)), key = lambda x: region_imgs[x].shape[1])
|
|
is_quadrilaterals = True
|
|
|
|
ix = 0
|
|
for indices in chunks(perm, max_chunk_size):
|
|
N = len(indices)
|
|
widths = [region_imgs[i].shape[1] for i in indices]
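            # Pad all strips in this batch onto a common canvas a few pixels wider than the widest one.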
|
|
max_width = 4 * (max(widths) + 7) // 4
|
|
region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8)
|
|
for i, idx in enumerate(indices):
|
|
                W = region_imgs[idx].shape[1]
                region[i, :, :W, :] = region_imgs[idx]
|
|
if verbose:
|
|
os.makedirs('result/ocrs/', exist_ok=True)
|
|
if quadrilaterals[idx][1] == 'v':
|
|
cv2.imwrite(f'result/ocrs/{ix}.png', cv2.rotate(cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR), cv2.ROTATE_90_CLOCKWISE))
|
|
else:
|
|
cv2.imwrite(f'result/ocrs/{ix}.png', cv2.cvtColor(region[i, :, :, :], cv2.COLOR_RGB2BGR))
|
|
ix += 1
|
|
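            # Scale pixels to [-1, 1] and convert NHWC -> NCHW for the model.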
image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
|
|
image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
|
|
if self.use_gpu:
|
|
image_tensor = image_tensor.to(self.device)
|
|
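            # Beam-search decode the whole batch; each entry holds character indices,
            # the sequence probability and per-character color predictions.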
with torch.no_grad():
|
|
ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
|
|
for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
|
|
if prob < 0.2:
|
|
continue
|
|
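                # The indicator heads say whether a character has a confident fg / bg color;
                # accepted colors are averaged over the sequence with AvgMeter.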
has_fg = (fg_ind_pred[:, 1] > fg_ind_pred[:, 0])
|
|
has_bg = (bg_ind_pred[:, 1] > bg_ind_pred[:, 0])
|
|
seq = []
|
|
fr = AvgMeter()
|
|
fg = AvgMeter()
|
|
fb = AvgMeter()
|
|
br = AvgMeter()
|
|
bg = AvgMeter()
|
|
bb = AvgMeter()
|
|
for chid, c_fg, c_bg, h_fg, h_bg in zip(pred_chars_index, fg_pred, bg_pred, has_fg, has_bg) :
|
|
ch = self.model.dictionary[chid]
|
|
if ch == '<S>':
|
|
continue
|
|
if ch == '</S>':
|
|
break
|
|
if ch == '<SP>':
|
|
ch = ' '
|
|
seq.append(ch)
|
|
if h_fg.item() :
|
|
fr(int(c_fg[0] * 255))
|
|
fg(int(c_fg[1] * 255))
|
|
fb(int(c_fg[2] * 255))
|
|
if h_bg.item() :
|
|
br(int(c_bg[0] * 255))
|
|
bg(int(c_bg[1] * 255))
|
|
bb(int(c_bg[2] * 255))
|
|
else :
|
|
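                        # No confident background color: fall back to the foreground prediction.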
br(int(c_fg[0] * 255))
|
|
bg(int(c_fg[1] * 255))
|
|
bb(int(c_fg[2] * 255))
|
|
txt = ''.join(seq)
|
|
fr = min(max(int(fr()), 0), 255)
|
|
fg = min(max(int(fg()), 0), 255)
|
|
fb = min(max(int(fb()), 0), 255)
|
|
br = min(max(int(br()), 0), 255)
|
|
bg = min(max(int(bg()), 0), 255)
|
|
bb = min(max(int(bb()), 0), 255)
|
|
self.logger.info(f'prob: {prob} {txt} fg: ({fr}, {fg}, {fb}) bg: ({br}, {bg}, {bb})')
|
|
cur_region = quadrilaterals[indices[i]][0]
|
|
if isinstance(cur_region, Quadrilateral):
|
|
cur_region.text = txt
|
|
cur_region.prob = prob
|
|
cur_region.fg_r = fr
|
|
cur_region.fg_g = fg
|
|
cur_region.fg_b = fb
|
|
cur_region.bg_r = br
|
|
cur_region.bg_g = bg
|
|
cur_region.bg_b = bb
|
|
else:
|
|
cur_region.text.append(txt)
|
|
cur_region.update_font_colors(np.array([fr, fg, fb]), np.array([br, bg, bb]))
|
|
|
|
out_regions.append(cur_region)
|
|
|
|
if is_quadrilaterals:
|
|
return out_regions
|
|
return textlines
|
|
|
|
class ConvNeXtBlock(nn.Module):
|
|
r""" ConvNeXt Block. There are two equivalent implementations:
|
|
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
|
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
|
We use (2) as we find it slightly faster in PyTorch
|
|
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
drop_path (float): Stochastic depth rate. Default: 0.0
|
|
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
|
"""
|
|
def __init__(self, dim, layer_scale_init_value=1e-6, ks = 7, padding = 3):
|
|
super().__init__()
|
|
self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=padding, groups=dim)
|
|
self.norm = nn.BatchNorm2d(dim, eps=1e-6)
|
|
self.pwconv1 = nn.Conv2d(dim, 4 * dim, 1, 1, 0)
|
|
self.act = nn.GELU()
|
|
self.pwconv2 = nn.Conv2d(4 * dim, dim, 1, 1, 0)
|
|
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(1, dim, 1, 1),
|
|
requires_grad=True) if layer_scale_init_value > 0 else None
|
|
|
|
def forward(self, x):
|
|
input = x
|
|
x = self.dwconv(x)
|
|
x = self.norm(x)
|
|
x = self.pwconv1(x)
|
|
x = self.act(x)
|
|
x = self.pwconv2(x)
|
|
if self.gamma is not None:
|
|
x = self.gamma * x
|
|
|
|
x = input + x
|
|
return x
|
|
|
|
class ConvNext_FeatureExtractor(nn.Module) :
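    """ConvNeXt-style CNN backbone: maps a (N, 3, 48, W) image to (N, dim, 1, ~W/4) features,
    i.e. roughly one feature column per 4 pixels of width."""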
|
|
def __init__(self, img_height = 48, in_dim = 3, dim = 512, n_layers = 12) -> None:
|
|
super().__init__()
|
|
base = dim // 8
|
|
self.stem = nn.Sequential(
|
|
nn.Conv2d(in_dim, base, kernel_size = 7, stride = 1, padding = 3),
|
|
nn.BatchNorm2d(base),
|
|
nn.ReLU(),
|
|
nn.Conv2d(base, base * 2, kernel_size = 2, stride = 2, padding = 0),
|
|
nn.BatchNorm2d(base * 2),
|
|
nn.ReLU(),
|
|
nn.Conv2d(base * 2, base * 2, kernel_size = 3, stride = 1, padding = 1),
|
|
nn.BatchNorm2d(base * 2),
|
|
nn.ReLU(),
|
|
)
|
|
self.block1 = self.make_layers(base * 2, 4)
|
|
self.down1 = nn.Sequential(
|
|
nn.Conv2d(base * 2, base * 4, kernel_size = 2, stride = 2, padding = 0),
|
|
nn.BatchNorm2d(base * 4),
|
|
nn.ReLU(),
|
|
)
|
|
self.block2 = self.make_layers(base * 4, 12)
|
|
self.down2 = nn.Sequential(
|
|
nn.Conv2d(base * 4, base * 8, kernel_size = (2, 1), stride = (2, 1), padding = (0, 0)),
|
|
nn.BatchNorm2d(base * 8),
|
|
nn.ReLU(),
|
|
)
|
|
self.block3 = self.make_layers(base * 8, 10, ks = 5, padding = 2)
|
|
self.down3 = nn.Sequential(
|
|
nn.Conv2d(base * 8, base * 8, kernel_size = (2, 1), stride = (2, 1), padding = (0, 0)),
|
|
nn.BatchNorm2d(base * 8),
|
|
nn.ReLU(),
|
|
)
|
|
self.block4 = self.make_layers(base * 8, 8, ks = 3, padding = 1)
|
|
self.down4 = nn.Sequential(
|
|
nn.Conv2d(base * 8, base * 8, kernel_size = (3, 1), stride = (1, 1), padding = (0, 0)),
|
|
nn.BatchNorm2d(base * 8),
|
|
nn.ReLU(),
|
|
)
|
|
|
|
def make_layers(self, dim, n, ks = 7, padding = 3) :
|
|
layers = []
|
|
for i in range(n) :
|
|
layers.append(ConvNeXtBlock(dim, ks = ks, padding = padding))
|
|
return nn.Sequential(*layers)
|
|
|
|
def forward(self, x) :
|
|
x = self.stem(x)
|
|
|
|
x = self.block1(x)
|
|
x = self.down1(x)
|
|
|
|
x = self.block2(x)
|
|
x = self.down2(x)
|
|
|
|
x = self.block3(x)
|
|
x = self.down3(x)
|
|
|
|
x = self.block4(x)
|
|
x = self.down4(x)
|
|
return x
|
|
|
|
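# Replacement forward for nn.TransformerEncoderLayer instances whose self_attn has been swapped
# for XposMultiheadAttention. It is assigned as an *instance* attribute (see OCR.__init__), so it
# is not bound to the layer and callers must pass the layer explicitly as the first argument.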
def transformer_encoder_forward(
|
|
self,
|
|
src: torch.Tensor,
|
|
src_mask: Optional[torch.Tensor] = None,
|
|
src_key_padding_mask: Optional[torch.Tensor] = None,
|
|
is_causal: bool = False) -> torch.Tensor:
|
|
x = src
|
|
if self.norm_first:
|
|
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
|
x = x + self._ff_block(self.norm2(x))
|
|
else:
|
|
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
|
x = self.norm2(x + self._ff_block(x))
|
|
|
|
return x
|
|
|
|
class XposMultiheadAttention(nn.Module):
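    """Multi-head attention with XPOS relative position encoding applied to queries and keys."""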
|
|
def __init__(
|
|
self,
|
|
embed_dim,
|
|
num_heads,
|
|
self_attention=False,
|
|
encoder_decoder_attention=False,
|
|
):
|
|
super().__init__()
|
|
self.embed_dim = embed_dim
|
|
self.num_heads = num_heads
|
|
self.head_dim = embed_dim // num_heads
|
|
self.scaling = self.head_dim**-0.5
|
|
|
|
self.self_attention = self_attention
|
|
self.encoder_decoder_attention = encoder_decoder_attention
|
|
assert self.self_attention ^ self.encoder_decoder_attention
|
|
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias = True)
|
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias = True)
|
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias = True)
|
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias = True)
|
|
self.xpos = XPOS(self.head_dim, embed_dim)
|
|
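        # Attributes mirrored from nn.MultiheadAttention that PyTorch's transformer layers may inspect.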
self.batch_first = True
|
|
self._qkv_same_embed_dim = True
|
|
|
|
def reset_parameters(self):
|
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
|
|
|
def forward(
|
|
self,
|
|
query,
|
|
key,
|
|
value,
|
|
key_padding_mask=None,
|
|
attn_mask=None,
|
|
need_weights = False,
|
|
is_causal = False,
|
|
k_offset = 0,
|
|
q_offset = 0
|
|
):
|
|
assert not is_causal
|
|
bsz, tgt_len, embed_dim = query.size()
|
|
src_len = tgt_len
|
|
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
|
|
|
key_bsz, src_len, _ = key.size()
|
|
assert key_bsz == bsz, f"{query.size(), key.size()}"
|
|
assert value is not None
|
|
        assert (bsz, src_len) == value.shape[:2]
|
|
|
|
q = self.q_proj(query)
|
|
k = self.k_proj(key)
|
|
v = self.v_proj(value)
|
|
q *= self.scaling
|
|
|
|
q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
|
|
k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
|
|
v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
|
|
|
|
if self.xpos is not None:
|
|
k = self.xpos(k, offset=k_offset, downscale=True)
|
|
q = self.xpos(q, offset=q_offset, downscale=False)
|
|
|
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
|
|
|
if attn_mask is not None:
|
|
attn_weights = torch.nan_to_num(attn_weights)
|
|
attn_mask = attn_mask.unsqueeze(0)
|
|
attn_weights += attn_mask
|
|
|
|
if key_padding_mask is not None:
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
|
attn_weights = attn_weights.masked_fill(
|
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
|
float("-inf"),
|
|
)
|
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
|
|
attn_weights
|
|
)
|
|
attn = torch.bmm(attn_weights, v)
|
|
attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)
|
|
|
|
attn = self.out_proj(attn)
|
|
attn_weights = attn_weights.view(
|
|
bsz, self.num_heads, tgt_len, src_len
|
|
).transpose(1, 0)
|
|
|
|
if need_weights:
|
|
return attn, attn_weights
|
|
else :
|
|
return attn, None
|
|
|
|
def generate_square_subsequent_mask(sz):
|
|
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
|
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
|
return mask
|
|
|
|
class Beam:
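    """A candidate character sequence together with its per-token log-probabilities."""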
|
|
def __init__(self, char_seq = [], logprobs = []):
|
|
|
|
if isinstance(char_seq, list):
|
|
self.chars = torch.tensor(char_seq, dtype=torch.long)
|
|
self.logprobs = torch.tensor(logprobs, dtype=torch.float32)
|
|
else:
|
|
self.chars = char_seq.clone()
|
|
self.logprobs = logprobs.clone()
|
|
|
|
def avg_logprob(self):
|
|
return self.logprobs.mean().item()
|
|
|
|
def sort_key(self):
|
|
return -self.avg_logprob()
|
|
|
|
def seq_end(self, end_tok):
|
|
return self.chars.view(-1)[-1] == end_tok
|
|
|
|
def extend(self, idx, logprob):
|
|
return Beam(
|
|
torch.cat([self.chars, idx.unsqueeze(0)], dim = -1),
|
|
torch.cat([self.logprobs, logprob.unsqueeze(0)], dim = -1),
|
|
)
|
|
|
|
DECODE_BLOCK_LENGTH = 8
|
|
|
|
class Hypothesis:
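    """A beam-search hypothesis. Caches the per-layer decoder activations of the tokens decoded
    so far, so each step only needs to run the decoder on the newest token."""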
|
|
def __init__(self, device, start_tok: int, end_tok: int, padding_tok: int, memory_idx: int, num_layers: int, embd_dim: int):
|
|
self.device = device
|
|
self.start_tok = start_tok
|
|
self.end_tok = end_tok
|
|
self.padding_tok = padding_tok
|
|
self.memory_idx = memory_idx
|
|
self.embd_size = embd_dim
|
|
self.num_layers = num_layers
|
|
|
|
self.cached_activations = [torch.zeros(1, 0, self.embd_size).to(self.device)] * (num_layers + 1)
|
|
self.out_idx = torch.LongTensor([start_tok]).to(self.device)
|
|
self.out_logprobs = torch.FloatTensor([0]).to(self.device)
|
|
self.length = 0
|
|
|
|
def seq_end(self):
|
|
return self.out_idx.view(-1)[-1] == self.end_tok
|
|
|
|
def logprob(self):
|
|
return self.out_logprobs.mean().item()
|
|
|
|
def sort_key(self):
|
|
return -self.logprob()
|
|
|
|
def prob(self):
|
|
return self.out_logprobs.mean().exp().item()
|
|
|
|
def __len__(self):
|
|
return self.length
|
|
|
|
def extend(self, idx, logprob):
|
|
ret = Hypothesis(self.device, self.start_tok, self.end_tok, self.padding_tok, self.memory_idx, self.num_layers, self.embd_size)
|
|
ret.cached_activations = [item.clone() for item in self.cached_activations]
|
|
ret.length = self.length + 1
|
|
ret.out_idx = torch.cat([self.out_idx, torch.LongTensor([idx]).to(self.device)], dim = 0)
|
|
ret.out_logprobs = torch.cat([self.out_logprobs, torch.FloatTensor([logprob]).to(self.device)], dim = 0)
|
|
return ret
|
|
|
|
def output(self):
|
|
return self.cached_activations[-1]
|
|
|
|
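# Run one incremental decoding step for a batch of hypotheses: embed each hypothesis' last token,
# attend over the cached activations of its previous tokens and over its encoder memory, and
# return the new per-hypothesis decoder output (the caches are updated in place).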
def next_token_batch(
|
|
hyps: List[Hypothesis],
|
|
memory: torch.Tensor,
|
|
memory_mask: torch.BoolTensor,
|
|
decoders: nn.ModuleList,
|
|
embd: nn.Embedding
|
|
):
|
|
layer: nn.TransformerDecoderLayer
|
|
N = len(hyps)
|
|
offset = len(hyps[0])
|
|
|
|
|
|
last_toks = torch.stack([item.out_idx[-1] for item in hyps])
|
|
|
|
tgt: torch.FloatTensor = embd(last_toks).unsqueeze_(1)
|
|
|
|
|
|
memory = torch.stack([memory[idx, :, :] for idx in [item.memory_idx for item in hyps]], dim = 0)
|
|
for l, layer in enumerate(decoders):
|
|
|
|
|
|
combined_activations = torch.cat([item.cached_activations[l] for item in hyps], dim = 0)
|
|
|
|
combined_activations = torch.cat([combined_activations, tgt], dim = 1)
|
|
for i in range(N):
|
|
hyps[i].cached_activations[l] = combined_activations[i: i + 1, :, :]
|
|
|
|
tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset = offset)[0]
|
|
tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask = memory_mask, q_offset = offset)[0]
|
|
tgt = tgt + layer._ff_block(layer.norm3(tgt))
|
|
|
|
for i in range(N):
|
|
hyps[i].cached_activations[len(decoders)] = torch.cat([hyps[i].cached_activations[len(decoders)], tgt[i: i + 1, :, :]], dim = 1)
|
|
|
|
return tgt.squeeze_(1)
|
|
|
|
class OCR(nn.Module):
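    """Encoder-decoder OCR model: ConvNeXt feature extractor, XPOS transformer encoder/decoder,
    a character head tied to the embedding matrix, and color heads for fg/bg prediction."""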
|
|
def __init__(self, dictionary, max_len):
|
|
super(OCR, self).__init__()
|
|
self.max_len = max_len
|
|
self.dictionary = dictionary
|
|
self.dict_size = len(dictionary)
|
|
n_decoders = 4
|
|
embd_dim = 320
|
|
nhead = 4
|
|
|
|
self.backbone = ConvNext_FeatureExtractor(48, 3, embd_dim)
|
|
self.encoders = nn.ModuleList()
|
|
self.decoders = nn.ModuleList()
|
|
for i in range(4) :
|
|
encoder = nn.TransformerEncoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
|
|
encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
|
|
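            # Swap in the XPOS-aware forward; note this assigns an unbound function, so callers
            # must pass the layer itself as the first argument (see infer_beam_batch).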
encoder.forward = transformer_encoder_forward
|
|
self.encoders.append(encoder)
|
|
for i in range(5) :
|
|
decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
|
|
decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
|
|
decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
|
|
self.decoders.append(decoder)
|
|
self.embd = nn.Embedding(self.dict_size, embd_dim)
|
|
self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
|
|
self.pred = nn.Linear(embd_dim, self.dict_size)
|
|
self.pred.weight = self.embd.weight
|
|
self.color_pred1 = nn.Sequential(nn.Linear(embd_dim, 64), nn.ReLU())
|
|
self.color_pred_fg = nn.Linear(64, 3)
|
|
self.color_pred_bg = nn.Linear(64, 3)
|
|
self.color_pred_fg_ind = nn.Linear(64, 2)
|
|
self.color_pred_bg_ind = nn.Linear(64, 2)
|
|
|
|
def forward(self,
|
|
img: torch.FloatTensor,
|
|
char_idx: torch.LongTensor,
|
|
decoder_mask: torch.BoolTensor,
|
|
encoder_mask: torch.BoolTensor
|
|
):
|
|
memory = self.backbone(img)
|
|
memory = einops.rearrange(memory, 'N C 1 W -> N W C')
|
|
for layer in self.encoders :
|
|
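            # The patched, unbound transformer_encoder_forward expects the layer as its first argument.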
            memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
|
|
N, L = char_idx.shape
|
|
char_embd = self.embd(char_idx)
|
|
|
|
        causal_mask = generate_square_subsequent_mask(L).to(img.device)
|
|
decoded = char_embd
|
|
for layer in self.decoders :
|
|
            decoded = layer(decoded, memory, tgt_mask = causal_mask, tgt_key_padding_mask = decoder_mask, memory_key_padding_mask = encoder_mask)
|
|
|
|
pred_char_logits = self.pred(self.pred1(decoded))
|
|
color_feats = self.color_pred1(decoded)
|
|
return pred_char_logits, \
|
|
self.color_pred_fg(color_feats), \
|
|
self.color_pred_bg(color_feats), \
|
|
self.color_pred_fg_ind(color_feats), \
|
|
self.color_pred_bg_ind(color_feats)
|
|
|
|
def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
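        """Batched beam-search decoding.

        For each input image returns a tuple of (character indices without the start token,
        sequence probability, fg colors, bg colors, fg color indicator logits, bg color indicator logits).
        """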
|
|
N, C, H, W = img.shape
|
|
assert H == 48 and C == 3
|
|
memory = self.backbone(img)
|
|
memory = einops.rearrange(memory, 'N C 1 W -> N W C')
|
|
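        # Number of feature columns to keep un-masked per sample: the backbone reduces width
        # roughly 4x; two extra columns are left unmasked as a safety margin.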
valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
|
|
input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)
|
|
for i, l in enumerate(valid_feats_length):
|
|
input_mask[i, l:] = True
|
|
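        # forward was replaced by the unbound transformer_encoder_forward, hence the explicit layer argument.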
for layer in self.encoders :
|
|
memory = layer(layer, src = memory, src_key_padding_mask = input_mask)
|
|
hypos = [Hypothesis(img.device, start_tok, end_tok, pad_tok, i, len(self.decoders), 320) for i in range(N)]
|
|
|
|
decoded = next_token_batch(hypos, memory, input_mask, self.decoders, self.embd)
|
|
|
|
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
|
|
|
|
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)
|
|
new_hypos: List[Hypothesis] = []
|
|
finished_hypos = defaultdict(list)
|
|
for i in range(N):
|
|
for k in range(beams_k):
|
|
new_hypos.append(hypos[i].extend(pred_chars_index[i, k], pred_chars_values[i, k]))
|
|
hypos = new_hypos
|
|
for ixx in range(max_seq_length):
|
|
|
|
            decoded = next_token_batch(hypos, memory, torch.stack([input_mask[hyp.memory_idx] for hyp in hypos]), self.decoders, self.embd)
|
|
|
|
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1)
|
|
|
|
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1)
|
|
hypos_per_sample = defaultdict(list)
|
|
h: Hypothesis
|
|
for i, h in enumerate(hypos):
|
|
for k in range(beams_k):
|
|
hypos_per_sample[h.memory_idx].append(h.extend(pred_chars_index[i, k], pred_chars_values[i, k]))
|
|
hypos = []
|
|
|
|
for i in hypos_per_sample.keys():
|
|
cur_hypos: List[Hypothesis] = hypos_per_sample[i]
|
|
cur_hypos = sorted(cur_hypos, key = lambda a: a.sort_key())[: beams_k + 1]
|
|
|
|
to_added_hypos = []
|
|
sample_done = False
|
|
for h in cur_hypos:
|
|
if h.seq_end():
|
|
finished_hypos[i].append(h)
|
|
if len(finished_hypos[i]) >= max_finished_hypos:
|
|
sample_done = True
|
|
break
|
|
else:
|
|
if len(to_added_hypos) < beams_k:
|
|
to_added_hypos.append(h)
|
|
if not sample_done:
|
|
hypos.extend(to_added_hypos)
|
|
if len(hypos) == 0:
|
|
break
|
|
|
|
for i in range(N):
|
|
if i not in finished_hypos:
|
|
cur_hypos: List[Hypothesis] = hypos_per_sample[i]
|
|
cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0]
|
|
finished_hypos[i].append(cur_hypo)
|
|
assert len(finished_hypos) == N
|
|
result = []
|
|
for i in range(N):
|
|
cur_hypos = finished_hypos[i]
|
|
cur_hypo = sorted(cur_hypos, key = lambda a: a.sort_key())[0]
|
|
decoded = cur_hypo.output()
|
|
color_feats = self.color_pred1(decoded)
|
|
fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
|
|
self.color_pred_fg(color_feats), \
|
|
self.color_pred_bg(color_feats), \
|
|
self.color_pred_fg_ind(color_feats), \
|
|
self.color_pred_bg_ind(color_feats)
|
|
result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
|
|
return result
|
|
|
|
|
|
|
|
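# Convert a (PyTorch Lightning style) training checkpoint into a plain state_dict by stripping
# the "model." prefix from its keys.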
def convert_pl_model(filename: str) :
|
|
sd = torch.load(filename, map_location = 'cpu')['state_dict']
|
|
sd2 = {}
|
|
for k, v in sd.items() :
|
|
k: str
|
|
k = k.removeprefix('model.')
|
|
sd2[k] = v
|
|
return sd2
|
|
|
|
def test_LocalViT_FeatureExtractor() :
|
|
net = ConvNext_FeatureExtractor(48, 3, 320)
|
|
inp = torch.randn(2, 3, 48, 512)
|
|
out = net(inp)
|
|
print(out.shape)
|
|
|
|
def test_infer() :
|
|
    with open('alphabet-all-v7.txt', 'r', encoding = 'utf-8') as fp :
|
|
dictionary = [s[:-1] for s in fp.readlines()]
|
|
model = OCR(dictionary, 32)
|
|
model.eval()
|
|
sd = convert_pl_model('epoch=0-step=13000.ckpt')
|
|
model.load_state_dict(sd)
|
|
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
|
params = sum([np.prod(p.size()) for p in model_parameters])
|
|
print(params)
|
|
|
|
img = cv2.cvtColor(cv2.imread('test3.png'), cv2.COLOR_BGR2RGB)
|
|
ratio = img.shape[1] / float(img.shape[0])
|
|
new_w = int(round(ratio * 48))
|
|
|
|
img = cv2.resize(img, (new_w, 48), interpolation=cv2.INTER_AREA)
|
|
|
|
img_torch = einops.rearrange((torch.from_numpy(img) / 127.5 - 1.0), 'h w c -> 1 c h w')
|
|
|
|
with torch.no_grad() :
|
|
idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch(img_torch, [new_w], 5, max_seq_length = 32)[0]
|
|
txt = ''
|
|
for i in idx :
|
|
txt += dictionary[i]
|
|
print(txt, prob)
|
|
for chid, fg, bg, fg_ind, bg_ind in zip(idx, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]) :
|
|
has_fg = (fg_ind[1] > fg_ind[0]).item()
|
|
has_bg = (bg_ind[1] > bg_ind[0]).item()
|
|
if has_fg :
|
|
fg = np.clip((fg * 255).numpy(), 0, 255)
|
|
if has_bg :
|
|
bg = np.clip((bg * 255).numpy(), 0, 255)
|
|
print(f'{dictionary[chid]} {fg if has_fg else "None"} {bg if has_bg else "None"}')
|
|
|
|
if __name__ == "__main__":
|
|
test_infer()
|
|
|