DeanGumas committed
Commit 2b3100e · 1 Parent(s): ace09b0

updated rag evaluation function to match that from baseline testing

__pycache__/rag_metadata.cpython-312.pyc ADDED
Binary file (3.32 kB)
 
test_rag.py CHANGED
@@ -24,7 +24,7 @@ print("\n")
 # ------------------------------
 # Load tokenizer and model
 # ------------------------------
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 tokenizer = AutoTokenizer.from_pretrained("./deepseek-coder-1.3b-instruct")
 model = AutoModelForCausalLM.from_pretrained(
     "./deepseek-coder-1.3b-instruct",
@@ -160,48 +160,111 @@ retriever.add_documents(metadata_docs)
 # ------------------------------
 # Define a function to compare model output to ground truth
 # ------------------------------
-def compare_result(sample_query, sample_result, generated_query, actual_result):
-    # Remove any prefixes from the generated query
-    if generated_query.startswith("SQLite: "):
-        query = generated_query[len("SQLite: "):]
-    elif generated_query.startswith("SQL: "):
-        query = generated_query[len("SQL: "):]
+def compare_result(sample_query, sample_result, query_output):
+    # Clean model output to only have the query output
+    if query_output[0:8] == "SQLite:\n":
+        query = query_output[8:]
+    elif query_output[0:8] == "SQLite: ":
+        query = query_output[8:]
+    elif query_output[0:7] == "SQLite:":
+        query = query_output[7:]
+    elif query_output[0:5] == "SQL:\n":
+        query = query_output[5:]
+    elif query_output[0:5] == "SQL: ":
+        query = query_output[5:]
+    elif query_output[0:4] == "SQL:":
+        query = query_output[4:]
     else:
-        query = generated_query
+        query = query_output
 
-    # Truncate query after the first semicolon (if present)
-    semicolon_index = query.find(";")
-    if semicolon_index != -1:
-        query = query[:semicolon_index+1]
-
-    # Simple function to clean strings: removes whitespace and lowercases.
-    clean_str = lambda s: "".join(s.split()).lower()
-
-    # Compare the generated query text with the sample query.
-    query_match = (clean_str(query) == clean_str(sample_query))
+    # Clean any excess text after the query semicolon
+    for i in range(len(query)):
+        if query[i] == ";":
+            query = query[:i+1]
+            break
 
-    # Compare the expected result and the actual result numerically if possible.
+    # Try to execute the query; if it fails, this counts as a failure of the model
     try:
-        sample_val = float(sample_result)
-        actual_val = float(actual_result)
-        result_match = math.isclose(sample_val, actual_val, abs_tol=1e-6)
-    except Exception:
-        # Otherwise, do a cleaned string comparison.
-        result_match = (clean_str(str(sample_result)) == clean_str(str(actual_result)))
-
-    overall_valid = query_match and result_match
+        # Execute query and obtain result
+        cursor.execute(query)
+        rows = cursor.fetchall()
 
-    # Debug output.
-    print("DEBUG: Expected Result (from dataset):", sample_result)
-    print("DEBUG: Actual DB Result:", actual_result)
-    try:
-        sample_val = float(sample_result)
-        actual_val = float(actual_result)
-        print("DEBUG: Numeric Comparison result:", math.isclose(sample_val, actual_val, abs_tol=1e-6))
-    except Exception:
-        print("DEBUG: Numeric Comparison: N/A")
-
-    return overall_valid, query_match, result_match
+        # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.
+        query = query.replace(" ", "").replace("\n", "").replace("\t", "")
+        sample_query = sample_query.replace(" ", "").replace("\n", "").replace("\t", "")
+        query_match = (query == sample_query)
+
+        # If the queries match, the results clearly also match
+        if query_match:
+            return True, True, True
+
+        # Check if the expected result contains multiple values
+        if "|" in sample_result or "(" in sample_result:
+            #print(rows)
+            # Create list of results by stripping separators and splitting on them
+            if "(" in sample_result:
+                sample_result = sample_result.replace("(", "").replace(")", "")
+                result_list = sample_result.split(",")
+            else:
+                result_list = sample_result.split("|")
+
+            # Strip all results in list
+            for i in range(len(result_list)):
+                result_list[i] = str(result_list[i]).strip()
+
+            # Loop through model result and see if it matches training example
+            result = False
+            for row in rows:
+                for r in row:
+                    for res in result_list:
+                        try:
+                            if math.isclose(float(r), float(res), abs_tol=0.5):
+                                return True, query_match, True
+                        except:
+                            if str(r) in res or res in str(r):
+                                return True, query_match, True
+
+            # Check if the model returned a sum of examples as opposed to the whole thing
+            if len(rows) == 1:
+                for r in rows[0]:
+                    if r == str(len(result_list)):
+                        return True, query_match, True
+
+            return True, query_match, result
+        # Else the sample result is a single value or string
+        else:
+            #print(rows)
+            result = False
+            # Loop through model result and see if it contains the sample result
+            for row in rows:
+                for r in row:
+                    # Check by string
+                    if str(r) in str(sample_result):
+                        try:
+                            if math.isclose(float(r), float(sample_result), abs_tol=0.5):
+                                return True, query_match, True
+                        except:
+                            return True, query_match, True
+                    # Check by number, using try in case the cast as float fails
+                    try:
+                        if math.isclose(float(r), float(sample_result), abs_tol=0.5):
+                            return True, query_match, True
+                    except:
+                        pass
+
+            # Check if the model returned a list of examples instead of a total sum (both acceptable)
+            try:
+                if len(rows) > 1 and len(rows) == int(sample_result):
+                    return True, query_match, True
+                if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):
+                    return True, query_match, True
+            except:
+                pass
+
+            # Compare results and return
+            return True, query_match, result
+    except:
+        return False, False, False
 
 
 # ------------------------------
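
For readers following the new comparison logic: compare_result now relies on a module-level cursor (defined elsewhere in test_rag.py, outside this diff) and counts a generated query as correct if either the whitespace-normalized SQL matches the ground truth, or the executed result matches the expected value within a 0.5 absolute tolerance (or by string containment). A minimal sketch of how it behaves; the table schema and values here are invented for illustration:

import sqlite3

conn = sqlite3.connect(":memory:")  # hypothetical stand-in for the script's real database
cursor = conn.cursor()
cursor.execute("CREATE TABLE games (team TEXT, pts REAL)")
cursor.executemany("INSERT INTO games VALUES (?, ?)", [("GSW", 120.0), ("LAL", 101.0)])

generated = "SQLite: SELECT pts FROM games WHERE team = 'GSW'; -- trailing text is cut at the semicolon"
ground_truth = "SELECT pts\nFROM games\nWHERE team = 'GSW';"

# Whitespace differences are ignored, so this should report (True, True, True);
# with a differently worded but equivalent query it would still report result_match = True,
# since 120.0 is within 0.5 of the expected "120".
print(compare_result(ground_truth, "120", generated))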
@@ -313,7 +376,7 @@ Request: {row["natural_query"]}
         actual_result = "Error executing query: " + str(e)
 
     # Compare the ground truth query and expected result to the generated query and actual result.
-    valid, sql_matched, result_matched = compare_result(row["sql_query"], row["result"], generated_query, actual_result)
+    valid, sql_matched, result_matched = compare_result(row["sql_query"], row["result"], generated_query)
     print("=============================================")
     print(f"Overall Valid: {valid}")
     print(f"SQL Query Matched: {sql_matched}")
 