import os
import shutil
import numpy as np
import torch
import cv2
import einops

from typing import List, Tuple

from .default_utils.DBNet_resnet34 import TextDetection as TextDetectionDefault
from .default_utils import imgproc, dbnet_utils, craft_utils
from .common import OfflineDetector
from ..utils import TextBlock, Quadrilateral, det_rearrange_forward
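
# Module-level handle to the loaded network so that det_batch_forward_default,
# a plain function passed as a callback into det_rearrange_forward, can reach it.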
MODEL = None
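
# Runs one batch through the model: normalizes uint8 HWC images to [-1, 1],
# rearranges to NCHW, and returns the sigmoid-activated text probability map
# together with the raw mask output, both as numpy arrays.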
def det_batch_forward_default(batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
    global MODEL
    if isinstance(batch, list):
        batch = np.array(batch)
    batch = einops.rearrange(batch.astype(np.float32) / 127.5 - 1.0, 'n h w c -> n c h w')
    batch = torch.from_numpy(batch).to(device)
    with torch.no_grad():
        db, mask = MODEL(batch)
        db = db.sigmoid().cpu().numpy()
        mask = mask.cpu().numpy()
    return db, mask
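

# Detector backed by the DBNet (ResNet-34) text detection model; the checkpoint
# is fetched and verified through OfflineDetector using _MODEL_MAPPING.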
class DefaultDetector(OfflineDetector):
    _MODEL_MAPPING = {
        'model': {
            'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/detect.ckpt',
            'hash': '69080aea78de0803092bc6b751ae283ca463011de5f07e1d20e6491b05571a30',
            'file': '.',
        }
    }
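
    # Older versions saved detect.ckpt into the working directory; migrate any
    # stray copy into the managed model directory before the base class
    # verifies or downloads files.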
    def __init__(self, *args, **kwargs):
        os.makedirs(self.model_dir, exist_ok=True)
        if os.path.exists('detect.ckpt'):
            shutil.move('detect.ckpt', self._get_file_path('detect.ckpt'))
        super().__init__(*args, **kwargs)
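
    # The released checkpoint may store the weights under a 'model' key or as a
    # bare state dict; accept both layouts.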
    async def _load(self, device: str):
        self.model = TextDetectionDefault()
        sd = torch.load(self._get_file_path('detect.ckpt'), map_location='cpu')
        self.model.load_state_dict(sd['model'] if 'model' in sd else sd)
        self.model.eval()
        self.device = device
        if device in ('cuda', 'mps'):
            self.model = self.model.to(self.device)
        global MODEL
        MODEL = self.model

    async def _unload(self):
        del self.model
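
    # Detection pipeline: rearrange/tile the page through the model when it is
    # large (or fall back to a single resized forward pass), extract boxes from
    # the DB probability map, rescale the polygons back to input coordinates,
    # and upsample the text mask.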
    async def _infer(self, image: np.ndarray, detect_size: int, text_threshold: float, box_threshold: float,
                     unclip_ratio: float, verbose: bool = False):
        # TODO: Move det_rearrange_forward to common.py and refactor
        db, mask = det_rearrange_forward(image, det_batch_forward_default, detect_size, 4, device=self.device, verbose=verbose)
        if db is None:
            # Rearrangement is not required; fall back to a single resized forward
            # pass. The bilateral filter denoises the page while preserving edges.
            img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(
                cv2.bilateralFilter(image, 17, 80, 80), detect_size, cv2.INTER_LINEAR, mag_ratio=1)
            img_resized_h, img_resized_w = img_resized.shape[:2]
            ratio_h = ratio_w = 1 / target_ratio
            db, mask = det_batch_forward_default([img_resized], self.device)
        else:
            img_resized_h, img_resized_w = image.shape[:2]
            ratio_w = ratio_h = 1
            pad_h = pad_w = 0
        self.logger.info(f'Detection resolution: {img_resized_w}x{img_resized_h}')
        mask = mask[0, 0, :, :]

        # DBNet post-processing: binarize the probability map, then unclip the
        # shrunken text regions back to full size.
        det = dbnet_utils.SegDetectorRepresenter(text_threshold, box_threshold, unclip_ratio=unclip_ratio)
        boxes, scores = det({'shape': [(img_resized_h, img_resized_w)]}, db)
        boxes, scores = boxes[0], scores[0]
        if boxes.size == 0:
            polys = []
        else:
            # Drop degenerate all-zero boxes and keep the scores aligned with the
            # surviving polygons (previously the filtered scores were discarded
            # into `_`, letting them drift out of sync with `polys`).
            idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
            polys, scores = boxes[idx], scores[idx]
            polys = polys.astype(np.float64)
            # Map coordinates from the resized detection space back to the input image.
            polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=1)
            polys = polys.astype(np.int16)

        textlines = [Quadrilateral(pts.astype(int), '', score) for pts, score in zip(polys, scores)]
        textlines = list(filter(lambda q: q.area > 16, textlines))
        # Upsample the mask (produced at half the detection resolution), then
        # crop off the padding introduced by resize_aspect_ratio.
        mask_resized = cv2.resize(mask, (mask.shape[1] * 2, mask.shape[0] * 2), interpolation=cv2.INTER_LINEAR)
        if pad_h > 0:
            mask_resized = mask_resized[:-pad_h, :]
        elif pad_w > 0:
            mask_resized = mask_resized[:, :-pad_w]
        raw_mask = np.clip(mask_resized * 255, 0, 255).astype(np.uint8)

        # if verbose:
        #     img_bbox_raw = np.copy(image)
        #     for txtln in textlines:
        #         cv2.polylines(img_bbox_raw, [txtln.pts], True, color=(255, 0, 0), thickness=2)
        #     cv2.imwrite(f'result/bboxes_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR))

        return textlines, raw_mask, None
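
# Minimal usage sketch (an assumption, not part of this module: callers normally
# go through OfflineDetector's wrapper methods rather than the underscore-prefixed
# coroutines, and the threshold values below are illustrative only):
#
#   import asyncio
#   import cv2
#
#   async def main():
#       detector = DefaultDetector()
#       await detector._load('cpu')
#       image = cv2.cvtColor(cv2.imread('page.png'), cv2.COLOR_BGR2RGB)  # hypothetical input
#       textlines, raw_mask, _ = await detector._infer(
#           image, detect_size=2048, text_threshold=0.5,
#           box_threshold=0.7, unclip_ratio=2.3)
#       print(len(textlines), 'text regions detected')
#
#   asyncio.run(main())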