# -*- coding: utf-8 -*-
# @Time    : 10/1/21
# @Author  : GXYM
import torch
import torch.nn as nn
from IndicPhotoOCR.detection.textbpn.network.layers.model_block import FPN
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
import numpy as np
from IndicPhotoOCR.detection.textbpn.network.layers.CircConv import DeepSnake
from IndicPhotoOCR.detection.textbpn.network.layers.GCN import GCN
from IndicPhotoOCR.detection.textbpn.network.layers.RNN import RNN
from IndicPhotoOCR.detection.textbpn.network.layers.Adaptive_Deformation import AdaptiveDeformation
# from IndicPhotoOCR.detection.textbpn.network.layers.Transformer_old import Transformer_old
from IndicPhotoOCR.detection.textbpn.network.layers.Transformer import Transformer
import cv2
from IndicPhotoOCR.detection.textbpn.util.misc import get_sample_point, fill_hole
from IndicPhotoOCR.detection.textbpn.network.layers.gcn_utils import get_node_feature, \
    get_adj_mat, get_adj_ind, coord_embedding, normalize_adj
import torch.nn.functional as F
import time


class Evolution(nn.Module):
    """Iteratively refine polygon proposals with a stack of deformation modules.

    ``self.iter`` refinement modules are registered as ``evolve_gcn0`` ..
    ``evolve_gcn{iter-1}``. Each round samples features at the current polygon
    vertices, predicts per-vertex offsets (clamped to ``±clip_dis``), and
    clamps the moved vertices back inside the image extent.

    Args:
        node_num: number of vertices per polygon (used to build adjacency).
        adj_num: adjacency size passed to the adjacency helpers.
        is_training: selects the proposal source used in ``forward``.
        device: device on which adjacency tensors are placed.
        model: refinement module type: "gcn", "rnn", "AD", "BT"; anything
            else falls back to ``DeepSnake``.
    """

    def __init__(self, node_num, adj_num, is_training=True, device=None, model="snake"):
        super(Evolution, self).__init__()
        self.node_num = node_num
        self.adj_num = adj_num
        self.device = device
        self.is_training = is_training
        # Upper bound on the offset any single refinement round may apply to a vertex.
        self.clip_dis = 16
        # Number of refinement rounds (one module per round).
        self.iter = 3

        # NOTE(review): the feature dim 36 presumably matches the 32 FPN channels
        # plus 4 seg-head channels concatenated in TextNet.forward — confirm.
        if model == "gcn":
            self.adj = get_adj_mat(self.adj_num, self.node_num)
            self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
            for i in range(self.iter):
                setattr(self, 'evolve_gcn' + str(i), GCN(36, 128))
        elif model == "rnn":
            self.adj = None
            for i in range(self.iter):
                setattr(self, 'evolve_gcn' + str(i), RNN(36, 128))
        elif model == "AD":
            self.adj = get_adj_mat(self.adj_num, self.node_num)
            self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
            for i in range(self.iter):
                setattr(self, 'evolve_gcn' + str(i), AdaptiveDeformation(36, 128))
        # elif model == "BT_old":
        #     self.adj = None
        #     for i in range(self.iter):
        #         evolve_gcn = Transformer_old(36, 512, num_heads=8,
        #                                      dim_feedforward=2048, drop_rate=0.0, if_resi=True, block_nums=4)
        #         self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
        elif model == "BT":
            self.adj = None
            for i in range(self.iter):
                evolve_gcn = Transformer(36, 128, num_heads=8,
                                         dim_feedforward=1024, drop_rate=0.0, if_resi=True, block_nums=3)
                setattr(self, 'evolve_gcn' + str(i), evolve_gcn)
        else:
            self.adj = get_adj_ind(self.adj_num, self.node_num, self.device)
            for i in range(self.iter):
                setattr(self, 'evolve_gcn' + str(i),
                        DeepSnake(state_dim=128, feature_dim=36, conv_type='dgrid'))

        # Re-initialise every conv layer of the refinement modules.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d)):
                m.weight.data.normal_(0.0, 0.02)
                # nn.init.kaiming_normal_(m.weight, mode='fan_in')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def get_boundary_proposal(input=None, seg_preds=None, switch="gt"):
        """Build initial polygons for training.

        With ``switch == "gt"`` the polygons are ``input['proposal_points']``
        at the positions where ``input['ignore_tags'] > 0``. Otherwise they
        are sampled from connected components of ``seg_preds`` channel 0
        thresholded at ``cfg.threshold``; each component's instance id is
        derived from the mean of ``input['tr_mask']`` over it.

        Returns:
            (init_polys, inds, None). For the non-"gt" path ``inds`` is a
            ``(2, N)`` long tensor of (batch id, instance id), which is
            ``(2, 0)`` when no component is found.
        """
        if switch == "gt":
            inds = torch.where(input['ignore_tags'] > 0)
            # if len(inds[0]) > 320:
            #     inds = (inds[0][:320], inds[1][:320])
            init_polys = input['proposal_points'][inds]
        else:
            tr_masks = input['tr_mask'].cpu().numpy()
            tcl_masks = seg_preds[:, 0, :, :].detach().cpu().numpy() > cfg.threshold
            inds = []
            init_polys = []
            for bid, tcl_mask in enumerate(tcl_masks):
                ret, labels = cv2.connectedComponents(tcl_mask.astype(np.uint8), connectivity=8)
                for idx in range(1, ret):  # label 0 is the background
                    text_mask = labels == idx
                    ist_id = int(np.sum(text_mask * tr_masks[bid]) / np.sum(text_mask)) - 1
                    inds.append([bid, ist_id])
                    poly = get_sample_point(text_mask, cfg.num_points, cfg.approx_factor)
                    init_polys.append(poly)
            # reshape(-1, 2) keeps the empty case 2-D so permute() works, and int64
            # guarantees a valid index dtype on every platform.
            inds = torch.from_numpy(np.array(inds, dtype=np.int64).reshape(-1, 2)).permute(1, 0)
            inds = inds.to(input["img"].device)
            init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device)

        return init_polys, inds, None

    def get_boundary_proposal_eval(self, input=None, seg_preds=None):
        """Build initial polygons at inference time from the predicted maps.

        Connected components of ``seg_preds[:, 1] > cfg.dis_threshold`` are
        kept when their area is at least ``50 / cfg.scale**2`` pixels and the
        mean of ``seg_preds[:, 0]`` over them (rounded to 3 decimals) is at
        least ``cfg.cls_threshold``.

        Returns:
            init_polys: float tensor of sampled polygons on the image device.
            inds: ``(2, N)`` tensor whose row 0 is the batch id and row 1 is
                always 0; an empty 1-D tensor when nothing is kept.
            confidences: list of the per-polygon confidences.
        """
        # if cfg.scale > 1:
        #     seg_preds = F.interpolate(seg_preds, scale_factor=cfg.scale, mode='bilinear')
        cls_preds = seg_preds[:, 0, :, :].detach().cpu().numpy()
        dis_preds = seg_preds[:, 1, :, :].detach().cpu().numpy()

        inds = []
        init_polys = []
        confidences = []
        for bid, dis_pred in enumerate(dis_preds):
            # dis_mask = (dis_pred / np.max(dis_pred)) > cfg.dis_threshold
            dis_mask = dis_pred > cfg.dis_threshold
            # dis_mask = fill_hole(dis_mask)
            ret, labels = cv2.connectedComponents(dis_mask.astype(np.uint8), connectivity=8, ltype=cv2.CV_16U)
            for idx in range(1, ret):  # label 0 is the background
                text_mask = labels == idx
                confidence = round(cls_preds[bid][text_mask].mean(), 3)
                # 50 for MLT2017 and ArT (or when DCN is used in the backbone); otherwise 150.
                # Setting it to 50 everywhere has little effect on performance.
                if np.sum(text_mask) < 50 / (cfg.scale * cfg.scale) or confidence < cfg.cls_threshold:
                    continue
                confidences.append(confidence)
                inds.append([bid, 0])

                poly = get_sample_point(text_mask, cfg.num_points,
                                        cfg.approx_factor, scales=np.array([cfg.scale, cfg.scale]))
                init_polys.append(poly)

        device = input["img"].device
        has_proposals = len(inds) > 0
        inds = torch.from_numpy(np.array(inds))
        if has_proposals:
            # (N, 2) -> (2, N) so that inds[0] is the vector of batch ids.
            inds = inds.permute(1, 0)
        inds = inds.to(device, non_blocking=True)
        init_polys = torch.from_numpy(np.array(init_polys)).to(device, non_blocking=True).float()

        return init_polys, inds, confidences

    def evolve_poly(self, snake, cnn_feature, i_it_poly, ind):
        """Run one refinement round and return the moved polygons.

        Vertices are moved by the module's offsets (clamped to ``±clip_dis``),
        then x is clamped to ``[0, w-1]`` and y to ``[0, h-1]``, where ``h, w``
        are the feature-map size multiplied by ``cfg.scale``. Assumes the last
        axis of ``i_it_poly`` holds (x, y) pairs.
        """
        if len(i_it_poly) == 0:
            return torch.zeros_like(i_it_poly)
        h, w = cnn_feature.size(2) * cfg.scale, cnn_feature.size(3) * cfg.scale
        node_feats = get_node_feature(cnn_feature, i_it_poly, ind, h, w)
        offsets = torch.clamp(snake(node_feats, self.adj).permute(0, 2, 1), -self.clip_dis, self.clip_dis)
        i_poly = i_it_poly + offsets
        # Clamp each axis against its own extent (the old training path clamped y
        # against w). Out-of-place stack keeps this safe for autograd; in-place
        # slice assignment would invalidate tensors saved by clamp's backward.
        xs = torch.clamp(i_poly[:, :, 0], 0, w - 1)
        ys = torch.clamp(i_poly[:, :, 1], 0, h - 1)
        return torch.stack((xs, ys), dim=2)

    def forward(self, embed_feature, input=None, seg_preds=None, switch="gt"):
        """Return ``(py_preds, inds, confidences)``.

        ``py_preds`` holds the initial polygons followed by the output of each
        of the ``self.iter`` refinement rounds (``iter + 1`` entries).
        """
        if self.is_training:
            init_polys, inds, confidences = self.get_boundary_proposal(input=input, seg_preds=seg_preds, switch=switch)
            # TODO sample fix number
        else:
            init_polys, inds, confidences = self.get_boundary_proposal_eval(input=input, seg_preds=seg_preds)
            if init_polys.shape[0] == 0:
                # Nothing to refine: keep the output length consistent.
                return [init_polys for _ in range(self.iter + 1)], inds, confidences

        py_preds = [init_polys]
        for i in range(self.iter):
            evolve_gcn = getattr(self, 'evolve_gcn' + str(i))
            init_polys = self.evolve_poly(evolve_gcn, embed_feature, init_polys, inds[0])
            py_preds.append(init_polys)

        return py_preds, inds, confidences


class TextNet(nn.Module):
    """FPN backbone + 4-channel segmentation head + polygon refinement (BPN)."""

    def __init__(self, backbone='vgg', is_training=True):
        super().__init__()
        self.is_training = is_training
        self.backbone_name = backbone
        self.fpn = FPN(self.backbone_name, is_training=(not cfg.resume and is_training))

        self.seg_head = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=2, dilation=2),
            nn.PReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=4, dilation=4),
            nn.PReLU(),
            nn.Conv2d(16, 4, kernel_size=1, stride=1, padding=0),
        )
        self.BPN = Evolution(cfg.num_points, adj_num=4,
                             is_training=is_training, device=cfg.device, model="BT")

    def load_model(self, model_path):
        """Load ``state_dict['model']`` from a checkpoint file.

        Loading is strict only when not training.
        """
        print('Loading from {}'.format(model_path))
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=torch.device(cfg.device))
        self.load_state_dict(state_dict['model'], strict=(not self.is_training))

    def forward(self, input_dict, test_speed=False):
        """Run the network on ``input_dict["img"]`` (shape ``(b, c, h, w)``).

        Outside training (unless ``cfg.exp_name`` is ArT/MLT2017/MLT2019 or
        ``test_speed`` is set), the image is zero-padded into a square of side
        ``cfg.test_size[1]``. This requires ``h, w <= cfg.test_size[1]``.

        Returns:
            dict with ``fy_preds`` (4 channels: 0-1 sigmoid, 2-3 raw),
            ``py_preds``, ``inds`` and ``confidences`` from the BPN.
        """
        b, c, h, w = input_dict["img"].shape
        if self.is_training or cfg.exp_name in ['ArT', 'MLT2017', "MLT2019"] or test_speed:
            image = input_dict["img"]
        else:
            image = torch.zeros((b, c, cfg.test_size[1], cfg.test_size[1]),
                                dtype=torch.float32, device=cfg.device)
            image[:, :, :h, :w] = input_dict["img"]

        up1, _, _, _, _ = self.fpn(image)
        # Drop the padded region again (features are at 1/cfg.scale resolution).
        up1 = up1[:, :, :h // cfg.scale, :w // cfg.scale]

        preds = self.seg_head(up1)
        fy_preds = torch.cat([torch.sigmoid(preds[:, 0:2, :, :]), preds[:, 2:4, :, :]], dim=1)
        cnn_feats = torch.cat([up1, fy_preds], dim=1)

        py_preds, inds, confidences = self.BPN(cnn_feats, input=input_dict, seg_preds=fy_preds, switch="gt")

        return {
            "fy_preds": fy_preds,
            "py_preds": py_preds,
            "inds": inds,
            "confidences": confidences,
        }