Spaces:
Build error
Build error
import json | |
import os | |
import argparse | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--gold_file') | |
parser.add_argument('--retrieval_file') | |
parser.add_argument('--output') | |
parser.add_argument('--test', action='store_true', default=False) | |
args = parser.parse_args() | |
filter_dict = dict() | |
data_dict = dict() | |
golden_dict = dict() | |
with open(args.gold_file) as f: | |
for line in f: | |
data = json.loads(line) | |
data_dict[data["id"]] = {"id": data["id"], "evidence":[], "claim": data["claim"]} | |
if "label" in data: | |
data_dict[data["id"]]["label"] = data["label"] | |
if not args.test: | |
for evidence in data["evidence"]: | |
data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], 1.0]) | |
string = str(data["id"]) + "_" + evidence[0] + "_" + str(evidence[1]) | |
golden_dict[string] = 1 | |
with open(args.retrieval_file) as f: | |
for line in f: | |
data = json.loads(line) | |
for step, evidence in enumerate(data["evidence"]): | |
string = str(data["id"]) + "_" + str(evidence[0]) + "_" + str(evidence[1]) | |
if string not in golden_dict and string not in filter_dict: | |
data_dict[data["id"]]["evidence"].append([evidence[0], evidence[1], evidence[2], evidence[4]]) | |
filter_dict[string] = 1 | |
with open(args.output, "w") as out: | |
for data in data_dict.values(): | |
evidence_tmp = data["evidence"] | |
evidence_tmp = sorted(evidence_tmp, key=lambda x:x[3], reverse=True) | |
data["evidence"] = evidence_tmp[:5] | |
out.write(json.dumps(data) + "\n") | |