DeanGumas commited on
Commit
7dc8863
·
1 Parent(s): f2c9e91

fixed issue with inconsistent formatting in training data

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +37 -19
test_pretrained.ipynb CHANGED
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "What is the highest combined pts in any game involving the Miami Heat?\n",
30
- "SELECT MAX(pts_home + pts_away) FROM game WHERE team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat';\n",
31
- "290.0\n"
32
  ]
33
  }
34
  ],
@@ -241,7 +241,7 @@
241
  "\n",
242
  "To filter by season, use season_id = '2YYYY'.\n",
243
  "\n",
244
- "Example: To get games from 2005, use season_id = '22005'.\n",
245
  "\n",
246
  "The game_id column links the game and other_stats tables.\n",
247
  "\n",
@@ -306,9 +306,9 @@
306
  "output_type": "stream",
307
  "text": [
308
  "SQLite:\n",
309
- "SELECT MAX(pts_home + pts_away) \n",
310
- "FROM game \n",
311
- "WHERE (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
312
  "\n"
313
  ]
314
  }
@@ -340,8 +340,18 @@
340
  "name": "stdout",
341
  "output_type": "stream",
342
  "text": [
343
- "cleaned\n",
344
- "(290.0,)\n"
 
 
 
 
 
 
 
 
 
 
345
  ]
346
  }
347
  ],
@@ -375,7 +385,7 @@
375
  },
376
  {
377
  "cell_type": "code",
378
- "execution_count": 16,
379
  "metadata": {},
380
  "outputs": [
381
  {
@@ -390,15 +400,15 @@
390
  "name": "stdout",
391
  "output_type": "stream",
392
  "text": [
393
- "What is the average number of reb in away games by the Detroit Pistons?\n",
394
- "SELECT AVG(reb_away) FROM game WHERE team_name_away = 'Detroit Pistons';\n",
395
- "42.10948081264108\n",
396
  "SQLite:\n",
397
- "SELECT AVG(reb_away) \n",
398
  "FROM game \n",
399
- "WHERE team_name_away = 'Detroit Pistons';\n",
400
  "\n",
401
- "[(42.10948081264108,)]\n",
402
  "SQL matched? True\n",
403
  "Result matched? True\n"
404
  ]
@@ -426,8 +436,13 @@
426
  " query_match = (query == sample_query)\n",
427
  "\n",
428
  " # Check if this is a multi-line query\n",
429
- " if \"|\" in sample_result:\n",
430
- " result_list = sample_result.split(\"|\") \n",
 
 
 
 
 
431
  " for i in range(len(result_list)):\n",
432
  " result_list[i] = str(result_list[i]).strip()\n",
433
  " result = False\n",
@@ -435,7 +450,10 @@
435
  " for r in row:\n",
436
  " if str(r) in result_list:\n",
437
  " return query_match, True\n",
438
- " print(rows)\n",
 
 
 
439
  " return query_match, result\n",
440
  " else:\n",
441
  " print(rows)\n",
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "What was the largest lead the Golden State Warriors had in a game during the 2018 season?\n",
30
+ "SELECT MAX(other_stats.largest_lead_home) FROM other_stats JOIN game ON other_stats.game_id = game.game_id WHERE game.team_name_home = 'Golden State Warriors' AND game.season_id = '22018';\n",
31
+ "44\n"
32
  ]
33
  }
34
  ],
 
241
  "\n",
242
  "To filter by season, use season_id = '2YYYY'.\n",
243
  "\n",
244
+ "Example: To get games from 2005, use season_id = '22005'. To get games from 1972, use season_id = \"21972\". To get games from 2015, use season_id = \"22015\".\n",
245
  "\n",
246
  "The game_id column links the game and other_stats tables.\n",
247
  "\n",
 
306
  "output_type": "stream",
307
  "text": [
308
  "SQLite:\n",
309
+ "SELECT MAX(largest_lead_home) \n",
310
+ "FROM other_stats \n",
311
+ "WHERE team_name_home = 'Golden State Warriors' AND season_id = '22018';\n",
312
  "\n"
313
  ]
314
  }
 
340
  "name": "stdout",
341
  "output_type": "stream",
342
  "text": [
343
+ "cleaned\n"
344
+ ]
345
+ },
346
+ {
347
+ "ename": "OperationalError",
348
+ "evalue": "no such column: team_name_home",
349
+ "output_type": "error",
350
+ "traceback": [
351
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
352
+ "\u001b[1;31mOperationalError\u001b[0m Traceback (most recent call last)",
353
+ "Cell \u001b[1;32mIn[5], line 15\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 14\u001b[0m query \u001b[38;5;241m=\u001b[39m query_output\n\u001b[1;32m---> 15\u001b[0m \u001b[43mcursor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 16\u001b[0m rows \u001b[38;5;241m=\u001b[39m cursor\u001b[38;5;241m.\u001b[39mfetchall()\n\u001b[0;32m 17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m rows:\n",
354
+ "\u001b[1;31mOperationalError\u001b[0m: no such column: team_name_home"
355
  ]
356
  }
357
  ],
 
385
  },
386
  {
387
  "cell_type": "code",
388
+ "execution_count": 17,
389
  "metadata": {},
390
  "outputs": [
391
  {
 
400
  "name": "stdout",
401
  "output_type": "stream",
402
  "text": [
403
+ "How many games had at least one team with 30+ assists?\n",
404
+ "SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30;\n",
405
+ "11305\n",
406
  "SQLite:\n",
407
+ "SELECT COUNT(*) \n",
408
  "FROM game \n",
409
+ "WHERE ast_home >= 30 OR ast_away >= 30;\n",
410
  "\n",
411
+ "[(11305,)]\n",
412
  "SQL matched? True\n",
413
  "Result matched? True\n"
414
  ]
 
436
  " query_match = (query == sample_query)\n",
437
  "\n",
438
  " # Check if this is a multi-line query\n",
439
+ " if \"|\" in sample_result or \"(\" in sample_result:\n",
440
+ " if \"(\" in sample_result:\n",
441
+ " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
442
+ " result_list = sample_result.split(\",\") \n",
443
+ " else:\n",
444
+ " result_list = sample_result.split(\"|\") \n",
445
+ "\n",
446
  " for i in range(len(result_list)):\n",
447
  " result_list[i] = str(result_list[i]).strip()\n",
448
  " result = False\n",
 
450
  " for r in row:\n",
451
  " if str(r) in result_list:\n",
452
  " return query_match, True\n",
453
+ " if len(rows) == 1:\n",
454
+ " for r in rows[0]:\n",
455
+ " if r == str(len(result_list)):\n",
456
+ " return query_match, True\n",
457
  " return query_match, result\n",
458
  " else:\n",
459
  " print(rows)\n",