from datetime import timedelta
from pathlib import Path

import numpy as np
import torch

from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset,
                                 PascalVocDataset, SBDEvaluationDataset)
from isegm.utils.serialization import load_model


def get_time_metrics(all_ious, elapsed_time):
    """Return mean seconds-per-click (SPC) and seconds-per-image (SPI)."""
    n_images = len(all_ious)
    n_clicks = sum(map(len, all_ious))

    mean_spc = elapsed_time / n_clicks
    mean_spi = elapsed_time / n_images

    return mean_spc, mean_spi
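# Illustrative usage sketch (the numbers are made up): `all_ious` holds one
# IoU-per-click array per image, so with 5 + 3 = 8 clicks over 2 images and 4.0 s of
# total evaluation time this yields 0.5 s per click and 2.0 s per image:
#   get_time_metrics([np.zeros(5), np.zeros(3)], elapsed_time=4.0)  # -> (0.5, 2.0)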


def load_is_model(checkpoint, device, **kwargs):
    if isinstance(checkpoint, (str, Path)):
        state_dict = torch.load(checkpoint, map_location="cpu")
    else:
        state_dict = checkpoint

    if isinstance(state_dict, list):
        # A list of state dicts is treated as an ensemble: return the first model
        # separately together with the full list of loaded models.
        model = load_single_is_model(state_dict[0], device, **kwargs)
        models = [load_single_is_model(x, device, **kwargs) for x in state_dict]

        return model, models
    else:
        return load_single_is_model(state_dict, device, **kwargs)


def load_single_is_model(state_dict, device, **kwargs):
    model = load_model(state_dict["config"], **kwargs)
    model.load_state_dict(state_dict["state_dict"], strict=False)

    # Freeze the weights and switch the model to inference mode.
    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    model.eval()

    return model
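# Illustrative usage sketch (the checkpoint path is an assumption, not part of this
# module): a single-model checkpoint yields one frozen eval-mode model, while a
# checkpoint storing a list of state dicts yields (first_model, [all_models]).
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#   model = load_is_model("./weights/coco_lvis_h18_itermask.pth", device)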


def get_dataset(dataset_name, cfg):
    if dataset_name == "GrabCut":
        dataset = GrabCutDataset(cfg.GRABCUT_PATH)
    elif dataset_name == "Berkeley":
        dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
    elif dataset_name == "DAVIS":
        dataset = DavisDataset(cfg.DAVIS_PATH)
    elif dataset_name == "SBD":
        dataset = SBDEvaluationDataset(cfg.SBD_PATH)
    elif dataset_name == "SBD_Train":
        dataset = SBDEvaluationDataset(cfg.SBD_PATH, split="train")
    elif dataset_name == "PascalVOC":
        dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split="test")
    elif dataset_name == "COCO_MVal":
        # COCO_MVal is stored in the same folder layout as DAVIS, so it reuses that loader.
        dataset = DavisDataset(cfg.COCO_MVAL_PATH)
    else:
        dataset = None

    return dataset
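# Illustrative usage sketch (`cfg` is assumed to be the attribute-style config object
# the evaluation scripts build from the project's config file, with *_PATH entries
# for each dataset):
#   dataset = get_dataset("GrabCut", cfg)  # unknown dataset names return None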


def get_iou(gt_mask, pred_mask, ignore_label=-1):
    # Pixels labelled `ignore_label` are excluded from both intersection and union.
    ignore_gt_mask_inv = gt_mask != ignore_label
    obj_gt_mask = gt_mask == 1

    intersection = np.logical_and(
        np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv
    ).sum()
    union = np.logical_and(
        np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv
    ).sum()

    return intersection / union
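# Illustrative example (made-up masks): the -1 pixel is ignored, leaving an
# intersection of 2 and a union of 4, i.e. an IoU of 0.5:
#   gt = np.array([[1, 1, -1], [0, 0, 1]])
#   pred = np.array([[1, 0, 0], [0, 1, 1]], dtype=bool)
#   get_iou(gt, pred)  # -> 0.5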


def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    # Number of Clicks (NoC) needed to reach each IoU threshold, capped at max_clicks.
    def _get_noc(iou_arr, iou_thr):
        vals = iou_arr >= iou_thr
        return np.argmax(vals) + 1 if np.any(vals) else max_clicks

    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        scores_arr = np.array(
            [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious],
            dtype=int,  # builtin int; the np.int alias was removed in NumPy 1.24
        )

        score = scores_arr.mean()
        over_max = (scores_arr == max_clicks).sum()

        noc_list.append(score)
        over_max_list.append(over_max)

    return noc_list, over_max_list
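# Illustrative example (made-up IoU traces): with thresholds [0.8, 0.9] and
# max_clicks=3, the first image reaches both thresholds on click 2, while the second
# never reaches either and is counted as max_clicks:
#   ious = [np.array([0.55, 0.92, 0.95]), np.array([0.70, 0.75, 0.78])]
#   compute_noc_metric(ious, iou_thrs=[0.8, 0.9], max_clicks=3)
#   # -> noc_list == [2.5, 2.5], over_max_list == [1, 1]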


def find_checkpoint(weights_folder, checkpoint_name):
    weights_folder = Path(weights_folder)
    if ":" in checkpoint_name:
        # "model_name:checkpoint_name" selects a unique model subfolder first.
        model_name, checkpoint_name = checkpoint_name.split(":")
        models_candidates = [
            x for x in weights_folder.glob(f"{model_name}*") if x.is_dir()
        ]
        assert len(models_candidates) == 1
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith(".pth"):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            checkpoint_path = weights_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth"))
        assert len(model_checkpoints) == 1
        checkpoint_path = model_checkpoints[0]

    return str(checkpoint_path)
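# Illustrative usage sketch (folder and file names are assumptions): a plain ".pth"
# name is resolved against the weights folder, while "prefix:name" looks for a unique
# "name*.pth" inside the unique "prefix*" subfolder:
#   find_checkpoint("./weights", "coco_lvis_h18_itermask.pth")
#   find_checkpoint("./weights", "hrnet18_sbd:last_checkpoint")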


def get_results_table(
    noc_list,
    over_max_list,
    brs_type,
    dataset_name,
    mean_spc,
    elapsed_time,
    n_clicks=20,
    model_name=None,
):
    table_header = (
        f'|{"BRS Type":^13}|{"Dataset":^11}|'
        f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
        f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
        f'{"SPC,s":^7}|{"Time":^9}|'
    )
    row_width = len(table_header)

    header = f"Eval results for model: {model_name}\n" if model_name is not None else ""
    header += "-" * row_width + "\n"
    header += table_header + "\n" + "-" * row_width

    eval_time = str(timedelta(seconds=int(elapsed_time)))

    table_row = f"|{brs_type:^13}|{dataset_name:^11}|"
    table_row += f"{noc_list[0]:^9.2f}|"
    table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|"

    return header, table_row
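# Illustrative usage sketch (all numbers are made up): `noc_list` holds NoC@80/85/90
# and `over_max_list` the per-threshold counts of images that hit the click limit;
# printing the two strings yields one formatted table row per evaluated dataset:
#   header, row = get_results_table([1.54, 1.71, 2.04], [0, 0, 1], "NoBRS", "GrabCut",
#                                   mean_spc=0.10, elapsed_time=12.0)
#   print(header)
#   print(row)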