First decent attempt, seems to generate valid queries after pre-training. Needs evaluation still
Browse files- fine-tuned-model/{checkpoint-236 β checkpoint-295}/README.md +0 -0
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/adapter_config.json +3 -3
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/adapter_model.safetensors +1 -1
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/optimizer.pt +1 -1
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/rng_state.pth +1 -1
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/scaler.pt +1 -1
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/scheduler.pt +1 -1
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/special_tokens_map.json +0 -0
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/tokenizer.json +0 -0
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/tokenizer_config.json +0 -0
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/trainer_state.json +64 -33
- fine-tuned-model/{checkpoint-236 β checkpoint-295}/training_args.bin +1 -1
- fine-tuned-model/model.safetensors +1 -1
- fine-tuned-model/runs/Apr03_12-32-10_DESKTOP-SMJC97K/events.out.tfevents.1743708731.DESKTOP-SMJC97K.4788.0 +3 -0
- fine-tuned-model/runs/Apr03_12-35-47_DESKTOP-SMJC97K/events.out.tfevents.1743708948.DESKTOP-SMJC97K.21476.0 +3 -0
- fine-tuned-model/runs/Apr03_14-34-10_DESKTOP-SMJC97K/events.out.tfevents.1743716051.DESKTOP-SMJC97K.9248.0 +3 -0
- finetune_model.ipynb +415 -26
fine-tuned-model/{checkpoint-236 β checkpoint-295}/README.md
RENAMED
File without changes
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/adapter_config.json
RENAMED
@@ -24,10 +24,10 @@
|
|
24 |
"rank_pattern": {},
|
25 |
"revision": null,
|
26 |
"target_modules": [
|
27 |
-
"q_proj",
|
28 |
"k_proj",
|
29 |
-
"
|
30 |
-
"o_proj"
|
|
|
31 |
],
|
32 |
"task_type": "CAUSAL_LM",
|
33 |
"trainable_token_indices": null,
|
|
|
24 |
"rank_pattern": {},
|
25 |
"revision": null,
|
26 |
"target_modules": [
|
|
|
27 |
"k_proj",
|
28 |
+
"q_proj",
|
29 |
+
"o_proj",
|
30 |
+
"v_proj"
|
31 |
],
|
32 |
"task_type": "CAUSAL_LM",
|
33 |
"trainable_token_indices": null,
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/adapter_model.safetensors
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 25191536
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5dbc52cfab26c0dfe8ab09edbd2f3c33aa892b2ea1e668a652968307cf887d90
|
3 |
size 25191536
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/optimizer.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 50492858
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bc34cbc76e28ac17c0afc1333f5d7b30a8beebce84a6f3e73589b31ff00733f2
|
3 |
size 50492858
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/rng_state.pth
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 14244
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:47500e947b0979f4187ced89844b0b41af88c14cc3ed27ad8cb01fdb1072cb88
|
3 |
size 14244
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/scaler.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 988
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a0cac5ada32db5a3f9944a8e52bf38c14646b18fcd0190b4fb8155b619d8f5ab
|
3 |
size 988
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/scheduler.pt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1064
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ee2a611a1529c905642325059ae97e660232f5907273b45d09703eb7bc5fa03
|
3 |
size 1064
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/special_tokens_map.json
RENAMED
File without changes
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/tokenizer.json
RENAMED
File without changes
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/tokenizer_config.json
RENAMED
File without changes
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/trainer_state.json
RENAMED
@@ -2,62 +2,93 @@
|
|
2 |
"best_global_step": null,
|
3 |
"best_metric": null,
|
4 |
"best_model_checkpoint": null,
|
5 |
-
"epoch":
|
6 |
"eval_steps": 500,
|
7 |
-
"global_step":
|
8 |
"is_hyper_param_search": false,
|
9 |
"is_local_process_zero": true,
|
10 |
"is_world_process_zero": true,
|
11 |
"log_history": [
|
12 |
{
|
13 |
-
"epoch": 0.
|
14 |
-
"grad_norm":
|
15 |
-
"learning_rate":
|
16 |
-
"loss":
|
17 |
"step": 50
|
18 |
},
|
19 |
{
|
20 |
-
"epoch": 0
|
21 |
-
"
|
22 |
-
"
|
23 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
"step": 100
|
25 |
},
|
26 |
{
|
27 |
-
"epoch":
|
28 |
-
"eval_loss":
|
29 |
-
"eval_runtime": 5.
|
30 |
-
"eval_samples_per_second":
|
31 |
-
"eval_steps_per_second":
|
32 |
"step": 118
|
33 |
},
|
34 |
{
|
35 |
-
"epoch":
|
36 |
-
"grad_norm":
|
37 |
-
"learning_rate":
|
38 |
-
"loss": 1.
|
39 |
"step": 150
|
40 |
},
|
41 |
{
|
42 |
-
"epoch":
|
43 |
-
"
|
44 |
-
"
|
45 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
"step": 200
|
47 |
},
|
48 |
{
|
49 |
-
"epoch":
|
50 |
-
"eval_loss":
|
51 |
-
"eval_runtime": 5.
|
52 |
-
"eval_samples_per_second":
|
53 |
-
"eval_steps_per_second":
|
54 |
"step": 236
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
}
|
56 |
],
|
57 |
"logging_steps": 50,
|
58 |
-
"max_steps":
|
59 |
"num_input_tokens_seen": 0,
|
60 |
-
"num_train_epochs":
|
61 |
"save_steps": 500,
|
62 |
"stateful_callbacks": {
|
63 |
"TrainerControl": {
|
@@ -66,13 +97,13 @@
|
|
66 |
"should_evaluate": false,
|
67 |
"should_log": false,
|
68 |
"should_save": true,
|
69 |
-
"should_training_stop":
|
70 |
},
|
71 |
"attributes": {}
|
72 |
}
|
73 |
},
|
74 |
-
"total_flos":
|
75 |
-
"train_batch_size":
|
76 |
"trial_name": null,
|
77 |
"trial_params": null
|
78 |
}
|
|
|
2 |
"best_global_step": null,
|
3 |
"best_metric": null,
|
4 |
"best_model_checkpoint": null,
|
5 |
+
"epoch": 5.0,
|
6 |
"eval_steps": 500,
|
7 |
+
"global_step": 295,
|
8 |
"is_hyper_param_search": false,
|
9 |
"is_local_process_zero": true,
|
10 |
"is_world_process_zero": true,
|
11 |
"log_history": [
|
12 |
{
|
13 |
+
"epoch": 0.847457627118644,
|
14 |
+
"grad_norm": 3.6365857124328613,
|
15 |
+
"learning_rate": 4.152542372881356e-05,
|
16 |
+
"loss": 8.5103,
|
17 |
"step": 50
|
18 |
},
|
19 |
{
|
20 |
+
"epoch": 1.0,
|
21 |
+
"eval_loss": 1.4866021871566772,
|
22 |
+
"eval_runtime": 5.4305,
|
23 |
+
"eval_samples_per_second": 19.335,
|
24 |
+
"eval_steps_per_second": 1.289,
|
25 |
+
"step": 59
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"epoch": 1.694915254237288,
|
29 |
+
"grad_norm": 3.137465715408325,
|
30 |
+
"learning_rate": 3.305084745762712e-05,
|
31 |
+
"loss": 1.7098,
|
32 |
"step": 100
|
33 |
},
|
34 |
{
|
35 |
+
"epoch": 2.0,
|
36 |
+
"eval_loss": 1.2273037433624268,
|
37 |
+
"eval_runtime": 5.362,
|
38 |
+
"eval_samples_per_second": 19.582,
|
39 |
+
"eval_steps_per_second": 1.305,
|
40 |
"step": 118
|
41 |
},
|
42 |
{
|
43 |
+
"epoch": 2.542372881355932,
|
44 |
+
"grad_norm": 1.6243258714675903,
|
45 |
+
"learning_rate": 2.457627118644068e-05,
|
46 |
+
"loss": 1.5421,
|
47 |
"step": 150
|
48 |
},
|
49 |
{
|
50 |
+
"epoch": 3.0,
|
51 |
+
"eval_loss": 1.1611202955245972,
|
52 |
+
"eval_runtime": 5.348,
|
53 |
+
"eval_samples_per_second": 19.634,
|
54 |
+
"eval_steps_per_second": 1.309,
|
55 |
+
"step": 177
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"epoch": 3.389830508474576,
|
59 |
+
"grad_norm": 1.7812302112579346,
|
60 |
+
"learning_rate": 1.6101694915254237e-05,
|
61 |
+
"loss": 1.4875,
|
62 |
"step": 200
|
63 |
},
|
64 |
{
|
65 |
+
"epoch": 4.0,
|
66 |
+
"eval_loss": 1.153254508972168,
|
67 |
+
"eval_runtime": 5.347,
|
68 |
+
"eval_samples_per_second": 19.637,
|
69 |
+
"eval_steps_per_second": 1.309,
|
70 |
"step": 236
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"epoch": 4.237288135593221,
|
74 |
+
"grad_norm": 2.1582489013671875,
|
75 |
+
"learning_rate": 7.627118644067798e-06,
|
76 |
+
"loss": 1.3883,
|
77 |
+
"step": 250
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"epoch": 5.0,
|
81 |
+
"eval_loss": 1.1216797828674316,
|
82 |
+
"eval_runtime": 5.3095,
|
83 |
+
"eval_samples_per_second": 19.776,
|
84 |
+
"eval_steps_per_second": 1.318,
|
85 |
+
"step": 295
|
86 |
}
|
87 |
],
|
88 |
"logging_steps": 50,
|
89 |
+
"max_steps": 295,
|
90 |
"num_input_tokens_seen": 0,
|
91 |
+
"num_train_epochs": 5,
|
92 |
"save_steps": 500,
|
93 |
"stateful_callbacks": {
|
94 |
"TrainerControl": {
|
|
|
97 |
"should_evaluate": false,
|
98 |
"should_log": false,
|
99 |
"should_save": true,
|
100 |
+
"should_training_stop": true
|
101 |
},
|
102 |
"attributes": {}
|
103 |
}
|
104 |
},
|
105 |
+
"total_flos": 9279085168558080.0,
|
106 |
+
"train_batch_size": 16,
|
107 |
"trial_name": null,
|
108 |
"trial_params": null
|
109 |
}
|
fine-tuned-model/{checkpoint-236 β checkpoint-295}/training_args.bin
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5368
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:145881f21d543612154f9ea67fda44fa38754eb20bd4eb13e1492be30ee670d7
|
3 |
size 5368
|
fine-tuned-model/model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1480793144
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9062451f577293bdd513174888e96a053a05c6f763eabe4c1196ad82987caf2c
|
3 |
size 1480793144
|
fine-tuned-model/runs/Apr03_12-32-10_DESKTOP-SMJC97K/events.out.tfevents.1743708731.DESKTOP-SMJC97K.4788.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e13d3e6b9cb02db1498428eecc437b77881de7829603019b54d53fcc6c0a9c7
|
3 |
+
size 5656
|
fine-tuned-model/runs/Apr03_12-35-47_DESKTOP-SMJC97K/events.out.tfevents.1743708948.DESKTOP-SMJC97K.21476.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5d2c6a4697d12bb25ba105c3c95241aa9d53ab1ea6c5a0116ec7a5416885586
|
3 |
+
size 8402
|
fine-tuned-model/runs/Apr03_14-34-10_DESKTOP-SMJC97K/events.out.tfevents.1743716051.DESKTOP-SMJC97K.9248.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1c85607a6a7da30d998111f662d1777c8550adb9a1d046bfae143ea1947ce2a
|
3 |
+
size 8402
|
finetune_model.ipynb
CHANGED
@@ -251,13 +251,396 @@
|
|
251 |
"name": "stderr",
|
252 |
"output_type": "stream",
|
253 |
"text": [
|
254 |
-
"Map:
|
255 |
]
|
256 |
},
|
257 |
{
|
258 |
"name": "stdout",
|
259 |
"output_type": "stream",
|
260 |
"text": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
"939\n",
|
262 |
"105\n"
|
263 |
]
|
@@ -296,7 +679,10 @@
|
|
296 |
" Tokenizes input natural language queries and corresponding SQL queries.\n",
|
297 |
" \"\"\"\n",
|
298 |
" inputs = [input_prompt + q for q in examples[\"natural_query\"]]\n",
|
299 |
-
" targets = examples[\"sql_query\"]\n",
|
|
|
|
|
|
|
300 |
"\n",
|
301 |
" model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=256)\n",
|
302 |
" labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=256)\n",
|
@@ -396,7 +782,7 @@
|
|
396 |
"text": [
|
397 |
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of π€ Transformers. Use `eval_strategy` instead\n",
|
398 |
" warnings.warn(\n",
|
399 |
-
"C:\\Users\\Dean\\AppData\\Local\\Temp\\
|
400 |
" trainer = Trainer(\n",
|
401 |
"No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
|
402 |
]
|
@@ -407,10 +793,10 @@
|
|
407 |
" output_dir=\"./fine-tuned-model\",\n",
|
408 |
" evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
|
409 |
" save_strategy=\"epoch\", # Save model every epoch\n",
|
410 |
-
" per_device_train_batch_size=
|
411 |
-
" per_device_eval_batch_size=
|
412 |
-
" num_train_epochs=
|
413 |
-
" learning_rate=5e-
|
414 |
" weight_decay=0.01,\n",
|
415 |
" logging_steps=50, # Print loss every 50 steps\n",
|
416 |
" save_total_limit=2, # Keep last 2 checkpoints\n",
|
@@ -454,8 +840,8 @@
|
|
454 |
"\n",
|
455 |
" <div>\n",
|
456 |
" \n",
|
457 |
-
" <progress value='
|
458 |
-
" [
|
459 |
" </div>\n",
|
460 |
" <table border=\"1\" class=\"dataframe\">\n",
|
461 |
" <thead>\n",
|
@@ -468,18 +854,28 @@
|
|
468 |
" <tbody>\n",
|
469 |
" <tr>\n",
|
470 |
" <td>1</td>\n",
|
471 |
-
" <td>
|
472 |
-
" <td>
|
473 |
" </tr>\n",
|
474 |
" <tr>\n",
|
475 |
" <td>2</td>\n",
|
476 |
-
" <td>1.
|
477 |
-
" <td>
|
478 |
" </tr>\n",
|
479 |
" <tr>\n",
|
480 |
" <td>3</td>\n",
|
481 |
-
" <td>1.
|
482 |
-
" <td>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
483 |
" </tr>\n",
|
484 |
" </tbody>\n",
|
485 |
"</table><p>"
|
@@ -531,22 +927,15 @@
|
|
531 |
},
|
532 |
{
|
533 |
"cell_type": "code",
|
534 |
-
"execution_count":
|
535 |
"metadata": {},
|
536 |
"outputs": [
|
537 |
-
{
|
538 |
-
"name": "stderr",
|
539 |
-
"output_type": "stream",
|
540 |
-
"text": [
|
541 |
-
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
|
542 |
-
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
|
543 |
-
]
|
544 |
-
},
|
545 |
{
|
546 |
"name": "stdout",
|
547 |
"output_type": "stream",
|
548 |
"text": [
|
549 |
-
"Generated SQL:
|
|
|
550 |
]
|
551 |
}
|
552 |
],
|
@@ -555,7 +944,7 @@
|
|
555 |
"tokenizer = AutoTokenizer.from_pretrained(\"./fine-tuned-model\")\n",
|
556 |
"\n",
|
557 |
"# Prepare query with the same prompt\n",
|
558 |
-
"input_text = \"
|
559 |
"message = [{'role': 'user', 'content': input_prompt + input_text}]\n",
|
560 |
"inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
|
561 |
"\n",
|
|
|
251 |
"name": "stderr",
|
252 |
"output_type": "stream",
|
253 |
"text": [
|
254 |
+
"Map: 0%| | 0/1044 [00:00<?, ? examples/s]"
|
255 |
]
|
256 |
},
|
257 |
{
|
258 |
"name": "stdout",
|
259 |
"output_type": "stream",
|
260 |
"text": [
|
261 |
+
"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
|
262 |
+
"Database Schema and Explanations\n",
|
263 |
+
"\n",
|
264 |
+
"team Table\n",
|
265 |
+
"Stores information about NBA teams.\n",
|
266 |
+
"CREATE TABLE IF NOT EXISTS \"team\" (\n",
|
267 |
+
" \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
|
268 |
+
" \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
|
269 |
+
" \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
|
270 |
+
" \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
|
271 |
+
" \"city\" TEXT, -- City where the team is based\n",
|
272 |
+
" \"state\" TEXT, -- State where the team is located\n",
|
273 |
+
" \"year_founded\" REAL -- Year the team was established\n",
|
274 |
+
");\n",
|
275 |
+
"\n",
|
276 |
+
"game Table\n",
|
277 |
+
"Contains detailed statistics for each NBA game, including home and away team performance.\n",
|
278 |
+
"CREATE TABLE IF NOT EXISTS \"game\" (\n",
|
279 |
+
" \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
|
280 |
+
" \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
|
281 |
+
" \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
|
282 |
+
" \"team_name_home\" TEXT, -- Full name of the home team\n",
|
283 |
+
" \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
|
284 |
+
" \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
|
285 |
+
" \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
|
286 |
+
" \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
|
287 |
+
" \"min\" INTEGER, -- Total minutes played in the game\n",
|
288 |
+
" \"fgm_home\" REAL, -- Field goals made by the home team\n",
|
289 |
+
" \"fga_home\" REAL, -- Field goals attempted by the home team\n",
|
290 |
+
" \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
|
291 |
+
" \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
|
292 |
+
" \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
|
293 |
+
" \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
|
294 |
+
" \"ftm_home\" REAL, -- Free throws made by the home team\n",
|
295 |
+
" \"fta_home\" REAL, -- Free throws attempted by the home team\n",
|
296 |
+
" \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
|
297 |
+
" \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
|
298 |
+
" \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
|
299 |
+
" \"reb_home\" REAL, -- Total rebounds by the home team\n",
|
300 |
+
" \"ast_home\" REAL, -- Assists by the home team\n",
|
301 |
+
" \"stl_home\" REAL, -- Steals by the home team\n",
|
302 |
+
" \"blk_home\" REAL, -- Blocks by the home team\n",
|
303 |
+
" \"tov_home\" REAL, -- Turnovers by the home team\n",
|
304 |
+
" \"pf_home\" REAL, -- Personal fouls by the home team\n",
|
305 |
+
" \"pts_home\" REAL, -- Total points scored by the home team\n",
|
306 |
+
" \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
|
307 |
+
" \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
308 |
+
" \"team_id_away\" TEXT, -- ID of the away team\n",
|
309 |
+
" \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
|
310 |
+
" \"team_name_away\" TEXT, -- Full name of the away team\n",
|
311 |
+
" \"matchup_away\" TEXT, -- Matchup details from the away teamβs perspective\n",
|
312 |
+
" \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
|
313 |
+
" \"fgm_away\" REAL, -- Field goals made by the away team\n",
|
314 |
+
" \"fga_away\" REAL, -- Field goals attempted by the away team\n",
|
315 |
+
" \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
|
316 |
+
" \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
|
317 |
+
" \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
|
318 |
+
" \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
|
319 |
+
" \"ftm_away\" REAL, -- Free throws made by the away team\n",
|
320 |
+
" \"fta_away\" REAL, -- Free throws attempted by the away team\n",
|
321 |
+
" \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
|
322 |
+
" \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
|
323 |
+
" \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
|
324 |
+
" \"reb_away\" REAL, -- Total rebounds by the away team\n",
|
325 |
+
" \"ast_away\" REAL, -- Assists by the away team\n",
|
326 |
+
" \"stl_away\" REAL, -- Steals by the away team\n",
|
327 |
+
" \"blk_away\" REAL, -- Blocks by the away team\n",
|
328 |
+
" \"tov_away\" REAL, -- Turnovers by the away team\n",
|
329 |
+
" \"pf_away\" REAL, -- Personal fouls by the away team\n",
|
330 |
+
" \"pts_away\" REAL, -- Total points scored by the away team\n",
|
331 |
+
" \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
|
332 |
+
" \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
333 |
+
" \"season_type\" TEXT -- Regular season or playoffs\n",
|
334 |
+
");\n",
|
335 |
+
"\n",
|
336 |
+
"other_stats Table\n",
|
337 |
+
"Stores additional statistics, linked to the game table via game_id.\n",
|
338 |
+
"CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
|
339 |
+
" \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
|
340 |
+
" \"league_id\" TEXT, -- League identifier\n",
|
341 |
+
" \"team_id_home\" TEXT, -- Home team identifier\n",
|
342 |
+
" \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
|
343 |
+
" \"team_city_home\" TEXT, -- Home team city\n",
|
344 |
+
" \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
|
345 |
+
" \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
|
346 |
+
" \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
|
347 |
+
" \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
|
348 |
+
" \"lead_changes\" INTEGER, -- Number of lead changes \n",
|
349 |
+
" \"times_tied\" INTEGER, -- Number of times the score was tied\n",
|
350 |
+
" \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
|
351 |
+
" \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
|
352 |
+
" \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
|
353 |
+
" \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
|
354 |
+
" \"team_id_away\" TEXT, -- Away team identifier\n",
|
355 |
+
" \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
|
356 |
+
" \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
|
357 |
+
" \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
|
358 |
+
" \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
|
359 |
+
" \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
|
360 |
+
" \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
|
361 |
+
" \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
|
362 |
+
" \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
|
363 |
+
" \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
|
364 |
+
");\n",
|
365 |
+
"\n",
|
366 |
+
"\n",
|
367 |
+
"Team Name Information\n",
|
368 |
+
"In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
|
369 |
+
"The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
|
370 |
+
"Notice they are separated by the | character in the following list:\n",
|
371 |
+
"\n",
|
372 |
+
"Atlanta Hawks|ATL\n",
|
373 |
+
"Boston Celtics|BOS\n",
|
374 |
+
"Cleveland Cavaliers|CLE\n",
|
375 |
+
"New Orleans Pelicans|NOP\n",
|
376 |
+
"Chicago Bulls|CHI\n",
|
377 |
+
"Dallas Mavericks|DAL\n",
|
378 |
+
"Denver Nuggets|DEN\n",
|
379 |
+
"Golden State Warriors|GSW\n",
|
380 |
+
"Houston Rockets|HOU\n",
|
381 |
+
"Los Angeles Clippers|LAC\n",
|
382 |
+
"Los Angeles Lakers|LAL\n",
|
383 |
+
"Miami Heat|MIA\n",
|
384 |
+
"Milwaukee Bucks|MIL\n",
|
385 |
+
"Minnesota Timberwolves|MIN\n",
|
386 |
+
"Brooklyn Nets|BKN\n",
|
387 |
+
"New York Knicks|NYK\n",
|
388 |
+
"Orlando Magic|ORL\n",
|
389 |
+
"Indiana Pacers|IND\n",
|
390 |
+
"Philadelphia 76ers|PHI\n",
|
391 |
+
"Phoenix Suns|PHX\n",
|
392 |
+
"Portland Trail Blazers|POR\n",
|
393 |
+
"Sacramento Kings|SAC\n",
|
394 |
+
"San Antonio Spurs|SAS\n",
|
395 |
+
"Oklahoma City Thunder|OKC\n",
|
396 |
+
"Toronto Raptors|TOR\n",
|
397 |
+
"Utah Jazz|UTA\n",
|
398 |
+
"Memphis Grizzlies|MEM\n",
|
399 |
+
"Washington Wizards|WAS\n",
|
400 |
+
"Detroit Pistons|DET\n",
|
401 |
+
"Charlotte Hornets|CHA\n",
|
402 |
+
"\n",
|
403 |
+
"Query Guidelines\n",
|
404 |
+
"Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
|
405 |
+
"\n",
|
406 |
+
"To filter by season, use season_id = '2YYYY'.\n",
|
407 |
+
"\n",
|
408 |
+
"Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
|
409 |
+
"\n",
|
410 |
+
"Ensure queries return relevant columns and avoid unnecessary joins.\n",
|
411 |
+
"\n",
|
412 |
+
"Example User Requests and SQLite Queries\n",
|
413 |
+
"Request:\n",
|
414 |
+
"\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
|
415 |
+
"SQLite:\n",
|
416 |
+
"SELECT MAX(pts_home) \n",
|
417 |
+
"FROM game \n",
|
418 |
+
"WHERE team_name_home = 'Los Angeles Lakers';\n",
|
419 |
+
"\n",
|
420 |
+
"Request:\n",
|
421 |
+
"\"Which teams are located in the state of California?\"\n",
|
422 |
+
"SQLite:\n",
|
423 |
+
"SELECT full_name FROM team WHERE state = 'California';\n",
|
424 |
+
"\n",
|
425 |
+
"Request:\n",
|
426 |
+
"\"Which team had the highest number of team turnovers in an away game?\"\n",
|
427 |
+
"SQLite:\n",
|
428 |
+
"SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
|
429 |
+
"\n",
|
430 |
+
"Request:\n",
|
431 |
+
"\"Which teams were founded before 1979?\"\n",
|
432 |
+
"SQLite:\n",
|
433 |
+
"SELECT full_name FROM team WHERE year_founded < 1979;\n",
|
434 |
+
"\n",
|
435 |
+
"Request:\n",
|
436 |
+
"\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
|
437 |
+
"SQLite:\n",
|
438 |
+
"SELECT MAX(pts_home - pts_away) AS biggest_win\n",
|
439 |
+
"FROM game\n",
|
440 |
+
"WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
|
441 |
+
"\n",
|
442 |
+
"Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
|
443 |
+
"Which NBA teams were established after the year 2000? List their names and founding years, sorted from newest to oldest\n",
|
444 |
+
"SQLite: \n",
|
445 |
+
"SELECT full_name FROM team WHERE year_founded > 2000 ORDER BY year_founded DESC;\n"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"name": "stderr",
|
450 |
+
"output_type": "stream",
|
451 |
+
"text": [
|
452 |
+
"Map: 100%|ββββββββββ| 1044/1044 [00:01<00:00, 552.67 examples/s]"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"name": "stdout",
|
457 |
+
"output_type": "stream",
|
458 |
+
"text": [
|
459 |
+
"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
|
460 |
+
"Database Schema and Explanations\n",
|
461 |
+
"\n",
|
462 |
+
"team Table\n",
|
463 |
+
"Stores information about NBA teams.\n",
|
464 |
+
"CREATE TABLE IF NOT EXISTS \"team\" (\n",
|
465 |
+
" \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
|
466 |
+
" \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
|
467 |
+
" \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
|
468 |
+
" \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
|
469 |
+
" \"city\" TEXT, -- City where the team is based\n",
|
470 |
+
" \"state\" TEXT, -- State where the team is located\n",
|
471 |
+
" \"year_founded\" REAL -- Year the team was established\n",
|
472 |
+
");\n",
|
473 |
+
"\n",
|
474 |
+
"game Table\n",
|
475 |
+
"Contains detailed statistics for each NBA game, including home and away team performance.\n",
|
476 |
+
"CREATE TABLE IF NOT EXISTS \"game\" (\n",
|
477 |
+
" \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
|
478 |
+
" \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
|
479 |
+
" \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
|
480 |
+
" \"team_name_home\" TEXT, -- Full name of the home team\n",
|
481 |
+
" \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
|
482 |
+
" \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
|
483 |
+
" \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
|
484 |
+
" \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
|
485 |
+
" \"min\" INTEGER, -- Total minutes played in the game\n",
|
486 |
+
" \"fgm_home\" REAL, -- Field goals made by the home team\n",
|
487 |
+
" \"fga_home\" REAL, -- Field goals attempted by the home team\n",
|
488 |
+
" \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
|
489 |
+
" \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
|
490 |
+
" \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
|
491 |
+
" \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
|
492 |
+
" \"ftm_home\" REAL, -- Free throws made by the home team\n",
|
493 |
+
" \"fta_home\" REAL, -- Free throws attempted by the home team\n",
|
494 |
+
" \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
|
495 |
+
" \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
|
496 |
+
" \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
|
497 |
+
" \"reb_home\" REAL, -- Total rebounds by the home team\n",
|
498 |
+
" \"ast_home\" REAL, -- Assists by the home team\n",
|
499 |
+
" \"stl_home\" REAL, -- Steals by the home team\n",
|
500 |
+
" \"blk_home\" REAL, -- Blocks by the home team\n",
|
501 |
+
" \"tov_home\" REAL, -- Turnovers by the home team\n",
|
502 |
+
" \"pf_home\" REAL, -- Personal fouls by the home team\n",
|
503 |
+
" \"pts_home\" REAL, -- Total points scored by the home team\n",
|
504 |
+
" \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
|
505 |
+
" \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
506 |
+
" \"team_id_away\" TEXT, -- ID of the away team\n",
|
507 |
+
" \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
|
508 |
+
" \"team_name_away\" TEXT, -- Full name of the away team\n",
|
509 |
+
" \"matchup_away\" TEXT, -- Matchup details from the away teamβs perspective\n",
|
510 |
+
" \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
|
511 |
+
" \"fgm_away\" REAL, -- Field goals made by the away team\n",
|
512 |
+
" \"fga_away\" REAL, -- Field goals attempted by the away team\n",
|
513 |
+
" \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
|
514 |
+
" \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
|
515 |
+
" \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
|
516 |
+
" \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
|
517 |
+
" \"ftm_away\" REAL, -- Free throws made by the away team\n",
|
518 |
+
" \"fta_away\" REAL, -- Free throws attempted by the away team\n",
|
519 |
+
" \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
|
520 |
+
" \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
|
521 |
+
" \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
|
522 |
+
" \"reb_away\" REAL, -- Total rebounds by the away team\n",
|
523 |
+
" \"ast_away\" REAL, -- Assists by the away team\n",
|
524 |
+
" \"stl_away\" REAL, -- Steals by the away team\n",
|
525 |
+
" \"blk_away\" REAL, -- Blocks by the away team\n",
|
526 |
+
" \"tov_away\" REAL, -- Turnovers by the away team\n",
|
527 |
+
" \"pf_away\" REAL, -- Personal fouls by the away team\n",
|
528 |
+
" \"pts_away\" REAL, -- Total points scored by the away team\n",
|
529 |
+
" \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
|
530 |
+
" \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
531 |
+
" \"season_type\" TEXT -- Regular season or playoffs\n",
|
532 |
+
");\n",
|
533 |
+
"\n",
|
534 |
+
"other_stats Table\n",
|
535 |
+
"Stores additional statistics, linked to the game table via game_id.\n",
|
536 |
+
"CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
|
537 |
+
" \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
|
538 |
+
" \"league_id\" TEXT, -- League identifier\n",
|
539 |
+
" \"team_id_home\" TEXT, -- Home team identifier\n",
|
540 |
+
" \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
|
541 |
+
" \"team_city_home\" TEXT, -- Home team city\n",
|
542 |
+
" \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
|
543 |
+
" \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
|
544 |
+
" \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
|
545 |
+
" \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
|
546 |
+
" \"lead_changes\" INTEGER, -- Number of lead changes \n",
|
547 |
+
" \"times_tied\" INTEGER, -- Number of times the score was tied\n",
|
548 |
+
" \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
|
549 |
+
" \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
|
550 |
+
" \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
|
551 |
+
" \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
|
552 |
+
" \"team_id_away\" TEXT, -- Away team identifier\n",
|
553 |
+
" \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
|
554 |
+
" \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
|
555 |
+
" \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
|
556 |
+
" \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
|
557 |
+
" \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
|
558 |
+
" \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
|
559 |
+
" \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
|
560 |
+
" \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
|
561 |
+
" \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
|
562 |
+
");\n",
|
563 |
+
"\n",
|
564 |
+
"\n",
|
565 |
+
"Team Name Information\n",
|
566 |
+
"In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
|
567 |
+
"The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
|
568 |
+
"Notice they are separated by the | character in the following list:\n",
|
569 |
+
"\n",
|
570 |
+
"Atlanta Hawks|ATL\n",
|
571 |
+
"Boston Celtics|BOS\n",
|
572 |
+
"Cleveland Cavaliers|CLE\n",
|
573 |
+
"New Orleans Pelicans|NOP\n",
|
574 |
+
"Chicago Bulls|CHI\n",
|
575 |
+
"Dallas Mavericks|DAL\n",
|
576 |
+
"Denver Nuggets|DEN\n",
|
577 |
+
"Golden State Warriors|GSW\n",
|
578 |
+
"Houston Rockets|HOU\n",
|
579 |
+
"Los Angeles Clippers|LAC\n",
|
580 |
+
"Los Angeles Lakers|LAL\n",
|
581 |
+
"Miami Heat|MIA\n",
|
582 |
+
"Milwaukee Bucks|MIL\n",
|
583 |
+
"Minnesota Timberwolves|MIN\n",
|
584 |
+
"Brooklyn Nets|BKN\n",
|
585 |
+
"New York Knicks|NYK\n",
|
586 |
+
"Orlando Magic|ORL\n",
|
587 |
+
"Indiana Pacers|IND\n",
|
588 |
+
"Philadelphia 76ers|PHI\n",
|
589 |
+
"Phoenix Suns|PHX\n",
|
590 |
+
"Portland Trail Blazers|POR\n",
|
591 |
+
"Sacramento Kings|SAC\n",
|
592 |
+
"San Antonio Spurs|SAS\n",
|
593 |
+
"Oklahoma City Thunder|OKC\n",
|
594 |
+
"Toronto Raptors|TOR\n",
|
595 |
+
"Utah Jazz|UTA\n",
|
596 |
+
"Memphis Grizzlies|MEM\n",
|
597 |
+
"Washington Wizards|WAS\n",
|
598 |
+
"Detroit Pistons|DET\n",
|
599 |
+
"Charlotte Hornets|CHA\n",
|
600 |
+
"\n",
|
601 |
+
"Query Guidelines\n",
|
602 |
+
"Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
|
603 |
+
"\n",
|
604 |
+
"To filter by season, use season_id = '2YYYY'.\n",
|
605 |
+
"\n",
|
606 |
+
"Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
|
607 |
+
"\n",
|
608 |
+
"Ensure queries return relevant columns and avoid unnecessary joins.\n",
|
609 |
+
"\n",
|
610 |
+
"Example User Requests and SQLite Queries\n",
|
611 |
+
"Request:\n",
|
612 |
+
"\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
|
613 |
+
"SQLite:\n",
|
614 |
+
"SELECT MAX(pts_home) \n",
|
615 |
+
"FROM game \n",
|
616 |
+
"WHERE team_name_home = 'Los Angeles Lakers';\n",
|
617 |
+
"\n",
|
618 |
+
"Request:\n",
|
619 |
+
"\"Which teams are located in the state of California?\"\n",
|
620 |
+
"SQLite:\n",
|
621 |
+
"SELECT full_name FROM team WHERE state = 'California';\n",
|
622 |
+
"\n",
|
623 |
+
"Request:\n",
|
624 |
+
"\"Which team had the highest number of team turnovers in an away game?\"\n",
|
625 |
+
"SQLite:\n",
|
626 |
+
"SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
|
627 |
+
"\n",
|
628 |
+
"Request:\n",
|
629 |
+
"\"Which teams were founded before 1979?\"\n",
|
630 |
+
"SQLite:\n",
|
631 |
+
"SELECT full_name FROM team WHERE year_founded < 1979;\n",
|
632 |
+
"\n",
|
633 |
+
"Request:\n",
|
634 |
+
"\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
|
635 |
+
"SQLite:\n",
|
636 |
+
"SELECT MAX(pts_home - pts_away) AS biggest_win\n",
|
637 |
+
"FROM game\n",
|
638 |
+
"WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
|
639 |
+
"\n",
|
640 |
+
"Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
|
641 |
+
"How many points did the Golden State Warriors score in their first game of the 2005 season?\n",
|
642 |
+
"SQLite: \n",
|
643 |
+
"SELECT pts_home FROM game WHERE team_abbreviation_home = 'GSW' AND season_id = '22005' ORDER BY game_date ASC LIMIT 1;\n",
|
644 |
"939\n",
|
645 |
"105\n"
|
646 |
]
|
|
|
679 |
" Tokenizes input natural language queries and corresponding SQL queries.\n",
|
680 |
" \"\"\"\n",
|
681 |
" inputs = [input_prompt + q for q in examples[\"natural_query\"]]\n",
|
682 |
+
" targets = [\"SQLite: \\n\" + q for q in examples[\"sql_query\"]]\n",
|
683 |
+
"\n",
|
684 |
+
" print(inputs[0])\n",
|
685 |
+
" print(targets[0])\n",
|
686 |
"\n",
|
687 |
" model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=256)\n",
|
688 |
" labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=256)\n",
|
|
|
782 |
"text": [
|
783 |
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of π€ Transformers. Use `eval_strategy` instead\n",
|
784 |
" warnings.warn(\n",
|
785 |
+
"C:\\Users\\Dean\\AppData\\Local\\Temp\\ipykernel_9248\\2737143648.py:17: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
|
786 |
" trainer = Trainer(\n",
|
787 |
"No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
|
788 |
]
|
|
|
793 |
" output_dir=\"./fine-tuned-model\",\n",
|
794 |
" evaluation_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
|
795 |
" save_strategy=\"epoch\", # Save model every epoch\n",
|
796 |
+
" per_device_train_batch_size=16, # LoRA allows higher batch size\n",
|
797 |
+
" per_device_eval_batch_size=16,\n",
|
798 |
+
" num_train_epochs=5, # Increase if needed\n",
|
799 |
+
" learning_rate=5e-5, # Higher LR since we're only training LoRA layers\n",
|
800 |
" weight_decay=0.01,\n",
|
801 |
" logging_steps=50, # Print loss every 50 steps\n",
|
802 |
" save_total_limit=2, # Keep last 2 checkpoints\n",
|
|
|
840 |
"\n",
|
841 |
" <div>\n",
|
842 |
" \n",
|
843 |
+
" <progress value='295' max='295' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
844 |
+
" [295/295 1:04:01, Epoch 5/5]\n",
|
845 |
" </div>\n",
|
846 |
" <table border=\"1\" class=\"dataframe\">\n",
|
847 |
" <thead>\n",
|
|
|
854 |
" <tbody>\n",
|
855 |
" <tr>\n",
|
856 |
" <td>1</td>\n",
|
857 |
+
" <td>8.510300</td>\n",
|
858 |
+
" <td>1.486602</td>\n",
|
859 |
" </tr>\n",
|
860 |
" <tr>\n",
|
861 |
" <td>2</td>\n",
|
862 |
+
" <td>1.709800</td>\n",
|
863 |
+
" <td>1.227304</td>\n",
|
864 |
" </tr>\n",
|
865 |
" <tr>\n",
|
866 |
" <td>3</td>\n",
|
867 |
+
" <td>1.542100</td>\n",
|
868 |
+
" <td>1.161120</td>\n",
|
869 |
+
" </tr>\n",
|
870 |
+
" <tr>\n",
|
871 |
+
" <td>4</td>\n",
|
872 |
+
" <td>1.487500</td>\n",
|
873 |
+
" <td>1.153255</td>\n",
|
874 |
+
" </tr>\n",
|
875 |
+
" <tr>\n",
|
876 |
+
" <td>5</td>\n",
|
877 |
+
" <td>1.388300</td>\n",
|
878 |
+
" <td>1.121680</td>\n",
|
879 |
" </tr>\n",
|
880 |
" </tbody>\n",
|
881 |
"</table><p>"
|
|
|
927 |
},
|
928 |
{
|
929 |
"cell_type": "code",
|
930 |
+
"execution_count": 7,
|
931 |
"metadata": {},
|
932 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
933 |
{
|
934 |
"name": "stdout",
|
935 |
"output_type": "stream",
|
936 |
"text": [
|
937 |
+
"Generated SQL: SQLite: SELECT AVG(pts_home) FROM game WHERE team_name_home = 'Los Angeles Lakers';\n",
|
938 |
+
"\n"
|
939 |
]
|
940 |
}
|
941 |
],
|
|
|
944 |
"tokenizer = AutoTokenizer.from_pretrained(\"./fine-tuned-model\")\n",
|
945 |
"\n",
|
946 |
"# Prepare query with the same prompt\n",
|
947 |
+
"input_text = \"How many points to the Los Angeles Lakers average at home?\"\n",
|
948 |
"message = [{'role': 'user', 'content': input_prompt + input_text}]\n",
|
949 |
"inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
|
950 |
"\n",
|