DeanGumas committed on
Commit
fdaf162
·
1 Parent(s): 1f9c86c

First decent attempt, seems to generate valid queries after pre-training. Needs evaluation still

Browse files
fine-tuned-model/{checkpoint-236 → checkpoint-295}/README.md RENAMED
File without changes
fine-tuned-model/{checkpoint-236 → checkpoint-295}/adapter_config.json RENAMED
@@ -24,10 +24,10 @@
24
  "rank_pattern": {},
25
  "revision": null,
26
  "target_modules": [
27
- "q_proj",
28
  "k_proj",
29
- "v_proj",
30
- "o_proj"
 
31
  ],
32
  "task_type": "CAUSAL_LM",
33
  "trainable_token_indices": null,
 
24
  "rank_pattern": {},
25
  "revision": null,
26
  "target_modules": [
 
27
  "k_proj",
28
+ "q_proj",
29
+ "o_proj",
30
+ "v_proj"
31
  ],
32
  "task_type": "CAUSAL_LM",
33
  "trainable_token_indices": null,
fine-tuned-model/{checkpoint-236 → checkpoint-295}/adapter_model.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c1675aadfd9bd5995a73de58148ef251da4d532994321a4944157eb47e23efe1
3
  size 25191536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dbc52cfab26c0dfe8ab09edbd2f3c33aa892b2ea1e668a652968307cf887d90
3
  size 25191536
fine-tuned-model/{checkpoint-236 → checkpoint-295}/optimizer.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4154204936127b67041d1fa14b3b47592d28098e3af1fbc90ccb186d4b7aa920
3
  size 50492858
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc34cbc76e28ac17c0afc1333f5d7b30a8beebce84a6f3e73589b31ff00733f2
3
  size 50492858
fine-tuned-model/{checkpoint-236 → checkpoint-295}/rng_state.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9594002d0ce67a88462882fb88058ad890c8624e05cfe9253cbfdb98dbcea4d7
3
  size 14244
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47500e947b0979f4187ced89844b0b41af88c14cc3ed27ad8cb01fdb1072cb88
3
  size 14244
fine-tuned-model/{checkpoint-236 → checkpoint-295}/scaler.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:713b6197e1c6fd2587869ada366ebb7d57d251ebefa1a0bc3b17bf2ed26a407b
3
  size 988
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0cac5ada32db5a3f9944a8e52bf38c14646b18fcd0190b4fb8155b619d8f5ab
3
  size 988
fine-tuned-model/{checkpoint-236 → checkpoint-295}/scheduler.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:babe16327608a66b40524b5a3c6405ef24c518e11abf91e1eb89f529de80b3db
3
  size 1064
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee2a611a1529c905642325059ae97e660232f5907273b45d09703eb7bc5fa03
3
  size 1064
fine-tuned-model/{checkpoint-236 → checkpoint-295}/special_tokens_map.json RENAMED
File without changes
fine-tuned-model/{checkpoint-236 → checkpoint-295}/tokenizer.json RENAMED
File without changes
fine-tuned-model/{checkpoint-236 → checkpoint-295}/tokenizer_config.json RENAMED
File without changes
fine-tuned-model/{checkpoint-236 → checkpoint-295}/trainer_state.json RENAMED
@@ -2,62 +2,93 @@
2
  "best_global_step": null,
3
  "best_metric": null,
4
  "best_model_checkpoint": null,
5
- "epoch": 2.0,
6
  "eval_steps": 500,
7
- "global_step": 236,
8
  "is_hyper_param_search": false,
9
  "is_local_process_zero": true,
10
  "is_world_process_zero": true,
11
  "log_history": [
12
  {
13
- "epoch": 0.423728813559322,
14
- "grad_norm": 1.3252296447753906,
15
- "learning_rate": 0.0004307909604519774,
16
- "loss": 3.0243,
17
  "step": 50
18
  },
19
  {
20
- "epoch": 0.847457627118644,
21
- "grad_norm": 1.302614450454712,
22
- "learning_rate": 0.00036016949152542374,
23
- "loss": 1.2356,
 
 
 
 
 
 
 
 
24
  "step": 100
25
  },
26
  {
27
- "epoch": 1.0,
28
- "eval_loss": 0.9709166288375854,
29
- "eval_runtime": 5.194,
30
- "eval_samples_per_second": 20.216,
31
- "eval_steps_per_second": 2.695,
32
  "step": 118
33
  },
34
  {
35
- "epoch": 1.271186440677966,
36
- "grad_norm": 0.6076271533966064,
37
- "learning_rate": 0.0002895480225988701,
38
- "loss": 1.2005,
39
  "step": 150
40
  },
41
  {
42
- "epoch": 1.694915254237288,
43
- "grad_norm": 1.1516226530075073,
44
- "learning_rate": 0.0002189265536723164,
45
- "loss": 1.2331,
 
 
 
 
 
 
 
 
46
  "step": 200
47
  },
48
  {
49
- "epoch": 2.0,
50
- "eval_loss": 0.9370157718658447,
51
- "eval_runtime": 5.1425,
52
- "eval_samples_per_second": 20.418,
53
- "eval_steps_per_second": 2.722,
54
  "step": 236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  }
56
  ],
57
  "logging_steps": 50,
58
- "max_steps": 354,
59
  "num_input_tokens_seen": 0,
60
- "num_train_epochs": 3,
61
  "save_steps": 500,
62
  "stateful_callbacks": {
63
  "TrainerControl": {
@@ -66,13 +97,13 @@
66
  "should_evaluate": false,
67
  "should_log": false,
68
  "should_save": true,
69
- "should_training_stop": false
70
  },
71
  "attributes": {}
72
  }
73
  },
74
- "total_flos": 3711634067423232.0,
75
- "train_batch_size": 8,
76
  "trial_name": null,
77
  "trial_params": null
78
  }
 
2
  "best_global_step": null,
3
  "best_metric": null,
4
  "best_model_checkpoint": null,
5
+ "epoch": 5.0,
6
  "eval_steps": 500,
7
+ "global_step": 295,
8
  "is_hyper_param_search": false,
9
  "is_local_process_zero": true,
10
  "is_world_process_zero": true,
11
  "log_history": [
12
  {
13
+ "epoch": 0.847457627118644,
14
+ "grad_norm": 3.6365857124328613,
15
+ "learning_rate": 4.152542372881356e-05,
16
+ "loss": 8.5103,
17
  "step": 50
18
  },
19
  {
20
+ "epoch": 1.0,
21
+ "eval_loss": 1.4866021871566772,
22
+ "eval_runtime": 5.4305,
23
+ "eval_samples_per_second": 19.335,
24
+ "eval_steps_per_second": 1.289,
25
+ "step": 59
26
+ },
27
+ {
28
+ "epoch": 1.694915254237288,
29
+ "grad_norm": 3.137465715408325,
30
+ "learning_rate": 3.305084745762712e-05,
31
+ "loss": 1.7098,
32
  "step": 100
33
  },
34
  {
35
+ "epoch": 2.0,
36
+ "eval_loss": 1.2273037433624268,
37
+ "eval_runtime": 5.362,
38
+ "eval_samples_per_second": 19.582,
39
+ "eval_steps_per_second": 1.305,
40
  "step": 118
41
  },
42
  {
43
+ "epoch": 2.542372881355932,
44
+ "grad_norm": 1.6243258714675903,
45
+ "learning_rate": 2.457627118644068e-05,
46
+ "loss": 1.5421,
47
  "step": 150
48
  },
49
  {
50
+ "epoch": 3.0,
51
+ "eval_loss": 1.1611202955245972,
52
+ "eval_runtime": 5.348,
53
+ "eval_samples_per_second": 19.634,
54
+ "eval_steps_per_second": 1.309,
55
+ "step": 177
56
+ },
57
+ {
58
+ "epoch": 3.389830508474576,
59
+ "grad_norm": 1.7812302112579346,
60
+ "learning_rate": 1.6101694915254237e-05,
61
+ "loss": 1.4875,
62
  "step": 200
63
  },
64
  {
65
+ "epoch": 4.0,
66
+ "eval_loss": 1.153254508972168,
67
+ "eval_runtime": 5.347,
68
+ "eval_samples_per_second": 19.637,
69
+ "eval_steps_per_second": 1.309,
70
  "step": 236
71
+ },
72
+ {
73
+ "epoch": 4.237288135593221,
74
+ "grad_norm": 2.1582489013671875,
75
+ "learning_rate": 7.627118644067798e-06,
76
+ "loss": 1.3883,
77
+ "step": 250
78
+ },
79
+ {
80
+ "epoch": 5.0,
81
+ "eval_loss": 1.1216797828674316,
82
+ "eval_runtime": 5.3095,
83
+ "eval_samples_per_second": 19.776,
84
+ "eval_steps_per_second": 1.318,
85
+ "step": 295
86
  }
87
  ],
88
  "logging_steps": 50,
89
+ "max_steps": 295,
90
  "num_input_tokens_seen": 0,
91
+ "num_train_epochs": 5,
92
  "save_steps": 500,
93
  "stateful_callbacks": {
94
  "TrainerControl": {
 
97
  "should_evaluate": false,
98
  "should_log": false,
99
  "should_save": true,
100
+ "should_training_stop": true
101
  },
102
  "attributes": {}
103
  }
104
  },
105
+ "total_flos": 9279085168558080.0,
106
+ "train_batch_size": 16,
107
  "trial_name": null,
108
  "trial_params": null
109
  }
fine-tuned-model/{checkpoint-236 → checkpoint-295}/training_args.bin RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:067bd988483e44c78bdd5214e17314d730fc21ddb68a86665b5964dfa1dff3cf
3
  size 5368
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:145881f21d543612154f9ea67fda44fa38754eb20bd4eb13e1492be30ee670d7
3
  size 5368
fine-tuned-model/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d7d5c722eaa522148cdd50895d0b066e1cdeb624bf1cf7c77f0de1b647d74ad4
3
  size 1480793144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9062451f577293bdd513174888e96a053a05c6f763eabe4c1196ad82987caf2c
3
  size 1480793144
fine-tuned-model/runs/Apr03_12-32-10_DESKTOP-SMJC97K/events.out.tfevents.1743708731.DESKTOP-SMJC97K.4788.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e13d3e6b9cb02db1498428eecc437b77881de7829603019b54d53fcc6c0a9c7
3
+ size 5656
fine-tuned-model/runs/Apr03_12-35-47_DESKTOP-SMJC97K/events.out.tfevents.1743708948.DESKTOP-SMJC97K.21476.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5d2c6a4697d12bb25ba105c3c95241aa9d53ab1ea6c5a0116ec7a5416885586
3
+ size 8402
fine-tuned-model/runs/Apr03_14-34-10_DESKTOP-SMJC97K/events.out.tfevents.1743716051.DESKTOP-SMJC97K.9248.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1c85607a6a7da30d998111f662d1777c8550adb9a1d046bfae143ea1947ce2a
3
+ size 8402
finetune_model.ipynb CHANGED
@@ -251,13 +251,396 @@
251
  "name": "stderr",
252
  "output_type": "stream",
253
  "text": [
254
- "Map: 100%|██████████| 1044/1044 [00:01<00:00, 543.61 examples/s]"
255
  ]
256
  },
257
  {
258
  "name": "stdout",
259
  "output_type": "stream",
260
  "text": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  "939\n",
262
  "105\n"
263
  ]
@@ -296,7 +679,10 @@
296
  " Tokenizes input natural language queries and corresponding SQL queries.\n",
297
  " \"\"\"\n",
298
  " inputs = [input_prompt + q for q in examples[\"natural_query\"]]\n",
299
- " targets = examples[\"sql_query\"]\n",
 
 
 
300
  "\n",
301
  " model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=256)\n",
302
  " labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=256)\n",
@@ -396,7 +782,7 @@
396
  "text": [
397
  "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
398
  " warnings.warn(\n",
399
- "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_10280\\3557190339.py:17: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
400
  " trainer = Trainer(\n",
401
  "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
402
  ]
@@ -407,10 +793,10 @@
407
  " output_dir=\"./fine-tuned-model\",\n",
408
  " evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
409
  " save_strategy=\"epoch\", # Save model every epoch\n",
410
- " per_device_train_batch_size=8, # LoRA allows higher batch size\n",
411
- " per_device_eval_batch_size=8,\n",
412
- " num_train_epochs=3, # Increase if needed\n",
413
- " learning_rate=5e-4, # Higher LR since we're only training LoRA layers\n",
414
  " weight_decay=0.01,\n",
415
  " logging_steps=50, # Print loss every 50 steps\n",
416
  " save_total_limit=2, # Keep last 2 checkpoints\n",
@@ -454,8 +840,8 @@
454
  "\n",
455
  " <div>\n",
456
  " \n",
457
- " <progress value='354' max='354' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
458
- " [354/354 05:21, Epoch 3/3]\n",
459
  " </div>\n",
460
  " <table border=\"1\" class=\"dataframe\">\n",
461
  " <thead>\n",
@@ -468,18 +854,28 @@
468
  " <tbody>\n",
469
  " <tr>\n",
470
  " <td>1</td>\n",
471
- " <td>1.235600</td>\n",
472
- " <td>0.970917</td>\n",
473
  " </tr>\n",
474
  " <tr>\n",
475
  " <td>2</td>\n",
476
- " <td>1.233100</td>\n",
477
- " <td>0.937016</td>\n",
478
  " </tr>\n",
479
  " <tr>\n",
480
  " <td>3</td>\n",
481
- " <td>1.157600</td>\n",
482
- " <td>0.940143</td>\n",
 
 
 
 
 
 
 
 
 
 
483
  " </tr>\n",
484
  " </tbody>\n",
485
  "</table><p>"
@@ -531,22 +927,15 @@
531
  },
532
  {
533
  "cell_type": "code",
534
- "execution_count": 6,
535
  "metadata": {},
536
  "outputs": [
537
- {
538
- "name": "stderr",
539
- "output_type": "stream",
540
- "text": [
541
- "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
542
- " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
543
- ]
544
- },
545
  {
546
  "name": "stdout",
547
  "output_type": "stream",
548
  "text": [
549
- "Generated SQL: SELECT_________________________________________________________________________________________________________\n"
 
550
  ]
551
  }
552
  ],
@@ -555,7 +944,7 @@
555
  "tokenizer = AutoTokenizer.from_pretrained(\"./fine-tuned-model\")\n",
556
  "\n",
557
  "# Prepare query with the same prompt\n",
558
- "input_text = \"Show the top 5 highest-scoring NBA games in history.\"\n",
559
  "message = [{'role': 'user', 'content': input_prompt + input_text}]\n",
560
  "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
561
  "\n",
 
251
  "name": "stderr",
252
  "output_type": "stream",
253
  "text": [
254
+ "Map: 0%| | 0/1044 [00:00<?, ? examples/s]"
255
  ]
256
  },
257
  {
258
  "name": "stdout",
259
  "output_type": "stream",
260
  "text": [
261
+ "You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
262
+ "Database Schema and Explanations\n",
263
+ "\n",
264
+ "team Table\n",
265
+ "Stores information about NBA teams.\n",
266
+ "CREATE TABLE IF NOT EXISTS \"team\" (\n",
267
+ " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
268
+ " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
269
+ " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
270
+ " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
271
+ " \"city\" TEXT, -- City where the team is based\n",
272
+ " \"state\" TEXT, -- State where the team is located\n",
273
+ " \"year_founded\" REAL -- Year the team was established\n",
274
+ ");\n",
275
+ "\n",
276
+ "game Table\n",
277
+ "Contains detailed statistics for each NBA game, including home and away team performance.\n",
278
+ "CREATE TABLE IF NOT EXISTS \"game\" (\n",
279
+ " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
280
+ " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
281
+ " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
282
+ " \"team_name_home\" TEXT, -- Full name of the home team\n",
283
+ " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
284
+ " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
285
+ " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
286
+ " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
287
+ " \"min\" INTEGER, -- Total minutes played in the game\n",
288
+ " \"fgm_home\" REAL, -- Field goals made by the home team\n",
289
+ " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
290
+ " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
291
+ " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
292
+ " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
293
+ " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
294
+ " \"ftm_home\" REAL, -- Free throws made by the home team\n",
295
+ " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
296
+ " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
297
+ " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
298
+ " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
299
+ " \"reb_home\" REAL, -- Total rebounds by the home team\n",
300
+ " \"ast_home\" REAL, -- Assists by the home team\n",
301
+ " \"stl_home\" REAL, -- Steals by the home team\n",
302
+ " \"blk_home\" REAL, -- Blocks by the home team\n",
303
+ " \"tov_home\" REAL, -- Turnovers by the home team\n",
304
+ " \"pf_home\" REAL, -- Personal fouls by the home team\n",
305
+ " \"pts_home\" REAL, -- Total points scored by the home team\n",
306
+ " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
307
+ " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
308
+ " \"team_id_away\" TEXT, -- ID of the away team\n",
309
+ " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
310
+ " \"team_name_away\" TEXT, -- Full name of the away team\n",
311
+ " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
312
+ " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
313
+ " \"fgm_away\" REAL, -- Field goals made by the away team\n",
314
+ " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
315
+ " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
316
+ " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
317
+ " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
318
+ " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
319
+ " \"ftm_away\" REAL, -- Free throws made by the away team\n",
320
+ " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
321
+ " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
322
+ " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
323
+ " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
324
+ " \"reb_away\" REAL, -- Total rebounds by the away team\n",
325
+ " \"ast_away\" REAL, -- Assists by the away team\n",
326
+ " \"stl_away\" REAL, -- Steals by the away team\n",
327
+ " \"blk_away\" REAL, -- Blocks by the away team\n",
328
+ " \"tov_away\" REAL, -- Turnovers by the away team\n",
329
+ " \"pf_away\" REAL, -- Personal fouls by the away team\n",
330
+ " \"pts_away\" REAL, -- Total points scored by the away team\n",
331
+ " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
332
+ " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
333
+ " \"season_type\" TEXT -- Regular season or playoffs\n",
334
+ ");\n",
335
+ "\n",
336
+ "other_stats Table\n",
337
+ "Stores additional statistics, linked to the game table via game_id.\n",
338
+ "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
339
+ " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
340
+ " \"league_id\" TEXT, -- League identifier\n",
341
+ " \"team_id_home\" TEXT, -- Home team identifier\n",
342
+ " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
343
+ " \"team_city_home\" TEXT, -- Home team city\n",
344
+ " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
345
+ " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
346
+ " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
347
+ " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
348
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
349
+ " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
350
+ " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
351
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
352
+ " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
353
+ " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
354
+ " \"team_id_away\" TEXT, -- Away team identifier\n",
355
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
356
+ " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
357
+ " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
358
+ " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
359
+ " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
360
+ " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
361
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
362
+ " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
363
+ " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
364
+ ");\n",
365
+ "\n",
366
+ "\n",
367
+ "Team Name Information\n",
368
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
369
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
370
+ "Notice they are separated by the | character in the following list:\n",
371
+ "\n",
372
+ "Atlanta Hawks|ATL\n",
373
+ "Boston Celtics|BOS\n",
374
+ "Cleveland Cavaliers|CLE\n",
375
+ "New Orleans Pelicans|NOP\n",
376
+ "Chicago Bulls|CHI\n",
377
+ "Dallas Mavericks|DAL\n",
378
+ "Denver Nuggets|DEN\n",
379
+ "Golden State Warriors|GSW\n",
380
+ "Houston Rockets|HOU\n",
381
+ "Los Angeles Clippers|LAC\n",
382
+ "Los Angeles Lakers|LAL\n",
383
+ "Miami Heat|MIA\n",
384
+ "Milwaukee Bucks|MIL\n",
385
+ "Minnesota Timberwolves|MIN\n",
386
+ "Brooklyn Nets|BKN\n",
387
+ "New York Knicks|NYK\n",
388
+ "Orlando Magic|ORL\n",
389
+ "Indiana Pacers|IND\n",
390
+ "Philadelphia 76ers|PHI\n",
391
+ "Phoenix Suns|PHX\n",
392
+ "Portland Trail Blazers|POR\n",
393
+ "Sacramento Kings|SAC\n",
394
+ "San Antonio Spurs|SAS\n",
395
+ "Oklahoma City Thunder|OKC\n",
396
+ "Toronto Raptors|TOR\n",
397
+ "Utah Jazz|UTA\n",
398
+ "Memphis Grizzlies|MEM\n",
399
+ "Washington Wizards|WAS\n",
400
+ "Detroit Pistons|DET\n",
401
+ "Charlotte Hornets|CHA\n",
402
+ "\n",
403
+ "Query Guidelines\n",
404
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
405
+ "\n",
406
+ "To filter by season, use season_id = '2YYYY'.\n",
407
+ "\n",
408
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
409
+ "\n",
410
+ "Ensure queries return relevant columns and avoid unnecessary joins.\n",
411
+ "\n",
412
+ "Example User Requests and SQLite Queries\n",
413
+ "Request:\n",
414
+ "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
415
+ "SQLite:\n",
416
+ "SELECT MAX(pts_home) \n",
417
+ "FROM game \n",
418
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
419
+ "\n",
420
+ "Request:\n",
421
+ "\"Which teams are located in the state of California?\"\n",
422
+ "SQLite:\n",
423
+ "SELECT full_name FROM team WHERE state = 'California';\n",
424
+ "\n",
425
+ "Request:\n",
426
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
427
+ "SQLite:\n",
428
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
429
+ "\n",
430
+ "Request:\n",
431
+ "\"Which teams were founded before 1979?\"\n",
432
+ "SQLite:\n",
433
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
434
+ "\n",
435
+ "Request:\n",
436
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
437
+ "SQLite:\n",
438
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
439
+ "FROM game\n",
440
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
441
+ "\n",
442
+ "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
443
+ "Which NBA teams were established after the year 2000? List their names and founding years, sorted from newest to oldest\n",
444
+ "SQLite: \n",
445
+ "SELECT full_name FROM team WHERE year_founded > 2000 ORDER BY year_founded DESC;\n"
446
+ ]
447
+ },
448
+ {
449
+ "name": "stderr",
450
+ "output_type": "stream",
451
+ "text": [
452
+ "Map: 100%|██████████| 1044/1044 [00:01<00:00, 552.67 examples/s]"
453
+ ]
454
+ },
455
+ {
456
+ "name": "stdout",
457
+ "output_type": "stream",
458
+ "text": [
459
+ "You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
460
+ "Database Schema and Explanations\n",
461
+ "\n",
462
+ "team Table\n",
463
+ "Stores information about NBA teams.\n",
464
+ "CREATE TABLE IF NOT EXISTS \"team\" (\n",
465
+ " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
466
+ " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
467
+ " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
468
+ " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
469
+ " \"city\" TEXT, -- City where the team is based\n",
470
+ " \"state\" TEXT, -- State where the team is located\n",
471
+ " \"year_founded\" REAL -- Year the team was established\n",
472
+ ");\n",
473
+ "\n",
474
+ "game Table\n",
475
+ "Contains detailed statistics for each NBA game, including home and away team performance.\n",
476
+ "CREATE TABLE IF NOT EXISTS \"game\" (\n",
477
+ " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
478
+ " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
479
+ " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
480
+ " \"team_name_home\" TEXT, -- Full name of the home team\n",
481
+ " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
482
+ " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
483
+ " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
484
+ " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
485
+ " \"min\" INTEGER, -- Total minutes played in the game\n",
486
+ " \"fgm_home\" REAL, -- Field goals made by the home team\n",
487
+ " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
488
+ " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
489
+ " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
490
+ " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
491
+ " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
492
+ " \"ftm_home\" REAL, -- Free throws made by the home team\n",
493
+ " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
494
+ " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
495
+ " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
496
+ " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
497
+ " \"reb_home\" REAL, -- Total rebounds by the home team\n",
498
+ " \"ast_home\" REAL, -- Assists by the home team\n",
499
+ " \"stl_home\" REAL, -- Steals by the home team\n",
500
+ " \"blk_home\" REAL, -- Blocks by the home team\n",
501
+ " \"tov_home\" REAL, -- Turnovers by the home team\n",
502
+ " \"pf_home\" REAL, -- Personal fouls by the home team\n",
503
+ " \"pts_home\" REAL, -- Total points scored by the home team\n",
504
+ " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
505
+ " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
506
+ " \"team_id_away\" TEXT, -- ID of the away team\n",
507
+ " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
508
+ " \"team_name_away\" TEXT, -- Full name of the away team\n",
509
+ " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
510
+ " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
511
+ " \"fgm_away\" REAL, -- Field goals made by the away team\n",
512
+ " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
513
+ " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
514
+ " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
515
+ " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
516
+ " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
517
+ " \"ftm_away\" REAL, -- Free throws made by the away team\n",
518
+ " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
519
+ " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
520
+ " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
521
+ " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
522
+ " \"reb_away\" REAL, -- Total rebounds by the away team\n",
523
+ " \"ast_away\" REAL, -- Assists by the away team\n",
524
+ " \"stl_away\" REAL, -- Steals by the away team\n",
525
+ " \"blk_away\" REAL, -- Blocks by the away team\n",
526
+ " \"tov_away\" REAL, -- Turnovers by the away team\n",
527
+ " \"pf_away\" REAL, -- Personal fouls by the away team\n",
528
+ " \"pts_away\" REAL, -- Total points scored by the away team\n",
529
+ " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
530
+ " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
531
+ " \"season_type\" TEXT -- Regular season or playoffs\n",
532
+ ");\n",
533
+ "\n",
534
+ "other_stats Table\n",
535
+ "Stores additional statistics, linked to the game table via game_id.\n",
536
+ "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
537
+ " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
538
+ " \"league_id\" TEXT, -- League identifier\n",
539
+ " \"team_id_home\" TEXT, -- Home team identifier\n",
540
+ " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
541
+ " \"team_city_home\" TEXT, -- Home team city\n",
542
+ " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
543
+ " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
544
+ " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
545
+ " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
546
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
547
+ " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
548
+ " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
549
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
550
+ " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
551
+ " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
552
+ " \"team_id_away\" TEXT, -- Away team identifier\n",
553
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
554
+ " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
555
+ " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
556
+ " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
557
+ " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
558
+ " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
559
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
560
+ " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
561
+ " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
562
+ ");\n",
563
+ "\n",
564
+ "\n",
565
+ "Team Name Information\n",
566
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
567
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
568
+ "Notice they are separated by the | character in the following list:\n",
569
+ "\n",
570
+ "Atlanta Hawks|ATL\n",
571
+ "Boston Celtics|BOS\n",
572
+ "Cleveland Cavaliers|CLE\n",
573
+ "New Orleans Pelicans|NOP\n",
574
+ "Chicago Bulls|CHI\n",
575
+ "Dallas Mavericks|DAL\n",
576
+ "Denver Nuggets|DEN\n",
577
+ "Golden State Warriors|GSW\n",
578
+ "Houston Rockets|HOU\n",
579
+ "Los Angeles Clippers|LAC\n",
580
+ "Los Angeles Lakers|LAL\n",
581
+ "Miami Heat|MIA\n",
582
+ "Milwaukee Bucks|MIL\n",
583
+ "Minnesota Timberwolves|MIN\n",
584
+ "Brooklyn Nets|BKN\n",
585
+ "New York Knicks|NYK\n",
586
+ "Orlando Magic|ORL\n",
587
+ "Indiana Pacers|IND\n",
588
+ "Philadelphia 76ers|PHI\n",
589
+ "Phoenix Suns|PHX\n",
590
+ "Portland Trail Blazers|POR\n",
591
+ "Sacramento Kings|SAC\n",
592
+ "San Antonio Spurs|SAS\n",
593
+ "Oklahoma City Thunder|OKC\n",
594
+ "Toronto Raptors|TOR\n",
595
+ "Utah Jazz|UTA\n",
596
+ "Memphis Grizzlies|MEM\n",
597
+ "Washington Wizards|WAS\n",
598
+ "Detroit Pistons|DET\n",
599
+ "Charlotte Hornets|CHA\n",
600
+ "\n",
601
+ "Query Guidelines\n",
602
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
603
+ "\n",
604
+ "To filter by season, use season_id = '2YYYY'.\n",
605
+ "\n",
606
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
607
+ "\n",
608
+ "Ensure queries return relevant columns and avoid unnecessary joins.\n",
609
+ "\n",
610
+ "Example User Requests and SQLite Queries\n",
611
+ "Request:\n",
612
+ "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
613
+ "SQLite:\n",
614
+ "SELECT MAX(pts_home) \n",
615
+ "FROM game \n",
616
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
617
+ "\n",
618
+ "Request:\n",
619
+ "\"Which teams are located in the state of California?\"\n",
620
+ "SQLite:\n",
621
+ "SELECT full_name FROM team WHERE state = 'California';\n",
622
+ "\n",
623
+ "Request:\n",
624
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
625
+ "SQLite:\n",
626
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
627
+ "\n",
628
+ "Request:\n",
629
+ "\"Which teams were founded before 1979?\"\n",
630
+ "SQLite:\n",
631
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
632
+ "\n",
633
+ "Request:\n",
634
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
635
+ "SQLite:\n",
636
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
637
+ "FROM game\n",
638
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
639
+ "\n",
640
+ "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
641
+ "How many points did the Golden State Warriors score in their first game of the 2005 season?\n",
642
+ "SQLite: \n",
643
+ "SELECT pts_home FROM game WHERE team_abbreviation_home = 'GSW' AND season_id = '22005' ORDER BY game_date ASC LIMIT 1;\n",
644
  "939\n",
645
  "105\n"
646
  ]
 
679
  " Tokenizes input natural language queries and corresponding SQL queries.\n",
680
  " \"\"\"\n",
681
  " inputs = [input_prompt + q for q in examples[\"natural_query\"]]\n",
682
+ " targets = [\"SQLite: \\n\" + q for q in examples[\"sql_query\"]]\n",
683
+ "\n",
684
+ " print(inputs[0])\n",
685
+ " print(targets[0])\n",
686
  "\n",
687
  " model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=256)\n",
688
  " labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=256)\n",
 
782
  "text": [
783
  "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
784
  " warnings.warn(\n",
785
+ "C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_9248\\2737143648.py:17: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
786
  " trainer = Trainer(\n",
787
  "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
788
  ]
 
793
  " output_dir=\"./fine-tuned-model\",\n",
794
  " evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
795
  " save_strategy=\"epoch\", # Save model every epoch\n",
796
+ " per_device_train_batch_size=16, # LoRA allows higher batch size\n",
797
+ " per_device_eval_batch_size=16,\n",
798
+ " num_train_epochs=5, # Increase if needed\n",
799
+ " learning_rate=5e-5, # Higher LR since we're only training LoRA layers\n",
800
  " weight_decay=0.01,\n",
801
  " logging_steps=50, # Print loss every 50 steps\n",
802
  " save_total_limit=2, # Keep last 2 checkpoints\n",
 
840
  "\n",
841
  " <div>\n",
842
  " \n",
843
+ " <progress value='295' max='295' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
844
+ " [295/295 1:04:01, Epoch 5/5]\n",
845
  " </div>\n",
846
  " <table border=\"1\" class=\"dataframe\">\n",
847
  " <thead>\n",
 
854
  " <tbody>\n",
855
  " <tr>\n",
856
  " <td>1</td>\n",
857
+ " <td>8.510300</td>\n",
858
+ " <td>1.486602</td>\n",
859
  " </tr>\n",
860
  " <tr>\n",
861
  " <td>2</td>\n",
862
+ " <td>1.709800</td>\n",
863
+ " <td>1.227304</td>\n",
864
  " </tr>\n",
865
  " <tr>\n",
866
  " <td>3</td>\n",
867
+ " <td>1.542100</td>\n",
868
+ " <td>1.161120</td>\n",
869
+ " </tr>\n",
870
+ " <tr>\n",
871
+ " <td>4</td>\n",
872
+ " <td>1.487500</td>\n",
873
+ " <td>1.153255</td>\n",
874
+ " </tr>\n",
875
+ " <tr>\n",
876
+ " <td>5</td>\n",
877
+ " <td>1.388300</td>\n",
878
+ " <td>1.121680</td>\n",
879
  " </tr>\n",
880
  " </tbody>\n",
881
  "</table><p>"
 
927
  },
928
  {
929
  "cell_type": "code",
930
+ "execution_count": 7,
931
  "metadata": {},
932
  "outputs": [
 
 
 
 
 
 
 
 
933
  {
934
  "name": "stdout",
935
  "output_type": "stream",
936
  "text": [
937
+ "Generated SQL: SQLite: SELECT AVG(pts_home) FROM game WHERE team_name_home = 'Los Angeles Lakers';\n",
938
+ "\n"
939
  ]
940
  }
941
  ],
 
944
  "tokenizer = AutoTokenizer.from_pretrained(\"./fine-tuned-model\")\n",
945
  "\n",
946
  "# Prepare query with the same prompt\n",
947
+ "input_text = \"How many points to the Los Angeles Lakers average at home?\"\n",
948
  "message = [{'role': 'user', 'content': input_prompt + input_text}]\n",
949
  "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
950
  "\n",