File size: 2,744 Bytes
3891395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import re

from evaluation.evaluate_utils.evaluate_factory import get_evaluator


def fix_ans(answer):
    """Rewrite a Python-style dict literal string into JSON-compatible text.

    Single quotes around keys and string values are swapped for double quotes
    so that ``json.loads`` can parse answers written like ``{'a': 'b'}``.
    Only the specific quote patterns below are rewritten, so text that merely
    resembles them (e.g. a value containing ``', '``) may also be altered.

    Args:
        answer: Candidate answer text. Values that do not support str-style
            ``replace`` (e.g. None, numbers, bytes) are returned unchanged.

    Returns:
        The rewritten string, or ``answer`` itself when it cannot be rewritten.
    """
    try:
        answer = (
            answer.replace("{'", '{"')
            .replace("', '", '", "')
            .replace("': '", '": "')
            .replace("'}", '"}')
        )
        # Must run after the "': '" rewrite above; running it first would
        # consume the key's quote and leave the value's opening quote single.
        answer = answer.replace("': ", '": ')
    except (AttributeError, TypeError):
        # AttributeError: no .replace (None, numbers);
        # TypeError: .replace rejects str arguments (bytes).
        pass
    return answer


def parse_answer(answer):
    """Parse a gold answer and choose the evaluator used to score it.

    Args:
        answer: List of gold-answer entries (strings).

    Returns:
        ``(parsed, evaluator)`` where ``evaluator`` is one of:

        * ``'number'`` -- a single entry that ``fix_number`` converts;
          ``parsed`` is that number.
        * ``'json'`` -- every entry decodes as JSON after ``fix_ans``;
          ``parsed`` is a list of the decoded values (empty for an empty
          input list).
        * ``'string'`` -- a single entry that is neither JSON nor a number;
          ``parsed`` is that string.
        * ``'string list'`` -- several entries, at least one not JSON;
          ``parsed`` is the input list unchanged.
    """
    if len(answer) == 1:
        single = answer[0]
        # Numeric strings (per str.isnumeric) are checked before JSON so that
        # e.g. "3" is scored as a number rather than decoded to the int 3.
        if single.isnumeric():
            ans, is_num = fix_number(single)
            if is_num:
                return ans, 'number'
        try:
            ans = json.loads(fix_ans(single))
        except ValueError:  # includes json.JSONDecodeError
            ans, is_num = fix_number(single)
            if is_num:
                return ans, 'number'
            return single, 'string'
        # NOTE(review): any valid JSON reaches this point, including scalars
        # such as "1.5", "-3", "true" or "null" (str.isnumeric() is False for
        # them), so those are scored by the 'json' evaluator -- confirm this
        # is intended.
        return [ans], 'json'

    # Zero or several entries: JSON only if every entry decodes.
    try:
        ans = [json.loads(fix_ans(ex)) for ex in answer]
    except (ValueError, TypeError):
        # TypeError: a non-string entry that fix_ans passes through untouched.
        return answer, "string list"
    return ans, 'json'


# Comma-grouped thousands such as "1,000" or "12,345,678.9": a leading group
# of 1-3 digits (not starting with 0) followed by one or more ",ddd" groups.
_THOUSANDS_RE = re.compile(r'[+-]?[1-9]\d{0,2}(?:,\d{3})+(?:\.\d+)?')


def fix_number(number):
    """Try to interpret ``number`` as a numeric value.

    Strings are cleaned before conversion:

    * ``$``, ``%`` and ``sqft`` are replaced by spaces.
    * A `` square kilometers`` suffix is dropped.
    * Comma-grouped thousands ("1,000", "1,234,567.5") have their commas
      removed.
    * Any other comma is treated as a decimal separator ("3,5" -> 3.5).

    Args:
        number: Value to convert.

    Returns:
        ``(value, is_number)``:

        * For strings, ``value`` is the parsed float and ``is_number`` is
          True. If parsing fails, the original string is returned with
          ``is_number`` False.
        * Plain ``int`` values are converted to float.
        * Any other value (floats, bools, None, ...) is returned unchanged
          with ``is_number`` True.
    """
    if isinstance(number, str):
        text = number
        for token in ('$', '%', 'sqft'):
            text = text.replace(token, ' ')
        text = text.strip().replace(' square kilometers', '').strip()
        if _THOUSANDS_RE.fullmatch(text):
            # "1,000" means one thousand; converting the comma to a dot would
            # yield 1.0 (and "1,234,567" would not parse at all).
            text = text.replace(',', '')
        else:
            # Otherwise keep the historical reading of a comma as a decimal
            # separator, e.g. "3,5" -> 3.5.
            text = text.replace(',', '.')
        try:
            return float(text), True
        except ValueError:
            return number, False
    if type(number) is int:  # exact type check: bools fall through unchanged
        return float(number), True
    return number, True


def fix_prediction(prediction, gold_answer, evaluator):
    """Normalize a model prediction into the shape the chosen evaluator expects.

    Args:
        prediction: Raw prediction: a string, a number, or a list.
        gold_answer: Gold answer as returned by ``parse_answer``. Only its type
            is inspected here: a ``float`` gold answer means a single number is
            expected.
        evaluator: Evaluator name chosen by ``parse_answer``. For ``'json'``,
            a non-list prediction is split into lines and each line is decoded
            as JSON.

    Returns:
        ``(prediction, run_eval)``. ``prediction`` is the normalized value.
        ``run_eval`` is False when the prediction should score 0 without
        running the evaluator: it is empty, or it has several values while the
        gold answer is a single number.
    """
    # Unwrap a one-element list holding an int or a numeric string so it is
    # scored as a bare number. fix_number returns a (value, is_number) pair;
    # keep only the value, otherwise the tuple itself reaches the evaluator.
    if (
        isinstance(prediction, list)
        and len(prediction) == 1
        and (
            type(prediction[0]) is int  # exact check: excludes bool
            or (isinstance(prediction[0], str) and prediction[0].isnumeric())
        )
    ):
        prediction, _ = fix_number(prediction[0])

    if not isinstance(prediction, list):
        prediction, _ = fix_number(prediction)
        if evaluator == 'json':
            try:
                # One JSON value per line.
                prediction = [json.loads(pred) for pred in prediction.split('\n')]
            except (AttributeError, TypeError, ValueError):
                # AttributeError/TypeError: prediction is not a str (e.g. it
                # became a float); ValueError: a line is not valid JSON.
                prediction = [prediction]

    if hasattr(type(prediction), '__len__') and len(prediction) == 0:
        return prediction, False

    if isinstance(prediction, list) and len(prediction) > 1 and isinstance(gold_answer, float):
        return prediction, False

    return prediction, True


def question_scorer(prediction, gold_answer):
    """Return the evaluator's score for ``prediction`` against ``gold_answer``.

    ``gold_answer`` is either a list of answer entries or a single string with
    one entry per line; blank lines are ignored. The gold entries are parsed
    to pick an evaluator, and the prediction is normalized to match. The
    evaluator from ``get_evaluator`` then computes the score. A prediction
    that ``fix_prediction`` rejects scores ``0.`` without running the
    evaluator.
    """
    if type(gold_answer) == list:
        answer_list = gold_answer
    else:
        answer_list = [line for line in gold_answer.split("\n") if line.strip()]

    parsed_gold, evaluator_name = parse_answer(answer_list)
    normalized_prediction, should_evaluate = fix_prediction(
        prediction, parsed_gold, evaluator_name
    )

    if should_evaluate:
        scorer = get_evaluator(evaluator_name)
        return scorer(normalized_prediction, parsed_gold)
    return 0.