import math

# Prefixes a model may put in front of its SQL. Order matters: longer variants
# are listed before the shorter prefixes they contain.
_OUTPUT_PREFIXES = ("SQLite:\n", "SQLite: ", "SQLite:", "SQL:\n", "SQL: ", "SQL:")

# Deletes spaces, newlines and tabs so queries can be compared ignoring layout.
_WHITESPACE_TABLE = str.maketrans("", "", " \n\t")

# Errors float()/int() raise for values that are not usable as numbers.
_NUMERIC_ERRORS = (TypeError, ValueError, OverflowError)

# Absolute tolerance used when comparing numeric results.
_ABS_TOLERANCE = 0.5


def _extract_query(query_output):
    """Strip a leading "SQL:"/"SQLite:" label and anything after the first ';'."""
    query = query_output
    for prefix in _OUTPUT_PREFIXES:
        if query_output.startswith(prefix):
            query = query_output[len(prefix):]
            break
    # Keep everything up to and including the first semicolon, if there is one.
    head, semicolon, _ = query.partition(";")
    return head + semicolon


def _strip_whitespace(text):
    """Return ``text`` with all spaces, newlines and tabs removed."""
    return text.translate(_WHITESPACE_TABLE)


def _to_float(value):
    """Return ``value`` as a float, or None when it cannot be converted."""
    try:
        return float(value)
    except _NUMERIC_ERRORS:
        return None


def _values_match(cell, cell_num, expected, expected_num, *, two_way):
    """Compare one result cell with one expected value.

    When both sides are numeric they match only if they are within the
    tolerance. Otherwise the comparison falls back to substring containment:
    the cell text inside the expected text, and (if ``two_way``) the reverse.
    """
    if cell_num is not None and expected_num is not None:
        return math.isclose(cell_num, expected_num, abs_tol=_ABS_TOLERANCE)
    cell_text = str(cell)
    if cell_text in expected:
        return True
    return two_way and expected in cell_text


def _matches_multi_value(rows, sample_result):
    """Return True if ``rows`` agree with a multi-value sample result."""
    # Split the sample result into its individual values.
    if "(" in sample_result:
        pieces = sample_result.replace("(", "").replace(")", "").split(",")
    else:
        pieces = sample_result.split("|")
    # Drop empty entries (e.g. from the trailing comma in "(5,)"): an empty
    # string is a substring of every value and would match any result.
    expected_values = [piece.strip() for piece in pieces if piece.strip()]
    expected_nums = [_to_float(expected) for expected in expected_values]

    # Any cell that matches any expected value counts as a match.
    for row in rows:
        for cell in row:
            cell_num = _to_float(cell)
            for expected, expected_num in zip(expected_values, expected_nums):
                if _values_match(cell, cell_num, expected, expected_num, two_way=True):
                    return True

    # The model may have returned a count of the entries instead of the
    # entries themselves. Compare as text so integer cells match too.
    if len(rows) == 1:
        expected_count = str(len(expected_values))
        return any(str(cell) == expected_count for cell in rows[0])
    return False


def _matches_single_value(rows, sample_result):
    """Return True if ``rows`` agree with a single-value sample result."""
    sample_text = str(sample_result)
    sample_num = _to_float(sample_result)

    for row in rows:
        for cell in row:
            if _values_match(cell, _to_float(cell), sample_text, sample_num, two_way=False):
                return True

    # The model may have returned a list of entries instead of their total;
    # accept it when the number of rows (or columns) equals the sample value.
    try:
        expected_count = int(sample_result)
    except _NUMERIC_ERRORS:
        return False
    if len(rows) > 1 and len(rows) == expected_count:
        return True
    return bool(
        rows
        and len(rows[0]) > 1
        and rows[0][1] is not None
        and len(rows[0]) == expected_count
    )


def compare_result(cursor, sample_query, sample_result, query_output):
    """Run a model-generated query and compare it with a reference example.

    The model output is cleaned (leading "SQL:"/"SQLite:" label removed, text
    after the first semicolon dropped), executed on ``cursor``, and compared
    with the reference query and reference result.

    Args:
        cursor: Object providing ``execute(sql)`` and ``fetchall()``.
        sample_query: Reference SQL text.
        sample_result: Reference result as text. If it contains "|" or "(" it
            is treated as several values separated by "|" or ","; otherwise it
            is treated as a single value.
        query_output: Raw model output containing the generated query.

    Returns:
        A tuple ``(executed, query_match, result_match)``:
        ``executed`` is False when running or comparing the query raised an
        exception (the other two are then False as well); ``query_match`` is
        True when the queries are equal ignoring spaces, newlines and tabs;
        ``result_match`` is True when the fetched rows are judged to agree
        with ``sample_result``.
    """
    query = _extract_query(query_output)

    # Any failure while executing or comparing is counted as a model failure.
    try:
        cursor.execute(query)
        rows = cursor.fetchall()

        query_match = _strip_whitespace(query) == _strip_whitespace(sample_query)
        # Identical queries necessarily produce identical results.
        if query_match:
            return True, True, True

        if "|" in sample_result or "(" in sample_result:
            result_match = _matches_multi_value(rows, sample_result)
        else:
            result_match = _matches_single_value(rows, sample_result)
        return True, query_match, result_match
    except Exception:
        return False, False, False