DeanGumas committed
Commit 23a14a5 · 1 Parent(s): 4022db3

Updated the compare_result function to allow passing in the cursor; also re-ran test_pretrained and test_rag with the updated loss function.

src/evaluation/__pycache__/compare_result.cpython-312.pyc CHANGED
Binary files a/src/evaluation/__pycache__/compare_result.cpython-312.pyc and b/src/evaluation/__pycache__/compare_result.cpython-312.pyc differ
 
src/evaluation/compare_result.py CHANGED
@@ -1,11 +1,6 @@
  import math
- import sqlite3 as sql
-
- def compare_result(sample_query, sample_result, query_output):
-     # Create connection to sqlite3 database
-     connection = sql.connect('./nba-data/nba.sqlite')
-     cursor = connection.cursor()
 
+ def compare_result(cursor, sample_query, sample_result, query_output):
      # Clean model output to only have the query output
      if query_output[0:8] == "SQLite:\n":
          query = query_output[8:]
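
With this change the caller owns the SQLite connection instead of compare_result opening a new one on every call. A minimal usage sketch under that assumption (the database path comes from the removed code above; the sample inputs are illustrative, borrowed from a dataset example shown in the notebook output below; the three return values match the unpacking used in test_rag.ipynb):

import sqlite3 as sql

from src.evaluation.compare_result import compare_result

# Open the connection once and reuse the cursor for every comparison,
# rather than reconnecting inside compare_result on each call.
connection = sql.connect('./nba-data/nba.sqlite')
cursor = connection.cursor()

# Illustrative inputs (taken from one of the evaluation examples).
sample_query = "SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'PHX';"
sample_result = "161.0"
query_output = "SQLite:\n" + sample_query

valid, sql_matched, result_matched = compare_result(cursor, sample_query, sample_result, query_output)
print(valid, sql_matched, result_matched)

connection.close()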
src/prompts/__pycache__/prompt.cpython-312.pyc ADDED
Binary file (9.22 kB).
 
test_pretrained.ipynb CHANGED
@@ -9,7 +9,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 31,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -26,7 +26,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 32,
+ "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -35,7 +35,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -56,7 +56,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 34,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -73,7 +73,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
  "metadata": {},
  "outputs": [
  {
@@ -83,9 +83,9 @@
  "Total dataset examples: 1044\n",
  "\n",
  "\n",
- "How many points did the Phoenix Suns score in the highest scoring away game they played?\n",
- "SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'PHX';\n",
- "161.0\n"
+ "How many times were games tied when the Indiana Pacers played at home?\n",
+ "SELECT SUM(times_tied) as total_times_tied FROM other_stats WHERE team_abbreviation_home = 'IND';\n",
+ "4805.0\n"
  ]
  }
  ],
@@ -111,16 +111,20 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
  "metadata": {},
  "outputs": [],
  "source": [
  "# Set device to cuda if available, otherwise CPU\n",
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  "\n",
  "# Load model and tokenizer\n",
- "tokenizer = AutoTokenizer.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"))\n",
- "model = AutoModelForCausalLM.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"), torch_dtype=torch.bfloat16, device_map=device) \n",
+ "if is_google_colab:\n",
+ "    tokenizer = AutoTokenizer.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"))\n",
+ "    model = AutoModelForCausalLM.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"), torch_dtype=torch.bfloat16, device_map=device) \n",
+ "else:\n",
+ "    tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
+ "    model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) \n",
  "model.generation_config.pad_token_id = tokenizer.pad_token_id"
  ]
  },
@@ -133,7 +137,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 28,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [
  {
@@ -141,7 +145,7 @@
  "output_type": "stream",
  "text": [
  "SQLite:\n",
- "SELECT team_abbreviation_home FROM other_stats WHERE lead_changes = 1 AND season_id = '2001';\n",
+ "SELECT COUNT(*) FROM game WHERE team_name_home = 'Indiana Pacers' AND wl_home = 'T';\n",
  "\n"
  ]
  }
@@ -166,14 +170,15 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "cleaned\n"
+ "cleaned\n",
+ "(0,)\n"
  ]
  }
  ],
@@ -209,18 +214,22 @@
  },
  {
  "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
  "metadata": {},
  "outputs": [
  {
- "ename": "ImportError",
- "evalue": "cannot import name 'compare_result_two' from 'src.evaluation.compare_result' (/Users/esteban/Documents/USC/spring_2025/NLP/SQL-Generation/src/evaluation/compare_result.py)",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[30], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmath\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mevaluation\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcompare_result\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m compare_result_two\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompare_result\u001b[39m(sample_query, sample_result, query_output):\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Clean model output to only have the query output\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m query_output[\u001b[38;5;241m0\u001b[39m:\u001b[38;5;241m7\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSQLite:\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
- "\u001b[0;31mImportError\u001b[0m: cannot import name 'compare_result_two' from 'src.evaluation.compare_result' (/Users/esteban/Documents/USC/spring_2025/NLP/SQL-Generation/src/evaluation/compare_result.py)"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "What is the year the Milwaukee team was founded?\n",
+ "SELECT year_founded FROM team WHERE city = 'Milwaukee';\n",
+ "1968.0\n",
+ "SQLite:\n",
+ "SELECT year_founded FROM team WHERE full_name = 'Milwaukee Bucks';\n",
+ "\n",
+ "Statement valid? True\n",
+ "SQLite matched? False\n",
+ "Result matched? True\n"
  ]
  }
  ],
@@ -256,7 +265,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -304,7 +313,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
  "metadata": {},
  "outputs": [
  {
@@ -317,9 +326,9 @@
  "Completed 200\n",
  "\n",
  "Less than 90 results:\n",
- "Percent valid: 0.8448979591836735\n",
- "Percent SQLite matched: 0.43673469387755104\n",
- "Percent result matched: 0.6530612244897959\n",
+ "Percent valid: 0.8734693877551021\n",
+ "Percent SQLite matched: 0.4448979591836735\n",
+ "Percent result matched: 0.6979591836734694\n",
  "Dataset length: 245\n"
  ]
  }
@@ -341,36 +350,7 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Completed 50\n",
- "Completed 100\n",
- "Completed 150\n",
- "Completed 200\n",
- "Completed 250\n",
- "Completed 300\n",
- "Completed 350\n",
- "Completed 400\n",
- "Completed 450\n",
- "Completed 500\n",
- "Completed 550\n",
- "Completed 600\n",
- "Completed 650\n",
- "Completed 700\n",
- "Completed 750\n",
- "Completed 800\n",
- "\n",
- "Queries from game results:\n",
- "Percent valid: 0.7613365155131265\n",
- "Percent SQLite matched: 0.13842482100238662\n",
- "Percent result matched: 0.383054892601432\n",
- "Dataset length: 838\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "game_queries = pd.read_csv(get_path(\"train-data/queries_from_game.tsv\"), sep='\\t')\n",
  "run_evaluation(game_queries, \"Queries from game\")\n",
@@ -388,23 +368,7 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Completed 50\n",
- "Completed 100\n",
- "Completed 150\n",
- "\n",
- "Queries from other stats results:\n",
- "Percent valid: 0.21428571428571427\n",
- "Percent SQLite matched: 0.01948051948051948\n",
- "Percent result matched: 0.07142857142857142\n",
- "Dataset length: 154\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "other_stats_queries = pd.read_csv(get_path(\"train-data/queries_from_other_stats.tsv\"), sep='\\t')\n",
  "run_evaluation(other_stats_queries, \"Queries from other stats\")\n",
@@ -422,21 +386,7 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Completed 50\n",
- "\n",
- "Queries from team results:\n",
- "Percent valid: 0.8653846153846154\n",
- "Percent SQLite matched: 0.5961538461538461\n",
- "Percent result matched: 0.7884615384615384\n",
- "Dataset length: 52\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "team_queries = pd.read_csv(get_path(\"train-data/queries_from_team.tsv\"), sep='\\t')\n",
  "run_evaluation(team_queries, \"Queries from team\")\n",
@@ -454,23 +404,7 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Completed 50\n",
- "Completed 100\n",
- "Completed 150\n",
- "\n",
- "Queries with join results:\n",
- "Percent valid: 0.1945945945945946\n",
- "Percent SQLite matched: 0.0\n",
- "Percent result matched: 0.04864864864864865\n",
- "Dataset length: 185\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "join_queries = pd.read_csv(get_path(\"train-data/with_join.tsv\"), sep='\\t')\n",
  "run_evaluation(join_queries, \"Queries with join\")\n",
@@ -488,37 +422,7 @@
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Completed 50\n",
- "Completed 100\n",
- "Completed 150\n",
- "Completed 200\n",
- "Completed 250\n",
- "Completed 300\n",
- "Completed 350\n",
- "Completed 400\n",
- "Completed 450\n",
- "Completed 500\n",
- "Completed 550\n",
- "Completed 600\n",
- "Completed 650\n",
- "Completed 700\n",
- "Completed 750\n",
- "Completed 800\n",
- "Completed 850\n",
- "\n",
- "Queries without join results:\n",
- "Percent valid: 0.7916181606519208\n",
- "Percent SQLite matched: 0.17462165308498254\n",
- "Percent result matched: 0.42374854481955765\n",
- "Dataset length: 859\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "no_join_queries = pd.read_csv(get_path(\"train-data/without_join.tsv\"), sep='\\t')\n",
  "run_evaluation(no_join_queries, \"Queries without join\")\n",
@@ -534,7 +438,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 15,
+ "execution_count": 12,
  "metadata": {},
  "outputs": [
  {
@@ -563,9 +467,9 @@
  "Completed 1000\n",
  "\n",
  "All training data results:\n",
- "Percent valid: 0.685823754789272\n",
+ "Percent valid: 0.7097701149425287\n",
  "Percent SQLite matched: 0.14367816091954022\n",
- "Percent result matched: 0.35823754789272033\n",
+ "Percent result matched: 0.3668582375478927\n",
  "Dataset length: 1044\n"
  ]
  }
@@ -579,7 +483,7 @@
  ],
  "metadata": {
  "kernelspec": {
- "display_name": "CSCI544",
+ "display_name": "Python 3",
  "language": "python",
  "name": "python3"
  },
@@ -593,7 +497,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.12.6"
  }
  },
  "nbformat": 4,
test_rag.ipynb CHANGED
@@ -375,7 +375,7 @@
  " actual_result = \"Error executing query: \" + str(e)\n",
  " \n",
  " # Compare the ground truth query and expected result to the generated query and actual result.\n",
- " valid, sql_matched, result_matched = compare_result(row[\"sql_query\"], row[\"result\"], generated_query)\n",
+ " valid, sql_matched, result_matched = compare_result(cursor, row[\"sql_query\"], row[\"result\"], generated_query)\n",
  " print(\"=============================================\")\n",
  " print(f\"Overall Valid: {valid}\")\n",
  " print(f\"SQL Query Matched: {sql_matched}\")\n",