# SKB-Explorer / eval.py
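"""Evaluate a retrieval model on a QA benchmark built over a semistructured knowledge base.

Supported models are VSS, MultiVSS, and an LLM-based reranker (LLMReranker).
For every query in the chosen split, the script ranks candidates, computes
ranking metrics (MRR, MAP, R-precision, recall@k, hit@k), and writes per-query
results to a CSV and averaged metrics to a JSON file.
"""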
import argparse
import json
import os
import os.path as osp

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from src.benchmarks import get_qa_dataset, get_semistructured_data
from src.models import get_model
from src.tools.args import load_args, merge_args
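

# Example invocation (illustrative only; the paths below are placeholders):
#   python eval.py --dataset amazon --model VSS \
#       --emb_dir <path/to/embeddings> --output_dir <path/to/output> --save_pred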
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="amazon", choices=["amazon", "primekg", "mag"])
    parser.add_argument(
        "--model", default="VSS", choices=["VSS", "MultiVSS", "LLMReranker"]
    )
    parser.add_argument("--split", default="test")

    # Optionally evaluate on only a fraction of the split
    parser.add_argument("--test_ratio", type=float, default=1.0)

    # MultiVSS options
    parser.add_argument("--chunk_size", type=int, default=None)
    parser.add_argument("--multi_vss_topk", type=int, default=None)
    parser.add_argument("--aggregate", type=str, default="max")

    # Embedding model used by VSS, MultiVSS, and LLMReranker
    parser.add_argument("--emb_model", type=str, default="text-embedding-ada-002")

    # LLMReranker options
    parser.add_argument("--llm_model", type=str, default="gpt-4-1106-preview",
                        help="The LLM used to rerank candidates.")
    parser.add_argument("--llm_topk", type=int, default=20)
    parser.add_argument("--max_retry", type=int, default=3)

    # Paths
    parser.add_argument("--emb_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    # Save per-query predictions and metrics to CSV
    parser.add_argument("--save_pred", action="store_true")
    return parser.parse_args()
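

# Dataset-specific defaults are loaded from config/default_args.json and combined
# with the CLI arguments before the datasets and model are constructed.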
if __name__ == "__main__":
    args = parse_args()
    default_args = load_args(
        json.load(open("config/default_args.json", "r"))[args.dataset]
    )
    args = merge_args(args, default_args)
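
    # Embedding directories follow the layout <emb_dir>/<dataset>/<emb_model>/{query,doc,chunk}.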
    args.query_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "query")
    args.node_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "doc")
    args.chunk_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "chunk")

    suffix = args.llm_model if args.model == "LLMReranker" else args.emb_model
    output_dir = osp.join(args.output_dir, "eval", args.dataset, args.model, suffix)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.query_emb_dir, exist_ok=True)
    os.makedirs(args.chunk_emb_dir, exist_ok=True)
    os.makedirs(args.node_emb_dir, exist_ok=True)
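
    # Record the resolved arguments and set up result paths; a partial test_ratio
    # gets its own metrics file.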
    json.dump(vars(args), open(osp.join(output_dir, "args.json"), "w"), indent=4)
    eval_csv_path = osp.join(output_dir, f"eval_results_{args.split}.csv")
    final_eval_path = (
        osp.join(output_dir, f"eval_metrics_{args.split}.json")
        if args.test_ratio == 1.0
        else osp.join(output_dir, f"eval_metrics_{args.split}_{args.test_ratio}.json")
    )
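
    # Load the semistructured knowledge base, the QA dataset, and the retrieval model,
    # then pick the query indices for the requested split.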
    kb = get_semistructured_data(args.dataset)
    qa_dataset = get_qa_dataset(args.dataset)
    model = get_model(args, kb)
    split_idx = qa_dataset.get_idx_split(test_ratio=args.test_ratio)
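
    # Ranking metrics computed for every query and averaged at the end.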
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
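
    # Resume support: queries that already have rows in an existing results CSV are skipped.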
    existing_idx = []
    if osp.exists(eval_csv_path):
        eval_csv = pd.read_csv(eval_csv_path)
        existing_idx = eval_csv["idx"].tolist()

    indices = split_idx[args.split].tolist()
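
    # For each query: run the model to get candidate scores, compute metrics against the
    # ground-truth answer ids, and record the top-1000 predicted ids sorted by score.
    # After every query the CSV is checkpointed (if --save_pred) and running metric
    # averages are printed, so an interrupted run can be resumed.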
    for idx in tqdm(indices):
        if idx in existing_idx:
            continue
        query, query_id, answer_ids, meta_info = qa_dataset[idx]
        pred_dict = model.forward(query, query_id)

        answer_ids = torch.LongTensor(answer_ids)
        result = model.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
        result["idx"], result["query_id"] = idx, query_id
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[:1000]
        ].tolist()
        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

        if args.save_pred:
            eval_csv.to_csv(eval_csv_path, index=False)
        for metric in eval_metrics:
            print(
                f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(indices)][metric])}"
            )
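
    # Final flush (if --save_pred) and aggregation: average each metric over the
    # evaluated indices and dump the results to JSON.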
    if args.save_pred:
        eval_csv.to_csv(eval_csv_path, index=False)
    final_metrics = (
        eval_csv[eval_csv["idx"].isin(indices)][eval_metrics].mean().to_dict()
    )
    json.dump(final_metrics, open(final_eval_path, "w"), indent=4)