File size: 1,596 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import math


class RecMetric:
    """Accumulate top-k ranking metrics (Recall, NDCG, MRR) over many labels.

    Every scored label adds a single-relevant-item score for each cutoff in
    ``k_list``; ``report`` divides those running sums by the number of scored
    labels.
    """

    def __init__(self, k_list=(1, 10, 50)):
        """Create an empty accumulator.

        Args:
            k_list: Cutoffs ``k`` at which each metric is computed.
        """
        self.k_list = k_list
        # Running sums keyed "<metric>@<k>", plus "count" of scored labels.
        self.metric = {}
        self.reset_metric()

    def evaluate(self, preds, labels):
        """Score every label in ``labels`` against the ranking ``preds``.

        Args:
            preds: Ranked predictions, best first (position 0 is rank 1).
                Any sliceable sequence works (list, tuple, NumPy array, ...).
            labels: Iterable of target items. Labels equal to ``-100`` are
                skipped and not counted (presumably an ignore/padding value;
                confirm with callers).
        """
        # NOTE(review): every label is compared against the same ``preds``
        # ranking; if per-label rankings were intended, callers must split
        # them before calling this method.
        for label in labels:
            if label == -100:
                continue
            for k in self.k_list:
                self.metric[f"recall@{k}"] += self.compute_recall(preds, label, k)
                self.metric[f"ndcg@{k}"] += self.compute_ndcg(preds, label, k)
                self.metric[f"mrr@{k}"] += self.compute_mrr(preds, label, k)
            self.metric["count"] += 1

    @staticmethod
    def _rank_in_top_k(pred_list, label, k):
        """Return the 0-based position of ``label`` in ``pred_list[:k]``, or None."""
        # list() makes sequences without an ``index`` method (e.g. NumPy arrays)
        # searchable and limits the search to the top-k slice only.
        top_k = list(pred_list[:k])
        try:
            return top_k.index(label)
        except ValueError:
            return None

    def compute_recall(self, pred_list, label, k):
        """Return 1 if ``label`` is among the top ``k`` predictions, else 0."""
        return int(self._rank_in_top_k(pred_list, label, k) is not None)

    def compute_mrr(self, pred_list, label, k):
        """Return the reciprocal rank ``1 / rank`` of ``label`` within the top ``k``.

        Returns 0 when ``label`` is not in the top ``k``.
        """
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        return 1 / (rank + 1)

    def compute_ndcg(self, pred_list, label, k):
        """Return NDCG@k for a single relevant item ``label``.

        With exactly one relevant item the ideal DCG is 1, so the score is just
        ``1 / log2(rank + 1)`` for a 1-based rank. Returns 0 when ``label`` is
        not in the top ``k``.
        """
        rank = self._rank_in_top_k(pred_list, label, k)
        if rank is None:
            return 0
        # rank is 0-based here, hence +2 rather than +1.
        return 1 / math.log2(rank + 2)

    def reset_metric(self):
        """Zero every running sum and the scored-label count."""
        for name in ["recall", "ndcg", "mrr"]:
            for k in self.k_list:
                self.metric[f"{name}@{k}"] = 0
        self.metric["count"] = 0

    def report(self):
        """Return the averaged metrics plus the raw ``count``.

        Each ``<metric>@<k>`` entry is its running sum divided by the number of
        scored labels. If no label has been scored yet, those entries are
        ``0.0`` rather than raising ``ZeroDivisionError``.
        """
        count = self.metric["count"]
        report = {}
        for key, total in self.metric.items():
            if key == "count":
                report[key] = total
            else:
                report[key] = total / count if count else 0.0
        return report