DeanGumas commited on
Commit
62c7f8d
·
1 Parent(s): fdaf162

added python notebook for testing finetuned model

Browse files
Files changed (1) hide show
  1. test_finetuned.ipynb +772 -0
test_finetuned.ipynb ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Run fine-tuned DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## First load dataset into pandas dataframe"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "name": "stdout",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "Total dataset examples: 1044\n",
27
+ "\n",
28
+ "\n",
29
+ "In which season did the Chicago Bulls have the highest average fg_pct at home?\n",
30
+ "SELECT season_id, AVG(fg_pct_home) as avg_stat FROM game WHERE team_name_home = 'Chicago Bulls' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1;\n",
31
+ "12022.0\n"
32
+ ]
33
+ }
34
+ ],
35
+ "source": [
36
+ "import pandas as pd \n",
37
+ "import warnings\n",
38
+ "warnings.filterwarnings(\"ignore\")\n",
39
+ "\n",
40
+ "# Load dataset and check length\n",
41
+ "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
42
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
43
+ "print(\"\\n\")\n",
44
+ "\n",
45
+ "# Test sampling\n",
46
+ "sample = df.sample(n=1)\n",
47
+ "print(sample[\"natural_query\"].values[0])\n",
48
+ "print(sample[\"sql_query\"].values[0])\n",
49
+ "print(sample[\"result\"].values[0])"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "## Load fine-tuned DeepSeek model using transformers and pytorch packages"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "cuda\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
74
+ "import torch\n",
75
+ "\n",
76
+ "# Set device to cuda if available, otherwise CPU\n",
77
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
+ "print(device)\n",
79
+ "\n",
80
+ "# Load model and tokenizer\n",
81
+ "tokenizer = AutoTokenizer.from_pretrained(\"./fine-tuned-model\")\n",
82
+ "model = AutoModelForCausalLM.from_pretrained(\"./fine-tuned-model\", torch_dtype=torch.bfloat16, device_map=device) \n",
83
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## Create prompt to setup the model for better performance"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 3,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "input_text = \"\"\"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
100
+ "Database Schema and Explanations\n",
101
+ "\n",
102
+ "team Table\n",
103
+ "Stores information about NBA teams.\n",
104
+ "CREATE TABLE IF NOT EXISTS \"team\" (\n",
105
+ " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
106
+ " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
107
+ " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
108
+ " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
109
+ " \"city\" TEXT, -- City where the team is based\n",
110
+ " \"state\" TEXT, -- State where the team is located\n",
111
+ " \"year_founded\" REAL -- Year the team was established\n",
112
+ ");\n",
113
+ "\n",
114
+ "game Table\n",
115
+ "Contains detailed statistics for each NBA game, including home and away team performance.\n",
116
+ "CREATE TABLE IF NOT EXISTS \"game\" (\n",
117
+ " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
118
+ " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
119
+ " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
120
+ " \"team_name_home\" TEXT, -- Full name of the home team\n",
121
+ " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
122
+ " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
123
+ " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
124
+ " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
125
+ " \"min\" INTEGER, -- Total minutes played in the game\n",
126
+ " \"fgm_home\" REAL, -- Field goals made by the home team\n",
127
+ " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
128
+ " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
129
+ " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
130
+ " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
131
+ " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
132
+ " \"ftm_home\" REAL, -- Free throws made by the home team\n",
133
+ " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
134
+ " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
135
+ " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
136
+ " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
137
+ " \"reb_home\" REAL, -- Total rebounds by the home team\n",
138
+ " \"ast_home\" REAL, -- Assists by the home team\n",
139
+ " \"stl_home\" REAL, -- Steals by the home team\n",
140
+ " \"blk_home\" REAL, -- Blocks by the home team\n",
141
+ " \"tov_home\" REAL, -- Turnovers by the home team\n",
142
+ " \"pf_home\" REAL, -- Personal fouls by the home team\n",
143
+ " \"pts_home\" REAL, -- Total points scored by the home team\n",
144
+ " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
145
+ " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
146
+ " \"team_id_away\" TEXT, -- ID of the away team\n",
147
+ " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
148
+ " \"team_name_away\" TEXT, -- Full name of the away team\n",
149
+ " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
150
+ " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
151
+ " \"fgm_away\" REAL, -- Field goals made by the away team\n",
152
+ " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
153
+ " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
154
+ " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
155
+ " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
156
+ " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
157
+ " \"ftm_away\" REAL, -- Free throws made by the away team\n",
158
+ " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
159
+ " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
160
+ " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
161
+ " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
162
+ " \"reb_away\" REAL, -- Total rebounds by the away team\n",
163
+ " \"ast_away\" REAL, -- Assists by the away team\n",
164
+ " \"stl_away\" REAL, -- Steals by the away team\n",
165
+ " \"blk_away\" REAL, -- Blocks by the away team\n",
166
+ " \"tov_away\" REAL, -- Turnovers by the away team\n",
167
+ " \"pf_away\" REAL, -- Personal fouls by the away team\n",
168
+ " \"pts_away\" REAL, -- Total points scored by the away team\n",
169
+ " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
170
+ " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
171
+ " \"season_type\" TEXT -- Regular season or playoffs\n",
172
+ ");\n",
173
+ "\n",
174
+ "other_stats Table\n",
175
+ "Stores additional statistics, linked to the game table via game_id.\n",
176
+ "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
177
+ " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
178
+ " \"league_id\" TEXT, -- League identifier\n",
179
+ " \"team_id_home\" TEXT, -- Home team identifier\n",
180
+ " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
181
+ " \"team_city_home\" TEXT, -- Home team city\n",
182
+ " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
183
+ " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
184
+ " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
185
+ " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
186
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
187
+ " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
188
+ " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
189
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
190
+ " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
191
+ " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
192
+ " \"team_id_away\" TEXT, -- Away team identifier\n",
193
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
194
+ " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
195
+ " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
196
+ " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
197
+ " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
198
+ " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
199
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
200
+ " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
201
+ " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
202
+ ");\n",
203
+ "\n",
204
+ "\n",
205
+ "Team Name Information\n",
206
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
207
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
208
+ "Notice they are separated by the | character in the following list:\n",
209
+ "\n",
210
+ "Atlanta Hawks|ATL\n",
211
+ "Boston Celtics|BOS\n",
212
+ "Cleveland Cavaliers|CLE\n",
213
+ "New Orleans Pelicans|NOP\n",
214
+ "Chicago Bulls|CHI\n",
215
+ "Dallas Mavericks|DAL\n",
216
+ "Denver Nuggets|DEN\n",
217
+ "Golden State Warriors|GSW\n",
218
+ "Houston Rockets|HOU\n",
219
+ "Los Angeles Clippers|LAC\n",
220
+ "Los Angeles Lakers|LAL\n",
221
+ "Miami Heat|MIA\n",
222
+ "Milwaukee Bucks|MIL\n",
223
+ "Minnesota Timberwolves|MIN\n",
224
+ "Brooklyn Nets|BKN\n",
225
+ "New York Knicks|NYK\n",
226
+ "Orlando Magic|ORL\n",
227
+ "Indiana Pacers|IND\n",
228
+ "Philadelphia 76ers|PHI\n",
229
+ "Phoenix Suns|PHX\n",
230
+ "Portland Trail Blazers|POR\n",
231
+ "Sacramento Kings|SAC\n",
232
+ "San Antonio Spurs|SAS\n",
233
+ "Oklahoma City Thunder|OKC\n",
234
+ "Toronto Raptors|TOR\n",
235
+ "Utah Jazz|UTA\n",
236
+ "Memphis Grizzlies|MEM\n",
237
+ "Washington Wizards|WAS\n",
238
+ "Detroit Pistons|DET\n",
239
+ "Charlotte Hornets|CHA\n",
240
+ "\n",
241
+ "Query Guidelines\n",
242
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
243
+ "\n",
244
+ "To filter by season, use season_id = '2YYYY'.\n",
245
+ "\n",
246
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
247
+ "\n",
248
+ "Ensure queries return relevant columns and avoid unnecessary joins.\n",
249
+ "\n",
250
+ "Example User Requests and SQLite Queries\n",
251
+ "Request:\n",
252
+ "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
253
+ "SQLite:\n",
254
+ "SELECT MAX(pts_home) \n",
255
+ "FROM game \n",
256
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
257
+ "\n",
258
+ "Request:\n",
259
+ "\"Which teams are located in the state of California?\"\n",
260
+ "SQLite:\n",
261
+ "SELECT full_name FROM team WHERE state = 'California';\n",
262
+ "\n",
263
+ "Request:\n",
264
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
265
+ "SQLite:\n",
266
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
267
+ "\n",
268
+ "Request:\n",
269
+ "\"Which teams were founded before 1979?\"\n",
270
+ "SQLite:\n",
271
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
272
+ "\n",
273
+ "Request:\n",
274
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
275
+ "SQLite:\n",
276
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
277
+ "FROM game\n",
278
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
279
+ "\n",
280
+ "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
281
+ "\"\"\""
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "markdown",
286
+ "metadata": {},
287
+ "source": [
288
+ "## Test model performance on a single example"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 4,
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "SQLite: SELECT season_id FROM game WHERE team_name_home = 'Chicago Bulls' GROUP BY season_id ORDER BY AVG(fg_pct_home) DESC LIMIT 1;\n",
301
+ "\n"
302
+ ]
303
+ }
304
+ ],
305
+ "source": [
306
+ "# Create message with sample query and run model\n",
307
+ "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
308
+ "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
309
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
310
+ "\n",
311
+ "# Print output\n",
312
+ "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
313
+ "print(query_output)"
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "markdown",
318
+ "metadata": {},
319
+ "source": [
320
+ "# Test sample output on sqlite3 database"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 5,
326
+ "metadata": {},
327
+ "outputs": [
328
+ {
329
+ "name": "stdout",
330
+ "output_type": "stream",
331
+ "text": [
332
+ "SELECT season_id FROM game WHERE team_name_home = 'Chicago Bulls' GROUP BY season_id ORDER BY AVG(fg_pct_home) DESC LIMIT 1;\n",
333
+ "('12022',)\n"
334
+ ]
335
+ }
336
+ ],
337
+ "source": [
338
+ "import sqlite3 as sql\n",
339
+ "\n",
340
+ "# Create connection to sqlite3 database\n",
341
+ "connection = sql.connect('./nba-data/nba.sqlite')\n",
342
+ "cursor = connection.cursor()\n",
343
+ "\n",
344
+ "# Execute query from model output and print result\n",
345
+ "if query_output[0:8] == \"SQLite: \":\n",
346
+ " query = query_output[8:]\n",
347
+ "elif query_output[0:5] == \"SQL: \":\n",
348
+ " query = query_output[5:]\n",
349
+ "else:\n",
350
+ " query = query_output\n",
351
+ "\n",
352
+ "for i in range(len(query)):\n",
353
+ " if query[i] == \";\":\n",
354
+ " query = query[:i+1]\n",
355
+ " break\n",
356
+ "\n",
357
+ "print(query)\n",
358
+ "\n",
359
+ "try:\n",
360
+ " cursor.execute(query)\n",
361
+ " rows = cursor.fetchall()\n",
362
+ " for row in rows:\n",
363
+ " print(row)\n",
364
+ "except:\n",
365
+ " pass"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "markdown",
370
+ "metadata": {},
371
+ "source": [
372
+ "## Create function to compare output to ground truth result from examples"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": 6,
378
+ "metadata": {},
379
+ "outputs": [
380
+ {
381
+ "name": "stdout",
382
+ "output_type": "stream",
383
+ "text": [
384
+ "In which season did the Chicago Bulls have the highest average fg_pct at home?\n",
385
+ "SELECT season_id, AVG(fg_pct_home) as avg_stat FROM game WHERE team_name_home = 'Chicago Bulls' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1;\n",
386
+ "12022.0\n",
387
+ "SQLite: SELECT season_id FROM game WHERE team_name_home = 'Chicago Bulls' GROUP BY season_id ORDER BY AVG(fg_pct_home) DESC LIMIT 1;\n",
388
+ "\n",
389
+ "Statement valid? True\n",
390
+ "SQLite matched? False\n",
391
+ "Result matched? True\n"
392
+ ]
393
+ }
394
+ ],
395
+ "source": [
396
+ "import math\n",
397
+ "\n",
398
+ "def compare_result(sample_query, sample_result, query_output):\n",
399
+ " # Clean model output to only have the query output\n",
400
+ " if query_output[0:8] == \"SQLite: \":\n",
401
+ " query = query_output[8:]\n",
402
+ " elif query_output[0:5] == \"SQL: \":\n",
403
+ " query = query_output[5:]\n",
404
+ " else:\n",
405
+ " query = query_output\n",
406
+ "\n",
407
+ " # Clean any excess text after the query semicolon\n",
408
+ " for i in range(len(query)):\n",
409
+ " if query[i] == \";\":\n",
410
+ " query = query[:i+1]\n",
411
+ " break\n",
412
+ " \n",
413
+ " # Try to execute query, if it fails, then this is a failure of the model\n",
414
+ " try:\n",
415
+ " # Execute query and obtain result\n",
416
+ " cursor.execute(query)\n",
417
+ " rows = cursor.fetchall()\n",
418
+ "\n",
419
+ " # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.\n",
420
+ " query = query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
421
+ " sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
422
+ " query_match = (query == sample_query)\n",
423
+ "\n",
424
+ " # If the queries match, the results clearly also match\n",
425
+ " if query_match:\n",
426
+ " return True, True, True\n",
427
+ "\n",
428
+ " # Check if this is a multi-line query\n",
429
+ " if \"|\" in sample_result or \"(\" in sample_result:\n",
430
+ " #print(rows)\n",
431
+ " # Create list of results by stripping separators and splitting on them\n",
432
+ " if \"(\" in sample_result:\n",
433
+ " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
434
+ " result_list = sample_result.split(\",\") \n",
435
+ " else:\n",
436
+ " result_list = sample_result.split(\"|\") \n",
437
+ "\n",
438
+ " # Strip all results in list\n",
439
+ " for i in range(len(result_list)):\n",
440
+ " result_list[i] = str(result_list[i]).strip()\n",
441
+ " \n",
442
+ " # Loop through model result and see if it matches training example\n",
443
+ " result = False\n",
444
+ " for row in rows:\n",
445
+ " for r in row:\n",
446
+ " for res in result_list:\n",
447
+ " try:\n",
448
+ " if math.isclose(float(r), float(res), abs_tol=0.5):\n",
449
+ " return True, query_match, True\n",
450
+ " except:\n",
451
+ " if r in res or res in r:\n",
452
+ " return True, query_match, True\n",
453
+ " \n",
454
+ " # Check if the model returned a sum of examples as opposed to the whole thing\n",
455
+ " if len(rows) == 1:\n",
456
+ " for r in rows[0]:\n",
457
+ " if r == str(len(result_list)):\n",
458
+ " return True, query_match, True\n",
459
+ " \n",
460
+ " return True, query_match, result\n",
461
+ " # Else the sample result is a single value or string\n",
462
+ " else:\n",
463
+ " #print(rows)\n",
464
+ " result = False\n",
465
+ " # Loop through model result and see if it contains the sample result\n",
466
+ " for row in rows:\n",
467
+ " for r in row:\n",
468
+ " # Check by string\n",
469
+ " if str(r) in str(sample_result):\n",
470
+ " try:\n",
471
+ " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
472
+ " return True, query_match, True\n",
473
+ " except:\n",
474
+ " return True, query_match, True\n",
475
+ " # Check by number, using try incase the cast as float fails\n",
476
+ " try:\n",
477
+ " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
478
+ " return True, query_match, True\n",
479
+ " except:\n",
480
+ " pass\n",
481
+ "\n",
482
+ " # Check if the model returned a list of examples instead of a total sum (both acceptable)\n",
483
+ " try:\n",
484
+ " if len(rows) > 1 and len(rows) == int(sample_result):\n",
485
+ " return True, query_match, True\n",
486
+ " if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):\n",
487
+ " return True, query_match, True\n",
488
+ " except:\n",
489
+ " pass\n",
490
+ "\n",
491
+ " # Compare results and return\n",
492
+ " return True, query_match, result\n",
493
+ " except:\n",
494
+ " return False, False, False\n",
495
+ "\n",
496
+ "# Obtain sample\n",
497
+ "#sample = df.sample(n=1)\n",
498
+ "print(sample[\"natural_query\"].values[0])\n",
499
+ "print(sample[\"sql_query\"].values[0])\n",
500
+ "print(sample[\"result\"].values[0])\n",
501
+ "\n",
502
+ "# Create message with sample query and run model\n",
503
+ "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
504
+ "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
505
+ "outputs = model.generate(inputs, max_new_tokens=256, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
506
+ "\n",
507
+ "# Print output\n",
508
+ "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
509
+ "print(query_output)\n",
510
+ "\n",
511
+ "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
512
+ "print(\"Statement valid? \" + str(result[0]))\n",
513
+ "print(\"SQLite matched? \" + str(result[1]))\n",
514
+ "print(\"Result matched? \" + str(result[2]))"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "markdown",
519
+ "metadata": {},
520
+ "source": [
521
+ "## Create function to evaluate finetuned model on full datasets"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": 7,
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "def run_evaluation(nba_df, title):\n",
531
+ " counter = 0\n",
532
+ " num_valid = 0\n",
533
+ " num_sql_matched = 0\n",
534
+ " num_result_matched = 0\n",
535
+ " for index, row in nba_df.iterrows():\n",
536
+ " # Create message with sample query and run model\n",
537
+ " message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
538
+ " inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
539
+ " outputs = model.generate(inputs, max_new_tokens=128, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
540
+ "\n",
541
+ " # Obtain output\n",
542
+ " query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
543
+ "\n",
544
+ " # Evaluate model result\n",
545
+ " valid, sql_matched, result_matched = compare_result(row[\"sql_query\"], row[\"result\"], query_output)\n",
546
+ " if valid:\n",
547
+ " num_valid += 1\n",
548
+ " if sql_matched:\n",
549
+ " num_sql_matched += 1\n",
550
+ " if result_matched:\n",
551
+ " num_result_matched += 1\n",
552
+ "\n",
553
+ " # Break after predefined number of examples\n",
554
+ " counter += 1\n",
555
+ " if counter % 50 == 0:\n",
556
+ " print(\"Completed \" + str(counter))\n",
557
+ "\n",
558
+ " # Print evaluation results\n",
559
+ " print(\"\\n\" + title + \" results:\")\n",
560
+ " print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
561
+ " print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
562
+ " print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "# Evaluate on less than 90 dataset"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": 8,
575
+ "metadata": {},
576
+ "outputs": [
577
+ {
578
+ "name": "stdout",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "Completed 50\n",
582
+ "Completed 100\n",
583
+ "Completed 150\n",
584
+ "Completed 200\n",
585
+ "\n",
586
+ "Less than 90 results:\n",
587
+ "Percent valid: 0.5183673469387755\n",
588
+ "Percent SQLite matched: 0.2857142857142857\n",
589
+ "Percent result matched: 0.42857142857142855\n",
590
+ "Dataset length: 245\n"
591
+ ]
592
+ }
593
+ ],
594
+ "source": [
595
+ "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
596
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
597
+ "print(\"Dataset length: \" + str(len(less_than_90_df)))"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "markdown",
602
+ "metadata": {},
603
+ "source": [
604
+ "# Evaluate on game table queries"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": 9,
610
+ "metadata": {},
611
+ "outputs": [
612
+ {
613
+ "ename": "KeyboardInterrupt",
614
+ "evalue": "",
615
+ "output_type": "error",
616
+ "traceback": [
617
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
618
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
619
+ "Cell \u001b[1;32mIn[9], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m game_queries \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./train-data/queries_from_game.tsv\u001b[39m\u001b[38;5;124m\"\u001b[39m, sep\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m----> 2\u001b[0m \u001b[43mrun_evaluation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgame_queries\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mQueries from game\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset length: \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mlen\u001b[39m(game_queries)))\n",
620
+ "Cell \u001b[1;32mIn[7], line 10\u001b[0m, in \u001b[0;36mrun_evaluation\u001b[1;34m(nba_df, title)\u001b[0m\n\u001b[0;32m 8\u001b[0m message\u001b[38;5;241m=\u001b[39m[{ \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrole\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcontent\u001b[39m\u001b[38;5;124m'\u001b[39m: input_text \u001b[38;5;241m+\u001b[39m row[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnatural_query\u001b[39m\u001b[38;5;124m\"\u001b[39m]}]\n\u001b[0;32m 9\u001b[0m inputs \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mapply_chat_template(message, add_generation_prompt\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mto(model\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m---> 10\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_p\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.95\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_return_sequences\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meos_token_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 12\u001b[0m \u001b[38;5;66;03m# Obtain output\u001b[39;00m\n\u001b[0;32m 13\u001b[0m query_output \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mdecode(outputs[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;28mlen\u001b[39m(inputs[\u001b[38;5;241m0\u001b[39m]):], skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
621
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
622
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\utils.py:2326\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[1;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)\u001b[0m\n\u001b[0;32m 2318\u001b[0m input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[0;32m 2319\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[0;32m 2320\u001b[0m expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[0;32m 2321\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 2322\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[0;32m 2323\u001b[0m )\n\u001b[0;32m 2325\u001b[0m \u001b[38;5;66;03m# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[1;32m-> 2326\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2327\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2328\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2329\u001b[0m \u001b[43m \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2330\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2331\u001b[0m \u001b[43m \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2332\u001b[0m \u001b[43m \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2333\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2334\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2336\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[0;32m 2337\u001b[0m \u001b[38;5;66;03m# 11. interleave input_ids with `num_beams` additional sequences per batch\u001b[39;00m\n\u001b[0;32m 2338\u001b[0m input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[0;32m 2339\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[0;32m 2340\u001b[0m expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_beams,\n\u001b[0;32m 2341\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 2342\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[0;32m 2343\u001b[0m )\n",
623
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\utils.py:3289\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[1;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[0;32m 3287\u001b[0m is_prefill \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 3288\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 3289\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 3291\u001b[0m \u001b[38;5;66;03m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[39;00m\n\u001b[0;32m 3292\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_model_kwargs_for_generation(\n\u001b[0;32m 3293\u001b[0m outputs,\n\u001b[0;32m 3294\u001b[0m model_kwargs,\n\u001b[0;32m 3295\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[0;32m 3296\u001b[0m )\n",
624
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
625
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
626
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\accelerate\\hooks.py:170\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[1;34m(module, *args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 170\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
627
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\utils\\deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[0;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[0;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
628
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\models\\llama\\modeling_llama.py:853\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[0;32m 850\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m 852\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[1;32m--> 853\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 854\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 855\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 856\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 862\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 863\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 865\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 867\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 868\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n",
629
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
630
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
631
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\accelerate\\hooks.py:170\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[1;34m(module, *args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 170\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
632
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\models\\llama\\modeling_llama.py:601\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[1;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[0;32m 589\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[0;32m 590\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[0;32m 591\u001b[0m hidden_states,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 598\u001b[0m position_embeddings,\n\u001b[0;32m 599\u001b[0m )\n\u001b[0;32m 600\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 601\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 602\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 603\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 604\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 605\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 606\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 607\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 608\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 610\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 611\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 613\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 615\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n",
633
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
634
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
635
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\accelerate\\hooks.py:170\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[1;34m(module, *args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 170\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
636
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\models\\llama\\modeling_llama.py:343\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[1;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[0;32m 340\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[0;32m 342\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[1;32m--> 343\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 344\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 345\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 346\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 347\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 353\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 354\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[0;32m 356\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
637
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
638
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
639
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\accelerate\\hooks.py:170\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[1;34m(module, *args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 170\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
640
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\models\\llama\\modeling_llama.py:277\u001b[0m, in \u001b[0;36mLlamaAttention.forward\u001b[1;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[0;32m 274\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[0;32m 275\u001b[0m hidden_shape \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim)\n\u001b[1;32m--> 277\u001b[0m query_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mq_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mview(hidden_shape)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m 278\u001b[0m key_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mk_proj(hidden_states)\u001b[38;5;241m.\u001b[39mview(hidden_shape)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m 279\u001b[0m value_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv_proj(hidden_states)\u001b[38;5;241m.\u001b[39mview(hidden_shape)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m)\n",
641
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
642
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
643
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\accelerate\\hooks.py:170\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[1;34m(module, *args, **kwargs)\u001b[0m\n\u001b[0;32m 168\u001b[0m output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 169\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 170\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
644
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\nn\\modules.py:990\u001b[0m, in \u001b[0;36mLinear8bitLt.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 987\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m!=\u001b[39m x\u001b[38;5;241m.\u001b[39mdtype:\n\u001b[0;32m 988\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mto(x\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m--> 990\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mbnb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmatmul\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mhas_fp16_weights \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mCB \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 993\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mCB\n",
645
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:509\u001b[0m, in \u001b[0;36mmatmul\u001b[1;34m(A, B, out, state, threshold, bias)\u001b[0m\n\u001b[0;32m 507\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m threshold \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0.0\u001b[39m:\n\u001b[0;32m 508\u001b[0m state\u001b[38;5;241m.\u001b[39mthreshold \u001b[38;5;241m=\u001b[39m threshold\n\u001b[1;32m--> 509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mMatMul8bitLt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mB\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m\n",
646
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\torch\\autograd\\function.py:574\u001b[0m, in \u001b[0;36mFunction.apply\u001b[1;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[0;32m 571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[0;32m 572\u001b[0m \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[0;32m 573\u001b[0m args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[1;32m--> 574\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 576\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_setup_ctx_defined:\n\u001b[0;32m 577\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[0;32m 578\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 579\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 580\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstaticmethod. For more details, please see \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 581\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/main/notes/extending.func.html\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 582\u001b[0m )\n",
647
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:326\u001b[0m, in \u001b[0;36mMatMul8bitLt.forward\u001b[1;34m(ctx, A, B, out, bias, state)\u001b[0m\n\u001b[0;32m 323\u001b[0m CA, CAt, SCA, SCAt, outlier_cols \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mint8_double_quant(A\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mfloat16), threshold\u001b[38;5;241m=\u001b[39mstate\u001b[38;5;241m.\u001b[39mthreshold)\n\u001b[0;32m 324\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 325\u001b[0m \u001b[38;5;66;03m# Fast path\u001b[39;00m\n\u001b[1;32m--> 326\u001b[0m CA, SCA, outlier_cols \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mint8_vectorwise_quant\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mthreshold\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mthreshold\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 327\u001b[0m CAt \u001b[38;5;241m=\u001b[39m SCAt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 329\u001b[0m has_grad \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
648
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\functional.py:2789\u001b[0m, in \u001b[0;36mint8_vectorwise_quant\u001b[1;34m(A, threshold)\u001b[0m\n\u001b[0;32m 2786\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m outliers\u001b[38;5;241m.\u001b[39many():\n\u001b[0;32m 2787\u001b[0m outlier_cols \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margwhere(outliers\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m))\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m-> 2789\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_cuda_device_of\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[0;32m 2790\u001b[0m lib\u001b[38;5;241m.\u001b[39mcint8_vector_quant(\n\u001b[0;32m 2791\u001b[0m get_ptr(A),\n\u001b[0;32m 2792\u001b[0m get_ptr(out_row),\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2797\u001b[0m _get_tensor_stream(A),\n\u001b[0;32m 2798\u001b[0m )\n\u001b[0;32m 2800\u001b[0m \u001b[38;5;66;03m# Zero out values from outlier columns across all rows.\u001b[39;00m\n\u001b[0;32m 2801\u001b[0m \u001b[38;5;66;03m# The kernel will handle this for outliers themselves, so we can optimize for rows=1.\u001b[39;00m\n",
649
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\functional.py:205\u001b[0m, in \u001b[0;36m_cuda_device_of\u001b[1;34m(a)\u001b[0m\n\u001b[0;32m 202\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 203\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mcontextlib\u001b[39;00m\n\u001b[1;32m--> 205\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_cuda_device_of\u001b[39m(a: torch\u001b[38;5;241m.\u001b[39mTensor):\n\u001b[0;32m 206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext()\n\u001b[0;32m 209\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_paged\u001b[39m(\u001b[38;5;241m*\u001b[39mshape, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, device\u001b[38;5;241m=\u001b[39mFIRST_CUDA_DEVICE):\n",
650
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
651
+ ]
652
+ }
653
+ ],
654
+ "source": [
655
+ "game_queries = pd.read_csv(\"./train-data/queries_from_game.tsv\", sep='\\t')\n",
656
+ "run_evaluation(game_queries, \"Queries from game\")\n",
657
+ "print(\"Dataset length: \" + str(len(game_queries)))"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "markdown",
662
+ "metadata": {},
663
+ "source": [
664
+ "## Evaluate on other stats queries"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": null,
670
+ "metadata": {},
671
+ "outputs": [],
672
+ "source": [
673
+ "other_stats_queries = pd.read_csv(\"./train-data/queries_from_other_stats.tsv\", sep='\\t')\n",
674
+ "run_evaluation(other_stats_queries, \"Queries from other stats\")\n",
675
+ "print(\"Dataset length: \" + str(len(other_stats_queries)))"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "markdown",
680
+ "metadata": {},
681
+ "source": [
682
+ "## Evaluate on team queries"
683
+ ]
684
+ },
685
+ {
686
+ "cell_type": "code",
687
+ "execution_count": null,
688
+ "metadata": {},
689
+ "outputs": [],
690
+ "source": [
691
+ "team_queries = pd.read_csv(\"./train-data/queries_from_team.tsv\", sep='\\t')\n",
692
+ "run_evaluation(team_queries, \"Queries from team\")\n",
693
+ "print(\"Dataset length: \" + str(len(team_queries)))"
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "markdown",
698
+ "metadata": {},
699
+ "source": [
700
+ "## Evaluate on queries requiring join statements"
701
+ ]
702
+ },
703
+ {
704
+ "cell_type": "code",
705
+ "execution_count": null,
706
+ "metadata": {},
707
+ "outputs": [],
708
+ "source": [
709
+ "join_queries = pd.read_csv(\"./train-data/with_join.tsv\", sep='\\t')\n",
710
+ "run_evaluation(join_queries, \"Queries with join\")\n",
711
+ "print(\"Dataset length: \" + str(len(join_queries)))"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "markdown",
716
+ "metadata": {},
717
+ "source": [
718
+ "## Evaluate on queries not requiring join statements"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": null,
724
+ "metadata": {},
725
+ "outputs": [],
726
+ "source": [
727
+ "no_join_queries = pd.read_csv(\"./train-data/without_join.tsv\", sep='\\t')\n",
728
+ "run_evaluation(no_join_queries, \"Queries without join\")\n",
729
+ "print(\"Dataset length: \" + str(len(no_join_queries)))"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "markdown",
734
+ "metadata": {},
735
+ "source": [
736
+ "## Evaluate on full training dataset"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": null,
742
+ "metadata": {},
743
+ "outputs": [],
744
+ "source": [
745
+ "# Run evaluation on all training data\n",
746
+ "run_evaluation(df, \"All training data\")\n",
747
+ "print(\"Dataset length: \" + str(len(df)))"
748
+ ]
749
+ }
750
+ ],
751
+ "metadata": {
752
+ "kernelspec": {
753
+ "display_name": "Python 3",
754
+ "language": "python",
755
+ "name": "python3"
756
+ },
757
+ "language_info": {
758
+ "codemirror_mode": {
759
+ "name": "ipython",
760
+ "version": 3
761
+ },
762
+ "file_extension": ".py",
763
+ "mimetype": "text/x-python",
764
+ "name": "python",
765
+ "nbconvert_exporter": "python",
766
+ "pygments_lexer": "ipython3",
767
+ "version": "3.12.6"
768
+ }
769
+ },
770
+ "nbformat": 4,
771
+ "nbformat_minor": 2
772
+ }