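"""Evaluate retrieval models (VSS, MultiVSS, LLMReranker) on a semistructured
QA benchmark (amazon, primekg, or mag).

For each query in the chosen split, the selected model scores candidate nodes,
standard ranking metrics are computed, and per-query results plus averaged
metrics are written to the output directory. Runs previously saved with
--save_pred are resumed rather than recomputed.
"""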
import argparse
import json
import os
import os.path as osp

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from src.benchmarks import get_qa_dataset, get_semistructured_data
from src.models import get_model
from src.tools.args import merge_args, load_args

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="amazon", choices=["amazon", "primekg", "mag"])
    parser.add_argument(
        "--model", default="VSS", choices=["VSS", "MultiVSS", "LLMReranker"]
    )
    parser.add_argument("--split", default="test")
    # optionally evaluate on only a fraction of the split (1.0 = full split)
    parser.add_argument("--test_ratio", type=float, default=1.0)
    # MultiVSS only
    parser.add_argument("--chunk_size", type=int, default=None)
    parser.add_argument("--multi_vss_topk", type=int, default=None)
    parser.add_argument("--aggregate", type=str, default="max")
    # shared by VSS, MultiVSS, and LLMReranker
    parser.add_argument("--emb_model", type=str, default="text-embedding-ada-002")
    # LLMReranker only
    parser.add_argument("--llm_model", type=str, default="gpt-4-1106-preview",
                        help="the LLM used to rerank candidates")
    parser.add_argument("--llm_topk", type=int, default=20)
    parser.add_argument("--max_retry", type=int, default=3)
    # paths
    parser.add_argument("--emb_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    # save per-query predictions to CSV
    parser.add_argument("--save_pred", action="store_true")
    return parser.parse_args()

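# Example invocation (script name and paths are illustrative, not fixed by the repo):
#   python eval.py --dataset amazon --model VSS \
#       --emb_dir emb/ --output_dir output/ --save_pred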
if __name__ == "__main__":
    args = parse_args()
    # fill in dataset-specific defaults from config/default_args.json
    default_args = load_args(
        json.load(open("config/default_args.json", "r"))[args.dataset]
    )
    args = merge_args(args, default_args)

    args.query_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "query")
    args.node_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "doc")
    args.chunk_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "chunk")
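    # Expected embedding layout (created below if missing):
    #   <emb_dir>/<dataset>/<emb_model>/query  - query embeddings
    #   <emb_dir>/<dataset>/<emb_model>/doc    - node/document embeddings
    #   <emb_dir>/<dataset>/<emb_model>/chunk  - chunk embeddings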
    # group results by the reranking LLM for LLMReranker, otherwise by the embedding model
    suffix = args.llm_model if args.model == "LLMReranker" else args.emb_model
    output_dir = osp.join(args.output_dir, "eval", args.dataset, args.model, suffix)
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.query_emb_dir, exist_ok=True)
    os.makedirs(args.chunk_emb_dir, exist_ok=True)
    os.makedirs(args.node_emb_dir, exist_ok=True)

    # record the resolved arguments alongside the results
    with open(osp.join(output_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    eval_csv_path = osp.join(output_dir, f"eval_results_{args.split}.csv")
    final_eval_path = (
        osp.join(output_dir, f"eval_metrics_{args.split}.json")
        if args.test_ratio == 1.0
        else osp.join(output_dir, f"eval_metrics_{args.split}_{args.test_ratio}.json")
    )
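    # Two artifacts are written: a per-query CSV (predicted ranks plus metrics,
    # also used to resume interrupted runs) and a JSON of metrics averaged over
    # the evaluated split.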
    # load the semistructured knowledge base, the QA benchmark, and the model
    kb = get_semistructured_data(args.dataset)
    qa_dataset = get_qa_dataset(args.dataset)
    model = get_model(args, kb)
    split_idx = qa_dataset.get_idx_split(test_ratio=args.test_ratio)
    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
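    # metric names are passed through to model.evaluate below; they cover MRR,
    # MAP, R-precision, and recall/hit rates at several cutoffs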
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
    existing_idx = []
    # resume from a previous run: queries already in the CSV are skipped below
    if osp.exists(eval_csv_path):
        eval_csv = pd.read_csv(eval_csv_path)
        existing_idx = eval_csv["idx"].tolist()
    indices = split_idx[args.split].tolist()
    for idx in tqdm(indices):
        # skip queries already evaluated in a previous run
        if idx in existing_idx:
            continue
        query, query_id, answer_ids, meta_info = qa_dataset[idx]
        # pred_dict maps candidate node ids to predicted relevance scores
        pred_dict = model.forward(query, query_id)
        answer_ids = torch.LongTensor(answer_ids)
        result = model.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
        result["idx"], result["query_id"] = idx, query_id
result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[ | |
torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[ | |
:1000 | |
] | |
].tolist() | |
eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True) | |
        # checkpoint after every query so interrupted runs can resume
        if args.save_pred:
            eval_csv.to_csv(eval_csv_path, index=False)
        # print running averages over the queries evaluated so far
        for metric in eval_metrics:
            print(
                f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(indices)][metric])}"
            )
    # final save, then aggregate metrics over the evaluated split
    if args.save_pred:
        eval_csv.to_csv(eval_csv_path, index=False)
    final_metrics = (
        eval_csv[eval_csv["idx"].isin(indices)][eval_metrics].mean().to_dict()
    )
    with open(final_eval_path, "w") as f:
        json.dump(final_metrics, f, indent=4)