File size: 4,041 Bytes
2a26d3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import re
import math
import numpy as np
import babel
from babel import numbers
from table_instruct.eval.scripts.qa_datadump_utils import naive_str_to_float
def hmt_score(prediction, answer):
prediction = hmt_process_answer(prediction)
answer = hmt_process_answer(answer)
if hmt_equal(prediction, answer):
return 1.0
else:
return 0.0
def hmt_process_answer(answer):
""" 4 types of answer: 1)region; 2)num_list(aggr); 3)header_list(argmax); 4)num(count/div)"""
if isinstance(answer, int) or isinstance(answer, float):
return float(answer)
if isinstance(answer, str):
return naive_str_to_float(answer.strip().lower())
if isinstance(answer, list):
if isinstance(answer[0], list): # pred region
if len(answer) == 1 and len(answer[0]) == 1: # pred region with one cell, flatten
return hmt_process_answer(answer[0][0])
elif len(answer) == 1: # pred region with one line
return hmt_process_answer(answer[0])
elif len(answer[0]) == 1: # pred region with one line
return hmt_process_answer([row[0] for row in answer])
else: # pred region is a matrix
return [hmt_process_answer(a) for a in answer]
else: # list or processed single-line region
if len(answer) == 1: # answer with one cell or pred list
return hmt_process_answer(answer[0])
else:
return [hmt_process_answer(a) for a in answer]
def hmt_equal(prediction, answer):
# import pdb
# pdb.set_trace()
try:
if prediction.split(",")[0] not in ['inf', 'infinity', 'INF', 'INFINITY', 'True', 'NAN', 'nan', 'False', '-inf', '-INF', '-INFINITY', '-infinity', 'NaN', 'Nan'] and float(prediction.split(",")[0]):
prediction = float(prediction.split(",")[0])
except:
prediction = prediction
if type(prediction) != type(answer):
return False
if isinstance(prediction, str):
# print("str:", answer, prediction)
return prediction.find(answer) == 0
if isinstance(prediction, int) or isinstance(prediction, float):
# return math.fabs(prediction - answer) < 1e-5
# print("float/int:", answer, prediction)
return math.fabs(prediction - answer) < 1e-5
if isinstance(prediction, list):
if len(prediction) != len(answer):
return False
return all([hmt_equal(prediction[i], answer[i]) for i in range(len(prediction))])
def evaluate(golds, preds):
correct_num = 0.0
# print(golds[0]["answer"])
# print(preds[0])
# import pdb
# pdb.set_trace()
correct_list = []
wrong_list = []
for i in range(len(preds)):
gold = golds[i]
pred = preds[i]
if hmt_score(pred, gold) == 1:
correct_item = {}
correct_item["idx"] = i
correct_item["output"] = gold
correct_item["pred"] = pred
correct_list.append(correct_item)
else:
wrong_item = {}
wrong_item["idx"] = i
wrong_item["output"] = gold
wrong_item["pred"] = pred
wrong_list.append(wrong_item)
correct_num += hmt_score(pred, gold)
correct_percent = 100.0 * correct_num/len(golds)
# import json
# with open("/users/PAA0201/shubaobao/LongLoRA/tommy_server/ckp-15k-pred/correct_hitab.json", "w") as f:
# json.dump(correct_list, f, indent = 2)
# with open("/users/PAA0201/shubaobao/LongLoRA/tommy_server/ckp-15k-pred/wrong_hitab.json", "w") as f:
# json.dump(wrong_list, f, indent = 2)
# print("correct_percent:", correct_percent)
# import pdb
# pdb.set_trace()
return {'exact_match': correct_percent}
if __name__ == '__main__':
# main()
# test()
pred = "-0.8"
gold = 0.8
print(hmt_score(pred, gold))
|