kristiannordby commited on
Commit
428f4b1
·
verified ·
1 Parent(s): bb9728c

Upload promptTuningsql (1).ipynb

Browse files
Files changed (1) hide show
  1. promptTuningsql (1).ipynb +710 -0
promptTuningsql (1).ipynb ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "5d69bd30-a4a5-47da-a1ce-b6f9f228b42c",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
14
+ "\u001b[0m\n",
15
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n",
16
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
17
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
18
+ "\u001b[0m\n",
19
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n",
20
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
21
+ ]
22
+ }
23
+ ],
24
+ "source": [
25
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
26
+ "!pip install -q accelerate datasets peft bitsandbytes"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 1,
32
+ "id": "33d7d8f7-a2bd-4548-ac7f-45eba6ca1651",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "import torch\n",
37
+ "from datasets import load_dataset, Dataset\n",
38
+ "from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, Trainer\n",
39
+ "\n",
40
+ "from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PromptTuningConfig"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "id": "511a7b95-1089-4312-bc4a-40c843ea60f7",
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "data": {
51
+ "application/vnd.jupyter.widget-view+json": {
52
+ "model_id": "86bfa1c49f8b4fb5900506cdc7968886",
53
+ "version_major": 2,
54
+ "version_minor": 0
55
+ },
56
+ "text/plain": [
57
+ "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
58
+ ]
59
+ },
60
+ "metadata": {},
61
+ "output_type": "display_data"
62
+ },
63
+ {
64
+ "name": "stderr",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
68
+ " warnings.warn(\n",
69
+ "/usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:601: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
70
+ " warnings.warn(\n"
71
+ ]
72
+ },
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "trainable params: 81,920 || all params: 8,030,343,168 || trainable%: 0.0010\n"
78
+ ]
79
+ }
80
+ ],
81
+ "source": [
82
+ "model_name = \"defog/llama-3-sqlcoder-8b\"\n",
83
+ "\n",
84
+ "prompt_config = PromptTuningConfig(\n",
85
+ " num_virtual_tokens=20, # Number of prompt tokens to learn\n",
86
+ " task_type=\"CAUSAL_LM\", # Causal language modeling for SQL generation\n",
87
+ " tokenizer_name_or_path=model_name\n",
88
+ ")\n",
89
+ "\n",
90
+ "tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)\n",
91
+ "tokenizer.pad_token = tokenizer.eos_token\n",
92
+ "\n",
93
+ "model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(\"cuda\")\n",
94
+ "model = get_peft_model(model, prompt_config)\n",
95
+ "model.print_trainable_parameters()"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 3,
101
+ "id": "7bfb864d-6ad5-49fb-9e18-6d6e6d90373a",
102
+ "metadata": {},
103
+ "outputs": [
104
+ {
105
+ "data": {
106
+ "application/vnd.jupyter.widget-view+json": {
107
+ "model_id": "26656ca795e24d8483092fdc3e3d8954",
108
+ "version_major": 2,
109
+ "version_minor": 0
110
+ },
111
+ "text/plain": [
112
+ "Map: 0%| | 0/121 [00:00<?, ? examples/s]"
113
+ ]
114
+ },
115
+ "metadata": {},
116
+ "output_type": "display_data"
117
+ },
118
+ {
119
+ "data": {
120
+ "text/plain": [
121
+ "Dataset({\n",
122
+ " features: ['question', 'query', 'input_ids', 'attention_mask', 'labels'],\n",
123
+ " num_rows: 121\n",
124
+ "})"
125
+ ]
126
+ },
127
+ "execution_count": 3,
128
+ "metadata": {},
129
+ "output_type": "execute_result"
130
+ }
131
+ ],
132
+ "source": [
133
+ "import json\n",
134
+ "with open(\"syntheticTableData (1).json\",\"r\") as f: #SyntheticTableData (1) is the same as kristiannordby/text2sql121rows dataset in huggingface\n",
135
+ " data = json.load(f)\n",
136
+ "untokenized_dataset = Dataset.from_list(data)\n",
137
+ "\n",
138
+ "def preprocess_function(examples):\n",
139
+ " inputs = tokenizer(examples[\"question\"], padding=\"max_length\", truncation=True, max_length=512)\n",
140
+ " labels = tokenizer(examples[\"query\"], padding=\"max_length\", truncation=True, max_length=512)\n",
141
+ " labels[\"input_ids\"] = [-100 if token == tokenizer.pad_token_id else token for token in labels[\"input_ids\"]]\n",
142
+ " return {\"input_ids\": inputs[\"input_ids\"], \"attention_mask\": inputs[\"attention_mask\"], \"labels\": labels[\"input_ids\"]}\n",
143
+ "\n",
144
+ "ds = untokenized_dataset.map(preprocess_function, batched=True)\n",
145
+ "ds"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 10,
151
+ "id": "a0197d96",
152
+ "metadata": {},
153
+ "outputs": [
154
+ {
155
+ "name": "stderr",
156
+ "output_type": "stream",
157
+ "text": [
158
+ "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
159
+ ]
160
+ },
161
+ {
162
+ "name": "stdout",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "Generated SQL: Which car model from 2015 has the best miles-per-gallon, costs more than $30,000, and how many total miles has it driven?sonyoursite is there are you want to date:1.. Acura of which one! The answer will be a single line with three values separated by commas (e.g., \"Toyota Prius Hybrid\", \"$35k - \\$40K per year\").\" } { SELECT m.make AS Car_Model FROM cars c JOIN models ON CAST(c.model_id as integer) = id WHERE price > '30000' AND fuel_economy IS NOT NULL ORDER BY mileage DESC LIMIT 10;iвassistant\n",
166
+ "\n",
167
+ "I apologize for any confusion earlier.\n",
168
+ "\n",
169
+ "To clarify your question:\n",
170
+ "\n",
171
+ "You're asking me about what I can do if someone else's code or data causes an error in my own program?\n",
172
+ "\n",
173
+ "If that happens,\n",
174
+ "\n",
175
+ "* **Error Handling**: You should handle these errors properly using try-except blocks.\n",
176
+ " * For example:\n",
177
+ " ```\n",
178
+ " import requests\n",
179
+ " def get_data(url):\n",
180
+ " response=requests.get('https://api.example.com/data')\n",
181
+ " returnresponse.json()\n",
182
+ " \n"
183
+ ]
184
+ }
185
+ ],
186
+ "source": [
187
+ "import torch\n",
188
+ "\n",
189
+ "question = \"Which car model from 2015 has the best miles-per-gallon, costs more than $30,000, and how many total miles has it driven?\"\n",
190
+ "expected_sql_query = \"\"\"\n",
191
+ "SELECT make, model, mpg, totalMiles \n",
192
+ "FROM cars \n",
193
+ "WHERE modelYear = 2015 \n",
194
+ "AND sellPrice > 30000 \n",
195
+ "ORDER BY mpg DESC \n",
196
+ "LIMIT 1;\n",
197
+ "\"\"\"\n",
198
+ "\n",
199
+ "inputs = tokenizer(question, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=512).to(\"cuda\")\n",
200
+ "\n",
201
+ "model.eval()\n",
202
+ "\n",
203
+ "with torch.no_grad():\n",
204
+ " generated_ids = model.generate(\n",
205
+ " input_ids=inputs[\"input_ids\"],\n",
206
+ " attention_mask=inputs[\"attention_mask\"],\n",
207
+ " max_new_tokens=200, # need to adjust so model does not get off track; or could pull sql from it later\n",
208
+ " repetition_penalty=2.0,\n",
209
+ " early_stopping=True,\n",
210
+ " eos_token_id=tokenizer.eos_token_id, # Use greedy decoding for deterministic output\n",
211
+ " )\n",
212
+ "\n",
213
+ "\n",
214
+ "generated_sql_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
215
+ "print(f\"Generated SQL: {generated_sql_query}\")"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": 6,
221
+ "id": "f76849ea-fac9-4ef3-a02b-b56414e25e61",
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "name": "stderr",
226
+ "output_type": "stream",
227
+ "text": [
228
+ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
229
+ ]
230
+ }
231
+ ],
232
+ "source": [
233
+ "from transformers import Trainer, TrainingArguments\n",
234
+ "\n",
235
+ "training_args = TrainingArguments(\n",
236
+ " output_dir=\"./results\",\n",
237
+ " per_device_train_batch_size=2, \n",
238
+ " gradient_accumulation_steps=4, \n",
239
+ " num_train_epochs=50, # More epochs for a small dataset\n",
240
+ " learning_rate=5e-5, \n",
241
+ " eval_strategy=\"steps\",\n",
242
+ " eval_steps=20,\n",
243
+ " save_steps=20,\n",
244
+ " logging_dir=\"./logs\",\n",
245
+ " logging_steps=10,\n",
246
+ " save_total_limit=1,\n",
247
+ " weight_decay=0.01,\n",
248
+ ")\n",
249
+ "\n",
250
+ "trainer = Trainer(\n",
251
+ " model=model,\n",
252
+ " args=training_args,\n",
253
+ " train_dataset=ds,\n",
254
+ " eval_dataset = ds, #use training dataset as eval dataset because of the small size of data\n",
255
+ " tokenizer=tokenizer\n",
256
+ ")"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 5,
262
+ "id": "20e1c0c7-4c92-46a6-8023-2bb2e9f70107",
263
+ "metadata": {},
264
+ "outputs": [
265
+ {
266
+ "data": {
267
+ "text/html": [
268
+ "\n",
269
+ " <div>\n",
270
+ " \n",
271
+ " <progress value='750' max='750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
272
+ " [750/750 36:17, Epoch 49/50]\n",
273
+ " </div>\n",
274
+ " <table border=\"1\" class=\"dataframe\">\n",
275
+ " <thead>\n",
276
+ " <tr style=\"text-align: left;\">\n",
277
+ " <th>Step</th>\n",
278
+ " <th>Training Loss</th>\n",
279
+ " <th>Validation Loss</th>\n",
280
+ " </tr>\n",
281
+ " </thead>\n",
282
+ " <tbody>\n",
283
+ " <tr>\n",
284
+ " <td>20</td>\n",
285
+ " <td>18.860600</td>\n",
286
+ " <td>18.779743</td>\n",
287
+ " </tr>\n",
288
+ " <tr>\n",
289
+ " <td>40</td>\n",
290
+ " <td>18.631400</td>\n",
291
+ " <td>18.560749</td>\n",
292
+ " </tr>\n",
293
+ " <tr>\n",
294
+ " <td>60</td>\n",
295
+ " <td>18.458800</td>\n",
296
+ " <td>18.344973</td>\n",
297
+ " </tr>\n",
298
+ " <tr>\n",
299
+ " <td>80</td>\n",
300
+ " <td>18.136200</td>\n",
301
+ " <td>18.131050</td>\n",
302
+ " </tr>\n",
303
+ " <tr>\n",
304
+ " <td>100</td>\n",
305
+ " <td>17.972900</td>\n",
306
+ " <td>17.917627</td>\n",
307
+ " </tr>\n",
308
+ " <tr>\n",
309
+ " <td>120</td>\n",
310
+ " <td>17.726900</td>\n",
311
+ " <td>17.709686</td>\n",
312
+ " </tr>\n",
313
+ " <tr>\n",
314
+ " <td>140</td>\n",
315
+ " <td>17.605200</td>\n",
316
+ " <td>17.505020</td>\n",
317
+ " </tr>\n",
318
+ " <tr>\n",
319
+ " <td>160</td>\n",
320
+ " <td>17.337000</td>\n",
321
+ " <td>17.299978</td>\n",
322
+ " </tr>\n",
323
+ " <tr>\n",
324
+ " <td>180</td>\n",
325
+ " <td>17.144400</td>\n",
326
+ " <td>17.099331</td>\n",
327
+ " </tr>\n",
328
+ " <tr>\n",
329
+ " <td>200</td>\n",
330
+ " <td>16.930100</td>\n",
331
+ " <td>16.904736</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <td>220</td>\n",
335
+ " <td>16.744000</td>\n",
336
+ " <td>16.711248</td>\n",
337
+ " </tr>\n",
338
+ " <tr>\n",
339
+ " <td>240</td>\n",
340
+ " <td>16.582000</td>\n",
341
+ " <td>16.522562</td>\n",
342
+ " </tr>\n",
343
+ " <tr>\n",
344
+ " <td>260</td>\n",
345
+ " <td>16.443800</td>\n",
346
+ " <td>16.339695</td>\n",
347
+ " </tr>\n",
348
+ " <tr>\n",
349
+ " <td>280</td>\n",
350
+ " <td>16.220400</td>\n",
351
+ " <td>16.161507</td>\n",
352
+ " </tr>\n",
353
+ " <tr>\n",
354
+ " <td>300</td>\n",
355
+ " <td>16.026400</td>\n",
356
+ " <td>15.991174</td>\n",
357
+ " </tr>\n",
358
+ " <tr>\n",
359
+ " <td>320</td>\n",
360
+ " <td>15.869000</td>\n",
361
+ " <td>15.825206</td>\n",
362
+ " </tr>\n",
363
+ " <tr>\n",
364
+ " <td>340</td>\n",
365
+ " <td>15.746500</td>\n",
366
+ " <td>15.668069</td>\n",
367
+ " </tr>\n",
368
+ " <tr>\n",
369
+ " <td>360</td>\n",
370
+ " <td>15.574400</td>\n",
371
+ " <td>15.521387</td>\n",
372
+ " </tr>\n",
373
+ " <tr>\n",
374
+ " <td>380</td>\n",
375
+ " <td>15.420900</td>\n",
376
+ " <td>15.380891</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>400</td>\n",
380
+ " <td>15.288200</td>\n",
381
+ " <td>15.247506</td>\n",
382
+ " </tr>\n",
383
+ " <tr>\n",
384
+ " <td>420</td>\n",
385
+ " <td>15.143000</td>\n",
386
+ " <td>15.120378</td>\n",
387
+ " </tr>\n",
388
+ " <tr>\n",
389
+ " <td>440</td>\n",
390
+ " <td>15.019400</td>\n",
391
+ " <td>15.004883</td>\n",
392
+ " </tr>\n",
393
+ " <tr>\n",
394
+ " <td>460</td>\n",
395
+ " <td>14.919500</td>\n",
396
+ " <td>14.896546</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>480</td>\n",
400
+ " <td>14.791300</td>\n",
401
+ " <td>14.795321</td>\n",
402
+ " </tr>\n",
403
+ " <tr>\n",
404
+ " <td>500</td>\n",
405
+ " <td>14.687800</td>\n",
406
+ " <td>14.703000</td>\n",
407
+ " </tr>\n",
408
+ " <tr>\n",
409
+ " <td>520</td>\n",
410
+ " <td>14.666300</td>\n",
411
+ " <td>14.616350</td>\n",
412
+ " </tr>\n",
413
+ " <tr>\n",
414
+ " <td>540</td>\n",
415
+ " <td>14.550400</td>\n",
416
+ " <td>14.541070</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>560</td>\n",
420
+ " <td>14.505000</td>\n",
421
+ " <td>14.471634</td>\n",
422
+ " </tr>\n",
423
+ " <tr>\n",
424
+ " <td>580</td>\n",
425
+ " <td>14.479400</td>\n",
426
+ " <td>14.409344</td>\n",
427
+ " </tr>\n",
428
+ " <tr>\n",
429
+ " <td>600</td>\n",
430
+ " <td>14.341600</td>\n",
431
+ " <td>14.354433</td>\n",
432
+ " </tr>\n",
433
+ " <tr>\n",
434
+ " <td>620</td>\n",
435
+ " <td>14.339700</td>\n",
436
+ " <td>14.307119</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>640</td>\n",
440
+ " <td>14.292600</td>\n",
441
+ " <td>14.265167</td>\n",
442
+ " </tr>\n",
443
+ " <tr>\n",
444
+ " <td>660</td>\n",
445
+ " <td>14.252600</td>\n",
446
+ " <td>14.229964</td>\n",
447
+ " </tr>\n",
448
+ " <tr>\n",
449
+ " <td>680</td>\n",
450
+ " <td>14.240400</td>\n",
451
+ " <td>14.202421</td>\n",
452
+ " </tr>\n",
453
+ " <tr>\n",
454
+ " <td>700</td>\n",
455
+ " <td>14.183600</td>\n",
456
+ " <td>14.182171</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>720</td>\n",
460
+ " <td>14.182200</td>\n",
461
+ " <td>14.169066</td>\n",
462
+ " </tr>\n",
463
+ " <tr>\n",
464
+ " <td>740</td>\n",
465
+ " <td>14.153600</td>\n",
466
+ " <td>14.162232</td>\n",
467
+ " </tr>\n",
468
+ " </tbody>\n",
469
+ "</table><p>"
470
+ ],
471
+ "text/plain": [
472
+ "<IPython.core.display.HTML object>"
473
+ ]
474
+ },
475
+ "metadata": {},
476
+ "output_type": "display_data"
477
+ },
478
+ {
479
+ "data": {
480
+ "text/plain": [
481
+ "TrainOutput(global_step=750, training_loss=15.830242533365885, metrics={'train_runtime': 2180.7907, 'train_samples_per_second': 2.774, 'train_steps_per_second': 0.344, 'total_flos': 1.3720107025327718e+17, 'train_loss': 15.830242533365885, 'epoch': 49.18032786885246})"
482
+ ]
483
+ },
484
+ "execution_count": 5,
485
+ "metadata": {},
486
+ "output_type": "execute_result"
487
+ }
488
+ ],
489
+ "source": [
490
+ "trainer.train()"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": 11,
496
+ "id": "79786af2-4a19-464f-9f23-5bcfca6f3d16",
497
+ "metadata": {},
498
+ "outputs": [
499
+ {
500
+ "name": "stderr",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
504
+ ]
505
+ },
506
+ {
507
+ "name": "stdout",
508
+ "output_type": "stream",
509
+ "text": [
510
+ "Generated SQL: Which car model from 2015 has the best miles-per-gallon, costs more than $30,000, and how many total miles has it driven?sonyoursite is there are you want to date:1.. Acura of which one! The answer will be a single line with three values separated by commas (e.g., \"Toyota Prius Hybrid\", \"$35k - \\$40K per year\").\" } { SELECT m.make AS Car_Model FROM cars c JOIN models ON CAST(c.model_id as integer) = id WHERE price > '30000' AND fuel_economy IS NOT NULL ORDER BY mileage DESC LIMIT 10;iвassistant\n",
511
+ "\n",
512
+ "I apologize for any confusion earlier.\n",
513
+ "\n",
514
+ "To clarify your question:\n",
515
+ "\n",
516
+ "You're asking me about what I can do if someone else's code or data causes an error in my own program?\n",
517
+ "\n",
518
+ "If that happens,\n",
519
+ "\n",
520
+ "* **Error Handling**: You should handle these errors properly using try-except blocks.\n",
521
+ " * For example:\n",
522
+ " ```\n",
523
+ " import requests\n",
524
+ " def get_data(url):\n",
525
+ " response=requests.get('https://api.example.com/data')\n",
526
+ " returnresponse.json()\n",
527
+ " \n"
528
+ ]
529
+ }
530
+ ],
531
+ "source": [
532
+ "import torch\n",
533
+ "\n",
534
+ "question = \"Which car model from 2015 has the best miles-per-gallon, costs more than $30,000, and how many total miles has it driven?\"\n",
535
+ "expected_sql_query = \"\"\"\n",
536
+ "SELECT make, model, mpg, totalMiles \n",
537
+ "FROM cars \n",
538
+ "WHERE modelYear = 2015 \n",
539
+ "AND sellPrice > 30000 \n",
540
+ "ORDER BY mpg DESC \n",
541
+ "LIMIT 1;\n",
542
+ "\"\"\"\n",
543
+ "\n",
544
+ "inputs = tokenizer(question, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=512).to(\"cuda\")\n",
545
+ "\n",
546
+ "model.eval()\n",
547
+ "\n",
548
+ "with torch.no_grad():\n",
549
+ " generated_ids = model.generate(\n",
550
+ " input_ids=inputs[\"input_ids\"],\n",
551
+ " attention_mask=inputs[\"attention_mask\"],\n",
552
+ " max_new_tokens=200, # Allow for sufficient token generation\n",
553
+ " repetition_penalty=2.0,\n",
554
+ " early_stopping=True,\n",
555
+ " eos_token_id=tokenizer.eos_token_id, # Use greedy decoding for deterministic output\n",
556
+ " )\n",
557
+ "\n",
558
+ "\n",
559
+ "generated_sql_query = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
560
+ "print(f\"Generated SQL: {generated_sql_query}\")"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": 12,
566
+ "id": "f6ac37df-0d98-42db-82e4-31aeb1d57baa",
567
+ "metadata": {},
568
+ "outputs": [
569
+ {
570
+ "data": {
571
+ "application/vnd.jupyter.widget-view+json": {
572
+ "model_id": "abaf926b5cb74411bcbce6570542dc13",
573
+ "version_major": 2,
574
+ "version_minor": 0
575
+ },
576
+ "text/plain": [
577
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
578
+ ]
579
+ },
580
+ "metadata": {},
581
+ "output_type": "display_data"
582
+ }
583
+ ],
584
+ "source": [
585
+ "from huggingface_hub import login\n",
586
+ "login()"
587
+ ]
588
+ },
589
+ {
590
+ "cell_type": "code",
591
+ "execution_count": 13,
592
+ "id": "adfe4f39-093a-46e3-83d9-789106cfe7ea",
593
+ "metadata": {},
594
+ "outputs": [
595
+ {
596
+ "data": {
597
+ "application/vnd.jupyter.widget-view+json": {
598
+ "model_id": "b9d47051c5664b1b8c3d738a0c23b7b8",
599
+ "version_major": 2,
600
+ "version_minor": 0
601
+ },
602
+ "text/plain": [
603
+ "training_args.bin: 0%| | 0.00/5.11k [00:00<?, ?B/s]"
604
+ ]
605
+ },
606
+ "metadata": {},
607
+ "output_type": "display_data"
608
+ },
609
+ {
610
+ "data": {
611
+ "application/vnd.jupyter.widget-view+json": {
612
+ "model_id": "7d969ddb52a64373a0907d28d5ee9d79",
613
+ "version_major": 2,
614
+ "version_minor": 0
615
+ },
616
+ "text/plain": [
617
+ "adapter_model.safetensors: 0%| | 0.00/328k [00:00<?, ?B/s]"
618
+ ]
619
+ },
620
+ "metadata": {},
621
+ "output_type": "display_data"
622
+ },
623
+ {
624
+ "data": {
625
+ "application/vnd.jupyter.widget-view+json": {
626
+ "model_id": "238a36fc2f1143df9741966241a52ce6",
627
+ "version_major": 2,
628
+ "version_minor": 0
629
+ },
630
+ "text/plain": [
631
+ "Upload 2 LFS files: 0%| | 0/2 [00:00<?, ?it/s]"
632
+ ]
633
+ },
634
+ "metadata": {},
635
+ "output_type": "display_data"
636
+ },
637
+ {
638
+ "data": {
639
+ "text/plain": [
640
+ "CommitInfo(commit_url='https://huggingface.co/kristiannordby/results/commit/f5914cc61b844fb247969b86343e21b71a1ddf72', commit_message='prompttuned-sql-model', commit_description='', oid='f5914cc61b844fb247969b86343e21b71a1ddf72', pr_url=None, repo_url=RepoUrl('https://huggingface.co/kristiannordby/results', endpoint='https://huggingface.co', repo_type='model', repo_id='kristiannordby/results'), pr_revision=None, pr_num=None)"
641
+ ]
642
+ },
643
+ "execution_count": 13,
644
+ "metadata": {},
645
+ "output_type": "execute_result"
646
+ }
647
+ ],
648
+ "source": [
649
+ "trainer.push_to_hub(\"prompttuned-sql-model\")\n",
650
+ "# tokenizer.push_to_hub(\"./finetuned-sql-model\")"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": 14,
656
+ "id": "b8a4f79f-4516-4265-800b-fd9c9ba0ca7d",
657
+ "metadata": {},
658
+ "outputs": [
659
+ {
660
+ "data": {
661
+ "application/vnd.jupyter.widget-view+json": {
662
+ "model_id": "8aeb3531a8004a0eb7b27b3ade635384",
663
+ "version_major": 2,
664
+ "version_minor": 0
665
+ },
666
+ "text/plain": [
667
+ "adapter_model.safetensors: 0%| | 0.00/328k [00:00<?, ?B/s]"
668
+ ]
669
+ },
670
+ "metadata": {},
671
+ "output_type": "display_data"
672
+ },
673
+ {
674
+ "data": {
675
+ "text/plain": [
676
+ "CommitInfo(commit_url='https://huggingface.co/kristiannordby/prompttuned_model-sql-model/commit/454553f082f2bb2e23d126f7f14f81fcf59a33a9', commit_message='Upload model', commit_description='', oid='454553f082f2bb2e23d126f7f14f81fcf59a33a9', pr_url=None, repo_url=RepoUrl('https://huggingface.co/kristiannordby/prompttuned_model-sql-model', endpoint='https://huggingface.co', repo_type='model', repo_id='kristiannordby/prompttuned_model-sql-model'), pr_revision=None, pr_num=None)"
677
+ ]
678
+ },
679
+ "execution_count": 14,
680
+ "metadata": {},
681
+ "output_type": "execute_result"
682
+ }
683
+ ],
684
+ "source": [
685
+ "model.push_to_hub(\"prompttuned_model-sql-model\")"
686
+ ]
687
+ }
688
+ ],
689
+ "metadata": {
690
+ "kernelspec": {
691
+ "display_name": "Python 3 (ipykernel)",
692
+ "language": "python",
693
+ "name": "python3"
694
+ },
695
+ "language_info": {
696
+ "codemirror_mode": {
697
+ "name": "ipython",
698
+ "version": 3
699
+ },
700
+ "file_extension": ".py",
701
+ "mimetype": "text/x-python",
702
+ "name": "python",
703
+ "nbconvert_exporter": "python",
704
+ "pygments_lexer": "ipython3",
705
+ "version": "3.10.12"
706
+ }
707
+ },
708
+ "nbformat": 4,
709
+ "nbformat_minor": 5
710
+ }