"""Debug visualisations for the TextBPN text detector.

Renders network outputs (predicted maps and polygon refinements) next to the
ground truth so training and inference results can be inspected by eye.
"""
import math
import os

import cv2
import matplotlib

# Select the non-interactive backend before pyplot is imported so the module
# also works on headless machines.
matplotlib.use('agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import cm

from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
from IndicPhotoOCR.detection.textbpn.util import canvas as cav

# Number of panel columns in the training-visualisation figure.
_PANEL_COLS = 4
# Side length (pixels) of each tile in the arrays returned by
# visualize_gt / visualize_detection.
_TILE = 320
# Colour tuples handed to OpenCV for the different overlays.
_INIT_COLOR = (255, 255, 0)
_GT_COLOR = (0, 255, 0)
_PRED_COLOR = (0, 0, 255)
_FIRST_VERTEX_COLOR = (255, 0, 255)
_SECOND_VERTEX_COLOR = (0, 255, 255)


def _upsampled_maps(output_dict):
    """Return ``output_dict["fy_preds"]`` upsampled by ``cfg.scale`` as NumPy.

    ``align_corners=False`` is PyTorch's default for bilinear mode; it is
    spelled out only to make the sampling convention explicit.
    """
    maps = F.interpolate(output_dict["fy_preds"], scale_factor=cfg.scale,
                         mode='bilinear', align_corners=False)
    return maps.detach().cpu().numpy()


def _norm_and_angle(u, v):
    """Return element-wise ``sqrt(u**2 + v**2)`` and ``atan2(u, v)`` in degrees.

    1e-5 is added to ``v`` before ``atan2``.
    """
    norm = np.sqrt(u ** 2 + v ** 2)
    angle = 180 / math.pi * np.arctan2(u, v + 0.00001)
    return norm, angle


def _denormalize_image(img_tensor):
    """Turn a CHW image tensor into a contiguous HWC uint8 array for drawing.

    Applies ``x * cfg.stds + cfg.means`` (presumably the inverse of the data
    pipeline's normalisation — TODO confirm), scales by 255, clips to the
    uint8 range and reverses the channel axis.
    """
    img = img_tensor.permute(1, 2, 0).cpu().numpy()
    # Clip so values nudged outside [0, 1] do not wrap around in the cast.
    img = np.clip((img * cfg.stds + cfg.means) * 255, 0, 255).astype(np.uint8)
    return np.ascontiguousarray(img[:, :, ::-1])


def _add_panel(fig, grid_rows, slot, data, title=None):
    """Show ``data`` with the jet colormap at 1-based ``slot`` of a grid_rows x 4 grid."""
    ax = fig.add_subplot(grid_rows, _PANEL_COLS, slot)
    if title is not None:
        ax.set_title(title)
    ax.imshow(data, cmap=cm.jet)
    return ax


def _boundary_overlays(base, init_py, gt_py, stages):
    """Return ``[init+GT]`` followed by one ``init+GT+stage`` overlay per stage.

    Every overlay is drawn on a fresh copy of ``base``; ``base`` is untouched.
    """
    def overlay(stage=None):
        canvas = base.copy()
        cv2.drawContours(canvas, init_py.astype(np.int32), -1, _INIT_COLOR, 2)
        cv2.drawContours(canvas, gt_py.astype(np.int32), -1, _GT_COLOR, 2)
        if stage is not None:
            cv2.drawContours(canvas, stage.astype(np.int32), -1, _PRED_COLOR, 2)
        return canvas

    return [overlay()] + [overlay(stage) for stage in stages]


def visualize_network_output(output_dict, input_dict, mode='train'):
    """Save one diagnostic figure per batch element.

    Writes ``<cfg.vis_dir>/<cfg.exp_name>_<mode>/<i>.png`` for every batch
    index ``i`` of ``output_dict["fy_preds"]``. Layout (4 columns):

    * row 1 - predicted channel 0 (``mask_pred``), channel 1
      (``distance_pred``), and the norm / angle of channels 2-3;
    * row 2 - ``tr_mask``, ``distance_field`` and the norm / angle of
      ``direction_field`` from ``input_dict``;
    * remaining rows - the image with the initial polygons and GT polygons,
      then one panel per refinement stage in ``output_dict["py_preds"][1:]``.

    Args:
        output_dict: network outputs with keys ``fy_preds``, ``py_preds``
            (initial polygons followed by refinement stages) and ``inds``.
        input_dict: batch with keys ``img``, ``tr_mask``, ``distance_field``,
            ``direction_field``, ``gt_points`` and ``ignore_tags``.
        mode: suffix of the output directory name.
    """
    vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name + '_' + mode)
    # makedirs also creates cfg.vis_dir when it is missing (os.mkdir raised).
    os.makedirs(vis_dir, exist_ok=True)

    fy_preds = _upsampled_maps(output_dict)
    init_polys = output_dict["py_preds"][0]
    py_preds = output_dict["py_preds"][1:]
    # NOTE(review): inds[0] appears to hold the batch index of each polygon
    # (it is compared with the batch index below) — confirm with the model.
    inds = output_dict["inds"]

    tr_mask = input_dict['tr_mask'].detach().cpu().numpy() > 0
    distance_field = input_dict['distance_field'].detach().cpu().numpy()
    direction_field = input_dict['direction_field']
    gt_points = input_dict['gt_points'].cpu().numpy()
    ignore_tags = input_dict['ignore_tags'].cpu().numpy()

    # Size the grid so any number of refinement stages fits; with <= 3 stages
    # this is the classic 3 x 4 layout.
    n_overlays = 1 + len(py_preds)
    grid_rows = 2 + max(1, math.ceil(n_overlays / _PANEL_COLS))

    for i in range(fy_preds.shape[0]):
        fig = plt.figure(figsize=(12, 3 * grid_rows))
        try:
            norm_pred, angle_pred = _norm_and_angle(fy_preds[i, 2], fy_preds[i, 3])
            pred_panels = [
                ('mask_pred', fy_preds[i, 0]),
                ('distance_pred', fy_preds[i, 1]),
                ('norm_pred', norm_pred),
                ('angle_pred', angle_pred),
            ]
            for slot, (title, data) in enumerate(pred_panels, start=1):
                _add_panel(fig, grid_rows, slot, data, title)

            gt_flux = direction_field[i].cpu().numpy()
            norm_gt, angle_gt = _norm_and_angle(gt_flux[0], gt_flux[1])
            gt_panels = [tr_mask[i], distance_field[i], norm_gt, angle_gt]
            for slot, data in enumerate(gt_panels, start=_PANEL_COLS + 1):
                _add_panel(fig, grid_rows, slot, data)

            img_show = _denormalize_image(input_dict['img'][i])
            # Only GT polygons whose ignore tag is positive are drawn.
            gt_py = gt_points[i][np.where(ignore_tags[i] > 0)[0]]
            index = torch.where(inds[0] == i)[0]
            init_py = init_polys[index].detach().cpu().numpy()
            stages = [py[index].detach().cpu().numpy() for py in py_preds]

            overlays = _boundary_overlays(img_show, init_py, gt_py, stages)
            first_slot = 2 * _PANEL_COLS + 1
            for idx, overlay in enumerate(overlays):
                _add_panel(fig, grid_rows, first_slot + idx, overlay,
                           'boundary_{}'.format(idx))

            fig.savefig(os.path.join(vis_dir, '{}.png'.format(i)))
        finally:
            # Always release the figure, even if drawing failed.
            plt.close(fig)


def visualize_gt(image, contours, label_tag):
    """Draw tagged ground-truth polygons on a copy of ``image``.

    Args:
        image: HxWxC array; its channel axis is reversed before drawing and
            the input itself is not modified.
        contours: polygons accepted by ``cv2.polylines``, one per tag.
        label_tag: per-contour tags. Tag > 0 is drawn with (0, 0, 255),
            tag < 0 with (0, 255, 0); tag == 0 is skipped.

    Returns:
        The annotated image resized to 320x320.
    """
    canvas = np.ascontiguousarray(image.copy()[:, :, ::-1])
    positives = [contours[k] for k, tag in enumerate(label_tag) if tag > 0]
    negatives = [contours[k] for k, tag in enumerate(label_tag) if tag < 0]
    # Skip empty groups rather than handing cv2.polylines an empty list.
    if positives:
        canvas = cv2.polylines(canvas, positives, True, (0, 0, 255), 3)
    if negatives:
        canvas = cv2.polylines(canvas, negatives, True, (0, 255, 0), 3)
    return cv2.resize(canvas, (_TILE, _TILE))


def _mark_vertices(canvas, points, other_color):
    """Dot every vertex of ``points`` on ``canvas`` (radius 3, filled).

    Vertex 0 and vertex 1 get dedicated colours so the polygon's start point
    and orientation are visible; the rest use ``other_color``.
    """
    for j, pt in enumerate(points):
        if j == 0:
            color = _FIRST_VERTEX_COLOR
        elif j == 1:
            color = _SECOND_VERTEX_COLOR
        else:
            color = other_color
        cv2.circle(canvas, (int(pt[0]), int(pt[1])), 3, color, -1)


def _write_image(path, img):
    """Write ``img`` to ``path``, creating the parent directory first.

    cv2.imwrite returns False instead of raising when the directory is
    missing, so without this the debug images were silently dropped.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    cv2.imwrite(path, img)


def _to_uint8(channel):
    """Scale a map by 255 and clip before casting, so out-of-range values saturate."""
    return np.clip(channel * 255, 0, 255).astype(np.uint8)


def visualize_detection(image, output_dict, meta=None):
    """Render the polygon refinement stages and score maps for one test image.

    Args:
        image: HxWxC array; its channel axis is reversed before drawing.
        output_dict: network outputs with keys ``fy_preds`` and ``py_preds``
            (initial polygons followed by refinement stages). Only batch
            element 0 of ``fy_preds`` is used.
        meta: optional dict whose ``meta['image_id'][0]`` names the image.
            When given, the init overlay and every stage overlay are also
            written to ``<cfg.vis_dir>/<cfg.exp_name>_test/`` as
            ``<stem>_init.png`` and ``<stem>_<k>iter.png``. When None, no
            files are written.

    Returns:
        ``(show_boundary, heat_map)``: the per-stage overlays tiled
        horizontally at 320x320 each (the init overlay alone if there are no
        stages), and the heatmaps of ``fy_preds`` channels 0 and 1 side by
        side, resized to 640x320.
    """
    image_show = np.ascontiguousarray(image.copy()[:, :, ::-1])
    cls_preds = _upsampled_maps(output_dict)[0]
    init_polys = output_dict["py_preds"][0]
    py_preds = output_dict["py_preds"][1:]

    out_prefix = None
    if meta is not None:
        out_prefix = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name),
                                  meta['image_id'][0].split(".")[0])

    # Initial polygons: each contour is drawn, then its vertices are marked.
    init_canvas = image_show.copy()
    for poly in init_polys.detach().cpu().numpy().astype(np.int32):
        cv2.drawContours(init_canvas, [poly], -1, _INIT_COLOR, 2)
        _mark_vertices(init_canvas, poly, (0, 0, 255))
    if out_prefix is not None:
        _write_image(out_prefix + "_init.png", init_canvas)

    # Each refinement stage is drawn on top of the init overlay.
    shows = []
    for idx, py in enumerate(py_preds):
        canvas = init_canvas.copy()
        contours = py.detach().cpu().numpy()
        cv2.drawContours(canvas, contours.astype(np.int32), -1, _PRED_COLOR, 2)
        for ppts in contours:
            _mark_vertices(canvas, ppts, (0, 255, 0))
        if out_prefix is not None:
            _write_image(out_prefix + "_{}iter.png".format(idx), canvas)
        shows.append(canvas)

    # np.concatenate([]) raises, so fall back to the init overlay.
    if not shows:
        shows = [init_canvas]
    show_boundary = cv2.resize(np.concatenate(shows, axis=1),
                               (_TILE * len(shows), _TILE))

    # NOTE(review): the "* 255" suggests cav.heatmap returns values in
    # [0, 1] — confirm against the canvas helper.
    cls_pred = cav.heatmap(_to_uint8(cls_preds[0]))
    dis_pred = cav.heatmap(_to_uint8(cls_preds[1]))
    heat_map = np.concatenate([cls_pred * 255, dis_pred * 255], axis=1)
    heat_map = cv2.resize(heat_map, (_TILE * 2, _TILE))
    return show_boundary, heat_map