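"""Evaluate a retrieval model (VSS, MultiVSS, or LLMReranker) on a
semi-structured QA benchmark: per-query results go to a CSV, aggregate
metrics to a JSON file.

Example invocation (script name and paths are illustrative):
    python eval.py --dataset amazon --model VSS \
        --emb_dir emb/ --output_dir output/ --save_pred
"""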
import os.path as osp
import json
import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from src.benchmarks import get_qa_dataset, get_semistructured_data
from src.models import get_model
from src.tools.args import merge_args, load_args


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="amazon", choices=['amazon', 'primekg', 'mag'])
    parser.add_argument(
        "--model", default="VSS", choices=["VSS", "MultiVSS", "LLMReranker"]
    )
    parser.add_argument("--split", default="test")

    # can eval on a subset only
    parser.add_argument("--test_ratio", type=float, default=1.0)

    # for multivss
    parser.add_argument("--chunk_size", type=int, default=None)
    parser.add_argument("--multi_vss_topk", type=int, default=None)
    parser.add_argument("--aggregate", type=str, default="max")

    # for vss, multivss, and llm reranker
    parser.add_argument("--emb_model", type=str, default="text-embedding-ada-002")

    # for llm reranker
    parser.add_argument("--llm_model", type=str, default="gpt-4-1106-preview",
                        help='the LLM to rerank candidates.')
    parser.add_argument("--llm_topk", type=int, default=20)
    parser.add_argument("--max_retry", type=int, default=3)

    # path
    parser.add_argument("--emb_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    # save prediction
    parser.add_argument("--save_pred", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
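    # Dataset-specific defaults are loaded from the repo config and merged with
    # the CLI args (merge_args presumably resolves conflicts between the two).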
    with open("config/default_args.json", "r") as f:
        default_args = load_args(json.load(f)[args.dataset])
    args = merge_args(args, default_args)

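    # Embeddings are laid out as <emb_dir>/<dataset>/<emb_model>/{query,doc,chunk}.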
    args.query_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "query")
    args.node_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "doc")
    args.chunk_emb_dir = osp.join(args.emb_dir, args.dataset, args.emb_model, "chunk")
    suffix = args.llm_model if args.model == "LLMReranker" else args.emb_model
    output_dir = osp.join(args.output_dir, "eval", args.dataset, args.model, suffix)

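    # Create the output and embedding cache directories, then persist the
    # resolved args for reproducibility.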
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.query_emb_dir, exist_ok=True)
    os.makedirs(args.chunk_emb_dir, exist_ok=True)
    os.makedirs(args.node_emb_dir, exist_ok=True)
    with open(osp.join(output_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)

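    # Per-query results accumulate in a CSV (enabling resumption); the final
    # metrics JSON name records the subsampling ratio when test_ratio < 1.0.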
    eval_csv_path = osp.join(output_dir, f"eval_results_{args.split}.csv")
    final_eval_path = (
        osp.join(output_dir, f"eval_metrics_{args.split}.json")
        if args.test_ratio == 1.0
        else osp.join(output_dir, f"eval_metrics_{args.split}_{args.test_ratio}.json")
    )

    kb = get_semistructured_data(args.dataset)
    qa_dataset = get_qa_dataset(args.dataset)
    model = get_model(args, kb)

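    # test_ratio < 1.0 restricts evaluation to a subset of the chosen split.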
    split_idx = qa_dataset.get_idx_split(test_ratio=args.test_ratio)

    eval_metrics = [
        "mrr",
        "map",
        "rprecision",
        "recall@5",
        "recall@10",
        "recall@20",
        "recall@50",
        "recall@100",
        "hit@1",
        "hit@3",
        "hit@5",
        "hit@10",
        "hit@20",
        "hit@50",
    ]
    eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)

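    # Resume support: queries already present in the CSV are skipped below.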
    existing_idx = []
    if osp.exists(eval_csv_path):
        eval_csv = pd.read_csv(eval_csv_path)
        existing_idx = eval_csv["idx"].tolist()

    indices = split_idx[args.split].tolist()

    for idx in tqdm(indices):
        if idx in existing_idx:
            continue
        query, query_id, answer_ids, meta_info = qa_dataset[idx]
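        # pred_dict maps candidate ids to predicted relevance scores.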
        pred_dict = model.forward(query, query_id)

        answer_ids = torch.LongTensor(answer_ids)
        result = model.evaluate(pred_dict, answer_ids, metrics=eval_metrics)

        result["idx"], result["query_id"] = idx, query_id
        result["pred_rank"] = torch.LongTensor(list(pred_dict.keys()))[
            torch.argsort(torch.tensor(list(pred_dict.values())), descending=True)[
                :1000
            ]
        ].tolist()

        eval_csv = pd.concat([eval_csv, pd.DataFrame([result])], ignore_index=True)

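        # Checkpoint the CSV each iteration (if requested) and log running
        # means of each metric over the indices evaluated so far.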
        if args.save_pred:
            eval_csv.to_csv(eval_csv_path, index=False)
        for metric in eval_metrics:
            print(
                f"{metric}: {np.mean(eval_csv[eval_csv['idx'].isin(indices)][metric])}"
            )
    if args.save_pred:
        eval_csv.to_csv(eval_csv_path, index=False)
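    # Aggregate each metric over the evaluated indices and write the summary.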
    final_metrics = (
        eval_csv[eval_csv["idx"].isin(indices)][eval_metrics].mean().to_dict()
    )
    with open(final_eval_path, "w") as f:
        json.dump(final_metrics, f, indent=4)