|
import re
|
|
import math
|
|
import numpy as np
|
|
import babel
|
|
from babel import numbers
|
|
|
|
from table_instruct.eval.scripts.qa_datadump_utils import naive_str_to_float
|
|
|
|
def hmt_score(prediction, answer):
|
|
prediction = hmt_process_answer(prediction)
|
|
answer = hmt_process_answer(answer)
|
|
if hmt_equal(prediction, answer):
|
|
return 1.0
|
|
else:
|
|
return 0.0
|
|
|
|
|
|
def hmt_process_answer(answer):
|
|
""" 4 types of answer: 1)region; 2)num_list(aggr); 3)header_list(argmax); 4)num(count/div)"""
|
|
if isinstance(answer, int) or isinstance(answer, float):
|
|
return float(answer)
|
|
if isinstance(answer, str):
|
|
return naive_str_to_float(answer.strip().lower())
|
|
if isinstance(answer, list):
|
|
if isinstance(answer[0], list):
|
|
if len(answer) == 1 and len(answer[0]) == 1:
|
|
return hmt_process_answer(answer[0][0])
|
|
elif len(answer) == 1:
|
|
return hmt_process_answer(answer[0])
|
|
elif len(answer[0]) == 1:
|
|
return hmt_process_answer([row[0] for row in answer])
|
|
else:
|
|
return [hmt_process_answer(a) for a in answer]
|
|
else:
|
|
if len(answer) == 1:
|
|
return hmt_process_answer(answer[0])
|
|
else:
|
|
return [hmt_process_answer(a) for a in answer]
|
|
|
|
|
|
def hmt_equal(prediction, answer):
|
|
|
|
|
|
try:
|
|
if prediction.split(",")[0] not in ['inf', 'infinity', 'INF', 'INFINITY', 'True', 'NAN', 'nan', 'False', '-inf', '-INF', '-INFINITY', '-infinity', 'NaN', 'Nan'] and float(prediction.split(",")[0]):
|
|
prediction = float(prediction.split(",")[0])
|
|
except:
|
|
prediction = prediction
|
|
|
|
if type(prediction) != type(answer):
|
|
return False
|
|
if isinstance(prediction, str):
|
|
|
|
return prediction.find(answer) == 0
|
|
if isinstance(prediction, int) or isinstance(prediction, float):
|
|
|
|
|
|
return math.fabs(prediction - answer) < 1e-5
|
|
if isinstance(prediction, list):
|
|
if len(prediction) != len(answer):
|
|
return False
|
|
return all([hmt_equal(prediction[i], answer[i]) for i in range(len(prediction))])
|
|
|
|
|
|
def evaluate(golds, preds):
|
|
correct_num = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
correct_list = []
|
|
wrong_list = []
|
|
for i in range(len(preds)):
|
|
gold = golds[i]
|
|
pred = preds[i]
|
|
if hmt_score(pred, gold) == 1:
|
|
correct_item = {}
|
|
correct_item["idx"] = i
|
|
correct_item["output"] = gold
|
|
correct_item["pred"] = pred
|
|
correct_list.append(correct_item)
|
|
else:
|
|
wrong_item = {}
|
|
wrong_item["idx"] = i
|
|
wrong_item["output"] = gold
|
|
wrong_item["pred"] = pred
|
|
wrong_list.append(wrong_item)
|
|
|
|
correct_num += hmt_score(pred, gold)
|
|
|
|
correct_percent = 100.0 * correct_num/len(golds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {'exact_match': correct_percent}
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
pred = "-0.8"
|
|
gold = 0.8
|
|
print(hmt_score(pred, gold))
|
|
|