jiangjiechen's picture
init loren for spaces
7f7285f
raw
history blame
1.79 kB
import json
import os
import argparse
def _iter_jsonl(lines):
    """Yield one decoded JSON value per non-blank line of *lines*."""
    for line in lines:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file)
        # instead of letting json.loads fail on an empty string.
        if line.strip():
            yield json.loads(line)


def _evidence_key(record_id, evidence):
    """Return the de-duplication key for one evidence item of one record.

    The key combines the record id with the first two fields of the evidence
    item. Every component is stringified so that gold and retrieved items are
    compared the same way regardless of their JSON types.
    """
    return (str(record_id), str(evidence[0]), str(evidence[1]))


def merge_evidence(gold_lines, retrieval_lines, test=False, top_k=5):
    """Merge gold and retrieved evidence per record and keep the best items.

    Args:
        gold_lines: Iterable of JSON strings. Each record must have "id" and
            "claim"; "label" is copied when present. Unless *test* is true,
            each record must also have "evidence", a list of items with at
            least three fields.
        retrieval_lines: Iterable of JSON strings. Each record must have "id"
            and "evidence"; every evidence item needs at least five fields
            (the item at index 4 is used as its score).
        test: When true, gold evidence is ignored and only retrieved evidence
            is collected.
        top_k: Maximum number of evidence items kept per record.

    Returns:
        A list of dicts, one per gold record in first-seen order, with keys
        "id", "evidence", "claim" and optionally "label". Each evidence item
        is ``[field0, field1, field2, score]``; gold items get score 1.0.

    Raises:
        KeyError: If a retrieved evidence item belongs to an id that does not
            occur in the gold input.
    """
    merged = {}
    # Keys of every evidence item already attached to a record, gold or
    # retrieved, so that retrieved duplicates are dropped.
    seen = set()

    for record in _iter_jsonl(gold_lines):
        record_id = record["id"]
        entry = {"id": record_id, "evidence": [], "claim": record["claim"]}
        if "label" in record:
            entry["label"] = record["label"]
        if not test:
            for evidence in record["evidence"]:
                entry["evidence"].append(
                    [evidence[0], evidence[1], evidence[2], 1.0]
                )
                seen.add(_evidence_key(record_id, evidence))
        merged[record_id] = entry

    for record in _iter_jsonl(retrieval_lines):
        record_id = record["id"]
        for evidence in record["evidence"]:
            key = _evidence_key(record_id, evidence)
            if key in seen:
                continue
            if record_id not in merged:
                raise KeyError(
                    "retrieval record id {!r} is missing from the gold "
                    "file".format(record_id)
                )
            merged[record_id]["evidence"].append(
                [evidence[0], evidence[1], evidence[2], evidence[4]]
            )
            seen.add(key)

    for entry in merged.values():
        # sorted() is stable, so items with equal scores keep insertion order
        # (gold items were inserted before retrieved ones).
        ranked = sorted(entry["evidence"], key=lambda item: item[3], reverse=True)
        entry["evidence"] = ranked[:top_k]

    return list(merged.values())


def main(argv=None):
    """Command-line entry point: read both JSONL inputs, write the merged JSONL.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Merge gold and retrieved evidence into one JSONL file."
    )
    parser.add_argument('--gold_file', required=True,
                        help="input JSONL file with the gold records")
    parser.add_argument('--retrieval_file', required=True,
                        help="input JSONL file with the retrieved evidence")
    parser.add_argument('--output', required=True,
                        help="path of the JSONL file to write")
    parser.add_argument('--test', action='store_true', default=False,
                        help="ignore the evidence stored in the gold file")
    parser.add_argument('--top_k', type=int, default=5,
                        help="evidence items kept per record (default: 5)")
    args = parser.parse_args(argv)

    # Both inputs are fully consumed before the output is opened, so a failure
    # while reading never truncates an existing output file.
    with open(args.gold_file, encoding="utf-8") as gold, \
            open(args.retrieval_file, encoding="utf-8") as retrieved:
        records = merge_evidence(
            gold, retrieved, test=args.test, top_k=args.top_k
        )

    with open(args.output, "w", encoding="utf-8") as out:
        for record in records:
            out.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    main()