aarohanverma commited on
Commit
71c12a0
·
verified ·
1 Parent(s): 976912c

Upload text2sql_flant5_qlora.ipynb

Browse files
Files changed (1) hide show
  1. text2sql_flant5_qlora.ipynb +1492 -0
text2sql_flant5_qlora.ipynb ADDED
@@ -0,0 +1,1492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "cadbd30d-57ce-4ef2-889f-24bd0ff06b89",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/workspace\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "!echo $PWD"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "12d99875-d86b-4442-8682-b9751118d90e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "#!pip3 install evaluate datasets bitsandbytes transformers peft rapidfuzz absl-py"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "5f167a6f-5139-46e6-afb2-a1fa4d12f3fd",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import time\n",
39
+ "import logging\n",
40
+ "import re\n",
41
+ "import random\n",
42
+ "import gc\n",
43
+ "import numpy as np\n",
44
+ "import pandas as pd\n",
45
+ "import torch\n",
46
+ "import evaluate\n",
47
+ "\n",
48
+ "from datasets import Dataset, DatasetDict, load_from_disk\n",
49
+ "from transformers import (\n",
50
+ " AutoModelForSeq2SeqLM,\n",
51
+ " AutoTokenizer,\n",
52
+ " TrainingArguments,\n",
53
+ " Trainer,\n",
54
+ " GenerationConfig,\n",
55
+ " BitsAndBytesConfig,\n",
56
+ ")\n",
57
+ "from transformers.trainer_callback import EarlyStoppingCallback\n",
58
+ "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 4,
64
+ "id": "53684b5e-c27e-4eb9-815e-583aa194e096",
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "name": "stdout",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "cuda\n"
72
+ ]
73
+ }
74
+ ],
75
+ "source": [
76
+ "# Enable cudnn benchmark for fixed input sizes (can speed up computation)\n",
77
+ "torch.backends.cudnn.benchmark = True\n",
78
+ "\n",
79
+ "# Set device to RTX 4090\n",
80
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
81
+ "print(device)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 5,
87
+ "id": "a47bf3cd-752d-4d1c-9697-70098d6204fa",
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "random.seed(42)\n",
92
+ "np.random.seed(42)\n",
93
+ "torch.manual_seed(42)\n",
94
+ "if torch.cuda.is_available():\n",
95
+ " torch.cuda.manual_seed_all(42)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 6,
101
+ "id": "f16df21e-9797-4f78-83a1-a2943759ba55",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def clear_memory():\n",
106
+ " gc.collect()\n",
107
+ " torch.cuda.empty_cache()"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 7,
113
+ "id": "196e83da-6c8c-4cd7-bd70-2598a5e2a16a",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "logging.basicConfig(\n",
118
+ " level=logging.INFO,\n",
119
+ " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
120
+ ")\n",
121
+ "logger = logging.getLogger(__name__)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 8,
127
+ "id": "cea22b9f-f309-4151-81ac-37547c8feeb0",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "def preprocess(text: str) -> str:\n",
132
+ " \"\"\"Remove extra whitespaces and newlines from a text string.\"\"\"\n",
133
+ " if not isinstance(text, str):\n",
134
+ " return \"\"\n",
135
+ " return re.sub(r'\\s+', ' ', text.replace('\\n', ' ')).strip()\n",
136
+ "\n",
137
+ "def clean_df(df, rename=None, drop=None, select=None):\n",
138
+ " \"\"\"\n",
139
+ " Clean and rename dataframe columns:\n",
140
+ " - drop: list of columns to drop\n",
141
+ " - rename: dict mapping old column names to new names\n",
142
+ " - select: list of columns to keep in final order\n",
143
+ " \"\"\"\n",
144
+ " if drop:\n",
145
+ " df = df.drop(columns=drop, errors='ignore')\n",
146
+ " if rename:\n",
147
+ " df = df.rename(columns=rename)\n",
148
+ " for col in ['query', 'context', 'response']:\n",
149
+ " if col in df.columns:\n",
150
+ " df[col] = df[col].apply(preprocess)\n",
151
+ " if select:\n",
152
+ " df = df[select]\n",
153
+ " return df"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 9,
159
+ "id": "d4eb82ce-1713-40b6-981d-43ce35aaa6f6",
160
+ "metadata": {},
161
+ "outputs": [
162
+ {
163
+ "name": "stderr",
164
+ "output_type": "stream",
165
+ "text": [
166
+ "2025-03-17 17:06:42,785 - INFO - Loading raw datasets from various sources...\n",
167
+ "2025-03-17 17:07:15,400 - INFO - Total rows before dropping duplicates: 490241\n",
168
+ "2025-03-17 17:07:16,852 - INFO - Total rows after dropping duplicates: 440785\n"
169
+ ]
170
+ }
171
+ ],
172
+ "source": [
173
+ "logger.info(\"Loading raw datasets from various sources...\")\n",
174
+ "\n",
175
+ "# Load datasets\n",
176
+ "df1 = pd.read_json(\"hf://datasets/Clinton/Text-to-sql-v1/texttosqlv2.jsonl\", lines=True)\n",
177
+ "df2 = pd.read_json(\"hf://datasets/b-mc2/sql-create-context/sql_create_context_v4.json\")\n",
178
+ "df3 = pd.read_parquet(\"hf://datasets/gretelai/synthetic_text_to_sql/synthetic_text_to_sql_train.snappy.parquet\")\n",
179
+ "df4 = pd.read_json(\"hf://datasets/knowrohit07/know_sql/know_sql_val3{ign}.json\")\n",
180
+ "\n",
181
+ "# Clean and rename columns to unify to 'query', 'context', 'response'\n",
182
+ "df1 = clean_df(df1, rename={'instruction': 'query', 'input': 'context'}, drop=['source', 'text'])\n",
183
+ "df2 = clean_df(df2, rename={'question': 'query', 'answer': 'response'})\n",
184
+ "df3 = clean_df(df3, rename={'sql_prompt': 'query', 'sql_context': 'context', 'sql': 'response'},\n",
185
+ " select=['query', 'context', 'response'])\n",
186
+ "df4 = clean_df(df4, rename={'question': 'query', 'answer': 'response'})\n",
187
+ "\n",
188
+ "# Concatenate all DataFrames\n",
189
+ "final_df = pd.concat([df1, df2, df3, df4], ignore_index=True)\n",
190
+ "logger.info(\"Total rows before dropping duplicates: %d\", len(final_df))\n",
191
+ "\n",
192
+ "# Force correct column order and drop rows with missing fields\n",
193
+ "final_df = final_df[['query', 'context', 'response']]\n",
194
+ "final_df = final_df.dropna(subset=['query', 'context', 'response'])\n",
195
+ "final_df = final_df.drop_duplicates()\n",
196
+ "logger.info(\"Total rows after dropping duplicates: %d\", len(final_df))"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 10,
202
+ "id": "8446814e-5a2c-48a4-8c01-059afcf1d3c1",
203
+ "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "name": "stderr",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (1113 > 512). Running this sequence through the model will result in indexing errors\n",
210
+ "2025-03-17 17:10:43,961 - INFO - Total rows after filtering by token length (prompt <= 500 and response <= 250 tokens): 398481\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
216
+ "\n",
217
+ "max_length_prompt = 500\n",
218
+ "max_length_response = 250\n",
219
+ "\n",
220
+ "def tokenize_length_filter(row):\n",
221
+ " start_prompt = \"Context:\\n\"\n",
222
+ " middle_prompt = \"\\n\\nQuery:\\n\"\n",
223
+ " end_prompt = \"\\n\\nResponse:\\n\"\n",
224
+ " \n",
225
+ " # Construct the prompt as used in the tokenize_function\n",
226
+ " prompt = f\"{start_prompt}{row['context']}{middle_prompt}{row['query']}{end_prompt}\"\n",
227
+ " \n",
228
+ " # Encode without truncation to get the full token count\n",
229
+ " prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True, truncation=False)\n",
230
+ " response_tokens = tokenizer.encode(row['response'], add_special_tokens=True, truncation=False)\n",
231
+ " \n",
232
+ " return len(prompt_tokens) <= max_length_prompt and len(response_tokens) <= max_length_response\n",
233
+ "\n",
234
+ "final_df = final_df[final_df.apply(tokenize_length_filter, axis=1)]\n",
235
+ "logger.info(\"Total rows after filtering by token length (prompt <= %d and response <= %d tokens): %d\", \n",
236
+ " max_length_prompt, max_length_response, len(final_df))\n"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 11,
242
+ "id": "177e1e6d-9fbc-442d-9774-5a3e5234329f",
243
+ "metadata": {},
244
+ "outputs": [
245
+ {
246
+ "name": "stderr",
247
+ "output_type": "stream",
248
+ "text": [
249
+ "2025-03-17 17:10:43,968 - INFO - Sample from filtered final_df:\n",
250
+ " query \\\n",
251
+ "0 Name the home team for carlton away team \n",
252
+ "1 what will the population of Asia be when Latin... \n",
253
+ "2 How many faculty members do we have for each g... \n",
254
+ "\n",
255
+ " context \\\n",
256
+ "0 CREATE TABLE table_name_77 ( home_team VARCHAR... \n",
257
+ "1 CREATE TABLE table_22767 ( \"Year\" real, \"World... \n",
258
+ "2 CREATE TABLE Student ( StuID INTEGER, LName VA... \n",
259
+ "\n",
260
+ " response \n",
261
+ "0 SELECT home_team FROM table_name_77 WHERE away... \n",
262
+ "1 SELECT \"Asia\" FROM table_22767 WHERE \"Latin Am... \n",
263
+ "2 SELECT Sex, COUNT(*) FROM Faculty GROUP BY Sex... \n"
264
+ ]
265
+ }
266
+ ],
267
+ "source": [
268
+ "logger.info(\"Sample from filtered final_df:\\n%s\", final_df.head(3))\n",
269
+ "clear_memory()"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": 12,
275
+ "id": "0b639efe-ebeb-4b34-bc3f-accf776ba0da",
276
+ "metadata": {},
277
+ "outputs": [
278
+ {
279
+ "name": "stderr",
280
+ "output_type": "stream",
281
+ "text": [
282
+ "2025-03-17 17:10:44,311 - INFO - Final split sizes: Train: 338708, Test: 39848, Validation: 19925\n"
283
+ ]
284
+ },
285
+ {
286
+ "data": {
287
+ "application/vnd.jupyter.widget-view+json": {
288
+ "model_id": "11dc405cf3d54b6abc81b8eaf6742bea",
289
+ "version_major": 2,
290
+ "version_minor": 0
291
+ },
292
+ "text/plain": [
293
+ "Saving the dataset (0/1 shards): 0%| | 0/338708 [00:00<?, ? examples/s]"
294
+ ]
295
+ },
296
+ "metadata": {},
297
+ "output_type": "display_data"
298
+ },
299
+ {
300
+ "data": {
301
+ "application/vnd.jupyter.widget-view+json": {
302
+ "model_id": "868d3a0d08874c448faac4b50dbb3685",
303
+ "version_major": 2,
304
+ "version_minor": 0
305
+ },
306
+ "text/plain": [
307
+ "Saving the dataset (0/1 shards): 0%| | 0/39848 [00:00<?, ? examples/s]"
308
+ ]
309
+ },
310
+ "metadata": {},
311
+ "output_type": "display_data"
312
+ },
313
+ {
314
+ "data": {
315
+ "application/vnd.jupyter.widget-view+json": {
316
+ "model_id": "0370d0dd07514d5cae499ab93ca47ee8",
317
+ "version_major": 2,
318
+ "version_minor": 0
319
+ },
320
+ "text/plain": [
321
+ "Saving the dataset (0/1 shards): 0%| | 0/19925 [00:00<?, ? examples/s]"
322
+ ]
323
+ },
324
+ "metadata": {},
325
+ "output_type": "display_data"
326
+ },
327
+ {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "2025-03-17 17:10:45,869 - INFO - Merged and Saved Dataset Successfully!\n",
332
+ "2025-03-17 17:10:45,870 - INFO - Dataset summary: DatasetDict({\n",
333
+ " train: Dataset({\n",
334
+ " features: ['query', 'context', 'response'],\n",
335
+ " num_rows: 338708\n",
336
+ " })\n",
337
+ " test: Dataset({\n",
338
+ " features: ['query', 'context', 'response'],\n",
339
+ " num_rows: 39848\n",
340
+ " })\n",
341
+ " validation: Dataset({\n",
342
+ " features: ['query', 'context', 'response'],\n",
343
+ " num_rows: 19925\n",
344
+ " })\n",
345
+ "})\n"
346
+ ]
347
+ }
348
+ ],
349
+ "source": [
350
+ "def split_dataframe(df, train_frac=0.85, test_frac=0.1, val_frac=0.05):\n",
351
+ " n = len(df)\n",
352
+ " train_end = int(n * train_frac)\n",
353
+ " test_end = train_end + int(n * test_frac)\n",
354
+ " train_df = df.iloc[:train_end].reset_index(drop=True)\n",
355
+ " test_df = df.iloc[train_end:test_end].reset_index(drop=True)\n",
356
+ " val_df = df.iloc[test_end:].reset_index(drop=True)\n",
357
+ " return train_df, test_df, val_df\n",
358
+ "\n",
359
+ "train_df, test_df, val_df = split_dataframe(final_df)\n",
360
+ "logger.info(\"Final split sizes: Train: %d, Test: %d, Validation: %d\", len(train_df), len(test_df), len(val_df))\n",
361
+ "\n",
362
+ "# Convert splits to Hugging Face Datasets\n",
363
+ "train_dataset = Dataset.from_pandas(train_df)\n",
364
+ "test_dataset = Dataset.from_pandas(test_df)\n",
365
+ "val_dataset = Dataset.from_pandas(val_df)\n",
366
+ "\n",
367
+ "dataset = DatasetDict({\n",
368
+ " 'train': train_dataset,\n",
369
+ " 'test': test_dataset,\n",
370
+ " 'validation': val_dataset\n",
371
+ "})\n",
372
+ "\n",
373
+ "dataset.save_to_disk(\"merged_dataset\")\n",
374
+ "logger.info(\"Merged and Saved Dataset Successfully!\")\n",
375
+ "logger.info(\"Dataset summary: %s\", dataset)\n",
376
+ "clear_memory()"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 13,
382
+ "id": "9f6e1095-d72d-4e22-b20d-683f1f84544c",
383
+ "metadata": {},
384
+ "outputs": [
385
+ {
386
+ "name": "stderr",
387
+ "output_type": "stream",
388
+ "text": [
389
+ "2025-03-17 17:10:46,218 - INFO - Reloaded dataset from disk. Example from test split:\n",
390
+ "{'query': \"Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\", 'context': \"CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\", 'response': 'SELECT command_name, type FROM defense_security.Military_Cyber_Commands;'}\n",
391
+ "2025-03-17 17:10:46,475 - INFO - Loaded Tokenized Dataset from disk.\n",
392
+ "2025-03-17 17:10:46,477 - INFO - Final tokenized dataset splits: dict_keys(['train', 'test', 'validation'])\n",
393
+ "2025-03-17 17:10:46,483 - INFO - Sample tokenized record from train split:\n",
394
+ "{'input_ids': tensor([ 1193, 6327, 10, 205, 4386, 6048, 332, 17098, 953, 834,\n",
395
+ " 4350, 834, 4013, 41, 234, 834, 11650, 584, 4280, 28027,\n",
396
+ " 6, 550, 834, 11650, 584, 4280, 28027, 3, 61, 3,\n",
397
+ " 27569, 10, 5570, 8, 234, 372, 21, 443, 7377, 550,\n",
398
+ " 372, 16361, 10, 3, 1, 0, 0, 0, 0, 0,\n",
399
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
400
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
401
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
402
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
403
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
404
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
405
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
406
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
407
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
408
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
409
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
410
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
411
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
412
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
413
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
414
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
415
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
416
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
417
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
418
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
419
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
420
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
421
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
422
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
423
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
424
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
425
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
426
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
427
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
428
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
429
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
430
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
431
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
432
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
433
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
434
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
435
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
436
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
437
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
438
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
439
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
440
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
441
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
442
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
443
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
444
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
445
+ " 0, 0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
446
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
447
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
448
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
449
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
450
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
451
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
452
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
453
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
454
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
455
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
456
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
457
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
458
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
459
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
460
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
461
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
462
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
463
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
464
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
465
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
466
+ " 0, 0, 0, 0, 0, 0, 0, 0]), 'labels': tensor([ 3, 23143, 14196, 234, 834, 11650, 21680, 953, 834, 4350,\n",
467
+ " 834, 4013, 549, 17444, 427, 550, 834, 11650, 3274, 96,\n",
468
+ " 1720, 7377, 121, 1, -100, -100, -100, -100, -100, -100,\n",
469
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
470
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
471
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
472
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
473
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
474
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
475
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
476
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
477
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
478
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
479
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
480
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
481
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
482
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
483
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
484
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
485
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
486
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
487
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
488
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
489
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
490
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
491
+ " -100, -100, -100, -100, -100, -100])}\n"
492
+ ]
493
+ }
494
+ ],
495
+ "source": [
496
+ "dataset = load_from_disk(\"merged_dataset\")\n",
497
+ "logger.info(\"Reloaded dataset from disk. Example from test split:\\n%s\", dataset['test'][0])\n",
498
+ "\n",
499
+ "model_name = \"google/flan-t5-base\"\n",
500
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
501
+ "\n",
502
+ "def tokenize_function(batch: dict) -> dict:\n",
503
+ " \"\"\"\n",
504
+ " Tokenizes a batch of examples for T5 fine-tuning.\n",
505
+ " Constructs a prompt in the format:\n",
506
+ " Context:\n",
507
+ " <context>\n",
508
+ " \n",
509
+ " Query:\n",
510
+ " <query>\n",
511
+ " \n",
512
+ " Response:\n",
513
+ " \"\"\"\n",
514
+ " start_prompt = \"Context:\\n\"\n",
515
+ " middle_prompt = \"\\n\\nQuery:\\n\"\n",
516
+ " end_prompt = \"\\n\\nResponse:\\n\"\n",
517
+ "\n",
518
+ " prompts = [\n",
519
+ " f\"{start_prompt}{ctx}{middle_prompt}{qry}{end_prompt}\"\n",
520
+ " for ctx, qry in zip(batch['context'], batch['query'])\n",
521
+ " ]\n",
522
+ "\n",
523
+ " tokenized_inputs = tokenizer(\n",
524
+ " prompts,\n",
525
+ " padding=\"max_length\",\n",
526
+ " truncation=True,\n",
527
+ " max_length=512\n",
528
+ " )\n",
529
+ " tokenized_labels = tokenizer(\n",
530
+ " batch['response'],\n",
531
+ " padding=\"max_length\",\n",
532
+ " truncation=True,\n",
533
+ " max_length=256\n",
534
+ " )\n",
535
+ " labels = [\n",
536
+ " [-100 if token == tokenizer.pad_token_id else token for token in seq]\n",
537
+ " for seq in tokenized_labels['input_ids']\n",
538
+ " ]\n",
539
+ "\n",
540
+ " batch['input_ids'] = tokenized_inputs['input_ids']\n",
541
+ " batch['attention_mask'] = tokenized_inputs['attention_mask']\n",
542
+ " batch['labels'] = labels\n",
543
+ " return batch\n",
544
+ "\n",
545
+ "try:\n",
546
+ " tokenized_datasets = load_from_disk(\"tokenized_datasets\")\n",
547
+ " logger.info(\"Loaded Tokenized Dataset from disk.\")\n",
548
+ "except Exception as e:\n",
549
+ " logger.info(\"Tokenized dataset not found. Creating a new one...\")\n",
550
+ " tokenized_datasets = dataset.map(\n",
551
+ " tokenize_function,\n",
552
+ " batched=True,\n",
553
+ " remove_columns=['query', 'context', 'response'],\n",
554
+ " num_proc=8\n",
555
+ " )\n",
556
+ " tokenized_datasets.save_to_disk(\"tokenized_datasets\")\n",
557
+ " logger.info(\"Tokenized and Saved Dataset.\")\n",
558
+ "\n",
559
+ "tokenized_datasets.set_format(\"torch\")\n",
560
+ "\n",
561
+ "logger.info(\"Final tokenized dataset splits: %s\", tokenized_datasets.keys())\n",
562
+ "logger.info(\"Sample tokenized record from train split:\\n%s\", tokenized_datasets['train'][0])"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": 14,
568
+ "id": "7f004e55-181c-47aa-9f3e-c7c1ceae780c",
569
+ "metadata": {},
570
+ "outputs": [
571
+ {
572
+ "name": "stdout",
573
+ "output_type": "stream",
574
+ "text": [
575
+ "----------------------------------------------------------------------------------------------------\n",
576
+ "INPUT PROMPT:\n",
577
+ "Context:\n",
578
+ "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n",
579
+ "\n",
580
+ "Query:\n",
581
+ "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n",
582
+ "\n",
583
+ "Response:\n",
584
+ "\n",
585
+ "----------------------------------------------------------------------------------------------------\n",
586
+ "BASELINE HUMAN ANSWER:\n",
587
+ "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
588
+ "\n",
589
+ "----------------------------------------------------------------------------------------------------\n",
590
+ "MODEL GENERATION - ZERO SHOT:\n",
591
+ "USCYBERCOM, JTF-CND, Offensive Cyber Operations, 10th Fleet, Network Warfare\n"
592
+ ]
593
+ }
594
+ ],
595
+ "source": [
596
+ "model_name = 'google/flan-t5-base'\n",
597
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
598
+ "original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
599
+ "original_model = original_model.to(device)\n",
600
+ "\n",
601
+ "index = 0\n",
602
+ "query = dataset['test'][index]['query']\n",
603
+ "context = dataset['test'][index]['context']\n",
604
+ "response = dataset['test'][index]['response']\n",
605
+ "\n",
606
+ "prompt = f\"\"\"Context:\n",
607
+ "{context}\n",
608
+ "\n",
609
+ "Query:\n",
610
+ "{query}\n",
611
+ "\n",
612
+ "Response:\n",
613
+ "\"\"\"\n",
614
+ "inputs = tokenizer(prompt, return_tensors='pt').to(device)\n",
615
+ "baseline_output = tokenizer.decode(\n",
616
+ " original_model.generate(\n",
617
+ " inputs[\"input_ids\"],\n",
618
+ " max_new_tokens=200,\n",
619
+ " )[0],\n",
620
+ " skip_special_tokens=True\n",
621
+ ")\n",
622
+ "dash_line = '-' * 100\n",
623
+ "print(dash_line)\n",
624
+ "print(f'INPUT PROMPT:\\n{prompt}')\n",
625
+ "print(dash_line)\n",
626
+ "print(f'BASELINE HUMAN ANSWER:\\n{response}\\n')\n",
627
+ "print(dash_line)\n",
628
+ "print(f'MODEL GENERATION - ZERO SHOT:\\n{baseline_output}')\n",
629
+ "clear_memory()"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "code",
634
+ "execution_count": 15,
635
+ "id": "f50e56c7-98b3-42bc-9129-89f3eff802e7",
636
+ "metadata": {},
637
+ "outputs": [
638
+ {
639
+ "name": "stderr",
640
+ "output_type": "stream",
641
+ "text": [
642
+ "2025-03-17 17:10:50,413 - INFO - Attempting to load the fine-tuned model...\n",
643
+ "2025-03-17 17:10:51,949 - INFO - Fine-tuned model loaded successfully.\n"
644
+ ]
645
+ }
646
+ ],
647
+ "source": [
648
+ "import math\n",
649
+ "\n",
650
+ "try:\n",
651
+ " logger.info(\"Attempting to load the fine-tuned model...\")\n",
652
+ " finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"text2sql_flant5base_finetuned\")\n",
653
+ " tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
654
+ " finetuned_model = finetuned_model.to(device)\n",
655
+ " to_train = False\n",
656
+ " logger.info(\"Fine-tuned model loaded successfully.\")\n",
657
+ "except Exception as e:\n",
658
+ " logger.info(\"Fine-tuned model not found.\")\n",
659
+ " logger.info(\"Initializing model and tokenizer for QLORA fine-tuning...\")\n",
660
+ " to_train = True\n",
661
+ "\n",
662
+ " quant_config = BitsAndBytesConfig(\n",
663
+ " load_in_4bit=True,\n",
664
+ " bnb_4bit_quant_type=\"nf4\",\n",
665
+ " bnb_4bit_use_double_quant=True,\n",
666
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
667
+ " )\n",
668
+ "\n",
669
+ " finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\n",
670
+ " model_name,\n",
671
+ " quantization_config=quant_config,\n",
672
+ " device_map=\"auto\",\n",
673
+ " torch_dtype=torch.bfloat16,\n",
674
+ " )\n",
675
+ " finetuned_model = prepare_model_for_kbit_training(finetuned_model)\n",
676
+ " \n",
677
+ " lora_config = LoraConfig(\n",
678
+ " r=32,\n",
679
+ " lora_alpha=64,\n",
680
+ " target_modules=[\"q\", \"v\"],\n",
681
+ " lora_dropout=0.1,\n",
682
+ " bias=\"none\",\n",
683
+ " task_type=\"SEQ_2_SEQ_LM\"\n",
684
+ " )\n",
685
+ " finetuned_model = get_peft_model(finetuned_model, lora_config)\n",
686
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
687
+ " logger.info(\"Base model loaded and prepared for QLORA fine-tuning.\")\n",
688
+ " clear_memory()\n",
689
+ "\n",
690
+ "if to_train:\n",
691
+ " output_dir = f\"./sql-training-{int(time.time())}\"\n",
692
+ " logger.info(\"Starting training. Output directory: %s\", output_dir)\n",
693
+ "\n",
694
+ " # Compute total training steps:\n",
695
+ " num_train_samples = len(tokenized_datasets[\"train\"])\n",
696
+ " per_device_train_batch_size = 64\n",
697
+ " per_device_eval_batch_size = 64\n",
698
+ " num_train_epochs = 6\n",
699
+ " # Assuming no gradient accumulation beyond the per-device batch size\n",
700
+ " total_steps = math.ceil(num_train_samples / per_device_train_batch_size) * num_train_epochs\n",
701
+ " # Set warmup steps as 10% of total steps (adjust as needed)\n",
702
+ " warmup_steps = int(total_steps * 0.1)\n",
703
+ " \n",
704
+ " logger.info(\"Total training steps: %d, Warmup steps (10%%): %d\", total_steps, warmup_steps)\n",
705
+ " \n",
706
+ " training_args = TrainingArguments(\n",
707
+ " output_dir=output_dir,\n",
708
+ " gradient_checkpointing=True,\n",
709
+ " gradient_checkpointing_kwargs={\"use_reentrant\": True},\n",
710
+ " gradient_accumulation_steps = 2,\n",
711
+ " learning_rate=2e-4,\n",
712
+ " optim=\"adamw_bnb_8bit\", # Memory-efficient optimizer\n",
713
+ " num_train_epochs=num_train_epochs,\n",
714
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
715
+ " per_device_eval_batch_size=per_device_eval_batch_size,\n",
716
+ " weight_decay=0.01,\n",
717
+ " logging_steps=200, \n",
718
+ " logging_dir=f\"{output_dir}/logs\",\n",
719
+ " eval_strategy=\"epoch\", # Evaluate at the end of each epoch\n",
720
+ " save_strategy=\"epoch\", # Save the model at the end of each epoch\n",
721
+ " save_total_limit=3,\n",
722
+ " load_best_model_at_end=True,\n",
723
+ " metric_for_best_model=\"eval_loss\",\n",
724
+ " bf16=True, \n",
725
+ " warmup_ratio=0.1, # Warmup 10% of total steps\n",
726
+ " lr_scheduler_type=\"cosine\",\n",
727
+ " )\n",
728
+ " trainer = Trainer(\n",
729
+ " model=finetuned_model,\n",
730
+ " args=training_args,\n",
731
+ " train_dataset=tokenized_datasets[\"train\"],\n",
732
+ " eval_dataset=tokenized_datasets[\"validation\"],\n",
733
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n",
734
+ " )\n",
735
+ " logger.info(\"Beginning fine-tuning...\")\n",
736
+ " trainer.train()\n",
737
+ " logger.info(\"Training completed.\")\n",
738
+ " save_path = \"text2sql_flant5base_finetuned\"\n",
739
+ " finetuned_model.save_pretrained(save_path)\n",
740
+ " logger.info(\"Model saved to %s\", save_path)\n",
741
+ " clear_memory()"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "execution_count": 16,
747
+ "id": "f364eb6b-56cb-4533-8ef6-b5e7f56895aa",
748
+ "metadata": {},
749
+ "outputs": [
750
+ {
751
+ "name": "stderr",
752
+ "output_type": "stream",
753
+ "text": [
754
+ "2025-03-17 17:10:51,987 - INFO - Running inference on 5 examples (displaying real responses).\n"
755
+ ]
756
+ },
757
+ {
758
+ "name": "stdout",
759
+ "output_type": "stream",
760
+ "text": [
761
+ "\n",
762
+ "====================================================================================================\n",
763
+ "----------------------------------------------------------------------------------------------------\n",
764
+ "Example 1\n",
765
+ "----------------------------------------------------------------------------------------------------\n",
766
+ "INPUT PROMPT:\n",
767
+ "Context:\n",
768
+ "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n",
769
+ "\n",
770
+ "Query:\n",
771
+ "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n",
772
+ "\n",
773
+ "Response:\n",
774
+ "\n",
775
+ "----------------------------------------------------------------------------------------------------\n",
776
+ "HUMAN RESPONSE:\n",
777
+ "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
778
+ "----------------------------------------------------------------------------------------------------\n",
779
+ "ORIGINAL MODEL OUTPUT:\n",
780
+ "USCYBERCOM, JTF-CND, Offensive Cyber Operations, 10th Fleet, Network Warfare\n",
781
+ "----------------------------------------------------------------------------------------------------\n",
782
+ "FINE-TUNED MODEL OUTPUT:\n",
783
+ "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
784
+ "====================================================================================================\n",
785
+ "\n",
786
+ "----------------------------------------------------------------------------------------------------\n",
787
+ "Example 2\n",
788
+ "----------------------------------------------------------------------------------------------------\n",
789
+ "INPUT PROMPT:\n",
790
+ "Context:\n",
791
+ "CREATE TABLE incidents (id INT, cause VARCHAR(255), cost INT, date DATE); INSERT INTO incidents (id, cause, cost, date) VALUES (1, 'insider threat', 10000, '2022-01-01'); INSERT INTO incidents (id, cause, cost, date) VALUES (2, 'phishing', 5000, '2022-01-02');\n",
792
+ "\n",
793
+ "Query:\n",
794
+ "Find the total cost of all security incidents caused by insider threats in the last 6 months\n",
795
+ "\n",
796
+ "Response:\n",
797
+ "\n",
798
+ "----------------------------------------------------------------------------------------------------\n",
799
+ "HUMAN RESPONSE:\n",
800
+ "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
801
+ "----------------------------------------------------------------------------------------------------\n",
802
+ "ORIGINAL MODEL OUTPUT:\n",
803
+ "5000\n",
804
+ "----------------------------------------------------------------------------------------------------\n",
805
+ "FINE-TUNED MODEL OUTPUT:\n",
806
+ "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
807
+ "====================================================================================================\n",
808
+ "\n",
809
+ "----------------------------------------------------------------------------------------------------\n",
810
+ "Example 3\n",
811
+ "----------------------------------------------------------------------------------------------------\n",
812
+ "INPUT PROMPT:\n",
813
+ "Context:\n",
814
+ "CREATE TABLE libraries (name VARCHAR(255), state VARCHAR(255), population DECIMAL(10,2), libraries DECIMAL(5,2)); INSERT INTO libraries (name, state, population, libraries) VALUES ('Library1', 'California', 39512223, 3154), ('Library2', 'Texas', 29528404, 2212), ('Library3', 'Florida', 21644287, 1835);\n",
815
+ "\n",
816
+ "Query:\n",
817
+ "Show the top 3 states with the most public libraries per capita.\n",
818
+ "\n",
819
+ "Response:\n",
820
+ "\n",
821
+ "----------------------------------------------------------------------------------------------------\n",
822
+ "HUMAN RESPONSE:\n",
823
+ "SELECT state, (libraries / population) AS libraries_per_capita FROM libraries ORDER BY libraries_per_capita DESC LIMIT 3;\n",
824
+ "----------------------------------------------------------------------------------------------------\n",
825
+ "ORIGINAL MODEL OUTPUT:\n",
826
+ "California, 39512223, 3154\n",
827
+ "----------------------------------------------------------------------------------------------------\n",
828
+ "FINE-TUNED MODEL OUTPUT:\n",
829
+ "SELECT state, population, RANK() OVER (ORDER BY population DESC) as rank FROM libraries GROUP BY state ORDER BY rank DESC LIMIT 3;\n",
830
+ "====================================================================================================\n",
831
+ "\n",
832
+ "----------------------------------------------------------------------------------------------------\n",
833
+ "Example 4\n",
834
+ "----------------------------------------------------------------------------------------------------\n",
835
+ "INPUT PROMPT:\n",
836
+ "Context:\n",
837
+ "CREATE TABLE users (id INT, location VARCHAR(50)); CREATE TABLE posts (id INT, user_id INT, created_at DATETIME);\n",
838
+ "\n",
839
+ "Query:\n",
840
+ "What is the total number of posts made by users located in Australia, in the last month?\n",
841
+ "\n",
842
+ "Response:\n",
843
+ "\n",
844
+ "----------------------------------------------------------------------------------------------------\n",
845
+ "HUMAN RESPONSE:\n",
846
+ "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH);\n",
847
+ "----------------------------------------------------------------------------------------------------\n",
848
+ "ORIGINAL MODEL OUTPUT:\n",
849
+ "INT users created a total of 50 posts in Australia in the last month.\n",
850
+ "----------------------------------------------------------------------------------------------------\n",
851
+ "FINE-TUNED MODEL OUTPUT:\n",
852
+ "SELECT COUNT(*) FROM posts p JOIN users u ON p.user_id = u.id WHERE u.location = 'Australia' AND p.created_at >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);\n",
853
+ "====================================================================================================\n",
854
+ "\n"
855
+ ]
856
+ },
857
+ {
858
+ "name": "stderr",
859
+ "output_type": "stream",
860
+ "text": [
861
+ "2025-03-17 17:11:00,034 - INFO - Starting evaluation on the full test set using batching.\n"
862
+ ]
863
+ },
864
+ {
865
+ "name": "stdout",
866
+ "output_type": "stream",
867
+ "text": [
868
+ "----------------------------------------------------------------------------------------------------\n",
869
+ "Example 5\n",
870
+ "----------------------------------------------------------------------------------------------------\n",
871
+ "INPUT PROMPT:\n",
872
+ "Context:\n",
873
+ "CREATE TABLE WindFarms (FarmID INT, FarmName VARCHAR(255), Capacity DECIMAL(5,2), Country VARCHAR(255)); INSERT INTO WindFarms (FarmID, FarmName, Capacity, Country) VALUES (1, 'WindFarm1', 150, 'USA'), (2, 'WindFarm2', 200, 'Canada'), (3, 'WindFarm3', 120, 'Mexico');\n",
874
+ "\n",
875
+ "Query:\n",
876
+ "List the total installed capacity of wind farms in the WindEnergy schema for each country?\n",
877
+ "\n",
878
+ "Response:\n",
879
+ "\n",
880
+ "----------------------------------------------------------------------------------------------------\n",
881
+ "HUMAN RESPONSE:\n",
882
+ "SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country;\n",
883
+ "----------------------------------------------------------------------------------------------------\n",
884
+ "ORIGINAL MODEL OUTPUT:\n",
885
+ "1, 150, USA, (2, 200, Canada, 3), 120, Mexico\n",
886
+ "----------------------------------------------------------------------------------------------------\n",
887
+ "FINE-TUNED MODEL OUTPUT:\n",
888
+ "SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n",
889
+ "====================================================================================================\n",
890
+ "\n"
891
+ ]
892
+ },
893
+ {
894
+ "name": "stderr",
895
+ "output_type": "stream",
896
+ "text": [
897
+ "2025-03-17 18:28:59,727 - INFO - Full test set comparison (first 5 rows):\n",
898
+ " Human Response \\\n",
899
+ "0 SELECT command_name, type FROM defense_securit... \n",
900
+ "1 SELECT SUM(cost) FROM incidents WHERE cause = ... \n",
901
+ "2 SELECT state, (libraries / population) AS libr... \n",
902
+ "3 SELECT COUNT(posts.id) FROM posts INNER JOIN u... \n",
903
+ "4 SELECT Country, SUM(Capacity) as TotalCapacity... \n",
904
+ "\n",
905
+ " Original Model Output \\\n",
906
+ "0 USCYBERCOM, JTF-CND, offensive Cyber operation... \n",
907
+ "1 t = t. \n",
908
+ "2 California \n",
909
+ "3 The total number of users in Australia is 50. \n",
910
+ "4 a \n",
911
+ "\n",
912
+ " Fine-Tuned Model Output \n",
913
+ "0 SELECT command_name, type FROM military_cyber_... \n",
914
+ "1 SELECT SUM(cost) FROM incidents WHERE cause = ... \n",
915
+ "2 SELECT state, t.population, t.tut FROM librari... \n",
916
+ "3 SELECT COUNT(*) FROM posts WHERE CUTS(CUTS.id,... \n",
917
+ "4 SELECT Country, SUM(Capacity) FROM WindFarms G... \n"
918
+ ]
919
+ },
920
+ {
921
+ "name": "stdout",
922
+ "output_type": "stream",
923
+ "text": [
924
+ "\n",
925
+ "Full Test Set Comparison (First 5 Rows):\n",
926
+ " Human Response Original Model Output Fine-Tuned Model Output\n",
927
+ " SELECT command_name, type FROM defense_security.Military_Cyber_Commands; USCYBERCOM, JTF-CND, offensive Cyber operations, 10th Fleet, Network Warfare SELECT command_name, type FROM military_cyber_Commands;\n",
928
+ " SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH); t = t. SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
929
+ " SELECT state, (libraries / population) AS libraries_per_capita FROM libraries ORDER BY libraries_per_capita DESC LIMIT 3; California SELECT state, t.population, t.tut FROM libraries t JOIN t ON t.state = t.state GROUP BY state ORDER BY t.tut DESC LIMIT 3;\n",
930
+ "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH); The total number of users in Australia is 50. SELECT COUNT(*) FROM posts WHERE CUTS(CUTS.id, CUTS.created_at) = CUTS.id AND CUTS.id = CUTS.id WHERE CUTS.location = 'Australia' AND CUTS.created_at >= DATE_SUB(CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CU\n",
931
+ " SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country; a SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n"
932
+ ]
933
+ },
934
+ {
935
+ "data": {
936
+ "application/vnd.jupyter.widget-view+json": {
937
+ "model_id": "fb9a4b84525845e78668fbb5472ac4c8",
938
+ "version_major": 2,
939
+ "version_minor": 0
940
+ },
941
+ "text/plain": [
942
+ "Downloading builder script: 0%| | 0.00/5.94k [00:00<?, ?B/s]"
943
+ ]
944
+ },
945
+ "metadata": {},
946
+ "output_type": "display_data"
947
+ },
948
+ {
949
+ "data": {
950
+ "application/vnd.jupyter.widget-view+json": {
951
+ "model_id": "5a92eb8c1607450d8babbce26891eb97",
952
+ "version_major": 2,
953
+ "version_minor": 0
954
+ },
955
+ "text/plain": [
956
+ "Downloading extra modules: 0%| | 0.00/1.55k [00:00<?, ?B/s]"
957
+ ]
958
+ },
959
+ "metadata": {},
960
+ "output_type": "display_data"
961
+ },
962
+ {
963
+ "data": {
964
+ "application/vnd.jupyter.widget-view+json": {
965
+ "model_id": "e5b5b1034f354abfbdfc46f0ff2b9349",
966
+ "version_major": 2,
967
+ "version_minor": 0
968
+ },
969
+ "text/plain": [
970
+ "Downloading extra modules: 0%| | 0.00/3.34k [00:00<?, ?B/s]"
971
+ ]
972
+ },
973
+ "metadata": {},
974
+ "output_type": "display_data"
975
+ },
976
+ {
977
+ "name": "stderr",
978
+ "output_type": "stream",
979
+ "text": [
980
+ "2025-03-17 18:29:02,580 - INFO - Using default tokenizer.\n",
981
+ "2025-03-17 18:30:27,253 - INFO - Using default tokenizer.\n"
982
+ ]
983
+ },
984
+ {
985
+ "name": "stdout",
986
+ "output_type": "stream",
987
+ "text": [
988
+ "\n",
989
+ "====================================================================================================\n",
990
+ "Evaluation Metrics:\n",
991
+ "====================================================================================================\n",
992
+ "ORIGINAL MODEL:\n",
993
+ " ROUGE: {'rouge1': np.float64(0.033688028857640176), 'rouge2': np.float64(0.008171862977966522), 'rougeL': np.float64(0.030557406905046474), 'rougeLsum': np.float64(0.030592110084298876)}\n",
994
+ " BLEU: {'bleu': 0.0036692781190090368, 'precisions': [0.02284408025462027, 0.004200643881640979, 0.002134841269783046, 0.0008848453895992066], 'brevity_penalty': 1.0, 'length_ratio': 1.1809102409373358, 'translation_length': 1421725, 'reference_length': 1203923}\n",
995
+ " Fuzzy Match Score: 11.31%\n",
996
+ " Exact Match Accuracy: 0.00%\n",
997
+ "\n",
998
+ "FINE-TUNED MODEL:\n",
999
+ " ROUGE: {'rouge1': np.float64(0.6914345907518044), 'rouge2': np.float64(0.5453255406268581), 'rougeL': np.float64(0.6642891642898592), 'rougeLsum': np.float64(0.6642865716725223)}\n",
1000
+ " BLEU: {'bleu': 0.31698443630421885, 'precisions': [0.46303833317311294, 0.34558772459086096, 0.2792686360724928, 0.2259198229483191], 'brevity_penalty': 1.0, 'length_ratio': 1.4083799379196178, 'translation_length': 1695581, 'reference_length': 1203923}\n",
1001
+ " Fuzzy Match Score: 81.98%\n",
1002
+ " Exact Match Accuracy: 16.39%\n",
1003
+ "====================================================================================================\n"
1004
+ ]
1005
+ }
1006
+ ],
1007
+ "source": [
1008
+ "from rapidfuzz import fuzz\n",
1009
+ "import pandas as pd\n",
1010
+ "import re\n",
1011
+ "import evaluate\n",
1012
+ "\n",
1013
+ "# --- Helper Functions for SQL Normalization and Exact Match ---\n",
1014
+ "def normalize_sql(sql):\n",
1015
+ " \"\"\"Normalize SQL by stripping whitespace and lowercasing.\"\"\"\n",
1016
+ " return \" \".join(sql.strip().lower().split())\n",
1017
+ "\n",
1018
+ "def compute_exact_match(predictions, references):\n",
1019
+ " \"\"\"Computes the exact match accuracy after normalization.\"\"\"\n",
1020
+ " matches = sum(1 for pred, ref in zip(predictions, references)\n",
1021
+ " if normalize_sql(pred) == normalize_sql(ref))\n",
1022
+ " return (matches / len(predictions)) * 100 if predictions else 0\n",
1023
+ "\n",
1024
+ "def compute_fuzzy_match(predictions, references):\n",
1025
+ " \"\"\"Computes a soft matching score using token_set_ratio from rapidfuzz.\"\"\"\n",
1026
+ " scores = [fuzz.token_set_ratio(pred, ref) for pred, ref in zip(predictions, references)]\n",
1027
+ " return sum(scores) / len(scores) if scores else 0\n",
1028
+ "\n",
1029
+ "# --- Part A: Inference on 5 Examples with Real Responses (unchanged) ---\n",
1030
+ "logger.info(\"Running inference on 5 examples (displaying real responses).\")\n",
1031
+ "\n",
1032
+ "num_examples = 5\n",
1033
+ "sample_queries = dataset[\"test\"][:num_examples][\"query\"]\n",
1034
+ "sample_contexts = dataset[\"test\"][:num_examples][\"context\"]\n",
1035
+ "sample_human_responses = dataset[\"test\"][:num_examples][\"response\"]\n",
1036
+ "\n",
1037
+ "print(\"\\n\" + \"=\"*100)\n",
1038
+ "for idx in range(num_examples):\n",
1039
+ " prompt = f\"\"\"Context:\n",
1040
+ "{sample_contexts[idx]}\n",
1041
+ "\n",
1042
+ "Query:\n",
1043
+ "{sample_queries[idx]}\n",
1044
+ "\n",
1045
+ "Response:\n",
1046
+ "\"\"\"\n",
1047
+ " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
1048
+ " \n",
1049
+ " # Generate outputs with both models using keyword arguments\n",
1050
+ " orig_out_ids = original_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=200)\n",
1051
+ " finetuned_out_ids = finetuned_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=200)\n",
1052
+ " \n",
1053
+ " orig_text = tokenizer.decode(orig_out_ids[0], skip_special_tokens=True)\n",
1054
+ " finetuned_text = tokenizer.decode(finetuned_out_ids[0], skip_special_tokens=True)\n",
1055
+ " \n",
1056
+ " print(\"-\" * 100)\n",
1057
+ " print(f\"Example {idx+1}\")\n",
1058
+ " print(\"-\" * 100)\n",
1059
+ " print(\"INPUT PROMPT:\")\n",
1060
+ " print(prompt)\n",
1061
+ " print(\"-\" * 100)\n",
1062
+ " print(\"HUMAN RESPONSE:\")\n",
1063
+ " print(sample_human_responses[idx])\n",
1064
+ " print(\"-\" * 100)\n",
1065
+ " print(\"ORIGINAL MODEL OUTPUT:\")\n",
1066
+ " print(orig_text)\n",
1067
+ " print(\"-\" * 100)\n",
1068
+ " print(\"FINE-TUNED MODEL OUTPUT:\")\n",
1069
+ " print(finetuned_text)\n",
1070
+ " print(\"=\" * 100 + \"\\n\")\n",
1071
+ " clear_memory()\n",
1072
+ "\n",
1073
+ "# --- Part B: Evaluation on Full Test Set with Batching (Optimized) ---\n",
1074
+ "logger.info(\"Starting evaluation on the full test set using batching.\")\n",
1075
+ "\n",
1076
+ "all_human_responses = []\n",
1077
+ "all_original_responses = []\n",
1078
+ "all_finetuned_responses = []\n",
1079
+ "\n",
1080
+ "batch_size = 128 # Adjust batch size based on your GPU memory\n",
1081
+ "test_dataset = dataset[\"test\"]\n",
1082
+ "\n",
1083
+ "for i in range(0, len(test_dataset), batch_size):\n",
1084
+ " # Slicing the dataset returns a dict of lists\n",
1085
+ " batch = test_dataset[i:i+batch_size]\n",
1086
+ " \n",
1087
+ " # Construct prompts for each example in the batch by iterating over indices\n",
1088
+ " prompts = [\n",
1089
+ " f\"Context:\\n{batch['context'][j]}\\n\\nQuery:\\n{batch['query'][j]}\\n\\nResponse:\"\n",
1090
+ " for j in range(len(batch[\"context\"]))\n",
1091
+ " ]\n",
1092
+ " \n",
1093
+ " # Extend human responses for each example\n",
1094
+ " all_human_responses.extend(batch[\"response\"])\n",
1095
+ " \n",
1096
+ " # Tokenize the batch of prompts\n",
1097
+ " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True).to(device)\n",
1098
+ " \n",
1099
+ " # Generate outputs with both models for the batch\n",
1100
+ " orig_ids = original_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=300)\n",
1101
+ " finetuned_ids = finetuned_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=300)\n",
1102
+ " \n",
1103
+ " # Decode each sample in the batch\n",
1104
+ " orig_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in orig_ids]\n",
1105
+ " finetuned_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in finetuned_ids]\n",
1106
+ " \n",
1107
+ " all_original_responses.extend(orig_texts)\n",
1108
+ " all_finetuned_responses.extend(finetuned_texts)\n",
1109
+ " clear_memory()\n",
1110
+ "\n",
1111
+ "# Create a DataFrame for a quick comparison of results\n",
1112
+ "zipped_all = list(zip(all_human_responses, all_original_responses, all_finetuned_responses))\n",
1113
+ "df_full = pd.DataFrame(zipped_all, columns=[\"Human Response\", \"Original Model Output\", \"Fine-Tuned Model Output\"])\n",
1114
+ "logger.info(\"Full test set comparison (first 5 rows):\\n%s\", df_full.head())\n",
1115
+ "print(\"\\nFull Test Set Comparison (First 5 Rows):\")\n",
1116
+ "print(df_full.head().to_string(index=False))\n",
1117
+ "clear_memory()\n",
1118
+ "\n",
1119
+ "# --- Compute Evaluation Metrics ---\n",
1120
+ "# Load evaluation libraries\n",
1121
+ "rouge = evaluate.load(\"rouge\")\n",
1122
+ "bleu = evaluate.load(\"bleu\")\n",
1123
+ "\n",
1124
+ "# Compute metrics for the original (non-fine-tuned) model\n",
1125
+ "orig_rouge = rouge.compute(\n",
1126
+ " predictions=all_original_responses,\n",
1127
+ " references=all_human_responses,\n",
1128
+ " use_aggregator=True,\n",
1129
+ " use_stemmer=True,\n",
1130
+ ")\n",
1131
+ "orig_bleu = bleu.compute(\n",
1132
+ " predictions=all_original_responses,\n",
1133
+ " references=[[ref] for ref in all_human_responses]\n",
1134
+ ")\n",
1135
+ "orig_fuzzy = compute_fuzzy_match(all_original_responses, all_human_responses)\n",
1136
+ "orig_exact = compute_exact_match(all_original_responses, all_human_responses)\n",
1137
+ "\n",
1138
+ "# Compute metrics for the fine-tuned model\n",
1139
+ "finetuned_rouge = rouge.compute(\n",
1140
+ " predictions=all_finetuned_responses,\n",
1141
+ " references=all_human_responses,\n",
1142
+ " use_aggregator=True,\n",
1143
+ " use_stemmer=True,\n",
1144
+ ")\n",
1145
+ "finetuned_bleu = bleu.compute(\n",
1146
+ " predictions=all_finetuned_responses,\n",
1147
+ " references=[[ref] for ref in all_human_responses]\n",
1148
+ ")\n",
1149
+ "finetuned_fuzzy = compute_fuzzy_match(all_finetuned_responses, all_human_responses)\n",
1150
+ "finetuned_exact = compute_exact_match(all_finetuned_responses, all_human_responses)\n",
1151
+ "\n",
1152
+ "print(\"\\n\" + \"=\"*100)\n",
1153
+ "print(\"Evaluation Metrics:\")\n",
1154
+ "print(\"=\"*100)\n",
1155
+ "print(\"ORIGINAL MODEL:\")\n",
1156
+ "print(f\" ROUGE: {orig_rouge}\")\n",
1157
+ "print(f\" BLEU: {orig_bleu}\")\n",
1158
+ "print(f\" Fuzzy Match Score: {orig_fuzzy:.2f}%\")\n",
1159
+ "print(f\" Exact Match Accuracy: {orig_exact:.2f}%\\n\")\n",
1160
+ "print(\"FINE-TUNED MODEL:\")\n",
1161
+ "print(f\" ROUGE: {finetuned_rouge}\")\n",
1162
+ "print(f\" BLEU: {finetuned_bleu}\")\n",
1163
+ "print(f\" Fuzzy Match Score: {finetuned_fuzzy:.2f}%\")\n",
1164
+ "print(f\" Exact Match Accuracy: {finetuned_exact:.2f}%\")\n",
1165
+ "print(\"=\"*100)\n",
1166
+ "clear_memory()\n"
1167
+ ]
1168
+ },
1169
+ {
1170
+ "cell_type": "code",
1171
+ "execution_count": 32,
1172
+ "id": "462546a7-6928-4723-b00e-23c3a4091d99",
1173
+ "metadata": {},
1174
+ "outputs": [
1175
+ {
1176
+ "name": "stderr",
1177
+ "output_type": "stream",
1178
+ "text": [
1179
+ "2025-03-18 16:55:06,158 - INFO - Running inference with deterministic decoding and beam search.\n"
1180
+ ]
1181
+ },
1182
+ {
1183
+ "name": "stdout",
1184
+ "output_type": "stream",
1185
+ "text": [
1186
+ "Prompt:\n",
1187
+ "Context:\n",
1188
+ "CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), (3, 'Charlie', 'Canada'), (4, 'David', 'USA'); INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES (101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), (103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), (105, 4, 900, '2024-03-05');\n",
1189
+ "\n",
1190
+ "Query:\n",
1191
+ "Retrieve the total order amount for each customer, showing only customers from the USA, and sort the result by total order amount in descending order.\n",
1192
+ "\n",
1193
+ "Response:\n",
1194
+ "SELECT customers.name, SUM(orders.total_amount) as total_amount FROM customers INNER JOIN orders ON customers.id = orders.customer_id WHERE customers.country = 'USA' GROUP BY customers.name ORDER BY total_amount DESC;\n",
1195
+ "\n",
1196
+ "EXPECTED RESPONSE:\n",
1197
+ "SELECT c.name, SUM(o.total_amount) as total_order_amount FROM customers c JOIN orders o ON c.id = o.customer_id WHERE c.country = 'USA' GROUP BY c.name ORDER BY total_order_amount DESC;\n"
1198
+ ]
1199
+ }
1200
+ ],
1201
+ "source": [
1202
+ "import torch\n",
1203
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
1204
+ "import logging\n",
1205
+ "\n",
1206
+ "# Set up logging\n",
1207
+ "logging.basicConfig(\n",
1208
+ " level=logging.INFO,\n",
1209
+ " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
1210
+ ")\n",
1211
+ "logger = logging.getLogger(__name__)\n",
1212
+ "\n",
1213
+ "# Ensure device is set (GPU if available)\n",
1214
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1215
+ "\n",
1216
+ "# Load the fine-tuned model and tokenizer\n",
1217
+ "model_name = \"text2sql_flant5base_finetuned\" # Directory of your fine-tuned model\n",
1218
+ "finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
1219
+ "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
1220
+ "finetuned_model.to(device)\n",
1221
+ "\n",
1222
+ "def run_inference(prompt_text: str) -> str:\n",
1223
+ " \"\"\"\n",
1224
+ " Runs inference on the fine-tuned model using deterministic decoding\n",
1225
+ " with beam search, returning the generated SQL query.\n",
1226
+ " \"\"\"\n",
1227
+ " inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(device)\n",
1228
+ " generated_ids = finetuned_model.generate(\n",
1229
+ " input_ids=inputs[\"input_ids\"],\n",
1230
+ " max_new_tokens=250, # Adjust based on query complexity\n",
1231
+ "        temperature=0.0,           # Ignored when do_sample=False; beam search below is already deterministic\n",
1232
+ " num_beams=3, # Beam search for better output quality\n",
1233
+ " early_stopping=True, # Stop early if possible\n",
1234
+ " )\n",
1235
+ " return tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1236
+ "\n",
1237
+ "# Sample context and query (example)\n",
1238
+ "context = (\n",
1239
+ " \"CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); \"\n",
1240
+ " \"CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), \"\n",
1241
+ " \"order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); \"\n",
1242
+ " \"INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), \"\n",
1243
+ " \"(3, 'Charlie', 'Canada'), (4, 'David', 'USA'); \"\n",
1244
+ " \"INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES \"\n",
1245
+ " \"(101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), \"\n",
1246
+ " \"(103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), \"\n",
1247
+ " \"(105, 4, 900, '2024-03-05');\"\n",
1248
+ ")\n",
1249
+ "query = (\n",
1250
+ " \"Retrieve the total order amount for each customer, showing only customers from the USA, \"\n",
1251
+ " \"and sort the result by total order amount in descending order.\"\n",
1252
+ ")\n",
1253
+ "\n",
1254
+ "# Construct the prompt\n",
1255
+ "sample_prompt = f\"\"\"Context:\n",
1256
+ "{context}\n",
1257
+ "\n",
1258
+ "Query:\n",
1259
+ "{query}\n",
1260
+ "\n",
1261
+ "Response:\n",
1262
+ "\"\"\"\n",
1263
+ "\n",
1264
+ "logger.info(\"Running inference with deterministic decoding and beam search.\")\n",
1265
+ "generated_sql = run_inference(sample_prompt)\n",
1266
+ "\n",
1267
+ "# Define the expected response (this is a placeholder - update as necessary)\n",
1268
+ "expected_response = (\n",
1269
+ " \"SELECT c.name, SUM(o.total_amount) as total_order_amount \"\n",
1270
+ " \"FROM customers c \"\n",
1271
+ " \"JOIN orders o ON c.id = o.customer_id \"\n",
1272
+ " \"WHERE c.country = 'USA' \"\n",
1273
+ " \"GROUP BY c.name \"\n",
1274
+ " \"ORDER BY total_order_amount DESC;\"\n",
1275
+ ")\n",
1276
+ "\n",
1277
+ "# Print output in the given format\n",
1278
+ "print(\"Prompt:\")\n",
1279
+ "print(\"Context:\")\n",
1280
+ "print(context)\n",
1281
+ "print(\"\\nQuery:\")\n",
1282
+ "print(query)\n",
1283
+ "print(\"\\nResponse:\")\n",
1284
+ "print(generated_sql)\n",
1285
+ "print(\"\\nEXPECTED RESPONSE:\")\n",
1286
+ "print(expected_response)\n"
1287
+ ]
1288
+ },
1289
+ {
1290
+ "cell_type": "code",
1291
+ "execution_count": 20,
1292
+ "id": "a69f268e-bc69-4633-9c15-4e118c20178e",
1293
+ "metadata": {},
1294
+ "outputs": [
1295
+ {
1296
+ "name": "stdout",
1297
+ "output_type": "stream",
1298
+ "text": [
1299
+ "✅ LoRA adapter saved at: text2sql_flant5base_finetuned\n",
1300
+ "✅ Fully merged fine-tuned model saved at: text2sql_flant5base_finetuned_full\n"
1301
+ ]
1302
+ }
1303
+ ],
1304
+ "source": [
1305
+ "import torch\n",
1306
+ "import json\n",
1307
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
1308
+ "from peft import PeftModel\n",
1309
+ "\n",
1310
+ "# Define paths\n",
1311
+ "base_model_name = \"google/flan-t5-base\" # Base model name\n",
1312
+ "lora_model_path = \"text2sql_flant5base_finetuned\" # Folder where LoRA adapter is saved\n",
1313
+ "full_model_output_path = \"text2sql_flant5base_finetuned_full\" # For merged full model\n",
1314
+ "\n",
1315
+ "# Load base model and tokenizer\n",
1316
+ "base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)\n",
1317
+ "tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n",
1318
+ "\n",
1319
+ "# Load fine-tuned LoRA adapter model\n",
1320
+ "lora_model = PeftModel.from_pretrained(base_model, lora_model_path)\n",
1321
+ "\n",
1322
+ "# Save the LoRA adapter separately (for users who want lightweight adapters)\n",
1323
+ "lora_model.save_pretrained(lora_model_path)\n",
1324
+ "tokenizer.save_pretrained(lora_model_path)\n",
1325
+ "\n",
1326
+ "# Merge LoRA into the base model to create a fully fine-tuned model\n",
1327
+ "merged_model = lora_model.merge_and_unload()\n",
1328
+ "\n",
1329
+ "# Save the full fine-tuned model\n",
1330
+ "merged_model.save_pretrained(full_model_output_path)\n",
1331
+ "tokenizer.save_pretrained(full_model_output_path)\n",
1332
+ "\n",
1333
+ "# Save generation config (optional but recommended for inference settings)\n",
1334
+ "generation_config = {\n",
1335
+ " \"max_new_tokens\": 250,\n",
1336
+ " \"temperature\": 0.0,\n",
1337
+ " \"num_beams\": 3,\n",
1338
+ " \"early_stopping\": True\n",
1339
+ "}\n",
1340
+ "with open(f\"{full_model_output_path}/generation_config.json\", \"w\") as f:\n",
1341
+ " json.dump(generation_config, f)\n",
1342
+ "\n",
1343
+ "print(f\"✅ LoRA adapter saved at: {lora_model_path}\")\n",
1344
+ "print(f\"✅ Fully merged fine-tuned model saved at: {full_model_output_path}\")\n"
1345
+ ]
1346
+ },
1347
+ {
1348
+ "cell_type": "code",
1349
+ "execution_count": 33,
1350
+ "id": "f1c95dfc-6662-44d8-8ecc-bff414fecee5",
1351
+ "metadata": {},
1352
+ "outputs": [
1353
+ {
1354
+ "name": "stderr",
1355
+ "output_type": "stream",
1356
+ "text": [
1357
+ "2025-03-18 16:55:46,428 - INFO - Running inference with beam search decoding.\n"
1358
+ ]
1359
+ },
1360
+ {
1361
+ "name": "stdout",
1362
+ "output_type": "stream",
1363
+ "text": [
1364
+ "Prompt:\n",
1365
+ "Context:\n",
1366
+ "CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); INSERT INTO employees (id, name, department, salary) VALUES (1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), (3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); INSERT INTO projects (project_id, project_name, budget) VALUES (101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); INSERT INTO employee_projects (employee_id, project_id, role) VALUES (1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), (4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\n",
1367
+ "\n",
1368
+ "Query:\n",
1369
+ "Find the names of employees who are working on the 'AI Research' project along with their roles.\n",
1370
+ "\n",
1371
+ "Response:\n",
1372
+ "SELECT employees.name, employee_projects.role FROM employees INNER JOIN employee_projects ON employees.id = employee_projects.employee_id INNER JOIN projects ON employee_projects.project_id = projects.project_id WHERE projects.project_name = 'AI Research';\n"
1373
+ ]
1374
+ }
1375
+ ],
1376
+ "source": [
1377
+ "import torch\n",
1378
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
1379
+ "import logging\n",
1380
+ "\n",
1381
+ "# Set up logging\n",
1382
+ "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
1383
+ "logger = logging.getLogger(__name__)\n",
1384
+ "\n",
1385
+ "# Ensure device is set (GPU if available)\n",
1386
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1387
+ "\n",
1388
+ "# Load the fine-tuned model and tokenizer\n",
1389
+ "model_name = \"aarohanverma/text2sql-flan-t5-base-qlora-finetuned\"\n",
1390
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)\n",
1391
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
1392
+ "\n",
1393
+ "# Ensure decoder start token is set\n",
1394
+ "if model.config.decoder_start_token_id is None:\n",
1395
+ " model.config.decoder_start_token_id = tokenizer.pad_token_id\n",
1396
+ "\n",
1397
+ "def run_inference(prompt_text: str) -> str:\n",
1398
+ " \"\"\"\n",
1399
+ " Runs inference on the fine-tuned model using beam search with fixes for repetition.\n",
1400
+ " \"\"\"\n",
1401
+ " inputs = tokenizer(prompt_text, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n",
1402
+ "\n",
1403
+ " generated_ids = model.generate(\n",
1404
+ " input_ids=inputs[\"input_ids\"],\n",
1405
+ " decoder_start_token_id=model.config.decoder_start_token_id, \n",
1406
+ " max_new_tokens=100, \n",
1407
+ "        do_sample=False, \n",
1408
+ " num_beams=5, \n",
1409
+ " repetition_penalty=1.2, \n",
1410
+ " early_stopping=True, \n",
1411
+ " )\n",
1412
+ "\n",
1413
+ " generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1414
+ "\n",
1415
+ " # Post-processing to remove repeated text\n",
1416
+ " generated_sql = generated_sql.split(\";\")[0] + \";\" # Keep only the first valid SQL query\n",
1417
+ "\n",
1418
+ " return generated_sql\n",
1419
+ "\n",
1420
+ "# Example usage:\n",
1421
+ "context = (\n",
1422
+ " \"CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); \"\n",
1423
+ " \"CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); \"\n",
1424
+ " \"CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), \"\n",
1425
+ " \"FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); \"\n",
1426
+ " \"INSERT INTO employees (id, name, department, salary) VALUES \"\n",
1427
+ " \"(1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), \"\n",
1428
+ " \"(3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); \"\n",
1429
+ " \"INSERT INTO projects (project_id, project_name, budget) VALUES \"\n",
1430
+ " \"(101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); \"\n",
1431
+ " \"INSERT INTO employee_projects (employee_id, project_id, role) VALUES \"\n",
1432
+ " \"(1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), \"\n",
1433
+ " \"(4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\"\n",
1434
+ ")\n",
1435
+ "\n",
1436
+ "query = (\"Find the names of employees who are working on the 'AI Research' project along with their roles.\")\n",
1437
+ "\n",
1438
+ "\n",
1439
+ "\n",
1440
+ "# Construct the prompt\n",
1441
+ "sample_prompt = f\"\"\"Context:\n",
1442
+ "{context}\n",
1443
+ "\n",
1444
+ "Query:\n",
1445
+ "{query}\n",
1446
+ "\n",
1447
+ "Response:\n",
1448
+ "\"\"\"\n",
1449
+ "\n",
1450
+ "logger.info(\"Running inference with beam search decoding.\")\n",
1451
+ "generated_sql = run_inference(sample_prompt)\n",
1452
+ "\n",
1453
+ "print(\"Prompt:\")\n",
1454
+ "print(\"Context:\")\n",
1455
+ "print(context)\n",
1456
+ "print(\"\\nQuery:\")\n",
1457
+ "print(query)\n",
1458
+ "print(\"\\nResponse:\")\n",
1459
+ "print(generated_sql)"
1460
+ ]
1461
+ },
1462
+ {
1463
+ "cell_type": "code",
1464
+ "execution_count": null,
1465
+ "id": "562458ed-53f4-44af-a7a3-e42a175c7245",
1466
+ "metadata": {},
1467
+ "outputs": [],
1468
+ "source": []
1469
+ }
1470
+ ],
1471
+ "metadata": {
1472
+ "kernelspec": {
1473
+ "display_name": "Python3 (ipykernel)",
1474
+ "language": "python",
1475
+ "name": "python3"
1476
+ },
1477
+ "language_info": {
1478
+ "codemirror_mode": {
1479
+ "name": "ipython",
1480
+ "version": 3
1481
+ },
1482
+ "file_extension": ".py",
1483
+ "mimetype": "text/x-python",
1484
+ "name": "python",
1485
+ "nbconvert_exporter": "python",
1486
+ "pygments_lexer": "ipython3",
1487
+ "version": "3.10.12"
1488
+ }
1489
+ },
1490
+ "nbformat": 4,
1491
+ "nbformat_minor": 5
1492
+ }