# SQL-Generation / src / evaluation / compare_result.py
# Author: DeanGumas
# Commit 23a14a5: Updated compare_result function to allow passing the cursor,
# also re-ran test_pretrained and test_rag with updated loss function.
# Size: 4.45 kB
import math
# Labels a model may emit before its SQL. Longer variants come first so that,
# e.g., "SQLite: " is removed together with its trailing space.
_QUERY_PREFIXES = ("SQLite:\n", "SQLite: ", "SQLite:", "SQL:\n", "SQL: ", "SQL:")

# Absolute tolerance used when a result cell and an expected value are both numeric.
_NUMERIC_TOLERANCE = 0.5


def _extract_query(query_output):
    """Return the SQL statement contained in raw model output.

    Removes at most one leading "SQLite:"/"SQL:" label and drops everything
    after the first ';' (the ';' itself is kept).
    """
    for prefix in _QUERY_PREFIXES:
        if query_output.startswith(prefix):
            query_output = query_output[len(prefix):]
            break
    # NOTE: a ';' inside a string literal also ends the statement here.
    statement, semicolon, _ = query_output.partition(";")
    return statement + semicolon


def _strip_whitespace(text):
    """Return ``text`` with all whitespace (spaces, tabs, newlines, CR, ...) removed."""
    return "".join(text.split())


def _cell_matches(value, expected, *, bidirectional):
    """Return True if the result cell ``value`` agrees with the ``expected`` text.

    When both parse as floats they are compared numerically within
    ``_NUMERIC_TOLERANCE``, and that verdict is final. Otherwise the check falls
    back to substring containment: ``str(value)`` must occur in ``expected``
    (or, with ``bidirectional``, either may contain the other). An empty string
    only matches another empty string, because an empty substring would
    otherwise match everything.
    """
    try:
        return math.isclose(float(value), float(expected), abs_tol=_NUMERIC_TOLERANCE)
    except (TypeError, ValueError, OverflowError):
        pass
    text = str(value)
    if not text or not expected:
        return text == expected
    return text in expected or (bidirectional and expected in text)


def _parse_expected_values(sample_result):
    """Split a multi-value expected result into stripped, non-empty tokens.

    Parenthesised tuples such as "(a, b), (c, d)" are split on commas after the
    parentheses are removed; otherwise values are split on '|'.
    """
    if "(" in sample_result:
        tokens = sample_result.replace("(", "").replace(")", "").split(",")
    else:
        tokens = sample_result.split("|")
    stripped = (token.strip() for token in tokens)
    # Empty tokens (from trailing/duplicate separators) carry no information.
    return [token for token in stripped if token]


def _multi_value_matches(rows, sample_result):
    """Heuristically decide whether ``rows`` agree with a multi-value expected result."""
    expected_values = _parse_expected_values(sample_result)
    # Any single cell agreeing with any expected value counts as a match.
    if any(
        _cell_matches(cell, expected, bidirectional=True)
        for row in rows
        for cell in row
        for expected in expected_values
    ):
        return True
    # The model may have returned the number of expected values instead of the
    # values themselves. Cells are typically ints, so compare as text.
    if len(rows) == 1:
        expected_count = str(len(expected_values))
        return any(str(cell) == expected_count for cell in rows[0])
    return False


def _single_value_matches(rows, sample_result):
    """Heuristically decide whether ``rows`` agree with a single expected value.

    A cell matches via ``_cell_matches`` (one-directional containment). Failing
    that, when ``sample_result`` parses as an int N, the rows are accepted if
    there is more than one row and exactly N of them, or if the first row has
    more than one column, a second column that is not None, and exactly N columns.
    """
    if any(
        _cell_matches(cell, sample_result, bidirectional=False)
        for row in rows
        for cell in row
    ):
        return True
    try:
        expected_count = int(sample_result)
    except ValueError:
        return False
    if len(rows) > 1 and len(rows) == expected_count:
        return True
    return (
        bool(rows)
        and len(rows[0]) > 1
        and rows[0][1] is not None
        and len(rows[0]) == expected_count
    )


def compare_result(cursor, sample_query, sample_result, query_output):
    """Execute a model-generated query and score it against a reference example.

    Args:
        cursor: Database cursor exposing ``execute`` and ``fetchall``.
        sample_query: Reference SQL for the example.
        sample_result: Reference result. Multiple values are written either
            '|'-separated or as parenthesised, comma-separated tuples.
            Non-string values are converted with ``str`` before comparison.
        query_output: Raw model output, optionally prefixed with "SQL:" or
            "SQLite:" and possibly followed by extra text after the first ';'.

    Returns:
        A tuple ``(executed, query_match, result_match)``:
        ``executed`` is False when the query raised during execution, in which
        case all three values are False. ``query_match`` is True when the query
        equals ``sample_query`` ignoring whitespace, which also implies
        ``result_match``. ``result_match`` is True when the fetched rows are
        judged to agree with ``sample_result`` by the heuristics above.
    """
    query = _extract_query(query_output)

    # A query that fails to run is a failure of the model. Only execution is
    # guarded so that errors in the comparison logic are not misreported as
    # model failures; ``Exception`` (not a bare except) lets
    # KeyboardInterrupt/SystemExit propagate.
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    except Exception:
        return False, False, False

    # Identical queries (ignoring formatting) trivially produce identical results.
    if _strip_whitespace(query) == _strip_whitespace(sample_query):
        return True, True, True

    sample_result = str(sample_result)
    # '|' or '(' marks a multi-value expected result; otherwise it is a single value.
    if "|" in sample_result or "(" in sample_result:
        matched = _multi_value_matches(rows, sample_result)
    else:
        matched = _single_value_matches(rows, sample_result)
    return True, False, matched