DeanGumas commited on
Commit
7f1de95
·
1 Parent(s): a30f35d

trying to prompt engineer fixes for other_stats table

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +102 -52
test_pretrained.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 1,
20
  "metadata": {},
21
  "outputs": [
22
  {
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "What is the average number of tov in home games by the Miami Heat?\n",
30
- "SELECT AVG(tov_home) FROM game WHERE team_name_home = 'Miami Heat';\n",
31
- "14.627184466019418\n"
32
  ]
33
  }
34
  ],
@@ -58,7 +58,7 @@
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 2,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
@@ -83,7 +83,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": null,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
@@ -163,9 +163,9 @@
163
  ");\n",
164
  "\n",
165
  "other_stats Table\n",
166
- "Stores additional game statistics, linked to the game table via game_id.\n",
167
  "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
168
- " \"game_id\" TEXT, -- Unique game identifier (links to \"game\" table)\n",
169
  " \"league_id\" TEXT, -- League identifier\n",
170
  " \"team_id_home\" TEXT, -- Home team identifier\n",
171
  " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
@@ -174,19 +174,20 @@
174
  " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
175
  " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
176
  " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
177
- " \"lead_changes\" INTEGER, -- Number of lead changes in the game\n",
178
  " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
179
  " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
180
- " \"total_turnovers_home\" INTEGER, -- Total turnovers in the game\n",
181
  " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
182
  " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
183
  " \"team_id_away\" TEXT, -- Away team identifier\n",
 
184
  " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
185
  " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
186
  " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
187
  " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
188
  " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
189
- " \"total_turnovers_away\" INTEGER, -- Total turnovers in the game\n",
190
  " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
191
  " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
192
  ");\n",
@@ -228,19 +229,28 @@
228
  "Detroit Pistons|DET\n",
229
  "Charlotte Hornets|CHA\n",
230
  "\n",
231
- "\n",
232
- "\n",
233
  "Query Guidelines\n",
234
- "Use team_name_home and team_name_away to match teams.\n",
235
  "\n",
236
  "To filter by season, use season_id = '2YYYY'.\n",
237
  "\n",
238
- "Example: To get games from 2005, use season_id = '22005'. To get games from 1972, use season_id = \"21972\". To get games from 2015, use season_id = \"22015\".\n",
239
  "\n",
240
- "The game_id column links the game and other_stats tables.\n",
241
  "\n",
242
  "Ensure queries return relevant columns and avoid unnecessary joins.\n",
243
  "\n",
 
 
 
 
 
 
 
 
 
 
 
244
  "Example User Requests and SQLite Queries\n",
245
  "Request:\n",
246
  "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
@@ -250,6 +260,19 @@
250
  "WHERE team_name_home = 'Los Angeles Lakers';\n",
251
  "\n",
252
  "Request:\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  "\"How many points did the Miami Heat score on January 10, 2010?\"\n",
254
  "SQLite:\n",
255
  "SELECT team_name_home, pts_home, team_name_away, pts_away \n",
@@ -258,6 +281,11 @@
258
  "AND (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
259
  "\n",
260
  "Request:\n",
 
 
 
 
 
261
  "\"Which team won the most home games in the 2000 season?\"\n",
262
  "SQLite:\n",
263
  "SELECT team_name_home, COUNT(*) AS wins\n",
@@ -267,7 +295,34 @@
267
  "ORDER BY wins DESC\n",
268
  "LIMIT 1;\n",
269
  "\n",
270
- "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following question: \"\"\""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  ]
272
  },
273
  {
@@ -279,7 +334,7 @@
279
  },
280
  {
281
  "cell_type": "code",
282
- "execution_count": 4,
283
  "metadata": {},
284
  "outputs": [
285
  {
@@ -287,9 +342,9 @@
287
  "output_type": "stream",
288
  "text": [
289
  "SQLite:\n",
290
- "SELECT AVG(tov_home) \n",
291
- "FROM game \n",
292
- "WHERE team_name_home = 'Miami Heat';\n",
293
  "\n"
294
  ]
295
  }
@@ -314,7 +369,7 @@
314
  },
315
  {
316
  "cell_type": "code",
317
- "execution_count": 5,
318
  "metadata": {},
319
  "outputs": [
320
  {
@@ -322,7 +377,7 @@
322
  "output_type": "stream",
323
  "text": [
324
  "cleaned\n",
325
- "(14.627184466019418,)\n"
326
  ]
327
  }
328
  ],
@@ -360,22 +415,20 @@
360
  },
361
  {
362
  "cell_type": "code",
363
- "execution_count": 6,
364
  "metadata": {},
365
  "outputs": [
366
  {
367
  "name": "stdout",
368
  "output_type": "stream",
369
  "text": [
370
- "How many times have the Houston Rockets won an away game while scoring at least 110 points?\n",
371
- "SELECT COUNT(*) FROM game WHERE team_abbreviation_away = 'HOU' AND pts_away >= 110 AND wl_away = 'W';\n",
372
- "425\n",
373
  "SQLite:\n",
374
- "SELECT COUNT(*) \n",
375
  "FROM game \n",
376
- "WHERE team_name_away = 'Houston Rockets' \n",
377
- "AND wl_away = 'W' \n",
378
- "AND pts_away >= 110;\n",
379
  "\n",
380
  "Statement valid? True\n",
381
  "SQLite matched? False\n",
@@ -508,7 +561,7 @@
508
  },
509
  {
510
  "cell_type": "code",
511
- "execution_count": 7,
512
  "metadata": {},
513
  "outputs": [],
514
  "source": [
@@ -556,7 +609,7 @@
556
  },
557
  {
558
  "cell_type": "code",
559
- "execution_count": 8,
560
  "metadata": {},
561
  "outputs": [
562
  {
@@ -564,21 +617,18 @@
564
  "output_type": "stream",
565
  "text": [
566
  "Completed 50\n",
567
- "Completed 100\n",
568
- "Completed 150\n",
569
- "Completed 200\n",
570
  "\n",
571
  "Less than 90 results:\n",
572
- "Percent valid: 0.8612244897959184\n",
573
- "Percent SQLite matched: 0.4163265306122449\n",
574
- "Percent result matched: 0.6530612244897959\n",
575
  "Dataset length: 245\n"
576
  ]
577
  }
578
  ],
579
  "source": [
580
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
581
- "run_evaluation(less_than_90_df, \"Less than 90\")\n",
582
  "print(\"Dataset length: \" + str(len(less_than_90_df)))"
583
  ]
584
  },
@@ -591,7 +641,7 @@
591
  },
592
  {
593
  "cell_type": "code",
594
- "execution_count": 9,
595
  "metadata": {},
596
  "outputs": [
597
  {
@@ -616,9 +666,9 @@
616
  "Completed 800\n",
617
  "\n",
618
  "Queries from game results:\n",
619
- "Percent valid: 0.7708830548926014\n",
620
- "Percent SQLite matched: 0.1431980906921241\n",
621
- "Percent result matched: 0.40692124105011934\n",
622
  "Dataset length: 838\n"
623
  ]
624
  }
@@ -638,7 +688,7 @@
638
  },
639
  {
640
  "cell_type": "code",
641
- "execution_count": 10,
642
  "metadata": {},
643
  "outputs": [
644
  {
@@ -650,9 +700,9 @@
650
  "Completed 150\n",
651
  "\n",
652
  "Queries from other stats results:\n",
653
- "Percent valid: 0.07792207792207792\n",
654
- "Percent SQLite matched: 0.0\n",
655
- "Percent result matched: 0.0\n",
656
  "Dataset length: 154\n"
657
  ]
658
  }
@@ -672,7 +722,7 @@
672
  },
673
  {
674
  "cell_type": "code",
675
- "execution_count": 11,
676
  "metadata": {},
677
  "outputs": [
678
  {
@@ -682,9 +732,9 @@
682
  "Completed 50\n",
683
  "\n",
684
  "Queries from team results:\n",
685
- "Percent valid: 0.75\n",
686
- "Percent SQLite matched: 0.2692307692307692\n",
687
- "Percent result matched: 0.6153846153846154\n",
688
  "Dataset length: 52\n"
689
  ]
690
  }
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": 34,
20
  "metadata": {},
21
  "outputs": [
22
  {
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "What was the largest deficit overcome by the Miami Heat in any home victory?\n",
30
+ "SELECT o.largest_lead_away AS max_deficit_overcome FROM game g JOIN other_stats o ON g.game_id = o.game_id WHERE g.team_name_home = 'Miami Heat' AND g.wl_home = 'W' ORDER BY o.largest_lead_away DESC LIMIT 1;\n",
31
+ "46\n"
32
  ]
33
  }
34
  ],
 
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": 35,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": 36,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
 
163
  ");\n",
164
  "\n",
165
  "other_stats Table\n",
166
+ "Stores additional statistics, linked to the game table via game_id.\n",
167
  "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
168
+ " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
169
  " \"league_id\" TEXT, -- League identifier\n",
170
  " \"team_id_home\" TEXT, -- Home team identifier\n",
171
  " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
 
174
  " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
175
  " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
176
  " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
177
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
178
  " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
179
  " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
180
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
181
  " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
182
  " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
183
  " \"team_id_away\" TEXT, -- Away team identifier\n",
184
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
185
  " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
186
  " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
187
  " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
188
  " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
189
  " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
190
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
191
  " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
192
  " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
193
  ");\n",
 
229
  "Detroit Pistons|DET\n",
230
  "Charlotte Hornets|CHA\n",
231
  "\n",
 
 
232
  "Query Guidelines\n",
233
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
234
  "\n",
235
  "To filter by season, use season_id = '2YYYY'.\n",
236
  "\n",
237
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
238
  "\n",
239
+ "The game_id column can be used to join the game and other_stats tables.\n",
240
  "\n",
241
  "Ensure queries return relevant columns and avoid unnecessary joins.\n",
242
  "\n",
243
+ "When obtaining certain statistics by team from the game table, use the team_name_home and team_name_away columns. \n",
244
+ "For example, to obtain home game data for the Washington Wizards from the game table use a statement like: team_name_home = 'Washington Wizards'\n",
245
+ "To obtain away game data from the Los Angeles Lakers from the game table use a statement like: team_name_away = 'Los Angeles Lakers'\n",
246
+ "To obtain general game data where home or away is not specified for the Chicago Bulls from the game table, use a statement like: (team_name_home = 'Chicago Bulls' OR team_name_away = 'Chicago Bulls')\n",
247
+ "\n",
248
+ "When obtaining certain statistics by team from the other_stats table, use the team_abbreviation_home and team_abbreviation away columns.\n",
249
+ "For example, to obtain home statistics from the Charlotte Hornets from the other_stats table use a statement like: team_abbreviation_home = 'CHA'\n",
250
+ "To obtain away statistics from the Dallas Mavericks from the other_stats table, use a statement like: team_abbreviation_away = 'DAL'\n",
251
+ "To obtain general statistics from the other_stats table where home or away is not specified for the Detroit Pistons use a statement like: (team_abbreviation_home = 'DET' OR team_abbreviation_away = 'DET)\n",
252
+ "\n",
253
+ "\n",
254
  "Example User Requests and SQLite Queries\n",
255
  "Request:\n",
256
  "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
 
260
  "WHERE team_name_home = 'Los Angeles Lakers';\n",
261
  "\n",
262
  "Request:\n",
263
+ "\"Which teams are located in the state of California?\"\n",
264
+ "SQLite:\n",
265
+ "SELECT full_name FROM team WHERE state = 'California';\n",
266
+ "\n",
267
+ "Request:\n",
268
+ "\"How many total team rebounds did the Los Angeles Clippers have in away games where they scored over 15 fast break points?\"\n",
269
+ "SQLite:\n",
270
+ "SELECT SUM(os.team_rebounds_away) \n",
271
+ "FROM other_stats os \n",
272
+ "JOIN game g ON os.game_id = g.game_id \n",
273
+ "WHERE g.team_abbreviation_away = 'LAC' AND os.pts_fb_away > 15;\n",
274
+ "\n",
275
+ "Request:\n",
276
  "\"How many points did the Miami Heat score on January 10, 2010?\"\n",
277
  "SQLite:\n",
278
  "SELECT team_name_home, pts_home, team_name_away, pts_away \n",
 
281
  "AND (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
282
  "\n",
283
  "Request:\n",
284
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
285
+ "SQLite:\n",
286
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
287
+ "\n",
288
+ "Request:\n",
289
  "\"Which team won the most home games in the 2000 season?\"\n",
290
  "SQLite:\n",
291
  "SELECT team_name_home, COUNT(*) AS wins\n",
 
295
  "ORDER BY wins DESC\n",
296
  "LIMIT 1;\n",
297
  "\n",
298
+ "Request:\n",
299
+ "\"Which teams were founded before 1979?\"\n",
300
+ "SQLite:\n",
301
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
302
+ "\n",
303
+ "Request:\n",
304
+ "\"Which game had the most lead changes in the 2020 season?\"\n",
305
+ "SQLite:\n",
306
+ "SELECT game_id, lead_changes \n",
307
+ "FROM other_stats \n",
308
+ "WHERE game_id IN \n",
309
+ "(SELECT game_id FROM game WHERE season_id = '22020')\n",
310
+ "ORDER BY lead_changes DESC LIMIT 1;\n",
311
+ "\n",
312
+ "Request:\n",
313
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
314
+ "SQLite:\n",
315
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
316
+ "FROM game\n",
317
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
318
+ "\n",
319
+ "Request:\n",
320
+ "\"How many fast break points did the Atlanta Hawks score at home?\"\n",
321
+ "SQLite:\n",
322
+ "SELECT SUM(pts_fb_home) as total_fb_points FROM other_stats WHERE team_abbreviation_home = 'ATL';\n",
323
+ "\n",
324
+ "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
325
+ "\"\"\""
326
  ]
327
  },
328
  {
 
334
  },
335
  {
336
  "cell_type": "code",
337
+ "execution_count": 37,
338
  "metadata": {},
339
  "outputs": [
340
  {
 
342
  "output_type": "stream",
343
  "text": [
344
  "SQLite:\n",
345
+ "SELECT MAX(pts_home - pts_away) AS largest_deficit\n",
346
+ "FROM game\n",
347
+ "WHERE wl_home = 'W' AND team_name_home = 'Miami Heat';\n",
348
  "\n"
349
  ]
350
  }
 
369
  },
370
  {
371
  "cell_type": "code",
372
+ "execution_count": 38,
373
  "metadata": {},
374
  "outputs": [
375
  {
 
377
  "output_type": "stream",
378
  "text": [
379
  "cleaned\n",
380
+ "(43.0,)\n"
381
  ]
382
  }
383
  ],
 
415
  },
416
  {
417
  "cell_type": "code",
418
+ "execution_count": 39,
419
  "metadata": {},
420
  "outputs": [
421
  {
422
  "name": "stdout",
423
  "output_type": "stream",
424
  "text": [
425
+ "How many home games did the Chicago Bulls play in the 2020 season?\n",
426
+ "SELECT COUNT(*) FROM game WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';\n",
427
+ "36.0\n",
428
  "SQLite:\n",
429
+ "SELECT COUNT(*) as total_home_games \n",
430
  "FROM game \n",
431
+ "WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';\n",
 
 
432
  "\n",
433
  "Statement valid? True\n",
434
  "SQLite matched? False\n",
 
561
  },
562
  {
563
  "cell_type": "code",
564
+ "execution_count": 40,
565
  "metadata": {},
566
  "outputs": [],
567
  "source": [
 
609
  },
610
  {
611
  "cell_type": "code",
612
+ "execution_count": 41,
613
  "metadata": {},
614
  "outputs": [
615
  {
 
617
  "output_type": "stream",
618
  "text": [
619
  "Completed 50\n",
 
 
 
620
  "\n",
621
  "Less than 90 results:\n",
622
+ "Percent valid: 0.62\n",
623
+ "Percent SQLite matched: 0.12\n",
624
+ "Percent result matched: 0.4\n",
625
  "Dataset length: 245\n"
626
  ]
627
  }
628
  ],
629
  "source": [
630
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
631
+ "run_evaluation(less_than_90_df.sample(n=50), \"Less than 90\")\n",
632
  "print(\"Dataset length: \" + str(len(less_than_90_df)))"
633
  ]
634
  },
 
641
  },
642
  {
643
  "cell_type": "code",
644
+ "execution_count": 25,
645
  "metadata": {},
646
  "outputs": [
647
  {
 
666
  "Completed 800\n",
667
  "\n",
668
  "Queries from game results:\n",
669
+ "Percent valid: 0.6181384248210023\n",
670
+ "Percent SQLite matched: 0.015513126491646777\n",
671
+ "Percent result matched: 0.24343675417661098\n",
672
  "Dataset length: 838\n"
673
  ]
674
  }
 
688
  },
689
  {
690
  "cell_type": "code",
691
+ "execution_count": 23,
692
  "metadata": {},
693
  "outputs": [
694
  {
 
700
  "Completed 150\n",
701
  "\n",
702
  "Queries from other stats results:\n",
703
+ "Percent valid: 0.6168831168831169\n",
704
+ "Percent SQLite matched: 0.06493506493506493\n",
705
+ "Percent result matched: 0.34415584415584416\n",
706
  "Dataset length: 154\n"
707
  ]
708
  }
 
722
  },
723
  {
724
  "cell_type": "code",
725
+ "execution_count": 24,
726
  "metadata": {},
727
  "outputs": [
728
  {
 
732
  "Completed 50\n",
733
  "\n",
734
  "Queries from team results:\n",
735
+ "Percent valid: 0.8846153846153846\n",
736
+ "Percent SQLite matched: 0.6346153846153846\n",
737
+ "Percent result matched: 0.8269230769230769\n",
738
  "Dataset length: 52\n"
739
  ]
740
  }