DeanGumas committed on
Commit
8c47142
·
1 Parent(s): 32c1934

Initial attempt at fine-tuning using LoRA with basic cross-entropy loss

Browse files
Files changed (1) hide show
  1. finetune_model.ipynb +524 -0
finetune_model.ipynb ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Finetune DeepSeek Coder 1.3B for NBA Kaggle Database SQLite Generation"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## First load data and convert to Dataset object tokenized by the DeepSeek model"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "name": "stderr",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
27
+ " from .autonotebook import tqdm as notebook_tqdm\n"
28
+ ]
29
+ },
30
+ {
31
+ "name": "stdout",
32
+ "output_type": "stream",
33
+ "text": [
34
+ "WARNING:tensorflow:From c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
35
+ "\n",
36
+ "Total dataset examples: 1044\n",
37
+ " natural_query \\\n",
38
+ "0 Which NBA teams were established after the yea... \n",
39
+ "1 What is the most points the Los Angeles Lakers... \n",
40
+ "2 What is the second-highest number of points th... \n",
41
+ "3 How many home games did the Golden State Warri... \n",
42
+ "4 What is the average number of assists by the B... \n",
43
+ "\n",
44
+ " sql_query result \n",
45
+ "0 SELECT full_name FROM team WHERE year_founded ... New Orleans Pelicans \n",
46
+ "1 SELECT MAX(pts_home) FROM game WHERE team_nam... 162 \n",
47
+ "2 SELECT pts_home FROM game WHERE team_name_home... 156 \n",
48
+ "3 SELECT COUNT(*) FROM game WHERE team_abbrevi... 29 \n",
49
+ "4 SELECT AVG(ast_home) FROM game WHERE team_ab... 26.51355662 \n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "Map: 100%|██████████| 1044/1044 [00:00<00:00, 4433.07 examples/s]"
57
+ ]
58
+ },
59
+ {
60
+ "name": "stdout",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "939\n",
64
+ "105\n"
65
+ ]
66
+ },
67
+ {
68
+ "name": "stderr",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "\n"
72
+ ]
73
+ }
74
+ ],
75
+ "source": [
76
+ "import pandas as pd\n",
77
+ "import torch\n",
78
+ "from datasets import Dataset\n",
79
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig\n",
80
+ "from torch.utils.data import DataLoader\n",
81
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
82
+ "import os\n",
83
+ "\n",
84
+ "# Load dataset\n",
85
+ "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
86
+ "\n",
87
+ "# Display dataset info\n",
88
+ "print(f\"Total dataset examples: {len(df)}\")\n",
89
+ "print(df.head())\n",
90
+ "\n",
91
+ "# Load tokenizer\n",
92
+ "model_name = \"./deepseek-coder-1.3b-instruct\"\n",
93
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
94
+ "\n",
95
+ "# Preprocessing function\n",
96
+ "def preprocess_function(examples):\n",
97
+ " \"\"\"\n",
98
+ " Tokenizes input natural language queries and corresponding SQL queries.\n",
99
+ " \"\"\"\n",
100
+ " inputs = [\"Translate to SQL: \" + q for q in examples[\"natural_query\"]]\n",
101
+ " targets = examples[\"sql_query\"]\n",
102
+ "\n",
103
+ " model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=256)\n",
104
+ " labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=256)\n",
105
+ "\n",
106
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
107
+ " return model_inputs\n",
108
+ "\n",
109
+ "# Convert to Hugging Face Dataset\n",
110
+ "dataset = Dataset.from_pandas(df)\n",
111
+ "\n",
112
+ "# Apply tokenization\n",
113
+ "tokenized_dataset = dataset.map(preprocess_function, batched=True)\n",
114
+ "\n",
115
+ "# Split into train/validation\n",
116
+ "split = int(0.9 * len(tokenized_dataset)) # 90% train, 10% validation\n",
117
+ "train_dataset = tokenized_dataset.select(range(split))\n",
118
+ "val_dataset = tokenized_dataset.select(range(split, len(tokenized_dataset)))\n",
119
+ "\n",
120
+ "print(len(train_dataset))\n",
121
+ "print(len(val_dataset))"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "metadata": {},
127
+ "source": [
128
+ "## Load model and define training arguments"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 2,
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "trainable params: 6,291,456 || all params: 1,352,763,392 || trainable%: 0.4651\n"
141
+ ]
142
+ }
143
+ ],
144
+ "source": [
145
+ "# Enable 8-bit quantization for lower memory usage\n",
146
+ "bnb_config = BitsAndBytesConfig(\n",
147
+ " load_in_8bit=True, \n",
148
+ " bnb_8bit_compute_dtype=torch.float16\n",
149
+ ")\n",
150
+ "\n",
151
+ "# Load model with quantization\n",
152
+ "#device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
153
+ "device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
154
+ "device = torch.device(device_name)\n",
155
+ "model = AutoModelForCausalLM.from_pretrained(\n",
156
+ " model_name, \n",
157
+ " quantization_config=bnb_config,\n",
158
+ " device_map=device\n",
159
+ ")\n",
160
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
161
+ "\n",
162
+ "# Define LoRA configuration\n",
163
+ "lora_config = LoraConfig(\n",
164
+ " r=16, # Rank of LoRA matrices (adjust for memory vs. accuracy)\n",
165
+ " lora_alpha=32, # Scaling factor\n",
166
+ " lora_dropout=0.1, # Dropout for regularization\n",
167
+ " bias=\"none\",\n",
168
+ " task_type=TaskType.CAUSAL_LM,\n",
169
+ " target_modules=[\n",
170
+ " \"q_proj\",\n",
171
+ " \"k_proj\",\n",
172
+ " \"v_proj\",\n",
173
+ " \"o_proj\"\n",
174
+ " ]\n",
175
+ ")\n",
176
+ "\n",
177
+ "# Wrap model with LoRA adapters\n",
178
+ "model = get_peft_model(model, lora_config)\n",
179
+ "model = model.to(device)\n",
180
+ "model.print_trainable_parameters() # Show trainable parameters count"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "markdown",
185
+ "metadata": {},
186
+ "source": [
187
+ "## Define prompt for model"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 3,
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "input_prompt = \"\"\"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
197
+ "Database Schema and Explanations\n",
198
+ "\n",
199
+ "team Table\n",
200
+ "Stores information about NBA teams.\n",
201
+ "CREATE TABLE IF NOT EXISTS \"team\" (\n",
202
+ " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
203
+ " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
204
+ " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
205
+ " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
206
+ " \"city\" TEXT, -- City where the team is based\n",
207
+ " \"state\" TEXT, -- State where the team is located\n",
208
+ " \"year_founded\" REAL -- Year the team was established\n",
209
+ ");\n",
210
+ "\n",
211
+ "game Table\n",
212
+ "Contains detailed statistics for each NBA game, including home and away team performance.\n",
213
+ "CREATE TABLE IF NOT EXISTS \"game\" (\n",
214
+ " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
215
+ " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
216
+ " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
217
+ " \"team_name_home\" TEXT, -- Full name of the home team\n",
218
+ " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
219
+ " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
220
+ " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
221
+ " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
222
+ " \"min\" INTEGER, -- Total minutes played in the game\n",
223
+ " \"fgm_home\" REAL, -- Field goals made by the home team\n",
224
+ " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
225
+ " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
226
+ " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
227
+ " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
228
+ " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
229
+ " \"ftm_home\" REAL, -- Free throws made by the home team\n",
230
+ " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
231
+ " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
232
+ " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
233
+ " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
234
+ " \"reb_home\" REAL, -- Total rebounds by the home team\n",
235
+ " \"ast_home\" REAL, -- Assists by the home team\n",
236
+ " \"stl_home\" REAL, -- Steals by the home team\n",
237
+ " \"blk_home\" REAL, -- Blocks by the home team\n",
238
+ " \"tov_home\" REAL, -- Turnovers by the home team\n",
239
+ " \"pf_home\" REAL, -- Personal fouls by the home team\n",
240
+ " \"pts_home\" REAL, -- Total points scored by the home team\n",
241
+ " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
242
+ " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
243
+ " \"team_id_away\" TEXT, -- ID of the away team\n",
244
+ " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
245
+ " \"team_name_away\" TEXT, -- Full name of the away team\n",
246
+ " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
247
+ " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
248
+ " \"fgm_away\" REAL, -- Field goals made by the away team\n",
249
+ " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
250
+ " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
251
+ " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
252
+ " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
253
+ " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
254
+ " \"ftm_away\" REAL, -- Free throws made by the away team\n",
255
+ " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
256
+ " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
257
+ " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
258
+ " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
259
+ " \"reb_away\" REAL, -- Total rebounds by the away team\n",
260
+ " \"ast_away\" REAL, -- Assists by the away team\n",
261
+ " \"stl_away\" REAL, -- Steals by the away team\n",
262
+ " \"blk_away\" REAL, -- Blocks by the away team\n",
263
+ " \"tov_away\" REAL, -- Turnovers by the away team\n",
264
+ " \"pf_away\" REAL, -- Personal fouls by the away team\n",
265
+ " \"pts_away\" REAL, -- Total points scored by the away team\n",
266
+ " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
267
+ " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
268
+ " \"season_type\" TEXT -- Regular season or playoffs\n",
269
+ ");\n",
270
+ "\n",
271
+ "other_stats Table\n",
272
+ "Stores additional statistics, linked to the game table via game_id.\n",
273
+ "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
274
+ " \"game_id\" TEXT, -- Unique game identifier, matches game_id column from game table\n",
275
+ " \"league_id\" TEXT, -- League identifier\n",
276
+ " \"team_id_home\" TEXT, -- Home team identifier\n",
277
+ " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
278
+ " \"team_city_home\" TEXT, -- Home team city\n",
279
+ " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
280
+ " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
281
+ " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
282
+ " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
283
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
284
+ " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
285
+ " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
286
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
287
+ " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
288
+ " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
289
+ " \"team_id_away\" TEXT, -- Away team identifier\n",
290
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
291
+ " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
292
+ " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
293
+ " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
294
+ " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
295
+ " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
296
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
297
+ " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
298
+ " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
299
+ ");\n",
300
+ "\n",
301
+ "\n",
302
+ "Team Name Information\n",
303
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
304
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
305
+ "Notice they are separated by the | character in the following list:\n",
306
+ "\n",
307
+ "Atlanta Hawks|ATL\n",
308
+ "Boston Celtics|BOS\n",
309
+ "Cleveland Cavaliers|CLE\n",
310
+ "New Orleans Pelicans|NOP\n",
311
+ "Chicago Bulls|CHI\n",
312
+ "Dallas Mavericks|DAL\n",
313
+ "Denver Nuggets|DEN\n",
314
+ "Golden State Warriors|GSW\n",
315
+ "Houston Rockets|HOU\n",
316
+ "Los Angeles Clippers|LAC\n",
317
+ "Los Angeles Lakers|LAL\n",
318
+ "Miami Heat|MIA\n",
319
+ "Milwaukee Bucks|MIL\n",
320
+ "Minnesota Timberwolves|MIN\n",
321
+ "Brooklyn Nets|BKN\n",
322
+ "New York Knicks|NYK\n",
323
+ "Orlando Magic|ORL\n",
324
+ "Indiana Pacers|IND\n",
325
+ "Philadelphia 76ers|PHI\n",
326
+ "Phoenix Suns|PHX\n",
327
+ "Portland Trail Blazers|POR\n",
328
+ "Sacramento Kings|SAC\n",
329
+ "San Antonio Spurs|SAS\n",
330
+ "Oklahoma City Thunder|OKC\n",
331
+ "Toronto Raptors|TOR\n",
332
+ "Utah Jazz|UTA\n",
333
+ "Memphis Grizzlies|MEM\n",
334
+ "Washington Wizards|WAS\n",
335
+ "Detroit Pistons|DET\n",
336
+ "Charlotte Hornets|CHA\n",
337
+ "\n",
338
+ "Query Guidelines\n",
339
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation_away to match teams to the other_stats table.\n",
340
+ "\n",
341
+ "To filter by season, use season_id = '2YYYY'.\n",
342
+ "\n",
343
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = '21972'. To get statistics from 2015, use a statement like: season_id = '22015'.\n",
344
+ "\n",
345
+ "Ensure queries return relevant columns and avoid unnecessary joins.\n",
346
+ "\n",
347
+ "Example User Requests and SQLite Queries\n",
348
+ "Request:\n",
349
+ "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
350
+ "SQLite:\n",
351
+ "SELECT MAX(pts_home) \n",
352
+ "FROM game \n",
353
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
354
+ "\n",
355
+ "Request:\n",
356
+ "\"Which teams are located in the state of California?\"\n",
357
+ "SQLite:\n",
358
+ "SELECT full_name FROM team WHERE state = 'California';\n",
359
+ "\n",
360
+ "Request:\n",
361
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
362
+ "SQLite:\n",
363
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
364
+ "\n",
365
+ "Request:\n",
366
+ "\"Which teams were founded before 1979?\"\n",
367
+ "SQLite:\n",
368
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
369
+ "\n",
370
+ "Request:\n",
371
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
372
+ "SQLite:\n",
373
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
374
+ "FROM game\n",
375
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
376
+ "\n",
377
+ "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
378
+ "\"\"\""
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "markdown",
383
+ "metadata": {},
384
+ "source": [
385
+ "## Setup model trainer"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 4,
391
+ "metadata": {},
392
+ "outputs": [
393
+ {
394
+ "name": "stderr",
395
+ "output_type": "stream",
396
+ "text": [
397
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
398
+ " warnings.warn(\n",
399
+ "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_12256\\3557190339.py:17: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
400
+ " trainer = Trainer(\n",
401
+ "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
402
+ ]
403
+ }
404
+ ],
405
+ "source": [
406
+ "training_args = TrainingArguments(\n",
407
+ " output_dir=\"./fine-tuned-model\",\n",
408
+ " evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
409
+ " save_strategy=\"epoch\", # Save model every epoch\n",
410
+ " per_device_train_batch_size=8, # LoRA allows higher batch size\n",
411
+ " per_device_eval_batch_size=8,\n",
412
+ " num_train_epochs=3, # Increase if needed\n",
413
+ " learning_rate=5e-4, # Higher LR since we're only training LoRA layers\n",
414
+ " weight_decay=0.01,\n",
415
+ " logging_steps=50, # Print loss every 50 steps\n",
416
+ " save_total_limit=2, # Keep last 2 checkpoints\n",
417
+ " fp16=True if torch.cuda.is_available() else False,\n",
418
+ " push_to_hub=False\n",
419
+ ")\n",
420
+ "\n",
421
+ "# Trainer setup\n",
422
+ "trainer = Trainer(\n",
423
+ " model=model,\n",
424
+ " args=training_args,\n",
425
+ " train_dataset=train_dataset,\n",
426
+ " eval_dataset=val_dataset,\n",
427
+ " tokenizer=tokenizer\n",
428
+ ")"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "metadata": {},
434
+ "source": [
435
+ "## Run fine-tuning and save model weights when complete"
436
+ ]
437
+ },
438
+ {
439
+ "cell_type": "code",
440
+ "execution_count": 5,
441
+ "metadata": {},
442
+ "outputs": [
443
+ {
444
+ "name": "stderr",
445
+ "output_type": "stream",
446
+ "text": [
447
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\integrations\\sdpa_attention.py:54: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
448
+ " attn_output = torch.nn.functional.scaled_dot_product_attention(\n"
449
+ ]
450
+ },
451
+ {
452
+ "data": {
453
+ "text/html": [
454
+ "\n",
455
+ " <div>\n",
456
+ " \n",
457
+ " <progress value='6' max='354' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
458
+ " [ 6/354 00:03 < 05:16, 1.10 it/s, Epoch 0.04/3]\n",
459
+ " </div>\n",
460
+ " <table border=\"1\" class=\"dataframe\">\n",
461
+ " <thead>\n",
462
+ " <tr style=\"text-align: left;\">\n",
463
+ " <th>Epoch</th>\n",
464
+ " <th>Training Loss</th>\n",
465
+ " <th>Validation Loss</th>\n",
466
+ " </tr>\n",
467
+ " </thead>\n",
468
+ " <tbody>\n",
469
+ " </tbody>\n",
470
+ "</table><p>"
471
+ ],
472
+ "text/plain": [
473
+ "<IPython.core.display.HTML object>"
474
+ ]
475
+ },
476
+ "metadata": {},
477
+ "output_type": "display_data"
478
+ },
479
+ {
480
+ "ename": "KeyboardInterrupt",
481
+ "evalue": "",
482
+ "output_type": "error",
483
+ "traceback": [
484
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
485
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
486
+ "Cell \u001b[1;32mIn[5], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Run training\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[38;5;66;03m# Save model and tokenizer weights\u001b[39;00m\n\u001b[0;32m 5\u001b[0m model\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./fine-tuned-model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
487
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\trainer.py:2245\u001b[0m, in \u001b[0;36mTrainer.train\u001b[1;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[0;32m 2243\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[0;32m 2244\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2245\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2246\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2247\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2248\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2249\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2250\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
488
+ "File \u001b[1;32mc:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\trainer.py:2561\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[1;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[0;32m 2555\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m 2556\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining_step(model, inputs, num_items_in_batch)\n\u001b[0;32m 2558\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[0;32m 2559\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[0;32m 2560\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m-> 2561\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43misinf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss_step\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 2562\u001b[0m ):\n\u001b[0;32m 2563\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[0;32m 2564\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n\u001b[0;32m 2565\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
489
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
490
+ ]
491
+ }
492
+ ],
493
+ "source": [
494
+ "# Run training\n",
495
+ "trainer.train()\n",
496
+ "\n",
497
+ "# Save model and tokenizer weights\n",
498
+ "model.save_pretrained(\"./fine-tuned-model\")\n",
499
+ "tokenizer.save_pretrained(\"./fine-tuned-model\")"
500
+ ]
501
+ }
502
+ ],
503
+ "metadata": {
504
+ "kernelspec": {
505
+ "display_name": "Python 3",
506
+ "language": "python",
507
+ "name": "python3"
508
+ },
509
+ "language_info": {
510
+ "codemirror_mode": {
511
+ "name": "ipython",
512
+ "version": 3
513
+ },
514
+ "file_extension": ".py",
515
+ "mimetype": "text/x-python",
516
+ "name": "python",
517
+ "nbconvert_exporter": "python",
518
+ "pygments_lexer": "ipython3",
519
+ "version": "3.12.6"
520
+ }
521
+ },
522
+ "nbformat": 4,
523
+ "nbformat_minor": 2
524
+ }