Spaces:
Sleeping
Sleeping
from typing import Dict, Tuple
def classify_predictions(gold: dict, pred: dict, union=False) -> Tuple[int, int, int]:
    """Count true positives, false positives and false negatives for one example.

    ``gold`` and ``pred`` map a tag to an iterable of phrases. Phrases are
    compared as set members, so a phrase repeated under the same tag is
    counted once.

    Args:
        gold: Reference mapping of tag -> phrases.
        pred: Predicted mapping of tag -> phrases.
        union: If True, tags are ignored and the union of all gold phrases is
            compared against the union of all predicted phrases. If False, a
            predicted phrase only counts as correct when it appears under the
            same tag in ``gold``.

    Returns:
        A tuple ``(n_tp, n_fp, n_fn)`` of integer counts.
    """
    if union:
        # Flatten every tag's phrases into one set per side; tag identity is ignored.
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn

    n_tp = 0
    n_fp = 0
    n_fn = 0
    # Visit tags present on either side so predictions under unseen tags
    # (false positives) and missed gold tags (false negatives) both count.
    for tag in set(gold.keys()).union(pred.keys()):
        gold_phrases = set(gold.get(tag, []))
        pred_phrases = set(pred.get(tag, []))
        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)
    return n_tp, n_fp, n_fn
def _precision_recall_f1(n_tp, n_fp, n_fn, ndigits=4):
    """Return ``(precision, recall, f1)`` from raw counts, each rounded to ``ndigits``.

    A ratio whose denominator is zero falls back to 0. F1 is computed from the
    unrounded precision and recall so rounding is applied exactly once.
    """
    n_predicted = n_tp + n_fp
    n_actual = n_tp + n_fn
    precision = n_tp / n_predicted if n_predicted > 0 else 0
    recall = n_tp / n_actual if n_actual > 0 else 0
    f1 = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0
    )
    return round(precision, ndigits), round(recall, ndigits), round(f1, ndigits)


def compute_metrics(running_time, pred_times, runtype, eval_metrics=None):
    """Build a dict of timing metrics and, for eval runs, precision/recall/F1.

    Args:
        running_time: Elapsed-time value, reported (rounded to 4 decimals) as
            ``"total_time"``.
        pred_times: Sequence of per-prediction times; their mean (rounded to 4
            decimals) is reported as ``"avg_pred_response_time_per_sentence"``,
            or 0 when the sequence is empty.
        runtype: Only when equal to ``"eval"`` are quality metrics added.
        eval_metrics: 6-tuple ``(n_tp, n_fp, n_fn, n_tp_union, n_fp_union,
            n_fn_union)`` of counts, presumably accumulated from
            ``classify_predictions`` in per-tag and union mode. Ignored when None.

    Returns:
        A dict with the timing keys, plus ``precision``, ``recall``, ``f1``,
        ``union_precision``, ``union_recall`` and ``union_f1`` (rounded to 4
        decimals, 0 on a zero denominator) for eval runs.
    """
    metrics = {
        "avg_pred_response_time_per_sentence": (
            round(sum(pred_times) / len(pred_times), 4) if pred_times else 0
        ),
        "total_time": round(running_time, 4),
    }
    if runtype == "eval" and eval_metrics is not None:
        n_tp, n_fp, n_fn, n_tp_union, n_fp_union, n_fn_union = eval_metrics
        precision, recall, f1 = _precision_recall_f1(n_tp, n_fp, n_fn)
        union_precision, union_recall, union_f1 = _precision_recall_f1(
            n_tp_union, n_fp_union, n_fn_union
        )
        metrics.update(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "union_precision": union_precision,
                "union_recall": union_recall,
                "union_f1": union_f1,
            }
        )
    return metrics