DeanGumas commited on
Commit
f2c9e91
·
1 Parent(s): 506f5a9

updating prompt to include team name and abbreviations

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +73 -28
test_pretrained.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 7,
20
  "metadata": {},
21
  "outputs": [
22
  {
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "List the full names of all teams founded in the 1980s.\n",
30
- "SELECT full_name FROM team WHERE year_founded BETWEEN 1980 AND 1989;\n",
31
- "Dallas Mavericks, Miami Heat, Minnesota Timberwolves, Orlando Magic, Charlotte Hornets\n"
32
  ]
33
  }
34
  ],
@@ -56,9 +56,18 @@
56
  },
57
  {
58
  "cell_type": "code",
59
- "execution_count": 8,
60
  "metadata": {},
61
- "outputs": [],
 
 
 
 
 
 
 
 
 
62
  "source": [
63
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
64
  "import torch\n",
@@ -80,7 +89,7 @@
80
  },
81
  {
82
  "cell_type": "code",
83
- "execution_count": 9,
84
  "metadata": {},
85
  "outputs": [],
86
  "source": [
@@ -189,6 +198,44 @@
189
  ");\n",
190
  "\n",
191
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  "Query Guidelines\n",
193
  "Use team_name_home and team_name_away to match teams.\n",
194
  "\n",
@@ -238,7 +285,7 @@
238
  },
239
  {
240
  "cell_type": "code",
241
- "execution_count": 10,
242
  "metadata": {},
243
  "outputs": [
244
  {
@@ -248,7 +295,10 @@
248
  "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
249
  " warnings.warn(\n",
250
  "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
251
- "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
 
 
 
252
  ]
253
  },
254
  {
@@ -256,9 +306,9 @@
256
  "output_type": "stream",
257
  "text": [
258
  "SQLite:\n",
259
- "SELECT full_name \n",
260
- "FROM team \n",
261
- "WHERE year_founded BETWEEN 1980 AND 1989;\n",
262
  "\n"
263
  ]
264
  }
@@ -283,7 +333,7 @@
283
  },
284
  {
285
  "cell_type": "code",
286
- "execution_count": 11,
287
  "metadata": {},
288
  "outputs": [
289
  {
@@ -291,11 +341,7 @@
291
  "output_type": "stream",
292
  "text": [
293
  "cleaned\n",
294
- "('Dallas Mavericks',)\n",
295
- "('Miami Heat',)\n",
296
- "('Minnesota Timberwolves',)\n",
297
- "('Orlando Magic',)\n",
298
- "('Charlotte Hornets',)\n"
299
  ]
300
  }
301
  ],
@@ -329,7 +375,7 @@
329
  },
330
  {
331
  "cell_type": "code",
332
- "execution_count": 76,
333
  "metadata": {},
334
  "outputs": [
335
  {
@@ -344,18 +390,17 @@
344
  "name": "stdout",
345
  "output_type": "stream",
346
  "text": [
347
- "How many games did the Indiana Pacers win at home with more than 15 fast break points in 1996?\n",
348
- "SELECT COUNT(*) as wins FROM other_stats os JOIN game g ON os.game_id = g.game_id WHERE g.team_name_home = 'Indiana Pacers' AND g.wl_home = 'W' AND os.pts_fb_home > 15 AND g.season_id = '21996';\n",
349
- "7.0\n",
350
  "SQLite:\n",
351
- "SELECT COUNT(*) \n",
352
  "FROM game \n",
353
- "WHERE wl_home = 'W' \n",
354
- "AND season_id = '2196' \n",
355
- "AND pts_fb_home > 15;\n",
356
  "\n",
357
- "SQL matched? False\n",
358
- "Result matched? False\n"
 
359
  ]
360
  }
361
  ],
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": 1,
20
  "metadata": {},
21
  "outputs": [
22
  {
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "What is the highest combined pts in any game involving the Miami Heat?\n",
30
+ "SELECT MAX(pts_home + pts_away) FROM game WHERE team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat';\n",
31
+ "290.0\n"
32
  ]
33
  }
34
  ],
 
56
  },
57
  {
58
  "cell_type": "code",
59
+ "execution_count": 2,
60
  "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stderr",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
67
+ " from .autonotebook import tqdm as notebook_tqdm\n"
68
+ ]
69
+ }
70
+ ],
71
  "source": [
72
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
73
  "import torch\n",
 
89
  },
90
  {
91
  "cell_type": "code",
92
+ "execution_count": 3,
93
  "metadata": {},
94
  "outputs": [],
95
  "source": [
 
198
  ");\n",
199
  "\n",
200
  "\n",
201
+ "Team Name Information\n",
202
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
203
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
204
+ "Notice they are separated by the | character in the following list:\n",
205
+ "\n",
206
+ "Atlanta Hawks|ATL\n",
207
+ "Boston Celtics|BOS\n",
208
+ "Cleveland Cavaliers|CLE\n",
209
+ "New Orleans Pelicans|NOP\n",
210
+ "Chicago Bulls|CHI\n",
211
+ "Dallas Mavericks|DAL\n",
212
+ "Denver Nuggets|DEN\n",
213
+ "Golden State Warriors|GSW\n",
214
+ "Houston Rockets|HOU\n",
215
+ "Los Angeles Clippers|LAC\n",
216
+ "Los Angeles Lakers|LAL\n",
217
+ "Miami Heat|MIA\n",
218
+ "Milwaukee Bucks|MIL\n",
219
+ "Minnesota Timberwolves|MIN\n",
220
+ "Brooklyn Nets|BKN\n",
221
+ "New York Knicks|NYK\n",
222
+ "Orlando Magic|ORL\n",
223
+ "Indiana Pacers|IND\n",
224
+ "Philadelphia 76ers|PHI\n",
225
+ "Phoenix Suns|PHX\n",
226
+ "Portland Trail Blazers|POR\n",
227
+ "Sacramento Kings|SAC\n",
228
+ "San Antonio Spurs|SAS\n",
229
+ "Oklahoma City Thunder|OKC\n",
230
+ "Toronto Raptors|TOR\n",
231
+ "Utah Jazz|UTA\n",
232
+ "Memphis Grizzlies|MEM\n",
233
+ "Washington Wizards|WAS\n",
234
+ "Detroit Pistons|DET\n",
235
+ "Charlotte Hornets|CHA\n",
236
+ "\n",
237
+ "\n",
238
+ "\n",
239
  "Query Guidelines\n",
240
  "Use team_name_home and team_name_away to match teams.\n",
241
  "\n",
 
285
  },
286
  {
287
  "cell_type": "code",
288
+ "execution_count": 4,
289
  "metadata": {},
290
  "outputs": [
291
  {
 
295
  "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
296
  " warnings.warn(\n",
297
  "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
298
+ "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
299
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
300
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\integrations\\sdpa_attention.py:53: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
301
+ " attn_output = torch.nn.functional.scaled_dot_product_attention(\n"
302
  ]
303
  },
304
  {
 
306
  "output_type": "stream",
307
  "text": [
308
  "SQLite:\n",
309
+ "SELECT MAX(pts_home + pts_away) \n",
310
+ "FROM game \n",
311
+ "WHERE (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
312
  "\n"
313
  ]
314
  }
 
333
  },
334
  {
335
  "cell_type": "code",
336
+ "execution_count": 5,
337
  "metadata": {},
338
  "outputs": [
339
  {
 
341
  "output_type": "stream",
342
  "text": [
343
  "cleaned\n",
344
+ "(290.0,)\n"
 
 
 
 
345
  ]
346
  }
347
  ],
 
375
  },
376
  {
377
  "cell_type": "code",
378
+ "execution_count": 16,
379
  "metadata": {},
380
  "outputs": [
381
  {
 
390
  "name": "stdout",
391
  "output_type": "stream",
392
  "text": [
393
+ "What is the average number of reb in away games by the Detroit Pistons?\n",
394
+ "SELECT AVG(reb_away) FROM game WHERE team_name_away = 'Detroit Pistons';\n",
395
+ "42.10948081264108\n",
396
  "SQLite:\n",
397
+ "SELECT AVG(reb_away) \n",
398
  "FROM game \n",
399
+ "WHERE team_name_away = 'Detroit Pistons';\n",
 
 
400
  "\n",
401
+ "[(42.10948081264108,)]\n",
402
+ "SQL matched? True\n",
403
+ "Result matched? True\n"
404
  ]
405
  }
406
  ],