File size: 852 Bytes
2a26d3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import json
import argparse
def main(args):
with open(args.pred_file, "r") as f:
data = json.load(f)
correct = 0
remove_count = 0
for i in range(len(data)):
ground_truth = data[i]["output"]
prediction = data[i]["predict"].strip("</s>")
# if prediction.find(ground_truth) == 0:
if prediction == ground_truth:
correct += 1
if prediction.find("<s>") == 0:
remove_count += 1
print("correct:", correct)
# print("remove_count:", remove_count)
print("accuracy:", correct/(len(data)-remove_count))
if __name__ == "__main__":
    # Command-line entry point: parse the prediction-file path and score it.
    cli = argparse.ArgumentParser(description='arg parser')
    cli.add_argument(
        '--pred_file',
        type=str,
        default='/TableLlama/ckpfinal_pred/tabfact_pred.json',
        help='',
    )
    args = cli.parse_args()
    main(args)
|