|
import math
|
|
|
|
|
|
class RecMetric:
    """Accumulate top-k ranking metrics (recall, nDCG, MRR) over evaluated labels.

    Running sums are kept in ``self.metric`` under keys ``"recall@k"``,
    ``"ndcg@k"`` and ``"mrr@k"`` for every ``k`` in ``k_list``, plus a
    ``"count"`` of scored labels. ``report()`` turns the sums into averages.
    """

    def __init__(self, k_list=(1, 10, 50)):
        """Create an accumulator for the cut-offs in ``k_list``.

        Args:
            k_list: Iterable of positive integer cut-offs; iterated once per
                evaluated label, so a tuple/list is expected.
        """
        self.k_list = k_list
        self.metric = {}
        self.reset_metric()

    def evaluate(self, preds, labels):
        """Score ``preds`` against each label in ``labels`` and add to the running sums.

        Labels equal to ``-100`` are skipped and not counted.

        Args:
            preds: Ranked predictions (best first); must support slicing.
            labels: Iterable of ground-truth items.
        """
        # NOTE(review): the same ``preds`` ranking is scored against every
        # label. If ``preds`` is meant to hold one ranking per label, this
        # should iterate ``zip(preds, labels)`` instead -- confirm with callers.
        pred_list = preds
        for label in labels:
            if label == -100:  # sentinel for "no target"; excluded from count
                continue
            for k in self.k_list:
                self.metric[f"recall@{k}"] += self.compute_recall(pred_list, label, k)
                self.metric[f"ndcg@{k}"] += self.compute_ndcg(pred_list, label, k)
                self.metric[f"mrr@{k}"] += self.compute_mrr(pred_list, label, k)
            self.metric["count"] += 1

    @staticmethod
    def _rank_in_top_k(pred_list, label, k):
        """Return the 0-based position of ``label`` among the first ``k`` predictions, or ``None``.

        Converting the top-k slice to a list gives ``.index`` even for
        sequences that lack it (e.g. numpy arrays). The first occurrence
        wins, matching ``list.index``.
        """
        top_k = list(pred_list[:k])
        try:
            return top_k.index(label)
        except ValueError:
            return None

    def compute_recall(self, pred_list, label, k):
        """Return 1 if ``label`` appears in the top ``k`` predictions, else 0."""
        return int(self._rank_in_top_k(pred_list, label, k) is not None)

    def compute_mrr(self, pred_list, label, k):
        """Return the reciprocal rank ``1 / (rank + 1)`` of ``label`` within the top ``k``, else 0."""
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / (rank + 1)

    def compute_ndcg(self, pred_list, label, k):
        """Return single-relevant-item nDCG ``1 / log2(rank + 2)`` for ``label`` within the top ``k``, else 0."""
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / math.log2(rank + 2)

    def reset_metric(self):
        """Zero every running sum and the label count, updating ``self.metric`` in place."""
        for metric in ["recall", "ndcg", "mrr"]:
            for k in self.k_list:
                self.metric[f"{metric}@{k}"] = 0
        self.metric["count"] = 0

    def report(self):
        """Return a new dict with each metric averaged over the number of scored labels.

        ``"count"`` is copied through unchanged. When no labels have been
        scored (count is 0), the averaged metrics are reported as ``0.0``
        instead of raising ``ZeroDivisionError``.
        """
        count = self.metric["count"]
        report = {}
        for name, total in self.metric.items():
            if name == "count":
                report[name] = total
            elif count:
                report[name] = total / count
            else:
                report[name] = 0.0
        return report
|
|
|