import os
import os.path as osp
from typing import Any, Union

import torch
import torch.nn as nn
from torchmetrics.functional.retrieval import (
    retrieval_hit_rate,
    retrieval_reciprocal_rank,
    retrieval_recall,
    retrieval_precision,
    retrieval_average_precision,
    retrieval_normalized_dcg,
    retrieval_r_precision,
)

from src.tools.api import get_openai_embedding


class ModelForSemiStructQA(nn.Module):

    def __init__(self, kb):
        super(ModelForSemiStructQA, self).__init__()
        self.kb = kb
        self.candidate_ids = kb.candidate_ids
        self.num_candidates = kb.num_candidates
        # Cache of query embeddings keyed by query_id; populated lazily in _get_query_emb.
        self.query_emb_dict = {}

    def forward(self,
                query: Union[str, list],
                candidates=None,
                query_id=None,
                **kwargs: Any):
        '''
        Args:
            query (Union[str, list]): query string or a list of query strings
            candidates (Union[list, None]): a list of candidate ids (optional)
            query_id (Union[int, list, None]): query index (optional)

        Returns:
            pred_dict (dict): a dictionary of predicted scores or answer ids
        '''
        raise NotImplementedError
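
    # A concrete subclass's forward() would return a dict mapping candidate id to a
    # relevance score for the query, e.g. {101: 0.87, 205: 0.42}; evaluate() below
    # consumes predictions in exactly this format.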

    def _get_query_emb(self, query: str, query_id: int,
                       emb_model: str = 'text-embedding-ada-002'):
        # Without a query_id there is nothing to cache against, so embed directly.
        if query_id is None:
            query_emb = get_openai_embedding(query, model=emb_model)
        # Reuse the in-memory cache once it has been populated.
        elif len(self.query_emb_dict) > 0:
            query_emb = self.query_emb_dict[query_id]
        else:
            # self.query_emb_dir is expected to be set by the subclass before this is called.
            query_emb_dict_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
            if osp.exists(query_emb_dict_path):
                print(f'Load query embeddings from {query_emb_dict_path}')
                self.query_emb_dict = torch.load(query_emb_dict_path)
                query_emb = self.query_emb_dict[query_id]
            else:
                # Fall back to per-query files: compute the embedding once and save it to disk.
                query_emb_dir = osp.join(self.query_emb_dir, 'query_embs')
                if not osp.exists(query_emb_dir):
                    os.makedirs(query_emb_dir)
                query_emb_path = osp.join(query_emb_dir, f'query_{query_id}.pt')
                query_emb = get_openai_embedding(query, model=emb_model)
                torch.save(query_emb, query_emb_path)
        return query_emb
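
    # On-disk cache layout assumed by _get_query_emb:
    #   <query_emb_dir>/query_emb_dict.pt          -> bulk dict of {query_id: embedding}
    #   <query_emb_dir>/query_embs/query_<id>.pt   -> per-query fallback files written on demand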

    def evaluate(self,
                 pred_dict: dict,
                 answer_ids: torch.LongTensor,
                 metrics=['mrr', 'hit@3', 'recall@20'],
                 **kwargs: Any):
        '''
        Args:
            pred_dict (dict): a dictionary mapping candidate ids to predicted scores
            answer_ids (torch.LongTensor): ground truth answer ids
            metrics (list): a list of metrics to be evaluated,
                including 'mrr', 'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'

        Returns:
            eval_metrics (dict): a dictionary of evaluation metrics
        '''
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        # Scatter the predicted scores into a dense vector over all candidates;
        # candidates missing from pred_dict get a score strictly below the minimum prediction.
        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * min(pred) - 1
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Boolean relevance vector aligned with candidate_ids.
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]

        eval_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if 'mrr' == metric:
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif 'rprecision' == metric:
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            eval_metrics[metric] = float(result)
        return eval_metrics
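

if __name__ == '__main__':
    # Minimal usage sketch: the kb stand-in, candidate ids, scores, and answer ids below
    # are made-up toy values, meant only to show the pred_dict format that evaluate()
    # expects and the metrics dict it returns.
    from types import SimpleNamespace

    toy_kb = SimpleNamespace(candidate_ids=[0, 1, 2, 3, 4], num_candidates=5)
    model = ModelForSemiStructQA(toy_kb)  # forward() is abstract, but evaluate() works

    pred_dict = {0: 0.10, 1: 0.90, 2: 0.30, 3: 0.05, 4: 0.70}  # candidate id -> score
    answer_ids = torch.LongTensor([1, 4])                       # ground-truth candidates

    print(model.evaluate(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@3']))
    # For this toy case the relevant candidates are ranked 1st and 2nd,
    # so mrr, hit@3, and recall@3 all come out as 1.0.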