import os
import os.path as osp
from typing import Any, Union

import torch
import torch.nn as nn
from torchmetrics.functional.retrieval import (
    retrieval_average_precision,
    retrieval_hit_rate,
    retrieval_normalized_dcg,
    retrieval_precision,
    retrieval_r_precision,
    retrieval_recall,
    retrieval_reciprocal_rank,
)

from src.tools.api import get_openai_embedding

class ModelForSemiStructQA(nn.Module):

    def __init__(self, kb, query_emb_dir: str = '.'):
        super(ModelForSemiStructQA, self).__init__()
        self.kb = kb
        self.candidate_ids = kb.candidate_ids
        self.num_candidates = kb.num_candidates
        # Directory for caching query embeddings (used by _get_query_emb).
        self.query_emb_dir = query_emb_dir
        self.query_emb_dict = {}

    def forward(self,
                query: Union[str, list],
                candidates: Union[list, None] = None,
                query_id: Union[int, list, None] = None,
                **kwargs: Any):
        '''
        Args:
            query (Union[str, list]): query string or a list of query strings
            candidates (Union[list, None]): a list of candidate ids (optional)
            query_id (Union[int, list, None]): query index (optional)

        Returns:
            pred_dict (dict): a dictionary of predicted scores or answer ids
        '''
        raise NotImplementedError

    def _get_query_emb(self, query: str, query_id: int,
                       emb_model: str = 'text-embedding-ada-002'):
        if query_id is None:
            # No id to cache under: embed the query on the fly.
            query_emb = get_openai_embedding(query, model=emb_model)
        elif len(self.query_emb_dict) > 0:
            # In-memory cache has already been populated.
            query_emb = self.query_emb_dict[query_id]
        else:
            query_emb_dict_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
            if osp.exists(query_emb_dict_path):
                # Load the precomputed embedding dictionary from disk.
                print(f'Load query embeddings from {query_emb_dict_path}')
                self.query_emb_dict = torch.load(query_emb_dict_path)
                query_emb = self.query_emb_dict[query_id]
            else:
                # Compute the embedding and cache it as a per-query file.
                query_emb_dir = osp.join(self.query_emb_dir, 'query_embs')
                os.makedirs(query_emb_dir, exist_ok=True)
                query_emb_path = osp.join(query_emb_dir, f'query_{query_id}.pt')
                query_emb = get_openai_embedding(query, model=emb_model)
                torch.save(query_emb, query_emb_path)
        return query_emb

    def evaluate(self,
                 pred_dict: dict,
                 answer_ids: torch.LongTensor,
                 metrics=['mrr', 'hit@3', 'recall@20'],
                 **kwargs: Any):
        '''
        Args:
            pred_dict (dict): mapping from candidate id to predicted score
            answer_ids (torch.LongTensor): ground truth answer ids
            metrics (list): a list of metrics to be evaluated, including
                'mrr', 'rprecision', 'hit@k', 'recall@k', 'precision@k',
                'map@k', 'ndcg@k'

        Returns:
            eval_metrics (dict): a dictionary of evaluation metrics
        '''
        # Scatter the predicted scores into a dense vector over all candidates;
        # candidates without a prediction get a score below the minimum so
        # they rank last.
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * (min(pred) - 1)
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Boolean ground-truth mask over the same candidate ordering.
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]

        eval_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if metric == 'mrr':
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif metric == 'rprecision':
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            else:
                raise ValueError(f'Unsupported metric: {metric}')
            eval_metrics[metric] = float(result)
        return eval_metrics
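

# ---------------------------------------------------------------------------
# Minimal sketch of a concrete subclass (not part of the original module):
# it shows one way forward() could be implemented, scoring every candidate by
# dot-product similarity against the cached query embedding. DotProductModel
# and candidate_emb_dict are hypothetical names, and the query embedding is
# assumed to be a (1, d) tensor returned by get_openai_embedding; the
# original file does not specify how candidate embeddings are produced.
class DotProductModel(ModelForSemiStructQA):

    def __init__(self, kb, query_emb_dir, candidate_emb_dict):
        super().__init__(kb, query_emb_dir=query_emb_dir)
        # Stack candidate embeddings in candidate_ids order: (num_candidates, d).
        self.candidate_embs = torch.cat(
            [candidate_emb_dict[cid] for cid in self.candidate_ids], dim=0)

    def forward(self, query, candidates=None, query_id=None, **kwargs):
        query_emb = self._get_query_emb(query, query_id)  # assumed (1, d)
        similarity = torch.matmul(query_emb, self.candidate_embs.T).view(-1)
        return {cid: float(similarity[i])
                for i, cid in enumerate(self.candidate_ids)}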
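

# Usage sketch for evaluate() on a toy knowledge base (illustrative values
# only): any object exposing candidate_ids and num_candidates suffices,
# since those are the only kb attributes this class reads.
if __name__ == '__main__':
    class ToyKB:
        candidate_ids = [0, 1, 2, 3]
        num_candidates = 4

    model = ModelForSemiStructQA(ToyKB())
    pred_dict = {0: 0.9, 1: 0.1, 2: 0.7, 3: 0.3}  # candidate id -> score
    answer_ids = torch.LongTensor([0, 2])         # ground-truth candidates
    # Candidate 0 ranks first and both answers appear in the top 3, so this
    # prints mrr = 1.0, hit@3 = 1.0, recall@20 = 1.0.
    print(model.evaluate(pred_dict, answer_ids,
                         metrics=['mrr', 'hit@3', 'recall@20']))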