import sys
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
import numpy as np
from davis2017.davis import MaskDataset
from davis2017.metrics import db_eval_boundary, db_eval_iou
from davis2017 import utils
from davis2017.results import Results
from scipy.optimize import linear_sum_assignment
from skimage.transform import resize
import cv2
import PIL.Image
def _resize_pil_image(img, long_edge_size, nearest=False):
    """Resize so the longer edge equals long_edge_size (NEAREST is used for label masks)."""
    S = max(img.size)
    if S > long_edge_size:
        # downscaling: LANCZOS for images, NEAREST for masks
        interp = PIL.Image.NEAREST if nearest else PIL.Image.LANCZOS
    else:
        # upscaling
        interp = PIL.Image.BICUBIC
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)
def crop_img(img, size, square_ok=False, nearest=True, crop=True):
    W1, H1 = img.size
    if size == 224:
        # resize short side to 224 (then center-crop to a square)
        img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)), nearest=nearest)
    else:
        # resize long side to `size` (e.g. 512)
        img = _resize_pil_image(img, size, nearest=nearest)
    W, H = img.size
    cx, cy = W // 2, H // 2
    if size == 224:
        half = min(cx, cy)
        img = img.crop((cx - half, cy - half, cx + half, cy + half))
    else:
        # crop (or resize) so both output sides are multiples of 16
        halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
        if not square_ok and W == H:
            halfh = 3 * halfw // 4  # force a 4:3 aspect ratio for square inputs
        if crop:
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
        else:  # resize instead of cropping
            img = img.resize((2 * halfw, 2 * halfh), PIL.Image.NEAREST)
    return img
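

# Illustrative sketch (not part of the original evaluation code): the helper below
# shows the geometry crop_img is expected to produce at size=512, namely that both
# output sides come out as multiples of 16 and that square inputs are cropped to 4:3.
# The input sizes are arbitrary examples.
def _demo_crop_img():
    img = PIL.Image.new("L", (640, 480))          # dummy single-channel mask
    out = crop_img(img, 512, square_ok=True)      # long side resized to 512
    assert out.size == (512, 384)                 # both sides are multiples of 16
    sq = crop_img(PIL.Image.new("L", (500, 500)), 512, square_ok=False)
    assert sq.size == (512, 384)                  # square input cropped to 4:3
    return out, sq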
class MaskEvaluation(object):
    def __init__(self, root, sequences):
        self.dataset = MaskDataset(root=root, sequences=sequences)
        self.sequences = sequences

    @staticmethod
    def _evaluate(all_gt_masks, all_res_masks, all_void_masks, metric):
        # Ground-truth masks are cropped/resized to the 512 setting and binarized
        for i in range(len(all_gt_masks)):
            all_gt_masks[i] = (np.array(crop_img(all_gt_masks[i], 512, square_ok=True)) > 0.5) * 255
        for i in range(len(all_res_masks)):
            all_res_masks[i] = np.array(all_res_masks[i])
        # Debug visualization: show every 10th ground-truth/result pair side by side
        import matplotlib.pyplot as plt
        for i in range(len(all_res_masks)):
            if i % 10 == 0:
                concatenated_mask = np.concatenate((all_gt_masks[i], all_res_masks[i]), axis=1).astype(np.uint8)
                plt.imshow(concatenated_mask, cmap='gray')
                plt.title(f'Mask {i}')
                plt.show()
        all_gt_masks = np.stack(all_gt_masks, axis=0)
        all_res_masks = np.stack(all_res_masks, axis=0)
        # Align the number of result frames with the number of ground-truth frames
        if all_res_masks.shape[0] > all_gt_masks.shape[0]:
            all_res_masks = all_res_masks[:all_gt_masks.shape[0], ...]
        elif all_res_masks.shape[0] < all_gt_masks.shape[0]:
            zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:]))
            all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0)
        # A resize of all_res_masks to the ground-truth resolution could go here;
        # the result masks are assumed to already match it.
        j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2])
        for ii in range(all_gt_masks.shape[0]):
            if 'J' in metric:
                j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
            if 'F' in metric:
                f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks)
        return j_metrics_res, f_metrics_res

    def evaluate(self, res_path, metric=('J', 'F')):
        metric = metric if isinstance(metric, (tuple, list)) else [metric]
        if 'T' in metric:
            raise ValueError('Temporal metric not supported!')
        if 'J' not in metric and 'F' not in metric:
            raise ValueError('Metric possible values are J for IoU or F for Boundary')
        # Containers
        metrics_res = {}
        if 'J' in metric:
            metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
        if 'F' in metric:
            metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}}
        results = MaskDataset(root=res_path, sequences=self.sequences, is_label=False)
        # Sweep all sequences
        for seq in tqdm(self.sequences):
            all_gt_masks = self.dataset.read_masks(seq)
            all_res_masks = results.read_masks(seq)
            j_metrics_res, f_metrics_res = self._evaluate(all_gt_masks, all_res_masks, None, metric)
            for ii in range(len(all_gt_masks)):
                seq_name = f'{seq}_{ii + 1}'
                if 'J' in metric:
                    [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii])
                    metrics_res['J']["M"].append(JM)
                    metrics_res['J']["R"].append(JR)
                    metrics_res['J']["D"].append(JD)
                    metrics_res['J']["M_per_object"][seq_name] = JM
                if 'F' in metric:
                    [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii])
                    metrics_res['F']["M"].append(FM)
                    metrics_res['F']["R"].append(FR)
                    metrics_res['F']["D"].append(FD)
                    metrics_res['F']["M_per_object"][seq_name] = FM
        return metrics_res
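

# Minimal usage sketch (not part of the original file): it assumes a DAVIS-style
# layout where <gt_root>/<sequence>/ holds the ground-truth masks and
# <res_root>/<sequence>/ holds the predicted masks; the paths and sequence names
# below are placeholders.
if __name__ == "__main__":
    sequences = ['bear', 'camel']                                    # hypothetical sequence names
    evaluator = MaskEvaluation(root='path/to/gt_masks', sequences=sequences)
    metrics = evaluator.evaluate('path/to/predicted_masks', metric=('J', 'F'))
    print('J mean:', np.mean(metrics['J']['M']))
    print('F mean:', np.mean(metrics['F']['M']))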