File size: 4,448 Bytes
c1d6d12
4022db3
23a14a5
c1d6d12
0405efb
 
 
 
 
c1d6d12
0405efb
 
 
 
c1d6d12
 
 
 
0405efb
 
 
 
 
 
c1d6d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0405efb
c1d6d12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0405efb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import math

# Labels a model may put in front of its SQL answer. Within each family the
# longest variant comes first so "SQLite:\n" is stripped in preference to
# "SQLite:" (the two families never overlap: "SQLite:" does not start with "SQL:").
_QUERY_PREFIXES = ("SQLite:\n", "SQLite: ", "SQLite:", "SQL:\n", "SQL: ", "SQL:")

# Exceptions float() can raise on a value that is not a usable number.
_NOT_NUMERIC = (TypeError, ValueError, OverflowError)


def _extract_query(query_output):
    """Return the SQL statement contained in raw model output.

    Strips one leading "SQLite:"/"SQL:" label (see ``_QUERY_PREFIXES``) and
    drops everything after the first semicolon, keeping the semicolon itself.
    Output without a semicolon is returned unchanged apart from the label.
    """
    query = query_output
    for prefix in _QUERY_PREFIXES:
        if query_output.startswith(prefix):
            query = query_output[len(prefix):]
            break
    # Models often append an explanation after the statement; keep only the
    # text up to and including the first ';'.
    head, semicolon, _ = query.partition(";")
    return head + semicolon


def _normalize_sql(text):
    """Remove every whitespace character so spacing/newline/CRLF differences don't matter."""
    return "".join(text.split())


def _multi_result_match(rows, sample_result):
    """Return True if any cell of *rows* matches any item of a multi-value sample result.

    *sample_result* is either "a | b | c" or a tuple-like "(a, b), (c, d)"
    string. Numeric items match within an absolute tolerance of 0.5; otherwise
    a substring test is applied in both directions. A single row containing the
    number of expected items is also accepted (the model counted the examples
    instead of listing them).
    """
    if "(" in sample_result:
        raw_items = sample_result.replace("(", "").replace(")", "").split(",")
    else:
        raw_items = sample_result.split("|")
    # Drop blanks left by trailing/doubled separators: "" is a substring of
    # every string and would otherwise make any result "match".
    expected = [item.strip() for item in raw_items if item.strip()]

    for row in rows:
        for value in row:
            text = str(value)
            for item in expected:
                try:
                    if math.isclose(float(value), float(item), abs_tol=0.5):
                        return True
                except _NOT_NUMERIC:
                    if text and (text in item or item in text):
                        return True

    if len(rows) == 1:
        count = len(expected)
        # Cells may come back as int (e.g. COUNT(*)) or as text, so compare
        # both numerically and by string form.
        return any(value == count or str(value) == str(count) for value in rows[0])
    return False


def _single_result_match(rows, sample_result):
    """Return True if *rows* matches a single-value sample result.

    A cell matches when it is numerically within 0.5 of the sample or, if
    either side is not numeric, when its string form is contained in the
    sample (an empty cell only matches an empty sample). Returning a list of
    rows whose length equals a numeric sample is also accepted, as is a first
    row whose width equals that number and whose second column is not NULL.
    """
    for row in rows:
        for value in row:
            try:
                if math.isclose(float(value), float(sample_result), abs_tol=0.5):
                    return True
            except _NOT_NUMERIC:
                text = str(value)
                if text == sample_result or (text and text in sample_result):
                    return True

    try:
        expected_count = int(sample_result)
    except ValueError:
        return False
    # The model listed the examples instead of returning their total.
    if len(rows) > 1 and len(rows) == expected_count:
        return True
    first = rows[0] if rows else ()
    return len(first) > 1 and first[1] is not None and len(first) == expected_count


def compare_result(cursor, sample_query, sample_result, query_output):
    """Execute a model-generated query and compare it against a reference example.

    Args:
        cursor: DB-API cursor used to run the extracted query.
        sample_query: Reference SQL text; compared ignoring all whitespace.
        sample_result: Reference result. A string containing "|" or "(" is
            treated as multiple values; anything else as a single value.
            Non-string values are compared via their ``str()`` form.
        query_output: Raw model output, optionally prefixed with "SQLite:" or
            "SQL:" and possibly followed by text after the first ';'.

    Returns:
        ``(executed, query_match, result_match)``. ``(False, False, False)``
        if the query raised while executing or fetching. ``(True, True, True)``
        if the query text equals *sample_query* ignoring whitespace. Otherwise
        ``(True, False, <whether the fetched rows match sample_result>)``.
    """
    query = _extract_query(query_output)

    # A query that fails to execute is a failure of the model. Only the
    # database calls are guarded so that bugs in the comparison below are not
    # misreported as execution failures.
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    except Exception:
        return False, False, False

    # Identical queries necessarily produce identical results.
    if _normalize_sql(query) == _normalize_sql(str(sample_query)):
        return True, True, True

    sample_result = str(sample_result)
    if "|" in sample_result or "(" in sample_result:
        result_match = _multi_result_match(rows, sample_result)
    else:
        result_match = _single_result_match(rows, sample_result)
    return True, False, result_match