# CRSArena — src/model/metric.py
# Commit b599481 ("Initial commit") by Nolwenn; 1.6 kB.
import math
class RecMetric:
    """Accumulate top-k ranking metrics (recall, nDCG, MRR) across calls.

    Each :meth:`evaluate` call compares one ranked prediction list against
    one or more target labels and adds the per-label scores to running sums
    stored in ``self.metric``. :meth:`report` turns those sums into averages
    over the number of labels scored.
    """

    def __init__(self, k_list=(1, 10, 50)):
        """Create an accumulator with every running sum set to zero.

        Args:
            k_list: Cut-off ranks; each metric is tracked at every ``k``.
        """
        self.k_list = k_list
        self.metric = {}
        self.reset_metric()

    def evaluate(self, preds, labels):
        """Score one ranked prediction list against each label and accumulate.

        Args:
            preds: Ranked predicted item ids, best first. The same ranking is
                scored against every label.
            labels: Iterable of target item ids. Entries equal to -100 are
                skipped and do not increment ``count``.
        """
        # Materialize once (outside the label loop) so any iterable works and
        # the slicing / ``.index`` used by the compute_* methods is available.
        pred_list = list(preds)
        metric = self.metric
        for label in labels:
            if label == -100:
                continue
            for k in self.k_list:
                metric[f"recall@{k}"] += self.compute_recall(pred_list, label, k)
                metric[f"ndcg@{k}"] += self.compute_ndcg(pred_list, label, k)
                metric[f"mrr@{k}"] += self.compute_mrr(pred_list, label, k)
            metric["count"] += 1

    @staticmethod
    def _rank_in_top_k(pred_list, label, k):
        """Return the 0-based position of ``label`` in the first ``k`` items.

        Returns None when ``label`` does not occur among them. The first
        occurrence wins, matching ``list.index`` semantics.
        """
        top_k = list(pred_list[:k])
        try:
            return top_k.index(label)
        except ValueError:
            return None

    def compute_recall(self, pred_list, label, k):
        """Return 1 if ``label`` is among the top ``k`` predictions, else 0."""
        return int(self._rank_in_top_k(pred_list, label, k) is not None)

    def compute_mrr(self, pred_list, label, k):
        """Return the reciprocal rank 1 / (rank + 1) of ``label``, or 0 if it
        is not in the top ``k`` predictions."""
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / (rank + 1)

    def compute_ndcg(self, pred_list, label, k):
        """Return nDCG@k for a single relevant ``label``.

        With exactly one relevant item the ideal DCG is 1, so the score is
        1 / log2(rank + 2) when ``label`` is in the top ``k``, else 0.
        """
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / math.log2(rank + 2)

    def reset_metric(self):
        """Zero every ``<metric>@<k>`` running sum and the label ``count``."""
        for name in ("recall", "ndcg", "mrr"):
            for k in self.k_list:
                self.metric[f"{name}@{k}"] = 0
        self.metric["count"] = 0

    def report(self):
        """Return the averaged metrics together with the raw ``count``.

        Every ``<metric>@<k>`` sum is divided by ``count``. When nothing has
        been scored yet (``count == 0``), the averages are reported as 0.0
        rather than raising ZeroDivisionError.
        """
        count = self.metric["count"]
        report = {}
        for key, value in self.metric.items():
            if key == "count":
                report[key] = value
            else:
                report[key] = value / count if count else 0.0
        return report