Spaces:
Paused
Paused
File size: 10,856 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 |
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
import os
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch
from detectron2.data import MetadataCatalog
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager
from .evaluator import DatasetEvaluator
class PascalVOCDetectionEvaluator(DatasetEvaluator):
    """
    Evaluate Pascal VOC style AP for Pascal VOC dataset.
    It contains a synchronization, therefore has to be called from all ranks.

    Note that the concept of AP can be implemented in different ways and may not
    produce identical results. This class mimics the implementation of the official
    Pascal VOC Matlab API, and should produce similar but not identical results to the
    official API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Too many tiny files, download all to local for speed.
        annotation_dir_local = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations/")
        )
        self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml")
        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)
        # Start with an empty store so `process`/`evaluate` do not fail with an
        # AttributeError when the caller has not invoked `reset()` yet.
        self._predictions = defaultdict(list)  # class id -> list of prediction strings

    def reset(self):
        """Discard every prediction accumulated so far."""
        self._predictions = defaultdict(list)  # class id -> list of prediction strings

    def process(self, inputs, outputs):
        """
        Record the detections of one batch.

        Args:
            inputs: iterable of dicts; each must provide the key "image_id".
            outputs: iterable of dicts, parallel to `inputs`; each must provide the
                key "instances" holding an object with `pred_boxes`, `scores` and
                `pred_classes`.
        """
        for sample, prediction in zip(inputs, outputs):
            image_id = sample["image_id"]
            instances = prediction["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for box, score, cls in zip(boxes, scores, classes):
                xmin, ymin, xmax, ymax = box
                # The inverse of data loading logic in `datasets/pascal_voc.py`
                xmin += 1
                ymin += 1
                # One detection per line: "<image_id> <score> <xmin> <ymin> <xmax> <ymax>",
                # which is the layout `voc_eval` parses back.
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Returns:
            dict or None: on the main process, a dict with the single key "bbox" whose
            value is a dict of "AP", "AP50", and "AP75" (all scaled by 100). "AP" is
            the mean over the IoU thresholds 0.50, 0.55, ..., 0.95. Every other
            process returns None.
        """
        # Collect the per-rank prediction dicts (requested with dst=0).
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                # A class without detections still gets an (empty) results file,
                # because `voc_eval` opens the file unconditionally.
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                # IoU thresholds 50, 55, ..., 95 (in percent).
                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        return ret
##############################################################################
#
# Below code is modified from
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Bharath Hariharan
# --------------------------------------------------------
"""Python implementation of the PASCAL VOC devkit's AP evaluation code."""
@lru_cache(maxsize=None)
def parse_rec(filename):
    """Read a PASCAL VOC annotation xml file and describe each annotated object.

    Args:
        filename: path of the xml file; opened through `PathManager.open`.

    Returns:
        list[dict]: one dict per `<object>` element with the keys "name", "pose",
        "truncated", "difficult" and "bbox" (``[xmin, ymin, xmax, ymax]`` as ints).

    Results are memoized per `filename` with an unbounded cache, so the very same
    list object is handed back on repeated calls; callers should not mutate it.
    """
    with PathManager.open(filename) as f:
        tree = ET.parse(f)

    def _to_record(node):
        # Build the plain-dict description of a single <object> element.
        box = node.find("bndbox")
        return {
            "name": node.find("name").text,
            "pose": node.find("pose").text,
            "truncated": int(node.find("truncated").text),
            "difficult": int(node.find("difficult").text),
            "bbox": [int(box.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")],
        }

    return [_to_record(node) for node in tree.findall("object")]
def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC average precision from a precision/recall curve.

    Args:
        rec: numpy array of recall values.
        prec: numpy array of precision values, parallel to `rec`.
        use_07_metric: when true, use the VOC 2007 11-point interpolation;
            otherwise integrate the area under the monotone precision envelope
            (default: False).

    Returns:
        The average precision as a scalar.
    """
    if use_07_metric:
        # 11-point metric: average the best precision reachable at each of the
        # recall levels 0.0, 0.1, ..., 1.0.
        ap = 0.0
        for level in np.arange(0.0, 1.1, 0.1):
            reachable = rec >= level
            best = np.max(prec[reachable]) if reachable.any() else 0
            ap += best / 11.0
        return ap

    # Exact area under the curve: pad both curves with sentinel end points.
    mrec = np.concatenate(([0.0], rec, [1.0]))
    mpre = np.concatenate(([0.0], prec, [0.0]))

    # Precision envelope: every entry becomes the maximum of itself and all
    # entries to its right (a running maximum taken from the end).
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Only the positions where recall changes contribute to the area.
    steps = np.flatnonzero(mrec[1:] != mrec[:-1])

    # Sum (delta recall) * precision over those positions.
    return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])
def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
    """Run the PASCAL VOC detection evaluation for a single class.

    Usage:
        rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                                 [ovthresh], [use_07_metric])

    Args:
        detpath: template of the detection results path;
            ``detpath.format(classname)`` is the file that gets read. Each line is
            parsed as ``<image_id> <score> <xmin> <ymin> <xmax> <ymax>``, separated
            by single spaces.
        annopath: template of the annotation path;
            ``annopath.format(imagename)`` is the xml annotation file of an image.
        imagesetfile: text file listing the images, one image name per line.
        classname: name of the category to evaluate.
        ovthresh: overlap (IoU) threshold a detection has to exceed to match a
            ground-truth box (default: 0.5).
        use_07_metric: whether to use VOC07's 11-point AP computation
            (default: False).

    Returns:
        tuple: ``(rec, prec, ap)`` where `rec` and `prec` are the cumulative recall
        and precision after each detection, taken in order of decreasing
        confidence, and `ap` is the value computed by `voc_ap`.
    """
    # Image names, one per line of the image-set file.
    with PathManager.open(imagesetfile, "r") as f:
        imagenames = [row.strip() for row in f.readlines()]

    # Ground truth of this class for every image, plus the number of positives
    # (objects not flagged "difficult").
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        objects = [
            obj for obj in parse_rec(annopath.format(imagename)) if obj["name"] == classname
        ]
        gt_boxes = np.array([obj["bbox"] for obj in objects])
        difficult = np.array([obj["difficult"] for obj in objects]).astype(bool)
        npos += int(np.count_nonzero(~difficult))
        class_recs[imagename] = {
            "bbox": gt_boxes,
            "difficult": difficult,
            # "det" remembers which ground-truth boxes were already claimed.
            "det": [False] * len(objects),
        }

    # Detections of this class.
    with open(detpath.format(classname), "r") as f:
        fields = [row.strip().split(" ") for row in f.readlines()]
    image_ids = [row[0] for row in fields]
    confidence = np.array([float(row[1]) for row in fields])
    BB = np.array([[float(value) for value in row[2:]] for row in fields]).reshape(-1, 4)

    # Visit the detections from the most to the least confident one.
    order = np.argsort(-confidence)
    BB = BB[order, :]
    image_ids = [image_ids[k] for k in order]

    # Mark every detection as a true positive or a false positive.
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d, image_id in enumerate(image_ids):
        gt = class_recs[image_id]
        det_box = BB[d, :].astype(float)
        gt_boxes = gt["bbox"].astype(float)
        best_iou = -np.inf

        if gt_boxes.size > 0:
            # Intersection with every ground-truth box; the +1.0 terms treat the
            # coordinates as inclusive pixel indices.
            ix_lo = np.maximum(gt_boxes[:, 0], det_box[0])
            iy_lo = np.maximum(gt_boxes[:, 1], det_box[1])
            ix_hi = np.minimum(gt_boxes[:, 2], det_box[2])
            iy_hi = np.minimum(gt_boxes[:, 3], det_box[3])
            inter_w = np.maximum(ix_hi - ix_lo + 1.0, 0.0)
            inter_h = np.maximum(iy_hi - iy_lo + 1.0, 0.0)
            inters = inter_w * inter_h

            # Union = detection area + ground-truth area - intersection.
            det_area = (det_box[2] - det_box[0] + 1.0) * (det_box[3] - det_box[1] + 1.0)
            gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0) * (
                gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0
            )
            overlaps = inters / (det_area + gt_area - inters)
            best_iou = np.max(overlaps)
            best_idx = np.argmax(overlaps)

        if not best_iou > ovthresh:
            # No ground-truth box overlaps enough.
            fp[d] = 1.0
            continue
        if gt["difficult"][best_idx]:
            # Matches to "difficult" objects count neither as TP nor as FP.
            continue
        if gt["det"][best_idx]:
            # The ground-truth box was already claimed by an earlier detection.
            fp[d] = 1.0
        else:
            tp[d] = 1.0
            gt["det"][best_idx] = 1

    # Cumulative precision and recall.
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # Avoid a division by zero in case the first detection matches a difficult
    # ground truth.
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
|