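# Hugging Face file-page header (scrape residue, not part of the module):
#   uploader avatar: shreyasvaidya
#   commit message: "Upload folder using huggingface_hub" (commit 01bb3bb, verified)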
import torch
import numpy as np
import cv2
import os
import math
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
from IndicPhotoOCR.detection.textbpn.util import canvas as cav
import matplotlib
matplotlib.use('agg')
import pylab as plt
from matplotlib import cm
import torch.nn.functional as F
def _add_map_panel(fig, rows, position, data, title=None):
    """Show ``data`` with the jet colormap in slot ``position`` of a ``rows`` x 4 grid."""
    ax = fig.add_subplot(rows, 4, position)
    if title is not None:
        ax.set_title(title)
    ax.imshow(data, cmap=cm.jet)
    return ax


def visualize_network_output(output_dict, input_dict, mode='train'):
    """Save one diagnostic figure per batch item comparing predictions with targets.

    Every figure is a grid four columns wide:

    * row 1 - maps taken from ``fy_preds``: channel 0 (``mask_pred``),
      channel 1 (``distance_pred``), and the magnitude / angle in degrees of
      channels 2 and 3 (``norm_pred`` / ``angle_pred``);
    * row 2 - the same four maps built from ``tr_mask``, ``distance_field``
      and ``direction_field`` (left untitled, as before);
    * remaining row(s) - the de-normalised input image with the initial
      polygons and selected ground-truth polygons drawn on it, followed by one
      panel per refinement stage that additionally shows that stage's polygons.

    Figures are written to ``<cfg.vis_dir>/<cfg.exp_name>_<mode>/<i>.png``
    where ``i`` is the position of the item inside the batch, so a later call
    with the same ``mode`` overwrites earlier files.

    Args:
        output_dict: mapping with
            ``"fy_preds"`` - 4-D tensor, upsampled here by ``cfg.scale``;
            ``"py_preds"`` - sequence whose first entry holds the initial
            polygons and whose remaining entries hold successive refinements;
            ``"inds"`` - ``inds[0]`` is compared against the batch index to
            pick the polygons that belong to each image.
        input_dict: mapping with the tensors ``"img"``, ``"tr_mask"``,
            ``"distance_field"``, ``"direction_field"``, ``"gt_points"`` and
            ``"ignore_tags"``.
        mode: suffix of the output directory name.

    Returns:
        None. The function only writes image files.
    """
    vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name + '_' + mode)
    # makedirs(exist_ok=True) also creates missing parents and cannot race
    # with another process the way an exists()/mkdir() pair can.
    os.makedirs(vis_dir, exist_ok=True)

    fy_preds = F.interpolate(output_dict["fy_preds"], scale_factor=cfg.scale, mode='bilinear')
    fy_preds = fy_preds.detach().cpu().numpy()

    py_preds = output_dict["py_preds"][1:]
    init_polys = output_dict["py_preds"][0]
    inds = output_dict["inds"]

    image = input_dict['img']
    tr_mask = input_dict['tr_mask'].detach().cpu().numpy() > 0
    distance_field = input_dict['distance_field'].detach().cpu().numpy()
    direction_field = input_dict['direction_field']
    gt_tags = input_dict['gt_points'].cpu().numpy()
    ignore_tags = input_dict['ignore_tags'].cpu().numpy()

    # Two rows of maps plus as many rows as the boundary panels need. With up
    # to four panels this is the historical 3 x 4 layout on a 12 x 9 canvas;
    # with more refinement stages the grid grows instead of add_subplot
    # raising ValueError for slot 13.
    n_boundary_panels = 1 + len(py_preds)
    rows = 2 + math.ceil(n_boundary_panels / 4)

    for i in range(fy_preds.shape[0]):
        fig = plt.figure(figsize=(12, 3 * rows))
        try:
            pred = fy_preds[i]
            _add_map_panel(fig, rows, 1, pred[0], 'mask_pred')
            _add_map_panel(fig, rows, 2, pred[1], 'distance_pred')
            _add_map_panel(fig, rows, 3, np.sqrt(pred[2] ** 2 + pred[3] ** 2), 'norm_pred')
            # The small offset keeps the second arctan2 argument away from exact zero.
            _add_map_panel(fig, rows, 4,
                           180 / math.pi * np.arctan2(pred[2], pred[3] + 0.00001),
                           'angle_pred')

            gt_flux = direction_field[i].cpu().numpy()
            _add_map_panel(fig, rows, 5, tr_mask[i])
            _add_map_panel(fig, rows, 6, distance_field[i])
            _add_map_panel(fig, rows, 7, np.sqrt(gt_flux[0] ** 2 + gt_flux[1] ** 2))
            _add_map_panel(fig, rows, 8,
                           180 / math.pi * np.arctan2(gt_flux[0], gt_flux[1] + 0.00001))

            # Undo the input normalisation, then reverse the channel axis
            # before handing the image to OpenCV for drawing.
            # NOTE(review): the reversed-channel image is later displayed by
            # matplotlib unchanged, so colours in the saved figure are
            # channel-swapped relative to the input - confirm this is intended.
            img_show = image[i].permute(1, 2, 0).cpu().numpy()
            img_show = ((img_show * cfg.stds + cfg.means) * 255).astype(np.uint8)
            img_show = np.ascontiguousarray(img_show[:, :, ::-1])

            # Keep the ground-truth polygons whose ignore_tags entry is positive.
            keep = np.where(ignore_tags[i] > 0)[0]
            gt_py = gt_tags[i][keep, :, :].astype(np.int32)

            # Polygons predicted for this batch item.
            index = torch.where(inds[0] == i)[0]
            init_py = init_polys[index].detach().cpu().numpy().astype(np.int32)

            # The initial and ground-truth outlines are common to every panel,
            # so draw them once and copy the result for each refinement stage.
            base = img_show.copy()
            cv2.drawContours(base, init_py, -1, (255, 255, 0), 2)
            cv2.drawContours(base, gt_py, -1, (0, 255, 0), 2)
            shows = [base]
            for py in py_preds:
                stage = base.copy()
                contours = py[index].detach().cpu().numpy().astype(np.int32)
                cv2.drawContours(stage, contours, -1, (0, 0, 255), 2)
                shows.append(stage)

            for idx, im_show in enumerate(shows):
                _add_map_panel(fig, rows, 9 + idx, im_show)

            fig.savefig(os.path.join(vis_dir, '{}.png'.format(i)))
        finally:
            # Always release the figure, even if drawing or saving failed,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
def visualize_gt(image, contours, label_tag):
    """Return a 320x320 preview of ``image`` with tagged contours outlined.

    Works on a copy of ``image`` with its last (channel) axis reversed.
    Contours whose tag is positive are drawn as closed polylines in
    ``(0, 0, 255)``, those whose tag is negative in ``(0, 255, 0)``; a tag of
    exactly zero is not drawn. Line thickness is 3.

    Args:
        image: H x W x C array; it is not modified.
        contours: indexable collection of polygons, aligned with ``label_tag``.
        label_tag: iterable of numeric tags, one per contour.

    Returns:
        The annotated image resized to 320 x 320.
    """
    canvas = np.ascontiguousarray(image.copy()[:, :, ::-1])

    # (tag predicate, line colour) pairs, applied in this order.
    passes = (
        (lambda tag: tag > 0, (0, 0, 255)),
        (lambda tag: tag < 0, (0, 255, 0)),
    )
    for wanted, colour in passes:
        chosen = [contours[pos] for pos, tag in enumerate(label_tag) if wanted(tag)]
        canvas = cv2.polylines(canvas, chosen, True, colour, 3)

    return cv2.resize(canvas, (320, 320))
def _mark_polygon_vertices(canvas, polygon, rest_color):
    """Draw a filled dot on each vertex: 1st magenta-coded, 2nd yellow-coded, rest ``rest_color``."""
    for j, pp in enumerate(polygon):
        if j == 0:
            color = (255, 0, 255)
        elif j == 1:
            color = (0, 255, 255)
        else:
            color = rest_color
        cv2.circle(canvas, (int(pp[0]), int(pp[1])), 3, color, -1)


def visualize_detection(image, output_dict, meta=None):
    """Render the detected boundaries and the first two prediction maps.

    The initial polygons are outlined on a channel-reversed copy of ``image``
    and every vertex is marked with a dot (the first two vertices get their
    own colours so the point ordering is visible). Each refinement stage is
    then drawn on top of a copy of that picture.

    When ``meta`` is given, the initial picture and one picture per stage are
    written to ``<cfg.vis_dir>/<cfg.exp_name>_test/<stem>_init.png`` and
    ``<stem>_<k>iter.png``, where ``<stem>`` is ``meta['image_id'][0]`` up to
    its first ``"."``. The directory is created if necessary. When ``meta`` is
    ``None`` nothing is written to disk (previously this raised ``TypeError``).

    Args:
        image: H x W x C array; it is not modified.
        output_dict: mapping with ``"fy_preds"`` (4-D tensor, upsampled here by
            ``cfg.scale``; only the first batch item is used) and
            ``"py_preds"`` (sequence whose first entry holds the initial
            polygons and whose remaining entries hold the refinement stages).
        meta: optional mapping whose ``'image_id'`` entry is a sequence of
            file names; only the first one is used.

    Returns:
        ``(show_boundary, heat_map)``: the stage pictures placed side by side
        and resized to 320 pixels per stage by 320, and the two heat maps side
        by side resized to 640 x 320. If there are no refinement stages the
        initial picture is returned as the only boundary panel.
    """
    image_show = np.ascontiguousarray(image.copy()[:, :, ::-1])

    cls_preds = F.interpolate(output_dict["fy_preds"], scale_factor=cfg.scale, mode='bilinear')
    cls_preds = cls_preds[0].detach().cpu().numpy()

    py_preds = output_dict["py_preds"][1:]
    init_polys = output_dict["py_preds"][0]

    # Resolve the output location once; ``stem`` stays None when there is no
    # metadata to name the files after.
    out_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
    stem = None
    if meta is not None:
        stem = meta['image_id'][0].split(".")[0]
        # cv2.imwrite does not create directories and only reports failure
        # through its return value, so make sure the target exists.
        os.makedirs(out_dir, exist_ok=True)

    im_show0 = image_show.copy()
    for bpts in init_polys.detach().cpu().numpy().astype(np.int32):
        cv2.drawContours(im_show0, [bpts], -1, (255, 255, 0), 2)
        _mark_polygon_vertices(im_show0, bpts, (0, 0, 255))
    if stem is not None:
        cv2.imwrite(os.path.join(out_dir, stem + "_init.png"), im_show0)

    shows = []
    for idx, py in enumerate(py_preds):
        im_show = im_show0.copy()
        contours = py.detach().cpu().numpy()
        cv2.drawContours(im_show, contours.astype(np.int32), -1, (0, 0, 255), 2)
        for ppts in contours:
            _mark_polygon_vertices(im_show, ppts, (0, 255, 0))
        if stem is not None:
            cv2.imwrite(os.path.join(out_dir, stem + "_{}iter.png".format(idx)), im_show)
        shows.append(im_show)

    # np.concatenate rejects an empty list and cv2.resize rejects a zero
    # width, so fall back to the initial picture when nothing was refined.
    panels = shows if shows else [im_show0]
    show_boundary = cv2.resize(np.concatenate(panels, axis=1), (320 * len(panels), 320))

    # Channels 0 and 1 are scaled to 8-bit before colouring.
    # NOTE(review): the result of cav.heatmap is multiplied by 255 again,
    # which presumes it returns values in [0, 1] - confirm against cav.heatmap.
    cls_pred = cav.heatmap(np.array(cls_preds[0] * 255, dtype=np.uint8))
    dis_pred = cav.heatmap(np.array(cls_preds[1] * 255, dtype=np.uint8))
    heat_map = np.concatenate([cls_pred * 255, dis_pred * 255], axis=1)
    heat_map = cv2.resize(heat_map, (320 * 2, 320))

    return show_boundary, heat_map