Spaces:
Runtime error
Runtime error
File size: 5,002 Bytes
2cdd41c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from datetime import timedelta
from pathlib import Path
import torch
import numpy as np
from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset
from isegm.utils.serialization import load_model
def get_time_metrics(all_ious, elapsed_time):
    """Return the average time spent per click and per image.

    Args:
        all_ious: sized sequence with one entry per image; each entry is itself
            sized, with one element per click (presumably per-click IoUs).
        elapsed_time: total time of the run, in whatever unit the caller uses
            (seconds, judging by the ``SPC,s`` label used elsewhere — confirm).

    Returns:
        Tuple ``(mean_spc, mean_spi)``: time per click and time per image.

    Raises:
        ZeroDivisionError: with plain Python numbers, when there are no clicks.
    """
    image_count = len(all_ious)
    click_count = 0
    for per_image in all_ious:
        click_count += len(per_image)
    per_click = elapsed_time / click_count
    per_image_time = elapsed_time / image_count
    return per_click, per_image_time
def load_is_model(checkpoint, device, **kwargs):
    """Load a single model, or a list of models, from a checkpoint.

    Args:
        checkpoint: either a path (``str`` or ``Path``), which is read with
            ``torch.load(..., map_location='cpu')``, or an already-loaded
            checkpoint object.
        device: forwarded to ``load_single_is_model``.
        **kwargs: forwarded to ``load_single_is_model``.

    Returns:
        If the loaded checkpoint is a ``list``, the tuple ``(model, models)``.
        ``model`` is built from the first entry and ``models`` has one model
        per entry, so the first entry is built twice as separate instances.
        Otherwise, the single model.
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
    # reject checkpoints whose 'config' holds arbitrary Python objects — confirm.
    loaded = (
        torch.load(checkpoint, map_location='cpu')
        if isinstance(checkpoint, (str, Path))
        else checkpoint
    )
    if not isinstance(loaded, list):
        return load_single_is_model(loaded, device, **kwargs)

    primary = load_single_is_model(loaded[0], device, **kwargs)
    ensemble = [load_single_is_model(entry, device, **kwargs) for entry in loaded]
    return primary, ensemble
def load_single_is_model(state_dict, device, **kwargs):
    """Build one frozen, eval-mode model from a checkpoint dictionary.

    Args:
        state_dict: mapping with a ``'config'`` entry, passed to ``load_model``
            together with ``**kwargs``, and a ``'state_dict'`` entry holding the
            weights.
        device: target passed to ``model.to``.
        **kwargs: extra arguments for ``load_model``.

    Returns:
        The model with every parameter's ``requires_grad`` set to ``False``,
        moved to ``device`` and switched to evaluation mode.
    """
    config = state_dict['config']
    weights = state_dict['state_dict']

    net = load_model(config, **kwargs)
    # strict=False: missing/unexpected keys in the checkpoint are tolerated.
    net.load_state_dict(weights, strict=False)

    for tensor in net.parameters():
        tensor.requires_grad = False

    net.to(device)
    net.eval()
    return net
def get_dataset(dataset_name, cfg):
    """Instantiate an evaluation dataset by name.

    Args:
        dataset_name: one of ``'GrabCut'``, ``'Berkeley'``, ``'DAVIS'``,
            ``'SBD'``, ``'SBD_Train'``, ``'PascalVOC'``, ``'COCO_MVal'``.
        cfg: configuration object exposing the ``*_PATH`` attribute for the
            requested dataset.

    Returns:
        The dataset instance, or ``None`` when the name is not recognised.
    """
    # Lazy factories: only the selected dataset's cfg attribute is read.
    factories = {
        'GrabCut': lambda: GrabCutDataset(cfg.GRABCUT_PATH),
        'Berkeley': lambda: BerkeleyDataset(cfg.BERKELEY_PATH),
        'DAVIS': lambda: DavisDataset(cfg.DAVIS_PATH),
        'SBD': lambda: SBDEvaluationDataset(cfg.SBD_PATH),
        'SBD_Train': lambda: SBDEvaluationDataset(cfg.SBD_PATH, split='train'),
        'PascalVOC': lambda: PascalVocDataset(cfg.PASCALVOC_PATH, split='test'),
        # COCO_MVal reuses the DavisDataset loader.
        'COCO_MVal': lambda: DavisDataset(cfg.COCO_MVAL_PATH),
    }
    factory = factories.get(dataset_name)
    return None if factory is None else factory()
def get_iou(gt_mask, pred_mask, ignore_label=-1):
    """Intersection-over-union between a predicted mask and a ground-truth mask.

    Pixels where ``gt_mask == ignore_label`` are excluded from both the
    intersection and the union. The object region is ``gt_mask == 1``.
    ``pred_mask`` is interpreted by truthiness (non-zero means foreground).

    Args:
        gt_mask: ground-truth label array.
        pred_mask: predicted mask, broadcast-compatible with ``gt_mask``.
        ignore_label: label value to exclude from the computation.

    Returns:
        ``intersection / union``. For NumPy-array inputs with an empty union,
        this is NumPy scalar division and yields nan (with a RuntimeWarning)
        rather than raising.
    """
    valid = gt_mask != ignore_label
    target = gt_mask == 1

    overlap = np.logical_and(valid, np.logical_and(pred_mask, target))
    covered = np.logical_and(valid, np.logical_or(pred_mask, target))

    return overlap.sum() / covered.sum()
def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    """Compute the Number-of-Clicks (NoC) metric for each IoU threshold.

    For each image, NoC is the 1-based index of the first click whose IoU
    reaches the threshold, or ``max_clicks`` if the threshold is never reached.

    Args:
        all_ious: iterable with one per-click IoU sequence (array or list) per
            image.
        iou_thrs: iterable of IoU thresholds.
        max_clicks: NoC assigned to images that never reach a threshold.

    Returns:
        ``(noc_list, over_max_list)``. For each threshold these hold the mean NoC
        over images and the number of images whose NoC equals ``max_clicks``.
        With no images at all, NumPy yields nan for the mean (with a
        RuntimeWarning).
    """
    def _get_noc(iou_arr, iou_thr):
        # 1-based index of the first click at/above the threshold, else max_clicks.
        reached = iou_arr >= iou_thr
        return np.argmax(reached) + 1 if np.any(reached) else max_clicks

    # Materialise once so a one-shot iterator is reused correctly for every
    # threshold, and so plain Python lists of IoUs are accepted as well.
    per_image_ious = [np.asarray(iou_arr) for iou_arr in all_ious]

    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        # `dtype=int` replaces `np.int`, a deprecated alias of the builtin `int`
        # that was removed in NumPy 1.24 (it raised AttributeError there).
        scores_arr = np.array([_get_noc(iou_arr, iou_thr) for iou_arr in per_image_ious],
                              dtype=int)
        noc_list.append(scores_arr.mean())
        # Images at the click budget. An image that first reaches the threshold
        # exactly on click `max_clicks` is counted here too.
        over_max_list.append((scores_arr == max_clicks).sum())
    return noc_list, over_max_list
def find_checkpoint(weights_folder, checkpoint_name):
    """Resolve a checkpoint specification to a filesystem path.

    ``checkpoint_name`` may optionally be prefixed as ``"<model_prefix>:<name>"``.
    The prefix selects the single sub-directory of ``weights_folder`` whose name
    starts with ``model_prefix``; without a prefix the model folder is
    ``weights_folder`` itself. ``<name>`` is then resolved as follows:

    * ending in ``.pth``: used as-is if that path exists; otherwise it is taken
      relative to the model folder;
    * anything else: treated as a file-name prefix and searched recursively as
      ``<name>*.pth`` under the model folder.

    Args:
        weights_folder: root directory containing model folders/checkpoints.
        checkpoint_name: checkpoint specification as described above.

    Returns:
        The resolved checkpoint path as a ``str``.

    Raises:
        AssertionError: when assertions are enabled and the model prefix or the
            checkpoint prefix does not match exactly one candidate.
    """
    weights_folder = Path(weights_folder)
    if ':' in checkpoint_name:
        # Split on the first colon only; the remainder is the checkpoint name.
        model_name, checkpoint_name = checkpoint_name.split(':', 1)
        models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
        assert len(models_candidates) == 1, (
            f'Expected exactly one model folder matching {model_name!r}* in '
            f'{weights_folder}, found {len(models_candidates)}')
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith('.pth'):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            # Resolve inside the selected model folder. Previously this used
            # weights_folder, silently ignoring a "model:" prefix. Without a
            # prefix, model_folder is weights_folder, so that case is unchanged.
            checkpoint_path = model_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
        assert len(model_checkpoints) == 1, (
            f'Expected exactly one checkpoint matching {checkpoint_name!r}*.pth under '
            f'{model_folder}, found {len(model_checkpoints)}')
        checkpoint_path = model_checkpoints[0]
    return str(checkpoint_path)
def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
                      n_clicks=20, model_name=None):
    """Format evaluation results as a fixed-width text table.

    Args:
        noc_list: mean NoC values. Entries 0, 1 and 2 are printed under the
            NoC@80%, NoC@85% and NoC@90% columns, so the caller is expected to
            pass thresholds in that order.
        over_max_list: per-threshold counts of images at the click budget.
            Entries 1 and 2 are printed under the ``>=N@85%`` / ``>=N@90%``
            columns.
        brs_type: label for the "BRS Type" column.
        dataset_name: label for the "Dataset" column.
        mean_spc: mean seconds per click.
        elapsed_time: total evaluation time in seconds, rendered via ``timedelta``.
        n_clicks: click budget shown in the ``>=N@..`` column headers.
        model_name: optional model name printed above the table.

    Returns:
        ``(header, table_row)``. ``header`` holds an optional model line, a rule,
        the column titles and a rule; ``table_row`` is the single data row.
        Values missing from either list are shown as a centred ``?``.
    """
    table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
                    f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
                    f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
                    f'{"SPC,s":^7}|{"Time":^9}|')
    row_width = len(table_header)

    header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
    header += '-' * row_width + '\n'
    header += table_header + '\n' + '-' * row_width

    def _cell(values, index, spec):
        # Render values[index] with `spec`, or a "?" placeholder when the list
        # is too short. Guarding on the list actually indexed avoids IndexError.
        if len(values) > index:
            return f'{values[index]:{spec}}|'
        return f'{"?":^9}|'

    eval_time = str(timedelta(seconds=int(elapsed_time)))
    table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
    table_row += ''.join(_cell(noc_list, i, '^9.2f') for i in range(3))
    # Over-max counts are reported only for the 85% and 90% thresholds.
    table_row += ''.join(_cell(over_max_list, i, '^9') for i in (1, 2))
    table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
    return header, table_row