# surveyor-0/src/eval/metrics.py
from typing import Dict, List, Optional, Tuple


def classify_predictions(
    gold: Dict[str, List[str]], pred: Dict[str, List[str]], union: bool = False
) -> Tuple[int, int, int]:
    """
    Return the counts of true positives, false positives, and false negatives
    for one example.

    If union is True, tag types are disregarded and only the union of all
    tagged phrases is compared.
    """
    n_tp = 0
    n_fp = 0
    n_fn = 0
    if union:
        # Pool every phrase across tags and compare the two flat sets.
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn
    # Otherwise count matches per tag: a phrase predicted under the wrong tag
    # counts as both a false positive and a false negative.
    for tag in set(gold.keys()).union(pred.keys()):
        gold_phrases = set(gold.get(tag, []))
        pred_phrases = set(pred.get(tag, []))
        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)
    return n_tp, n_fp, n_fn
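
# Illustrative sketch of the two counting modes (the example values below are
# hypothetical, not taken from the original repository):
#
#   gold = {"dataset": ["MNIST"], "method": ["CNN"]}
#   pred = {"dataset": ["MNIST", "CNN"], "method": []}
#
#   classify_predictions(gold, pred)              -> (1, 1, 1)  # per-tag: "CNN" under the
#                                                               # wrong tag is both FP and FN
#   classify_predictions(gold, pred, union=True)  -> (2, 0, 0)  # union: tags ignored,
#                                                               # both phrases match
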
def compute_metrics(
    running_time: float,
    pred_times: List[float],
    runtype: str,
    eval_metrics: Optional[Tuple[int, int, int, int, int, int]] = None,
) -> Dict[str, float]:
    """
    Assemble timing metrics and, for evaluation runs, precision/recall/F1
    scores from the aggregated tag-level and union-level counts.
    """

    def _safe_ratio(numerator: float, denominator: float) -> float:
        # Guard against division by zero when there are no predictions or no gold labels.
        return round(numerator / denominator, 4) if denominator > 0 else 0

    metrics = {}
    metrics["avg_pred_response_time_per_sentence"] = (
        round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
    )
    metrics["total_time"] = round(running_time, 4)
    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics
        precision = _safe_ratio(n_tp, n_tp + n_fp)
        recall = _safe_ratio(n_tp, n_tp + n_fn)
        f1 = _safe_ratio(2 * precision * recall, precision + recall)
        union_precision = _safe_ratio(n_tp_union, n_tp_union + n_fp_union)
        union_recall = _safe_ratio(n_tp_union, n_tp_union + n_fn_union)
        union_f1 = _safe_ratio(
            2 * union_precision * union_recall, union_precision + union_recall
        )
        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )
    return metrics
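
# Hedged usage sketch (the numbers below are hypothetical, not from the original repo).
# eval_metrics packs the six aggregated counts in the order unpacked above:
# (n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union).
#
#   compute_metrics(
#       running_time=12.3456,
#       pred_times=[0.8, 1.2, 1.0],
#       runtype="eval",
#       eval_metrics=(8, 2, 4, 10, 1, 2),
#   )
#   -> {
#       "avg_pred_response_time_per_sentence": 1.0,
#       "total_time": 12.3456,
#       "precision": 0.8, "recall": 0.6667, "f1": 0.7273,
#       "union_precision": 0.9091, "union_recall": 0.8333, "union_f1": 0.8696,
#   }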