DeanGumas committed on
Commit
32c1934
·
1 Parent(s): 7f1de95

ran full testing with updated prompt, improves team and other_stats performance, slight drop on game performance

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +52 -103
test_pretrained.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 34,
20
  "metadata": {},
21
  "outputs": [
22
  {
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "What was the largest deficit overcome by the Miami Heat in any home victory?\n",
30
- "SELECT o.largest_lead_away AS max_deficit_overcome FROM game g JOIN other_stats o ON g.game_id = o.game_id WHERE g.team_name_home = 'Miami Heat' AND g.wl_home = 'W' ORDER BY o.largest_lead_away DESC LIMIT 1;\n",
31
- "46\n"
32
  ]
33
  }
34
  ],
@@ -58,7 +58,7 @@
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 35,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
@@ -83,7 +83,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 36,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
@@ -236,21 +236,8 @@
236
  "\n",
237
  "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
238
  "\n",
239
- "The game_id column can be used to join the game and other_stats tables.\n",
240
- "\n",
241
  "Ensure queries return relevant columns and avoid unnecessary joins.\n",
242
  "\n",
243
- "When obtaining certain statistics by team from the game table, use the team_name_home and team_name_away columns. \n",
244
- "For example, to obtain home game data for the Washington Wizards from the game table use a statement like: team_name_home = 'Washington Wizards'\n",
245
- "To obtain away game data from the Los Angeles Lakers from the game table use a statement like: team_name_away = 'Los Angeles Lakers'\n",
246
- "To obtain general game data where home or away is not specified for the Chicago Bulls from the game table, use a statement like: (team_name_home = 'Chicago Bulls' OR team_name_away = 'Chicago Bulls')\n",
247
- "\n",
248
- "When obtaining certain statistics by team from the other_stats table, use the team_abbreviation_home and team_abbreviation away columns.\n",
249
- "For example, to obtain home statistics from the Charlotte Hornets from the other_stats table use a statement like: team_abbreviation_home = 'CHA'\n",
250
- "To obtain away statistics from the Dallas Mavericks from the other_stats table, use a statement like: team_abbreviation_away = 'DAL'\n",
251
- "To obtain general statistics from the other_stats table where home or away is not specified for the Detroit Pistons use a statement like: (team_abbreviation_home = 'DET' OR team_abbreviation_away = 'DET)\n",
252
- "\n",
253
- "\n",
254
  "Example User Requests and SQLite Queries\n",
255
  "Request:\n",
256
  "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
@@ -265,62 +252,22 @@
265
  "SELECT full_name FROM team WHERE state = 'California';\n",
266
  "\n",
267
  "Request:\n",
268
- "\"How many total team rebounds did the Los Angeles Clippers have in away games where they scored over 15 fast break points?\"\n",
269
- "SQLite:\n",
270
- "SELECT SUM(os.team_rebounds_away) \n",
271
- "FROM other_stats os \n",
272
- "JOIN game g ON os.game_id = g.game_id \n",
273
- "WHERE g.team_abbreviation_away = 'LAC' AND os.pts_fb_away > 15;\n",
274
- "\n",
275
- "Request:\n",
276
- "\"How many points did the Miami Heat score on January 10, 2010?\"\n",
277
- "SQLite:\n",
278
- "SELECT team_name_home, pts_home, team_name_away, pts_away \n",
279
- "FROM game \n",
280
- "WHERE DATE(game_date) = '2010-01-10' \n",
281
- "AND (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
282
- "\n",
283
- "Request:\n",
284
  "\"Which team had the highest number of team turnovers in an away game?\"\n",
285
  "SQLite:\n",
286
  "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
287
  "\n",
288
  "Request:\n",
289
- "\"Which team won the most home games in the 2000 season?\"\n",
290
- "SQLite:\n",
291
- "SELECT team_name_home, COUNT(*) AS wins\n",
292
- "FROM game\n",
293
- "WHERE wl_home = 'W' AND season_id = '22000'\n",
294
- "GROUP BY team_name_home\n",
295
- "ORDER BY wins DESC\n",
296
- "LIMIT 1;\n",
297
- "\n",
298
- "Request:\n",
299
  "\"Which teams were founded before 1979?\"\n",
300
  "SQLite:\n",
301
  "SELECT full_name FROM team WHERE year_founded < 1979;\n",
302
  "\n",
303
  "Request:\n",
304
- "\"Which game had the most lead changes in the 2020 season?\"\n",
305
- "SQLite:\n",
306
- "SELECT game_id, lead_changes \n",
307
- "FROM other_stats \n",
308
- "WHERE game_id IN \n",
309
- "(SELECT game_id FROM game WHERE season_id = '22020')\n",
310
- "ORDER BY lead_changes DESC LIMIT 1;\n",
311
- "\n",
312
- "Request:\n",
313
  "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
314
  "SQLite:\n",
315
  "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
316
  "FROM game\n",
317
  "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
318
  "\n",
319
- "Request:\n",
320
- "\"How many fast break points did the Atlanta Hawks score at home?\"\n",
321
- "SQLite:\n",
322
- "SELECT SUM(pts_fb_home) as total_fb_points FROM other_stats WHERE team_abbreviation_home = 'ATL';\n",
323
- "\n",
324
  "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
325
  "\"\"\""
326
  ]
@@ -334,7 +281,7 @@
334
  },
335
  {
336
  "cell_type": "code",
337
- "execution_count": 37,
338
  "metadata": {},
339
  "outputs": [
340
  {
@@ -342,9 +289,9 @@
342
  "output_type": "stream",
343
  "text": [
344
  "SQLite:\n",
345
- "SELECT MAX(pts_home - pts_away) AS largest_deficit\n",
346
- "FROM game\n",
347
- "WHERE wl_home = 'W' AND team_name_home = 'Miami Heat';\n",
348
  "\n"
349
  ]
350
  }
@@ -369,15 +316,14 @@
369
  },
370
  {
371
  "cell_type": "code",
372
- "execution_count": 38,
373
  "metadata": {},
374
  "outputs": [
375
  {
376
  "name": "stdout",
377
  "output_type": "stream",
378
  "text": [
379
- "cleaned\n",
380
- "(43.0,)\n"
381
  ]
382
  }
383
  ],
@@ -415,24 +361,24 @@
415
  },
416
  {
417
  "cell_type": "code",
418
- "execution_count": 39,
419
  "metadata": {},
420
  "outputs": [
421
  {
422
  "name": "stdout",
423
  "output_type": "stream",
424
  "text": [
425
- "How many home games did the Chicago Bulls play in the 2020 season?\n",
426
- "SELECT COUNT(*) FROM game WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';\n",
427
- "36.0\n",
428
  "SQLite:\n",
429
- "SELECT COUNT(*) as total_home_games \n",
430
- "FROM game \n",
431
- "WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';\n",
432
  "\n",
433
- "Statement valid? True\n",
434
  "SQLite matched? False\n",
435
- "Result matched? True\n"
436
  ]
437
  }
438
  ],
@@ -561,7 +507,7 @@
561
  },
562
  {
563
  "cell_type": "code",
564
- "execution_count": 40,
565
  "metadata": {},
566
  "outputs": [],
567
  "source": [
@@ -609,7 +555,7 @@
609
  },
610
  {
611
  "cell_type": "code",
612
- "execution_count": 41,
613
  "metadata": {},
614
  "outputs": [
615
  {
@@ -617,18 +563,21 @@
617
  "output_type": "stream",
618
  "text": [
619
  "Completed 50\n",
 
 
 
620
  "\n",
621
  "Less than 90 results:\n",
622
- "Percent valid: 0.62\n",
623
- "Percent SQLite matched: 0.12\n",
624
- "Percent result matched: 0.4\n",
625
  "Dataset length: 245\n"
626
  ]
627
  }
628
  ],
629
  "source": [
630
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
631
- "run_evaluation(less_than_90_df.sample(n=50), \"Less than 90\")\n",
632
  "print(\"Dataset length: \" + str(len(less_than_90_df)))"
633
  ]
634
  },
@@ -641,7 +590,7 @@
641
  },
642
  {
643
  "cell_type": "code",
644
- "execution_count": 25,
645
  "metadata": {},
646
  "outputs": [
647
  {
@@ -666,9 +615,9 @@
666
  "Completed 800\n",
667
  "\n",
668
  "Queries from game results:\n",
669
- "Percent valid: 0.6181384248210023\n",
670
- "Percent SQLite matched: 0.015513126491646777\n",
671
- "Percent result matched: 0.24343675417661098\n",
672
  "Dataset length: 838\n"
673
  ]
674
  }
@@ -688,7 +637,7 @@
688
  },
689
  {
690
  "cell_type": "code",
691
- "execution_count": 23,
692
  "metadata": {},
693
  "outputs": [
694
  {
@@ -700,9 +649,9 @@
700
  "Completed 150\n",
701
  "\n",
702
  "Queries from other stats results:\n",
703
- "Percent valid: 0.6168831168831169\n",
704
- "Percent SQLite matched: 0.06493506493506493\n",
705
- "Percent result matched: 0.34415584415584416\n",
706
  "Dataset length: 154\n"
707
  ]
708
  }
@@ -722,7 +671,7 @@
722
  },
723
  {
724
  "cell_type": "code",
725
- "execution_count": 24,
726
  "metadata": {},
727
  "outputs": [
728
  {
@@ -732,9 +681,9 @@
732
  "Completed 50\n",
733
  "\n",
734
  "Queries from team results:\n",
735
- "Percent valid: 0.8846153846153846\n",
736
- "Percent SQLite matched: 0.6346153846153846\n",
737
- "Percent result matched: 0.8269230769230769\n",
738
  "Dataset length: 52\n"
739
  ]
740
  }
@@ -754,7 +703,7 @@
754
  },
755
  {
756
  "cell_type": "code",
757
- "execution_count": 12,
758
  "metadata": {},
759
  "outputs": [
760
  {
@@ -766,9 +715,9 @@
766
  "Completed 150\n",
767
  "\n",
768
  "Queries with join results:\n",
769
- "Percent valid: 0.06486486486486487\n",
770
  "Percent SQLite matched: 0.0\n",
771
- "Percent result matched: 0.010810810810810811\n",
772
  "Dataset length: 185\n"
773
  ]
774
  }
@@ -788,7 +737,7 @@
788
  },
789
  {
790
  "cell_type": "code",
791
- "execution_count": 13,
792
  "metadata": {},
793
  "outputs": [
794
  {
@@ -814,9 +763,9 @@
814
  "Completed 850\n",
815
  "\n",
816
  "Queries without join results:\n",
817
- "Percent valid: 0.7974388824214202\n",
818
- "Percent SQLite matched: 0.1559953434225844\n",
819
- "Percent result matched: 0.4318975552968568\n",
820
  "Dataset length: 859\n"
821
  ]
822
  }
@@ -836,7 +785,7 @@
836
  },
837
  {
838
  "cell_type": "code",
839
- "execution_count": 14,
840
  "metadata": {},
841
  "outputs": [
842
  {
@@ -865,8 +814,8 @@
865
  "Completed 1000\n",
866
  "\n",
867
  "All training data results:\n",
868
- "Percent valid: 0.6676245210727969\n",
869
- "Percent SQLite matched: 0.12835249042145594\n",
870
  "Percent result matched: 0.35823754789272033\n",
871
  "Dataset length: 1044\n"
872
  ]
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": 1,
20
  "metadata": {},
21
  "outputs": [
22
  {
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "Which team committed the fewest total turnovers in an away game that resulted in a win?\n",
30
+ "SELECT team_abbreviation_away FROM other_stats WHERE game_id IN (SELECT game_id FROM game WHERE wl_away = 'W') ORDER BY total_turnovers_away ASC LIMIT 1;\n",
31
+ "PHX\n"
32
  ]
33
  }
34
  ],
 
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": 2,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": 3,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
 
236
  "\n",
237
  "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
238
  "\n",
 
 
239
  "Ensure queries return relevant columns and avoid unnecessary joins.\n",
240
  "\n",
 
 
 
 
 
 
 
 
 
 
 
241
  "Example User Requests and SQLite Queries\n",
242
  "Request:\n",
243
  "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
 
252
  "SELECT full_name FROM team WHERE state = 'California';\n",
253
  "\n",
254
  "Request:\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  "\"Which team had the highest number of team turnovers in an away game?\"\n",
256
  "SQLite:\n",
257
  "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
258
  "\n",
259
  "Request:\n",
 
 
 
 
 
 
 
 
 
 
260
  "\"Which teams were founded before 1979?\"\n",
261
  "SQLite:\n",
262
  "SELECT full_name FROM team WHERE year_founded < 1979;\n",
263
  "\n",
264
  "Request:\n",
 
 
 
 
 
 
 
 
 
265
  "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
266
  "SQLite:\n",
267
  "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
268
  "FROM game\n",
269
  "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
270
  "\n",
 
 
 
 
 
271
  "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
272
  "\"\"\""
273
  ]
 
281
  },
282
  {
283
  "cell_type": "code",
284
+ "execution_count": 4,
285
  "metadata": {},
286
  "outputs": [
287
  {
 
289
  "output_type": "stream",
290
  "text": [
291
  "SQLite:\n",
292
+ "SELECT team_abbreviation_away \n",
293
+ "FROM other_stats \n",
294
+ "WHERE wl_away = 'W' AND total_turnovers_away < (SELECT MIN(total_turnovers_away) FROM other_stats WHERE wl_away = 'L');\n",
295
  "\n"
296
  ]
297
  }
 
316
  },
317
  {
318
  "cell_type": "code",
319
+ "execution_count": 5,
320
  "metadata": {},
321
  "outputs": [
322
  {
323
  "name": "stdout",
324
  "output_type": "stream",
325
  "text": [
326
+ "cleaned\n"
 
327
  ]
328
  }
329
  ],
 
361
  },
362
  {
363
  "cell_type": "code",
364
+ "execution_count": 6,
365
  "metadata": {},
366
  "outputs": [
367
  {
368
  "name": "stdout",
369
  "output_type": "stream",
370
  "text": [
371
+ "What is the largest lead the Minnesota Timberwolves had at home?\n",
372
+ "SELECT MAX(largest_lead_home) as max_lead FROM other_stats WHERE team_abbreviation_home = 'MIN';\n",
373
+ "48.0\n",
374
  "SQLite:\n",
375
+ "SELECT MAX(largest_lead_home) \n",
376
+ "FROM other_stats \n",
377
+ "WHERE team_name_home = 'Minnesota Timberwolves';\n",
378
  "\n",
379
+ "Statement valid? False\n",
380
  "SQLite matched? False\n",
381
+ "Result matched? False\n"
382
  ]
383
  }
384
  ],
 
507
  },
508
  {
509
  "cell_type": "code",
510
+ "execution_count": 7,
511
  "metadata": {},
512
  "outputs": [],
513
  "source": [
 
555
  },
556
  {
557
  "cell_type": "code",
558
+ "execution_count": 9,
559
  "metadata": {},
560
  "outputs": [
561
  {
 
563
  "output_type": "stream",
564
  "text": [
565
  "Completed 50\n",
566
+ "Completed 100\n",
567
+ "Completed 150\n",
568
+ "Completed 200\n",
569
  "\n",
570
  "Less than 90 results:\n",
571
+ "Percent valid: 0.8448979591836735\n",
572
+ "Percent SQLite matched: 0.43673469387755104\n",
573
+ "Percent result matched: 0.6530612244897959\n",
574
  "Dataset length: 245\n"
575
  ]
576
  }
577
  ],
578
  "source": [
579
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
580
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
581
  "print(\"Dataset length: \" + str(len(less_than_90_df)))"
582
  ]
583
  },
 
590
  },
591
  {
592
  "cell_type": "code",
593
+ "execution_count": 10,
594
  "metadata": {},
595
  "outputs": [
596
  {
 
615
  "Completed 800\n",
616
  "\n",
617
  "Queries from game results:\n",
618
+ "Percent valid: 0.7613365155131265\n",
619
+ "Percent SQLite matched: 0.13842482100238662\n",
620
+ "Percent result matched: 0.383054892601432\n",
621
  "Dataset length: 838\n"
622
  ]
623
  }
 
637
  },
638
  {
639
  "cell_type": "code",
640
+ "execution_count": 11,
641
  "metadata": {},
642
  "outputs": [
643
  {
 
649
  "Completed 150\n",
650
  "\n",
651
  "Queries from other stats results:\n",
652
+ "Percent valid: 0.21428571428571427\n",
653
+ "Percent SQLite matched: 0.01948051948051948\n",
654
+ "Percent result matched: 0.07142857142857142\n",
655
  "Dataset length: 154\n"
656
  ]
657
  }
 
671
  },
672
  {
673
  "cell_type": "code",
674
+ "execution_count": 12,
675
  "metadata": {},
676
  "outputs": [
677
  {
 
681
  "Completed 50\n",
682
  "\n",
683
  "Queries from team results:\n",
684
+ "Percent valid: 0.8653846153846154\n",
685
+ "Percent SQLite matched: 0.5961538461538461\n",
686
+ "Percent result matched: 0.7884615384615384\n",
687
  "Dataset length: 52\n"
688
  ]
689
  }
 
703
  },
704
  {
705
  "cell_type": "code",
706
+ "execution_count": 13,
707
  "metadata": {},
708
  "outputs": [
709
  {
 
715
  "Completed 150\n",
716
  "\n",
717
  "Queries with join results:\n",
718
+ "Percent valid: 0.1945945945945946\n",
719
  "Percent SQLite matched: 0.0\n",
720
+ "Percent result matched: 0.04864864864864865\n",
721
  "Dataset length: 185\n"
722
  ]
723
  }
 
737
  },
738
  {
739
  "cell_type": "code",
740
+ "execution_count": 14,
741
  "metadata": {},
742
  "outputs": [
743
  {
 
763
  "Completed 850\n",
764
  "\n",
765
  "Queries without join results:\n",
766
+ "Percent valid: 0.7916181606519208\n",
767
+ "Percent SQLite matched: 0.17462165308498254\n",
768
+ "Percent result matched: 0.42374854481955765\n",
769
  "Dataset length: 859\n"
770
  ]
771
  }
 
785
  },
786
  {
787
  "cell_type": "code",
788
+ "execution_count": 15,
789
  "metadata": {},
790
  "outputs": [
791
  {
 
814
  "Completed 1000\n",
815
  "\n",
816
  "All training data results:\n",
817
+ "Percent valid: 0.685823754789272\n",
818
+ "Percent SQLite matched: 0.14367816091954022\n",
819
  "Percent result matched: 0.35823754789272033\n",
820
  "Dataset length: 1044\n"
821
  ]