import os
import os.path as osp
from typing import Any, Union

import torch
import torch.nn as nn
from torchmetrics.functional.retrieval import (
    retrieval_average_precision,
    retrieval_hit_rate,
    retrieval_normalized_dcg,
    retrieval_precision,
    retrieval_r_precision,
    retrieval_recall,
    retrieval_reciprocal_rank,
)

from src.tools.api import get_openai_embedding


class ModelForSemiStructQA(nn.Module):
    
    def __init__(self, kb):
        '''
        Base model for question answering over a semi-structured knowledge base.

        Args:
            kb: knowledge base object exposing `candidate_ids` and `num_candidates`.
        '''
        super().__init__()
        self.kb = kb
        self.candidate_ids = kb.candidate_ids
        self.num_candidates = kb.num_candidates
        # Lazily populated cache mapping query_id -> query embedding.
        self.query_emb_dict = {}
    def forward(self, 
                query: Union[str, list], 
                candidates=None,
                query_id=None,
                **kwargs: Any):
        '''
        Args:
            query (Union[str, list]): query string or a list of query strings
            candidates (Union[list, None]): a list of candidate ids (optional)
            query_id (Union[int, list, None]): query index (optional)
            
        Returns:
            pred_dict (dict): a dictionary of predicted scores or answer ids
        '''
        raise NotImplementedError
    
    def _get_query_emb(self, query: str, query_id: int, 
                       emb_model: str = 'text-embedding-ada-002'):
        '''
        Return the embedding of `query`, reusing cached embeddings keyed by
        `query_id` when available. When `query_id` is not None, this expects
        `self.query_emb_dir` to be set (e.g., by a subclass).
        '''
        if query_id is None:
            # No id to look up; embed the query directly via the API.
            query_emb = get_openai_embedding(query, model=emb_model)
        elif len(self.query_emb_dict) > 0:
            # In-memory cache is already loaded; reuse the stored embedding.
            query_emb = self.query_emb_dict[query_id]
        else:
            query_emb_dict_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
            if osp.exists(query_emb_dict_path):
                # Load the full embedding cache from disk once.
                print(f'Load query embeddings from {query_emb_dict_path}')
                self.query_emb_dict = torch.load(query_emb_dict_path)
                query_emb = self.query_emb_dict[query_id]
            else:
                # No cache on disk: embed this query and save it individually.
                query_emb_dir = osp.join(self.query_emb_dir, 'query_embs')
                os.makedirs(query_emb_dir, exist_ok=True)
                query_emb_path = osp.join(query_emb_dir, f'query_{query_id}.pt')
                query_emb = get_openai_embedding(query, model=emb_model)
                torch.save(query_emb, query_emb_path)
        return query_emb
    
    def evaluate(self, 
                 pred_dict: dict, 
                 answer_ids: torch.LongTensor, 
                 metrics=['mrr', 'hit@3', 'recall@20'], 
                 **kwargs: Any):
        '''
        Args:
            pred_dict (dict): mapping from candidate id to predicted score
            answer_ids (torch.LongTensor): ground truth answer ids
            metrics (list): a list of metrics to be evaluated, 
                including 'mrr', 'rprecision', 'hit@k', 'recall@k', 
                'precision@k', 'map@k', 'ndcg@k'
        Returns:
            eval_metrics (dict): a dictionary of evaluation metrics
        '''
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        # Scatter predicted scores into a dense vector over all candidate ids;
        # candidates without a prediction get a score strictly lower than any
        # predicted score so they rank last.
        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * (min(pred) - 1)
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Boolean relevance labels over the same candidate ordering.
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]

        eval_metrics = {}
        for metric in metrics:
            # Metrics of the form 'name@k' carry their cutoff after the '@'.
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if 'mrr' == metric:
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif 'rprecision' == metric:
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            else:
                raise ValueError(f'Unsupported metric: {metric}')
            eval_metrics[metric] = float(result)

        return eval_metrics
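

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): it
# shows how `evaluate` turns a {candidate_id: score} dict into retrieval
# metrics. `_ToyKB` is a hypothetical stand-in that only provides the two
# attributes the base class reads (`candidate_ids`, `num_candidates`).
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _ToyKB:
        candidate_ids = list(range(10))
        num_candidates = 10

    model = ModelForSemiStructQA(_ToyKB())
    # Scores for a subset of candidates; candidates absent from the dict are
    # ranked below every scored candidate inside `evaluate`.
    pred_dict = {0: 0.9, 3: 0.7, 5: 0.2}
    answer_ids = torch.LongTensor([3])
    print(model.evaluate(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@20']))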