File size: 3,168 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import json
import os
import sys

from tqdm import tqdm

sys.path.append("..")

from src.model.metric import RecMetric

# Evaluation splits; the text before the first "_" names the directory
# under ../data/ that holds the split's entity2id.json.
datasets = [
    "redial_eval",
    "opendialkg_eval",
]
# Recommender models whose saved simulator dialogs are scored.
models = [
    "kbrd",
    "barcor",
    "unicrs",
    "chatgpt",
]


def _load_entity2id(dataset):
    """Load the entity-name -> id mapping for ``dataset``.

    The data directory is the part of the dataset name before the first
    underscore, e.g. ``"redial_eval"`` -> ``../data/redial/entity2id.json``.
    """
    data_dir = dataset.split("_")[0]
    with open(f"../data/{data_dir}/entity2id.json", "r", encoding="utf-8") as f:
        return json.load(f)


def _list_dialog_files(save_path):
    """Return the names of regular files directly under ``save_path``.

    Returns an empty list when ``save_path`` is not an existing directory, so
    "missing" and "empty" are handled the same way. Sub-directories are
    skipped because they cannot be opened and parsed as JSON.
    """
    if not os.path.isdir(save_path):
        return []
    return [
        name
        for name in os.listdir(save_path)
        if os.path.isfile(os.path.join(save_path, name))
    ]


def _print_report(report, k_list):
    """Print ``r{k}: <recall@k>`` for each k in ``k_list``, then the count."""
    fields = []
    for k in k_list:
        recall = report[f"recall@{k}"]
        fields += [f"r{k}:", f"{recall:.3f}"]
    fields += ["count:", report["count"]]
    print(*fields)


def rec_eval(
    turn_num,
    mode,
    *,
    dataset_names=None,
    model_names=None,
    k_list=(1, 10, 25, 50),
):
    """Compute recommendation recall for each (dataset, model) pair and save it.

    For every dataset and model, each saved dialog file under
    ``../save_{turn_num}/{mode}/{model}/{dataset}`` is loaded. The gold
    ``rec`` entries are mapped to ids with the dataset's ``entity2id.json``;
    names absent from that mapping are dropped. The ``rec_items`` of the last
    context turn that has them are scored with ``RecMetric``. Dialogs with no
    such turn are not scored. When ``mode == "chat"``, the mean
    ``persuasiveness_score`` over all files is reported as well.

    The report is printed and written as JSON to
    ``../save_{turn_num}/result/{mode}/{model}/{dataset}.json``. Pairs with
    no saved dialog files are skipped and create no result directory.

    Args:
        turn_num: Turn budget; selects the ``../save_{turn_num}`` tree.
        mode: Run sub-directory; ``"chat"`` also averages persuasiveness.
        dataset_names: Datasets to evaluate (defaults to module ``datasets``).
        model_names: Models to evaluate (defaults to module ``models``).
        k_list: Recall cut-offs passed to ``RecMetric`` and printed as
            ``r{k}``; the report is assumed to expose ``recall@{k}`` keys.
    """
    dataset_names = datasets if dataset_names is None else dataset_names
    model_names = models if model_names is None else model_names

    for dataset in dataset_names:
        entity2id = _load_entity2id(dataset)

        for model in model_names:
            save_path = f"../save_{turn_num}/{mode}/{model}/{dataset}"  # data loaded path
            path_list = _list_dialog_files(save_path)
            if not path_list:
                continue

            print(
                f"turn_num: {turn_num}, mode: {mode} model: {model} dataset: {dataset}",
                len(path_list),
            )

            metric = RecMetric(list(k_list))
            persuasiveness_total = 0.0
            for path in tqdm(path_list):
                with open(os.path.join(save_path, path), "r", encoding="utf-8") as f:
                    data = json.load(f)

                if mode == "chat":
                    persuasiveness_total += float(data["persuasiveness_score"])
                contexts = data["simulator_dialog"]["context"]
                rec_label = [entity2id[rec] for rec in data["rec"] if rec in entity2id]
                # Score only the most recent recommendation made in the dialog.
                for context in reversed(contexts):
                    if "rec_items" in context:
                        metric.evaluate(context["rec_items"], rec_label)
                        break

            report = metric.report()
            _print_report(report, k_list)
            if mode == "chat":
                persuasiveness_score = persuasiveness_total / len(path_list)
                print(f"{persuasiveness_score:.3f}")
                # Key spelling kept unchanged so existing result readers still work.
                report["persuativeness"] = persuasiveness_score

            # Create the output directory only once there is a report to write.
            result_path = f"../save_{turn_num}/result/{mode}/{model}"
            os.makedirs(result_path, exist_ok=True)
            with open(f"{result_path}/{dataset}.json", "w", encoding="utf-8") as w:
                json.dump(report, w)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compute recommendation recall for saved simulator dialogs."
    )
    # Both values are path components for rec_eval. Without them it would
    # silently look under "../save_None/None/...", so they are required.
    parser.add_argument(
        "--turn_num",
        type=int,
        required=True,
        help="turn budget; selects the ../save_{turn_num} directory",
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        help='run sub-directory; "chat" also reports mean persuasiveness',
    )
    args = parser.parse_args()
    rec_eval(args.turn_num, args.mode)