File size: 5,397 Bytes
4f6b78d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

import numpy as np
from davis2017.davis import MaskDataset
from davis2017.metrics import db_eval_boundary, db_eval_iou
from davis2017 import utils
from davis2017.results import Results
from scipy.optimize import linear_sum_assignment
from skimage.transform import resize
import cv2
import PIL

def _resize_pil_image(img, long_edge_size, nearest=False):
    """Resize a PIL image so that its longer side becomes ``long_edge_size``.

    Both sides are scaled by the same factor and rounded to the nearest
    integer, so the aspect ratio is (approximately) preserved.

    Args:
        img: PIL image to resize.
        long_edge_size: Target length, in pixels, of the longer side.
        nearest: If true, use nearest-neighbour resampling whether the image
            is shrunk or enlarged. Needed for masks/label maps, whose values
            must not be interpolated.

    Returns:
        The resized PIL image.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule.
    import PIL.Image

    S = max(img.size)
    if nearest:
        # Honour the flag in both directions; previously upscaling fell back
        # to bicubic, which blurs mask boundaries.
        interp = PIL.Image.NEAREST
    elif S > long_edge_size:
        interp = PIL.Image.LANCZOS  # downscaling: anti-aliased filter
    else:
        interp = PIL.Image.BICUBIC  # upscaling (or same size)
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)

def crop_img(img, size, square_ok=False, nearest=True, crop=True):
    """Resize and center-crop a PIL image to a network-friendly resolution.

    Args:
        img: PIL image.
        size: If 224, the image is resized so its *short* side is ~224 and the
            central square is kept. For any other value, the image is resized
            so its *long* side is ``size``. Each side is then reduced to the
            largest multiple of 16 that fits, keeping the image centered.
        square_ok: Only used when ``size != 224``. When false and the resized
            image is square, the output height is 3/4 of the output width.
        nearest: Forwarded to ``_resize_pil_image`` (nearest-neighbour
            resampling, appropriate for masks).
        crop: Only used when ``size != 224``. If true, center-crop to the
            target size. Otherwise, resize to it with nearest-neighbour
            resampling.

    Returns:
        The processed PIL image.
    """
    import PIL.Image  # ``import PIL`` alone does not load the submodule

    W1, H1 = img.size
    if size == 224:
        # Long side = 224 * aspect ratio, i.e. the short side becomes 224.
        img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)), nearest=nearest)
    else:
        # Long side becomes ``size``.
        img = _resize_pil_image(img, size, nearest=nearest)
    W, H = img.size
    cx, cy = W // 2, H // 2
    if size == 224:
        half = min(cx, cy)
        return img.crop((cx - half, cy - half, cx + half, cy + half))

    # Half-extents chosen so the full extents are multiples of 16.
    halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
    if not square_ok and W == H:
        # Integer division: halfw is a multiple of 8, so this equals the exact
        # 3/4 value, but it stays an int (Image.resize rejects float sizes).
        halfh = 3 * halfw // 4
    if crop:
        return img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
    return img.resize((2 * halfw, 2 * halfh), PIL.Image.NEAREST)


class MaskEvaluation:
    """DAVIS-style region (J) and boundary (F) evaluation of mask sequences.

    Each sequence is scored as a single object: ground-truth frame ``i`` is
    compared with result frame ``i``.
    """

    def __init__(self, root, sequences):
        """Set up the ground-truth dataset.

        Args:
            root: Directory passed to ``MaskDataset`` for the ground truth.
            sequences: Names of the sequences to evaluate.
        """
        self.dataset = MaskDataset(root=root, sequences=sequences)
        self.sequences = sequences

    @staticmethod
    def _evaluate(all_gt_masks, all_res_masks, all_void_masks, metric, visualize=False):
        """Compute per-frame J and/or F scores for one sequence.

        Args:
            all_gt_masks: Per-frame ground-truth masks (PIL images). Each is
                passed through ``crop_img(..., 512, square_ok=True)`` and
                binarised to {0, 255}. The input list is not modified.
            all_res_masks: Per-frame predicted masks (anything ``np.array``
                accepts). Frames beyond the ground-truth count are ignored.
                Missing frames are scored as empty masks.
            all_void_masks: Forwarded unchanged to every per-frame metric call
                (``None`` disables void handling).
            metric: Container holding ``'J'`` and/or ``'F'``.
            visualize: If true, show every 10th ground-truth/result pair side
                by side with matplotlib (blocking). Debugging aid only.

        Returns:
            ``(j_metrics_res, f_metrics_res)``, each of shape ``(1, n_frames)``.
            There is one row per object (a single object here) and one column
            per frame. A metric that was not requested stays all zeros.

        Raises:
            ValueError: if ``all_gt_masks`` is empty.
        """
        if len(all_gt_masks) == 0:
            raise ValueError('No ground-truth masks to evaluate.')

        # Build new arrays instead of overwriting the caller's list entries.
        gt_masks = np.stack(
            [
                np.where(np.array(crop_img(mask, 512, square_ok=True)) > 0.5, 255, 0).astype(np.uint8)
                for mask in all_gt_masks
            ],
            axis=0,
        )
        n_frames = gt_masks.shape[0]

        # At most one result frame per ground-truth frame ...
        res_frames = [np.array(all_res_masks[i]) for i in range(min(len(all_res_masks), n_frames))]
        if res_frames:
            pad_shape, pad_dtype = res_frames[0].shape, res_frames[0].dtype
        else:
            pad_shape, pad_dtype = gt_masks.shape[1:], gt_masks.dtype
        # ... and missing result frames count as empty predictions.
        res_frames.extend(
            np.zeros(pad_shape, dtype=pad_dtype) for _ in range(n_frames - len(res_frames))
        )
        res_masks = np.stack(res_frames, axis=0)

        if visualize:
            import matplotlib.pyplot as plt  # optional, debugging only

            for i in range(0, n_frames, 10):
                side_by_side = np.concatenate((gt_masks[i], res_masks[i]), axis=1).astype(np.uint8)
                plt.imshow(side_by_side, cmap='gray')
                plt.title(f'Mask {i}')
                plt.show()

        # Previously shaped (n_frames, H), which made every "object" a single
        # frame and broke the recall/decay statistics computed per row.
        j_metrics_res = np.zeros((1, n_frames))
        f_metrics_res = np.zeros((1, n_frames))
        for ii in range(n_frames):
            if 'J' in metric:
                j_metrics_res[0, ii] = db_eval_iou(gt_masks[ii], res_masks[ii], all_void_masks)
            if 'F' in metric:
                f_metrics_res[0, ii] = db_eval_boundary(gt_masks[ii], res_masks[ii], all_void_masks)
        return j_metrics_res, f_metrics_res

    def evaluate(self, res_path, metric=('J', 'F'), visualize=False):
        """Evaluate the predicted masks of every sequence against the ground truth.

        Args:
            res_path: Directory passed to ``MaskDataset`` (with
                ``is_label=False``) for the predicted masks.
            metric: ``'J'``, ``'F'``, or a tuple/list containing them.
            visualize: Forwarded to ``_evaluate`` (blocking matplotlib display).

        Returns:
            Dict keyed by each requested metric (``'J'`` and/or ``'F'``).
            Each value has:
            - lists ``"M"``, ``"R"``, ``"D"``, with one entry per evaluated
              object (one per sequence) — the three values returned by
              ``utils.db_statistics``;
            - ``"M_per_object"``, mapping ``f'{seq}_{k}'`` to that object's
              ``"M"`` value.

        Raises:
            ValueError: if ``'T'`` is requested, or if neither ``'J'`` nor
                ``'F'`` is.
        """
        metric = metric if isinstance(metric, (tuple, list)) else [metric]
        if 'T' in metric:
            raise ValueError('Temporal metric not supported!')
        if 'J' not in metric and 'F' not in metric:
            raise ValueError('Metric possible values are J for IoU or F for Boundary')

        metrics_res = {
            key: {"M": [], "R": [], "D": [], "M_per_object": {}}
            for key in ('J', 'F')
            if key in metric
        }

        results = MaskDataset(root=res_path, sequences=self.sequences, is_label=False)

        for seq in tqdm(self.sequences):
            all_gt_masks = self.dataset.read_masks(seq)
            all_res_masks = results.read_masks(seq)
            j_metrics_res, f_metrics_res = self._evaluate(
                all_gt_masks, all_res_masks, None, metric, visualize=visualize
            )
            for key, scores in (('J', j_metrics_res), ('F', f_metrics_res)):
                if key not in metric:
                    continue
                stats = metrics_res[key]
                # Each row holds one object's per-frame scores.
                for ii, per_frame in enumerate(scores):
                    m_stat, r_stat, d_stat = utils.db_statistics(per_frame)
                    stats["M"].append(m_stat)
                    stats["R"].append(r_stat)
                    stats["D"].append(d_stat)
                    stats["M_per_object"][f'{seq}_{ii + 1}'] = m_stat

        return metrics_res