File size: 4,440 Bytes
c1d6d12 0405efb c1d6d12 0405efb c1d6d12 0405efb c1d6d12 0405efb c1d6d12 0405efb c1d6d12 0405efb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import math
def _extract_query(query_output):
    """Return the SQL statement contained in a raw model answer.

    A leading ``SQLite:`` or ``SQL:`` label is removed, together with at most
    one space or newline right after it. Anything after the first ``;`` is
    discarded; the ``;`` itself is kept.
    """
    query = query_output
    for prefix in ("SQLite:", "SQL:"):
        if query.startswith(prefix):
            query = query[len(prefix):]
            if query[:1] in (" ", "\n"):
                query = query[1:]
            break
    statement, semicolon, _ = query.partition(";")
    return statement + semicolon


def _squash_whitespace(text):
    """Remove spaces, newlines and tabs so SQL text compares layout-insensitively."""
    return text.translate(str.maketrans("", "", " \n\t"))


def _values_match(actual, expected, abs_tol, *, bidirectional):
    """Decide whether one database value matches one expected text token.

    If both sides convert to float, they match when they differ by at most
    ``abs_tol``; the text is not consulted in that case. Otherwise the text
    is compared by substring: ``str(actual) in expected`` and, when
    ``bidirectional`` is set, also ``expected in str(actual)``. Empty text
    never matches, because the empty string is a substring of everything.
    """
    try:
        return math.isclose(float(actual), float(expected), abs_tol=abs_tol)
    except (TypeError, ValueError, OverflowError):
        pass
    actual_text = str(actual)
    if not actual_text or not expected:
        return False
    if actual_text in expected:
        return True
    return bidirectional and expected in actual_text


def _match_multi(rows, sample_result, abs_tol):
    """Match fetched rows against an expected result listing several values.

    ``sample_result`` is either ``"a | b | c"`` or tuple-like ``"(a, b, c)"``.
    Any fetched value matching any expected token is a match. So is a
    one-row answer containing the number of expected tokens, i.e. the model
    returned a count instead of the list.
    """
    if "(" in sample_result:
        raw_tokens = sample_result.replace("(", "").replace(")", "").split(",")
    else:
        raw_tokens = sample_result.split("|")
    # Drop empty tokens (e.g. from a trailing separator or "(x,)"): an empty
    # token would otherwise substring-match every fetched value.
    expected = [token.strip() for token in raw_tokens if token.strip()]

    for row in rows:
        for value in row:
            if any(
                _values_match(value, token, abs_tol, bidirectional=True)
                for token in expected
            ):
                return True

    # Compare as text: the DB usually hands back an int for COUNT(*), so a
    # direct comparison against the string count would never succeed.
    if len(rows) == 1:
        count_text = str(len(expected))
        return any(str(value) == count_text for value in rows[0])
    return False


def _match_single(rows, sample_result, abs_tol):
    """Match fetched rows against an expected result that is a single value.

    Any fetched value matching ``sample_result`` (numerically within
    ``abs_tol``, or by one-way substring) is a match. If ``sample_result``
    is an integer, the answer also matches when the model listed that many
    rows, or when its first row has that many columns and a non-NULL
    second column, instead of returning the count.
    """
    expected_text = str(sample_result)
    for row in rows:
        for value in row:
            if _values_match(value, expected_text, abs_tol, bidirectional=False):
                return True

    try:
        expected_count = int(sample_result)
    except (TypeError, ValueError):
        return False
    if len(rows) > 1 and len(rows) == expected_count:
        return True
    first_row = rows[0] if rows else ()
    return (
        len(first_row) > 1
        and first_row[1] is not None
        and len(first_row) == expected_count
    )


def compare_result(sample_query, sample_result, query_output, *,
                   db_cursor=None, abs_tol=0.5):
    """Check a model-generated SQL answer against a reference query and result.

    The model output is cleaned (``SQLite:``/``SQL:`` label removed, text
    after the first ``;`` dropped) and then executed. If the cleaned query
    equals ``sample_query`` once spaces, newlines and tabs are ignored, it
    counts as a full match. Otherwise the fetched rows are compared to
    ``sample_result`` leniently:

    * numbers within ``abs_tol`` match;
    * text matches by substring;
    * a count of the expected items is accepted in place of the items
      themselves, and vice versa.

    Args:
        sample_query: Reference SQL text.
        sample_result: Reference result as text: a single value, or several
            values separated by ``|`` or written as ``(a, b, ...)``.
        query_output: Raw model answer containing the generated SQL.
        db_cursor: Object with ``execute()`` and ``fetchall()`` used to run
            the query. Defaults to the module-level ``cursor`` global.
        abs_tol: Absolute tolerance for numeric comparisons.

    Returns:
        ``(executed, query_match, result_match)``: whether the generated
        query ran without error, whether it textually matched
        ``sample_query``, and whether its result was judged to match
        ``sample_result``. A query that fails to run yields
        ``(False, False, False)``.
    """
    query = _extract_query(query_output)

    try:
        # Fall back to the module-level ``cursor``. If it does not exist, the
        # resulting NameError is treated like any other execution failure.
        active_cursor = cursor if db_cursor is None else db_cursor
        active_cursor.execute(query)
        rows = active_cursor.fetchall()
    except Exception:
        # A generated query that cannot be run is a failure of the model.
        return False, False, False

    if _squash_whitespace(query) == _squash_whitespace(sample_query):
        # Identical queries necessarily produce identical results.
        return True, True, True

    if "|" in sample_result or "(" in sample_result:
        return True, False, _match_multi(rows, sample_result, abs_tol)
    return True, False, _match_single(rows, sample_result, abs_tol)