|
import argparse |
|
import json |
|
import os |
|
import sys |
|
|
|
from tqdm import tqdm |
|
|
|
sys.path.append("..") |
|
|
|
from src.model.metric import RecMetric |
|
|
|
# Evaluation splits to score. The part before the first "_" names the
# directory under ../data that holds the split's entity2id.json.
datasets = [
    "redial_eval",
    "opendialkg_eval",
]
# Model names; each one selects a ../save_{turn_num}/{mode}/{model} directory.
models = [
    "kbrd",
    "barcor",
    "unicrs",
    "chatgpt",
]
|
|
|
|
|
|
|
def rec_eval(turn_num, mode):
    """Score saved simulator dialogues for every (dataset, model) pair.

    For each dataset in ``datasets`` and each model in ``models``, every JSON
    file in ``../save_{turn_num}/{mode}/{model}/{dataset}`` is loaded. The
    ``rec_items`` of the last context entry that has them are scored against
    the file's gold ``rec`` entities (mapped to ids through the dataset's
    ``entity2id.json``; unknown entities are dropped) using ``RecMetric``
    with cutoffs [1, 10, 25, 50].

    In ``"chat"`` mode the mean ``persuasiveness_score`` over the files is
    also computed. The metric report is printed and written to
    ``../save_{turn_num}/result/{mode}/{model}/{dataset}.json``.

    Args:
        turn_num: Value used to select the ``../save_{turn_num}`` run directory.
        mode: Run sub-directory to evaluate; ``"chat"`` additionally enables
            persuasiveness averaging.
    """
    for dataset in datasets:
        # "redial_eval" -> "redial": the entity vocabulary lives in the base dir.
        data_dir = dataset.split("_")[0]
        with open(f"../data/{data_dir}/entity2id.json", "r", encoding="utf-8") as f:
            entity2id = json.load(f)

        for model in models:
            metric = RecMetric([1, 10, 25, 50])
            persuasiveness_total = 0.0
            save_path = f"../save_{turn_num}/{mode}/{model}/{dataset}"
            result_path = f"../save_{turn_num}/result/{mode}/{model}"
            os.makedirs(result_path, exist_ok=True)

            # Recomputed for every model so a missing directory yields an empty
            # list instead of an undefined (or stale, previous-model) path_list.
            path_list = os.listdir(save_path) if os.path.isdir(save_path) else []

            if path_list:
                print(
                    f"turn_num: {turn_num}, mode: {mode} model: {model} dataset: {dataset}",
                    len(path_list),
                )

                for path in tqdm(path_list):
                    with open(f"{save_path}/{path}", "r", encoding="utf-8") as f:
                        data = json.load(f)

                    if mode == "chat":
                        persuasiveness_total += float(data["persuasiveness_score"])

                    rec_label = [
                        entity2id[rec] for rec in data["rec"] if rec in entity2id
                    ]
                    # Score only the latest context entry that carries
                    # recommendations; files without any are not evaluated.
                    for context in reversed(data["simulator_dialog"]["context"]):
                        if "rec_items" in context:
                            metric.evaluate(context["rec_items"], rec_label)
                            break

            report = metric.report()

            print(
                "r1:",
                f"{report['recall@1']:.3f}",
                "r10:",
                f"{report['recall@10']:.3f}",
                "r25:",
                f"{report['recall@25']:.3f}",
                "r50:",
                f"{report['recall@50']:.3f}",
                "count:",
                report["count"],
            )

            if mode == "chat":
                if path_list:
                    persuasiveness = persuasiveness_total / len(path_list)
                    print(f"{persuasiveness:.3f}")
                    # Key spelling kept unchanged: existing result files and any
                    # downstream readers use "persuativeness".
                    report["persuativeness"] = persuasiveness
                else:
                    print(f"no dialogues in {save_path}; persuasiveness not computed")

            with open(f"{result_path}/{dataset}.json", "w", encoding="utf-8") as w:
                json.dump(report, w)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: evaluate one (turn_num, mode) run.
    parser = argparse.ArgumentParser(
        description="Compute recommendation recall (and chat persuasiveness) "
        "for saved simulator dialogues."
    )
    # Both arguments are required: without them the paths silently become
    # ../save_None/None/... and nothing meaningful is evaluated.
    parser.add_argument(
        "--turn_num",
        type=int,
        required=True,
        help="selects the ../save_{turn_num} run directory",
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        help='run sub-directory to evaluate; "chat" also averages persuasiveness',
    )
    args = parser.parse_args()
    rec_eval(args.turn_num, args.mode)
|
|