from typing import List, Tuple
import numpy as np
from numpy import ndarray
def get_mAP(
    preds: ndarray,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, ndarray]:
    assert preds.shape[1] == len(taglist)

    # When mapping categories from a test dataset to our tag system, several
    # entries of `taglist` may refer to the same tag because the two sides
    # define tag semantics differently. Duplicate tags are therefore allowed:
    # each tag name maps to the list of all its column indices.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        if tag not in tag2idxs:
            tag2idxs[tag] = []
        tag2idxs[tag].append(idx)

    # build targets from the ground-truth file; each line is
    # "<image_id>,<tag_1>,<tag_2>,..."
    targets = np.zeros_like(preds)
    with open(gt_file, "r") as f:
        lines = [line.strip("\n").split(",") for line in f.readlines()]
    assert len(lines) == targets.shape[0]
    for i, line in enumerate(lines):
        for tag in line[1:]:
            targets[i, tag2idxs[tag]] = 1.0

    # compute average precision for each class
    APs = np.zeros(preds.shape[1])
    for k in range(preds.shape[1]):
        APs[k] = _average_precision(preds[:, k], targets[:, k])

    return APs.mean(), APs
def _average_precision(output: ndarray, target: ndarray) -> float:
    epsilon = 1e-8

    # sort examples by predicted score, highest first
    indices = output.argsort()[::-1]

    # compute precision@i at every rank
    total_count_ = np.cumsum(np.ones((len(output), 1)))
    target_ = target[indices]
    ind = target_ == 1
    pos_count_ = np.cumsum(ind)
    total = pos_count_[-1]
    pos_count_[np.logical_not(ind)] = 0  # keep precision only at positive ranks
    pp = pos_count_ / total_count_
    precision_at_i_ = np.sum(pp)

    # average over the number of positives
    precision_at_i = precision_at_i_ / (total + epsilon)
    return precision_at_i
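

# Worked example (illustrative, not part of the original file): for scores
# [0.9, 0.8, 0.1] with targets [1, 0, 1], the two positives land at ranks 1
# and 3, so _average_precision returns (1/1 + 2/3) / 2 ≈ 0.8333.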
def get_PR(
    pred_file: str,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, float, ndarray, ndarray]:
    # When mapping categories from a test dataset to our tag system, several
    # entries of `taglist` may refer to the same tag because the two sides
    # define tag semantics differently. Duplicate tags are therefore allowed:
    # each tag name maps to the list of all its column indices.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        if tag not in tag2idxs:
            tag2idxs[tag] = []
        tag2idxs[tag].append(idx)

    # build preds; arrays are sized by len(taglist) rather than the number of
    # unique tags so that duplicate tags index safely, matching `get_mAP`
    with open(pred_file, "r", encoding="utf-8") as f:
        lines = [line.strip().split(",") for line in f.readlines()]
    preds = np.zeros((len(lines), len(taglist)), dtype=bool)
    for i, line in enumerate(lines):
        for tag in line[1:]:
            preds[i, tag2idxs[tag]] = True

    # build targets
    with open(gt_file, "r", encoding="utf-8") as f:
        lines = [line.strip().split(",") for line in f.readlines()]
    targets = np.zeros((len(lines), len(taglist)), dtype=bool)
    for i, line in enumerate(lines):
        for tag in line[1:]:
            targets[i, tag2idxs[tag]] = True

    assert preds.shape == targets.shape

    # calculate per-tag precision and recall
    TPs = ( preds &  targets).sum(axis=0)  # noqa: E201, E222
    FPs = ( preds & ~targets).sum(axis=0)  # noqa: E201, E222
    FNs = (~preds &  targets).sum(axis=0)  # noqa: E201, E222
    eps = 1.e-9
    Ps = TPs / (TPs + FPs + eps)
    Rs = TPs / (TPs + FNs + eps)

    return Ps.mean(), Rs.mean(), Ps, Rs
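

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the file names, tag
# list, and scores below are made up purely to illustrate the expected CSV
# format, where each line is "<image_id>,<tag_1>,<tag_2>,...".
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    import tempfile

    taglist = ["cat", "dog", "person"]  # hypothetical tag vocabulary

    with tempfile.TemporaryDirectory() as tmpdir:
        gt_file = os.path.join(tmpdir, "gt.csv")
        pred_file = os.path.join(tmpdir, "pred.csv")

        # two images: the first contains a cat and a person, the second a dog
        with open(gt_file, "w", encoding="utf-8") as f:
            f.write("img_0,cat,person\n")
            f.write("img_1,dog\n")

        # hard (thresholded) tag predictions for get_PR
        with open(pred_file, "w", encoding="utf-8") as f:
            f.write("img_0,cat\n")
            f.write("img_1,dog,person\n")

        # soft scores for get_mAP: one row per image, one column per tag
        scores = np.array([
            [0.9, 0.2, 0.7],
            [0.1, 0.8, 0.4],
        ])

        mAP, APs = get_mAP(scores, gt_file, taglist)
        P, R, Ps, Rs = get_PR(pred_file, gt_file, taglist)
        print(f"mAP={mAP:.4f}, mean P={P:.4f}, mean R={R:.4f}")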