SQL-Generation / src /evaluation /compare_result.py
DeanGumas's picture
Tweaked compare_result to work with proper directory, updated rag python notebook prompt and vector stores to increase accuracy
4022db3
raw
history blame
4.59 kB
import math
import sqlite3 as sql
from pathlib import Path
# Labels some models prepend to their SQL answer; removed before execution.
_OUTPUT_LABELS = ("SQLite:", "SQL:")
# Absolute tolerance used when comparing numeric results.
_NUMERIC_TOLERANCE = 0.5


def _extract_query(query_output):
    """Strip an optional leading "SQLite:"/"SQL:" label and drop text after the first ';'."""
    query = query_output
    for label in _OUTPUT_LABELS:
        if query.startswith(label):
            query = query[len(label):]
            break
    # Keep only the first statement (including its ';'); models often append explanations.
    head, semicolon, _ = query.partition(";")
    return head + semicolon


def _normalize_sql(text):
    """Remove all whitespace so queries differing only in spacing/newlines compare equal."""
    return "".join(text.split())


def _to_float(value):
    """Return value converted to float, or None when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _cell_matches(cell, expected):
    """Compare one result cell with one expected token: numerically if both parse, else by substring."""
    cell_num, expected_num = _to_float(cell), _to_float(expected)
    if cell_num is not None and expected_num is not None:
        return math.isclose(cell_num, expected_num, abs_tol=_NUMERIC_TOLERANCE)
    text = str(cell)
    return text in expected or expected in text


def _matches_multi_value(rows, sample_result):
    """Judge rows against a '|'-separated or tuple-like '(a, b, ...)' sample result."""
    if "(" in sample_result:
        tokens = sample_result.replace("(", "").replace(")", "").split(",")
    else:
        tokens = sample_result.split("|")
    # Drop empty tokens: "(5,)" would otherwise yield "", a substring of every cell.
    expected = [token.strip() for token in tokens if token.strip()]
    for row in rows:
        for cell in row:
            for value in expected:
                if _cell_matches(cell, value):
                    return True
    # The model may have returned the number of expected values instead of the values.
    # Compare as text so an integer COUNT(*) cell can match.
    if len(rows) == 1:
        return any(str(cell) == str(len(expected)) for cell in rows[0])
    return False


def _matches_single_value(rows, sample_result):
    """Judge rows against a sample result that is a single value or string."""
    expected_num = _to_float(sample_result)
    for row in rows:
        for cell in row:
            cell_num = _to_float(cell)
            if cell_num is not None and expected_num is not None:
                if math.isclose(cell_num, expected_num, abs_tol=_NUMERIC_TOLERANCE):
                    return True
            elif str(cell) in sample_result:
                # At least one side is not numeric: fall back to a substring check.
                return True
    # Listing the matching rows instead of returning their count is also acceptable.
    try:
        expected_count = int(sample_result)
    except ValueError:
        return False
    if len(rows) > 1 and len(rows) == expected_count:
        return True
    # Existing heuristic: a single row whose column count equals the expected number
    # (with a non-NULL second column) is accepted as well.
    first = rows[0] if rows else ()
    return len(first) > 1 and first[1] is not None and len(first) == expected_count


def compare_result(sample_query, sample_result, query_output, db_path="./nba-data/nba.sqlite"):
    """Execute a model-generated SQL query and grade it against a reference example.

    Args:
        sample_query: Reference SQL. It is compared with the model's query after
            removing all whitespace; an exact match short-circuits to (True, True, True).
        sample_result: Textual rendering of the reference result: a single value,
            a "|"-separated list, or a tuple-like "(a, b, ...)" string. Non-string
            values are converted with str().
        query_output: Raw model output. A leading "SQLite:" or "SQL:" label is
            removed and anything after the first ";" is discarded.
        db_path: SQLite database to run the query against (relative paths are
            resolved against the current working directory). Opened read-only.

    Returns:
        A tuple (executed, query_match, result_match). executed is False when the
        query could not be run, in which case all three are False. query_match is
        True when the normalized queries are identical. result_match is True when
        the returned rows are judged to contain the reference result (numbers
        within 0.5, strings by substring, or a matching row/value count).

    Raises:
        sqlite3.OperationalError: if db_path cannot be opened (e.g. it does not exist).
    """
    query = _extract_query(query_output)

    # Open read-only: the SQL comes from a model, and statements such as DROP TABLE
    # run outside sqlite3's implicit transaction and would otherwise persist.
    db_uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    connection = sql.connect(db_uri, uri=True)
    try:
        rows = connection.cursor().execute(query).fetchall()
    except (sql.Error, sql.Warning, ValueError):
        # Any failure to run the generated query counts against the model.
        return False, False, False
    finally:
        connection.close()

    # If the queries match, the results clearly also match.
    if _normalize_sql(query) == _normalize_sql(str(sample_query)):
        return True, True, True

    sample_result = str(sample_result)
    if "|" in sample_result or "(" in sample_result:
        matched = _matches_multi_value(rows, sample_result)
    else:
        matched = _matches_single_value(rows, sample_result)
    return True, False, matched