licesma committed
Commit 569052e · 1 Parent(s): 2b3100e

Prepare fine-tune for colab

Files changed (1)
  1. finetune_model.ipynb +86 -219
finetune_model.ipynb CHANGED
@@ -7,6 +7,73 @@
7
  "# Finetune DeepSeek Coder 1.3B for NBA Kaggle Database SQLite Generation"
8
  ]
9
  },
10
  {
11
  "cell_type": "markdown",
12
  "metadata": {},
@@ -28,185 +95,7 @@
28
  }
29
  ],
30
  "source": [
31
- "input_prompt = \"\"\"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
32
- "Database Schema and Explanations\n",
33
- "\n",
34
- "team Table\n",
35
- "Stores information about NBA teams.\n",
36
- "CREATE TABLE IF NOT EXISTS \"team\" (\n",
37
- " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
38
- " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
39
- " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
40
- " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
41
- " \"city\" TEXT, -- City where the team is based\n",
42
- " \"state\" TEXT, -- State where the team is located\n",
43
- " \"year_founded\" REAL -- Year the team was established\n",
44
- ");\n",
45
- "\n",
46
- "game Table\n",
47
- "Contains detailed statistics for each NBA game, including home and away team performance.\n",
48
- "CREATE TABLE IF NOT EXISTS \"game\" (\n",
49
- " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
50
- " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
51
- " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
52
- " \"team_name_home\" TEXT, -- Full name of the home team\n",
53
- " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
54
- " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
55
- " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
56
- " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
57
- " \"min\" INTEGER, -- Total minutes played in the game\n",
58
- " \"fgm_home\" REAL, -- Field goals made by the home team\n",
59
- " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
60
- " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
61
- " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
62
- " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
63
- " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
64
- " \"ftm_home\" REAL, -- Free throws made by the home team\n",
65
- " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
66
- " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
67
- " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
68
- " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
69
- " \"reb_home\" REAL, -- Total rebounds by the home team\n",
70
- " \"ast_home\" REAL, -- Assists by the home team\n",
71
- " \"stl_home\" REAL, -- Steals by the home team\n",
72
- " \"blk_home\" REAL, -- Blocks by the home team\n",
73
- " \"tov_home\" REAL, -- Turnovers by the home team\n",
74
- " \"pf_home\" REAL, -- Personal fouls by the home team\n",
75
- " \"pts_home\" REAL, -- Total points scored by the home team\n",
76
- " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
77
- " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
78
- " \"team_id_away\" TEXT, -- ID of the away team\n",
79
- " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
80
- " \"team_name_away\" TEXT, -- Full name of the away team\n",
81
- " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
82
- " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
83
- " \"fgm_away\" REAL, -- Field goals made by the away team\n",
84
- " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
85
- " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
86
- " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
87
- " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
88
- " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
89
- " \"ftm_away\" REAL, -- Free throws made by the away team\n",
90
- " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
91
- " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
92
- " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
93
- " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
94
- " \"reb_away\" REAL, -- Total rebounds by the away team\n",
95
- " \"ast_away\" REAL, -- Assists by the away team\n",
96
- " \"stl_away\" REAL, -- Steals by the away team\n",
97
- " \"blk_away\" REAL, -- Blocks by the away team\n",
98
- " \"tov_away\" REAL, -- Turnovers by the away team\n",
99
- " \"pf_away\" REAL, -- Personal fouls by the away team\n",
100
- " \"pts_away\" REAL, -- Total points scored by the away team\n",
101
- " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
102
- " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
103
- " \"season_type\" TEXT -- Regular season or playoffs\n",
104
- ");\n",
105
- "\n",
106
- "other_stats Table\n",
107
- "Stores additional statistics, linked to the game table via game_id.\n",
108
- "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
109
- " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
110
- " \"league_id\" TEXT, -- League identifier\n",
111
- " \"team_id_home\" TEXT, -- Home team identifier\n",
112
- " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
113
- " \"team_city_home\" TEXT, -- Home team city\n",
114
- " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
115
- " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
116
- " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
117
- " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
118
- " \"lead_changes\" INTEGER, -- Number of lead changes \n",
119
- " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
120
- " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
121
- " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
122
- " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
123
- " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
124
- " \"team_id_away\" TEXT, -- Away team identifier\n",
125
- " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
126
- " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
127
- " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
128
- " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
129
- " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
130
- " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
131
- " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
132
- " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
133
- " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
134
- ");\n",
135
- "\n",
136
- "\n",
137
- "Team Name Information\n",
138
- "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
139
- "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
140
- "Notice they are separated by the | character in the following list:\n",
141
- "\n",
142
- "Atlanta Hawks|ATL\n",
143
- "Boston Celtics|BOS\n",
144
- "Cleveland Cavaliers|CLE\n",
145
- "New Orleans Pelicans|NOP\n",
146
- "Chicago Bulls|CHI\n",
147
- "Dallas Mavericks|DAL\n",
148
- "Denver Nuggets|DEN\n",
149
- "Golden State Warriors|GSW\n",
150
- "Houston Rockets|HOU\n",
151
- "Los Angeles Clippers|LAC\n",
152
- "Los Angeles Lakers|LAL\n",
153
- "Miami Heat|MIA\n",
154
- "Milwaukee Bucks|MIL\n",
155
- "Minnesota Timberwolves|MIN\n",
156
- "Brooklyn Nets|BKN\n",
157
- "New York Knicks|NYK\n",
158
- "Orlando Magic|ORL\n",
159
- "Indiana Pacers|IND\n",
160
- "Philadelphia 76ers|PHI\n",
161
- "Phoenix Suns|PHX\n",
162
- "Portland Trail Blazers|POR\n",
163
- "Sacramento Kings|SAC\n",
164
- "San Antonio Spurs|SAS\n",
165
- "Oklahoma City Thunder|OKC\n",
166
- "Toronto Raptors|TOR\n",
167
- "Utah Jazz|UTA\n",
168
- "Memphis Grizzlies|MEM\n",
169
- "Washington Wizards|WAS\n",
170
- "Detroit Pistons|DET\n",
171
- "Charlotte Hornets|CHA\n",
172
- "\n",
173
- "Query Guidelines\n",
174
- "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
175
- "\n",
176
- "To filter by season, use season_id = '2YYYY'.\n",
177
- "\n",
178
- "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
179
- "\n",
180
- "Ensure queries return relevant columns and avoid unnecessary joins.\n",
181
- "\n",
182
- "Example User Requests and SQLite Queries\n",
183
- "Request:\n",
184
- "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
185
- "SQLite:\n",
186
- "SELECT MAX(pts_home) FROM game WHERE team_name_home = 'Los Angeles Lakers';\n",
187
- "\n",
188
- "Request:\n",
189
- "\"Which teams are located in the state of California?\"\n",
190
- "SQLite:\n",
191
- "SELECT full_name FROM team WHERE state = 'California';\n",
192
- "\n",
193
- "Request:\n",
194
- "\"Which team had the highest number of team turnovers in an away game?\"\n",
195
- "SQLite:\n",
196
- "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
197
- "\n",
198
- "Request:\n",
199
- "\"Which teams were founded before 1979?\"\n",
200
- "SQLite:\n",
201
- "SELECT full_name FROM team WHERE year_founded < 1979;\n",
202
- "\n",
203
- "Request:\n",
204
- "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
205
- "SQLite:\n",
206
- "SELECT MAX(pts_home - pts_away) AS biggest_win FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
207
- "\n",
208
- "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
209
- "\"\"\"\n",
210
  "\n",
211
  "print(len(input_prompt))"
212
  ]
@@ -220,30 +109,14 @@
220
  },
221
  {
222
  "cell_type": "code",
223
- "execution_count": 2,
224
  "metadata": {},
225
  "outputs": [
226
  {
227
  "name": "stderr",
228
  "output_type": "stream",
229
  "text": [
230
- "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
231
- " from .autonotebook import tqdm as notebook_tqdm\n"
232
- ]
233
- },
234
- {
235
- "name": "stdout",
236
- "output_type": "stream",
237
- "text": [
238
- "WARNING:tensorflow:From c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
239
- "\n"
240
- ]
241
- },
242
- {
243
- "name": "stderr",
244
- "output_type": "stream",
245
- "text": [
246
- "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_6496\\2921743792.py:18: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
247
  " df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n"
248
  ]
249
  },
@@ -274,7 +147,7 @@
274
  "name": "stderr",
275
  "output_type": "stream",
276
  "text": [
277
- "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1044/1044 [00:37<00:00, 27.57 examples/s]"
278
  ]
279
  },
280
  {
@@ -295,22 +168,14 @@
295
  }
296
  ],
297
  "source": [
298
- "import pandas as pd\n",
299
- "import torch\n",
300
- "from datasets import Dataset\n",
301
- "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig, EarlyStoppingCallback, PreTrainedTokenizer\n",
302
- "from torch.utils.data import DataLoader\n",
303
- "from peft import LoraConfig, get_peft_model, TaskType\n",
304
- "import os\n",
305
- "import re\n",
306
- "import numpy as np\n",
307
  "\n",
308
  "# Model output directories\n",
309
- "MODEL_DIR = \"./fine-tuned-model-16\"\n",
310
- "VAL_OUTPUT = \"val-16.hf\"\n",
311
  "\n",
312
  "# Load dataset\n",
313
- "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
314
  "\n",
315
  "df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n",
316
  "\n",
@@ -319,14 +184,16 @@
319
  "print(df.head())\n",
320
  "\n",
321
  "# Load tokenizer\n",
322
- "model_name = \"./deepseek-coder-1.3b-instruct\"\n",
323
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
324
  "\n",
325
  "# Enable 8-bit quantization for lower memory usage\n",
326
- "bnb_config = BitsAndBytesConfig(\n",
327
- " load_in_8bit=True, \n",
328
- " bnb_8bit_compute_dtype=torch.float16\n",
329
- ")\n",
330
  "\n",
331
  "# Load model with quantization\n",
332
  "#device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@@ -870,7 +737,7 @@
870
  },
871
  {
872
  "cell_type": "code",
873
- "execution_count": 8,
874
  "metadata": {},
875
  "outputs": [
876
  {
@@ -893,7 +760,7 @@
893
  "print(prompt_length)\n",
894
  "\n",
895
  "# Create connection to sqlite3 database\n",
896
- "connection = sql.connect('./nba-data/nba.sqlite')\n",
897
  "cursor = connection.cursor()\n",
898
  "\n",
899
  "for v in val_dataset:\n",
@@ -4248,7 +4115,7 @@
4248
  ],
4249
  "metadata": {
4250
  "kernelspec": {
4251
- "display_name": "Python 3",
4252
  "language": "python",
4253
  "name": "python3"
4254
  },
@@ -4262,7 +4129,7 @@
4262
  "name": "python",
4263
  "nbconvert_exporter": "python",
4264
  "pygments_lexer": "ipython3",
4265
- "version": "3.12.6"
4266
  }
4267
  },
4268
  "nbformat": 4,
 
7
  "# Finetune DeepSeek Coder 1.3B for NBA Kaggle Database SQLite Generation"
8
  ]
9
  },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [
15
+ {
16
+ "name": "stderr",
17
+ "output_type": "stream",
18
+ "text": [
19
+ "/opt/anaconda3/envs/CSCI544/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
20
+ " from .autonotebook import tqdm as notebook_tqdm\n"
21
+ ]
22
+ }
23
+ ],
24
+ "source": [
25
+ "import pandas as pd\n",
26
+ "import torch\n",
27
+ "from datasets import Dataset\n",
28
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig, EarlyStoppingCallback, PreTrainedTokenizer\n",
29
+ "from torch.utils.data import DataLoader\n",
30
+ "import sys\n",
31
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
32
+ "from huggingface_hub import snapshot_download\n",
33
+ "import os\n",
34
+ "import re\n",
35
+ "import numpy as np"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "is_google_colab = False\n",
45
+ "use_bnb = True"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "current_read_path = \"./\"\n",
55
+ "current_write_path = \"./\"\n",
56
+ "\n",
57
+ "def read_path(rel_path):\n",
58
+ " return os.path.join(current_read_path, rel_path)\n",
59
+ "\n",
60
+ "def write_path(rel_path):\n",
61
+ " return os.path.join(current_write_path, rel_path)\n",
62
+ "\n",
63
+ "if is_google_colab:\n",
64
+ " from google.colab import drive\n",
65
+ " drive.mount('/content/drive')\n",
66
+ " current_write_path = \"/content/drive/MyDrive/sql_gen\"\n",
67
+ "\n",
68
+ " hugging_face_path = snapshot_download(\n",
69
+ " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
70
+ " repo_type=\"model\", \n",
71
+ " allow_patterns=[\"src/*\", \"train-data/*\", \"deepseek-coder-1.3b-instruct/*\", \"nba-data/*\"], \n",
72
+ " )\n",
73
+ " sys.path.append(hugging_face_path)\n",
74
+ " current_path = hugging_face_path"
75
+ ]
76
+ },
77
  {
78
  "cell_type": "markdown",
79
  "metadata": {},
 
95
  }
96
  ],
97
  "source": [
98
+ "from src.prompts.prompt import input_text as input_prompt\n",
99
  "\n",
100
  "print(len(input_prompt))"
101
  ]
 
109
  },
110
  {
111
  "cell_type": "code",
112
+ "execution_count": null,
113
  "metadata": {},
114
  "outputs": [
115
  {
116
  "name": "stderr",
117
  "output_type": "stream",
118
  "text": [
119
+ "/var/folders/g0/47tr69v179dg7w6zyphp9b280000gn/T/ipykernel_35112/48906000.py:8: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n",
120
  " df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n"
121
  ]
122
  },
 
147
  "name": "stderr",
148
  "output_type": "stream",
149
  "text": [
150
+ "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1044/1044 [00:17<00:00, 59.19 examples/s]"
151
  ]
152
  },
153
  {
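The Map progress bar above comes from tokenizing the 1,044 training examples; the cell body that does this is unchanged and not shown in the diff. A rough sketch of the kind of `Dataset.map` call that produces such output, assuming `df`, `tokenizer`, and `input_prompt` from the cells above; the column names are hypothetical:

```python
from datasets import Dataset

# Sketch only: "natural_query" and "sql_query" are hypothetical column names,
# the real notebook may use different ones.
dataset = Dataset.from_pandas(df)

def tokenize(example):
    text = input_prompt + example["natural_query"] + "\nSQLite:\n" + example["sql_query"]
    return tokenizer(text, truncation=True, max_length=1024)

dataset = dataset.map(tokenize)  # emits the "Map: 1044/1044 examples" progress bar
```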
 
168
  }
169
  ],
170
  "source": [
171
+ "\n",
172
  "\n",
173
  "# Model output directories\n",
174
+ "MODEL_DIR = write_path(\"fine-tuned-model-16-test\")\n",
175
+ "VAL_OUTPUT = write_path(\"val-16.hf\")\n",
176
  "\n",
177
  "# Load dataset\n",
178
+ "df = pd.read_csv(read_path(\"train-data/sql_train.tsv\"), sep='\\t')\n",
179
  "\n",
180
  "df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n",
181
  "\n",
 
184
  "print(df.head())\n",
185
  "\n",
186
  "# Load tokenizer\n",
187
+ "model_name = read_path(\"deepseek-coder-1.3b-instruct\")\n",
188
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
189
  "\n",
190
  "# Enable 8-bit quantization for lower memory usage\n",
191
+ "bnb_config = None\n",
192
+ "if use_bnb:\n",
193
+ " bnb_config = BitsAndBytesConfig(\n",
194
+ " load_in_8bit=True, \n",
195
+ " bnb_8bit_compute_dtype=torch.float16\n",
196
+ " )\n",
197
  "\n",
198
  "# Load model with quantization\n",
199
  "#device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
 
737
  },
738
  {
739
  "cell_type": "code",
740
+ "execution_count": null,
741
  "metadata": {},
742
  "outputs": [
743
  {
 
760
  "print(prompt_length)\n",
761
  "\n",
762
  "# Create connection to sqlite3 database\n",
763
+ "connection = sql.connect(read_path('nba-data/nba.sqlite'))\n",
764
  "cursor = connection.cursor()\n",
765
  "\n",
766
  "for v in val_dataset:\n",
 
4115
  ],
4116
  "metadata": {
4117
  "kernelspec": {
4118
+ "display_name": "CSCI544",
4119
  "language": "python",
4120
  "name": "python3"
4121
  },
 
4129
  "name": "python",
4130
  "nbconvert_exporter": "python",
4131
  "pygments_lexer": "ipython3",
4132
+ "version": "3.11.11"
4133
  }
4134
  },
4135
  "nbformat": 4,