Updated RAG evaluation function to match the one used in baseline testing
- __pycache__/rag_metadata.cpython-312.pyc  +0 -0
- test_rag.py  +102 -39

__pycache__/rag_metadata.cpython-312.pyc ADDED
Binary file (3.32 kB)
test_rag.py CHANGED
@@ -24,7 +24,7 @@ print("\n")
 # ------------------------------
 # Load tokenizer and model
 # ------------------------------
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 tokenizer = AutoTokenizer.from_pretrained("./deepseek-coder-1.3b-instruct")
 model = AutoModelForCausalLM.from_pretrained(
     "./deepseek-coder-1.3b-instruct",
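The only functional change in this hunk is pinning inference to the first GPU ("cuda:0"). As a hedged aside, a more defensive version of that line could look like the sketch below; resolve_device is a hypothetical helper written for illustration, not something defined in test_rag.py:

import torch

def resolve_device(preferred: str = "cuda:0") -> torch.device:
    # Use the requested GPU only when CUDA is actually available; otherwise fall back to CPU.
    if torch.cuda.is_available():
        return torch.device(preferred)
    return torch.device("cpu")

device = resolve_device()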
@@ -160,48 +160,111 @@ retriever.add_documents(metadata_docs)
 # ------------------------------
 # Define a function to compare model output to ground truth
 # ------------------------------
-def compare_result(sample_query, sample_result, ...
-    # ...
-    if ...
-        query = ...
-    elif ...
-        query = ...
-    else:
-        query = ...
-
-    # ...
-    ...
-    # Simple function to clean strings: removes whitespace and lowercases.
-    clean_str = lambda s: "".join(s.split()).lower()
-
-    # Compare the generated query text with the sample query.
-    query_match = (clean_str(query) == clean_str(sample_query))
-
-    # ...
-    try:
-        ...
-    except Exception:
-        # Otherwise, do a cleaned string comparison.
-        result_match = (clean_str(str(sample_result)) == clean_str(str(actual_result)))
-
-    overall_valid = query_match and result_match
-    ...
+def compare_result(sample_query, sample_result, query_output):
+    # Clean model output to only have the query output
+    if query_output[0:8] == "SQLite:\n":
+        query = query_output[8:]
+    elif query_output[0:8] == "SQLite: ":
+        query = query_output[8:]
+    elif query_output[0:7] == "SQLite:":
+        query = query_output[7:]
+    elif query_output[0:5] == "SQL:\n":
+        query = query_output[5:]
+    elif query_output[0:5] == "SQL: ":
+        query = query_output[5:]
+    elif query_output[0:4] == "SQL:":
+        query = query_output[4:]
+    else:
+        query = query_output
+
+    # Clean any excess text after the query semicolon
+    for i in range(len(query)):
+        if query[i] == ";":
+            query = query[:i+1]
+            break
+
+    # Try to execute query, if it fails, then this is a failure of the model
+    try:
+        # Execute query and obtain result
+        cursor.execute(query)
+        rows = cursor.fetchall()
+
+        # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.
+        query = query.replace(" ", "").replace("\n", "").replace("\t", "")
+        sample_query = sample_query.replace(" ", "").replace("\n", "").replace("\t", "")
+        query_match = (query == sample_query)
+
+        # If the queries match, the results clearly also match
+        if query_match:
+            return True, True, True
+
+        # Check if this is a multi-line query
+        if "|" in sample_result or "(" in sample_result:
+            #print(rows)
+            # Create list of results by stripping separators and splitting on them
+            if "(" in sample_result:
+                sample_result = sample_result.replace("(", "").replace(")", "")
+                result_list = sample_result.split(",")
+            else:
+                result_list = sample_result.split("|")
+
+            # Strip all results in list
+            for i in range(len(result_list)):
+                result_list[i] = str(result_list[i]).strip()
+
+            # Loop through model result and see if it matches training example
+            result = False
+            for row in rows:
+                for r in row:
+                    for res in result_list:
+                        try:
+                            if math.isclose(float(r), float(res), abs_tol=0.5):
+                                return True, query_match, True
+                        except:
+                            if str(r) in res or res in str(r):
+                                return True, query_match, True
+
+            # Check if the model returned a sum of examples as opposed to the whole thing
+            if len(rows) == 1:
+                for r in rows[0]:
+                    if r == str(len(result_list)):
+                        return True, query_match, True
+
+            return True, query_match, result
+        # Else the sample result is a single value or string
+        else:
+            #print(rows)
+            result = False
+            # Loop through model result and see if it contains the sample result
+            for row in rows:
+                for r in row:
+                    # Check by string
+                    if str(r) in str(sample_result):
+                        try:
+                            if math.isclose(float(r), float(sample_result), abs_tol=0.5):
+                                return True, query_match, True
+                        except:
+                            return True, query_match, True
+                    # Check by number, using try incase the cast as float fails
+                    try:
+                        if math.isclose(float(r), float(sample_result), abs_tol=0.5):
+                            return True, query_match, True
+                    except:
+                        pass
+
+            # Check if the model returned a list of examples instead of a total sum (both acceptable)
+            try:
+                if len(rows) > 1 and len(rows) == int(sample_result):
+                    return True, query_match, True
+                if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):
+                    return True, query_match, True
+            except:
+                pass
+
+            # Compare results and return
+            return True, query_match, result
+    except:
+        return False, False, False
 
 # ------------------------------
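For reference, here is a minimal sketch of how the new compare_result is meant to be called. It assumes the function (and the module-level cursor it reads) is pasted into the same script; the table, rows, and model output below are invented purely for illustration:

import sqlite3

connection = sqlite3.connect(":memory:")
cursor = connection.cursor()  # compare_result executes queries through this module-level cursor
cursor.execute("CREATE TABLE team (name TEXT, wins INTEGER)")
cursor.execute("INSERT INTO team VALUES ('Hawks', 41), ('Bulls', 40)")

sample_query = "SELECT wins FROM team WHERE name = 'Hawks';"
sample_result = "41"
model_output = "SQLite: SELECT wins FROM team WHERE name = 'Hawks'; -- extra text the model appended"

valid, query_match, result_match = compare_result(sample_query, sample_result, model_output)
print(valid, query_match, result_match)  # should print: True True True

Because the prefix stripping, semicolon truncation, and whitespace-insensitive comparison all succeed here, the call short-circuits on the query match and never needs to compare result rows.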
@@ -313,7 +376,7 @@ Request: {row["natural_query"]}
 actual_result = "Error executing query: " + str(e)
 
 # Compare the ground truth query and expected result to the generated query and actual result.
-valid, sql_matched, result_matched = compare_result(row["sql_query"], row["result"], generated_query
+valid, sql_matched, result_matched = compare_result(row["sql_query"], row["result"], generated_query)
 print("=============================================")
 print(f"Overall Valid: {valid}")
 print(f"SQL Query Matched: {sql_matched}")