aarohanverma commited on
Commit
0788177
·
verified ·
1 Parent(s): ccc6cb6

Upload text2sql_flant5_qlora.ipynb

Browse files
Files changed (1) hide show
  1. text2sql_flant5_qlora.ipynb +152 -200
text2sql_flant5_qlora.ipynb CHANGED
@@ -3,34 +3,6 @@
3
  {
4
  "cell_type": "code",
5
  "execution_count": 1,
6
- "id": "cadbd30d-57ce-4ef2-889f-24bd0ff06b89",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "/workspace\n"
14
- ]
15
- }
16
- ],
17
- "source": [
18
- "!echo $PWD"
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 2,
24
- "id": "12d99875-d86b-4442-8682-b9751118d90e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "#!pip3 install evaluate datasets bitsandbytes transformers peft rapidfuzz absl-py"
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": 3,
34
  "id": "5f167a6f-5139-46e6-afb2-a1fa4d12f3fd",
35
  "metadata": {},
36
  "outputs": [],
@@ -60,7 +32,7 @@
60
  },
61
  {
62
  "cell_type": "code",
63
- "execution_count": 4,
64
  "id": "53684b5e-c27e-4eb9-815e-583aa194e096",
65
  "metadata": {},
66
  "outputs": [
@@ -83,7 +55,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 5,
87
  "id": "a47bf3cd-752d-4d1c-9697-70098d6204fa",
88
  "metadata": {},
89
  "outputs": [],
@@ -97,7 +69,7 @@
97
  },
98
  {
99
  "cell_type": "code",
100
- "execution_count": 6,
101
  "id": "f16df21e-9797-4f78-83a1-a2943759ba55",
102
  "metadata": {},
103
  "outputs": [],
@@ -109,7 +81,7 @@
109
  },
110
  {
111
  "cell_type": "code",
112
- "execution_count": 7,
113
  "id": "196e83da-6c8c-4cd7-bd70-2598a5e2a16a",
114
  "metadata": {},
115
  "outputs": [],
@@ -123,7 +95,7 @@
123
  },
124
  {
125
  "cell_type": "code",
126
- "execution_count": 8,
127
  "id": "cea22b9f-f309-4151-81ac-37547c8feeb0",
128
  "metadata": {},
129
  "outputs": [],
@@ -155,7 +127,7 @@
155
  },
156
  {
157
  "cell_type": "code",
158
- "execution_count": 9,
159
  "id": "d4eb82ce-1713-40b6-981d-43ce35aaa6f6",
160
  "metadata": {},
161
  "outputs": [
@@ -163,9 +135,9 @@
163
  "name": "stderr",
164
  "output_type": "stream",
165
  "text": [
166
- "2025-03-17 17:06:42,785 - INFO - Loading raw datasets from various sources...\n",
167
- "2025-03-17 17:07:15,400 - INFO - Total rows before dropping duplicates: 490241\n",
168
- "2025-03-17 17:07:16,852 - INFO - Total rows after dropping duplicates: 440785\n"
169
  ]
170
  }
171
  ],
@@ -198,7 +170,7 @@
198
  },
199
  {
200
  "cell_type": "code",
201
- "execution_count": 10,
202
  "id": "8446814e-5a2c-48a4-8c01-059afcf1d3c1",
203
  "metadata": {},
204
  "outputs": [
@@ -207,7 +179,7 @@
207
  "output_type": "stream",
208
  "text": [
209
  "Token indices sequence length is longer than the specified maximum sequence length for this model (1113 > 512). Running this sequence through the model will result in indexing errors\n",
210
- "2025-03-17 17:10:43,961 - INFO - Total rows after filtering by token length (prompt <= 500 and response <= 250 tokens): 398481\n"
211
  ]
212
  }
213
  ],
@@ -238,7 +210,7 @@
238
  },
239
  {
240
  "cell_type": "code",
241
- "execution_count": 11,
242
  "id": "177e1e6d-9fbc-442d-9774-5a3e5234329f",
243
  "metadata": {},
244
  "outputs": [
@@ -246,7 +218,7 @@
246
  "name": "stderr",
247
  "output_type": "stream",
248
  "text": [
249
- "2025-03-17 17:10:43,968 - INFO - Sample from filtered final_df:\n",
250
  " query \\\n",
251
  "0 Name the home team for carlton away team \n",
252
  "1 what will the population of Asia be when Latin... \n",
@@ -271,7 +243,7 @@
271
  },
272
  {
273
  "cell_type": "code",
274
- "execution_count": 12,
275
  "id": "0b639efe-ebeb-4b34-bc3f-accf776ba0da",
276
  "metadata": {},
277
  "outputs": [
@@ -279,13 +251,13 @@
279
  "name": "stderr",
280
  "output_type": "stream",
281
  "text": [
282
- "2025-03-17 17:10:44,311 - INFO - Final split sizes: Train: 338708, Test: 39848, Validation: 19925\n"
283
  ]
284
  },
285
  {
286
  "data": {
287
  "application/vnd.jupyter.widget-view+json": {
288
- "model_id": "11dc405cf3d54b6abc81b8eaf6742bea",
289
  "version_major": 2,
290
  "version_minor": 0
291
  },
@@ -299,7 +271,7 @@
299
  {
300
  "data": {
301
  "application/vnd.jupyter.widget-view+json": {
302
- "model_id": "868d3a0d08874c448faac4b50dbb3685",
303
  "version_major": 2,
304
  "version_minor": 0
305
  },
@@ -313,7 +285,7 @@
313
  {
314
  "data": {
315
  "application/vnd.jupyter.widget-view+json": {
316
- "model_id": "0370d0dd07514d5cae499ab93ca47ee8",
317
  "version_major": 2,
318
  "version_minor": 0
319
  },
@@ -328,8 +300,8 @@
328
  "name": "stderr",
329
  "output_type": "stream",
330
  "text": [
331
- "2025-03-17 17:10:45,869 - INFO - Merged and Saved Dataset Successfully!\n",
332
- "2025-03-17 17:10:45,870 - INFO - Dataset summary: DatasetDict({\n",
333
  " train: Dataset({\n",
334
  " features: ['query', 'context', 'response'],\n",
335
  " num_rows: 338708\n",
@@ -378,7 +350,7 @@
378
  },
379
  {
380
  "cell_type": "code",
381
- "execution_count": 13,
382
  "id": "9f6e1095-d72d-4e22-b20d-683f1f84544c",
383
  "metadata": {},
384
  "outputs": [
@@ -386,11 +358,11 @@
386
  "name": "stderr",
387
  "output_type": "stream",
388
  "text": [
389
- "2025-03-17 17:10:46,218 - INFO - Reloaded dataset from disk. Example from test split:\n",
390
  "{'query': \"Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\", 'context': \"CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\", 'response': 'SELECT command_name, type FROM defense_security.Military_Cyber_Commands;'}\n",
391
- "2025-03-17 17:10:46,475 - INFO - Loaded Tokenized Dataset from disk.\n",
392
- "2025-03-17 17:10:46,477 - INFO - Final tokenized dataset splits: dict_keys(['train', 'test', 'validation'])\n",
393
- "2025-03-17 17:10:46,483 - INFO - Sample tokenized record from train split:\n",
394
  "{'input_ids': tensor([ 1193, 6327, 10, 205, 4386, 6048, 332, 17098, 953, 834,\n",
395
  " 4350, 834, 4013, 41, 234, 834, 11650, 584, 4280, 28027,\n",
396
  " 6, 550, 834, 11650, 584, 4280, 28027, 3, 61, 3,\n",
@@ -564,7 +536,7 @@
564
  },
565
  {
566
  "cell_type": "code",
567
- "execution_count": 14,
568
  "id": "7f004e55-181c-47aa-9f3e-c7c1ceae780c",
569
  "metadata": {},
570
  "outputs": [
@@ -631,7 +603,7 @@
631
  },
632
  {
633
  "cell_type": "code",
634
- "execution_count": 15,
635
  "id": "f50e56c7-98b3-42bc-9129-89f3eff802e7",
636
  "metadata": {},
637
  "outputs": [
@@ -639,8 +611,8 @@
639
  "name": "stderr",
640
  "output_type": "stream",
641
  "text": [
642
- "2025-03-17 17:10:50,413 - INFO - Attempting to load the fine-tuned model...\n",
643
- "2025-03-17 17:10:51,949 - INFO - Fine-tuned model loaded successfully.\n"
644
  ]
645
  }
646
  ],
@@ -743,7 +715,7 @@
743
  },
744
  {
745
  "cell_type": "code",
746
- "execution_count": 16,
747
  "id": "f364eb6b-56cb-4533-8ef6-b5e7f56895aa",
748
  "metadata": {},
749
  "outputs": [
@@ -751,7 +723,9 @@
751
  "name": "stderr",
752
  "output_type": "stream",
753
  "text": [
754
- "2025-03-17 17:10:51,987 - INFO - Running inference on 5 examples (displaying real responses).\n"
 
 
755
  ]
756
  },
757
  {
@@ -777,7 +751,7 @@
777
  "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
778
  "----------------------------------------------------------------------------------------------------\n",
779
  "ORIGINAL MODEL OUTPUT:\n",
780
- "USCYBERCOM, JTF-CND, Offensive Cyber Operations, 10th Fleet, Network Warfare\n",
781
  "----------------------------------------------------------------------------------------------------\n",
782
  "FINE-TUNED MODEL OUTPUT:\n",
783
  "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
@@ -800,7 +774,7 @@
800
  "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
801
  "----------------------------------------------------------------------------------------------------\n",
802
  "ORIGINAL MODEL OUTPUT:\n",
803
- "5000\n",
804
  "----------------------------------------------------------------------------------------------------\n",
805
  "FINE-TUNED MODEL OUTPUT:\n",
806
  "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
@@ -846,7 +820,7 @@
846
  "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH);\n",
847
  "----------------------------------------------------------------------------------------------------\n",
848
  "ORIGINAL MODEL OUTPUT:\n",
849
- "INT users created a total of 50 posts in Australia in the last month.\n",
850
  "----------------------------------------------------------------------------------------------------\n",
851
  "FINE-TUNED MODEL OUTPUT:\n",
852
  "SELECT COUNT(*) FROM posts p JOIN users u ON p.user_id = u.id WHERE u.location = 'Australia' AND p.created_at >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);\n",
@@ -858,7 +832,7 @@
858
  "name": "stderr",
859
  "output_type": "stream",
860
  "text": [
861
- "2025-03-17 17:11:00,034 - INFO - Starting evaluation on the full test set using batching.\n"
862
  ]
863
  },
864
  {
@@ -882,7 +856,7 @@
882
  "SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country;\n",
883
  "----------------------------------------------------------------------------------------------------\n",
884
  "ORIGINAL MODEL OUTPUT:\n",
885
- "1, 150, USA, (2, 200, Canada, 3), 120, Mexico\n",
886
  "----------------------------------------------------------------------------------------------------\n",
887
  "FINE-TUNED MODEL OUTPUT:\n",
888
  "SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n",
@@ -890,51 +864,10 @@
890
  "\n"
891
  ]
892
  },
893
- {
894
- "name": "stderr",
895
- "output_type": "stream",
896
- "text": [
897
- "2025-03-17 18:28:59,727 - INFO - Full test set comparison (first 5 rows):\n",
898
- " Human Response \\\n",
899
- "0 SELECT command_name, type FROM defense_securit... \n",
900
- "1 SELECT SUM(cost) FROM incidents WHERE cause = ... \n",
901
- "2 SELECT state, (libraries / population) AS libr... \n",
902
- "3 SELECT COUNT(posts.id) FROM posts INNER JOIN u... \n",
903
- "4 SELECT Country, SUM(Capacity) as TotalCapacity... \n",
904
- "\n",
905
- " Original Model Output \\\n",
906
- "0 USCYBERCOM, JTF-CND, offensive Cyber operation... \n",
907
- "1 t = t. \n",
908
- "2 California \n",
909
- "3 The total number of users in Australia is 50. \n",
910
- "4 a \n",
911
- "\n",
912
- " Fine-Tuned Model Output \n",
913
- "0 SELECT command_name, type FROM military_cyber_... \n",
914
- "1 SELECT SUM(cost) FROM incidents WHERE cause = ... \n",
915
- "2 SELECT state, t.population, t.tut FROM librari... \n",
916
- "3 SELECT COUNT(*) FROM posts WHERE CUTS(CUTS.id,... \n",
917
- "4 SELECT Country, SUM(Capacity) FROM WindFarms G... \n"
918
- ]
919
- },
920
- {
921
- "name": "stdout",
922
- "output_type": "stream",
923
- "text": [
924
- "\n",
925
- "Full Test Set Comparison (First 5 Rows):\n",
926
- " Human Response Original Model Output Fine-Tuned Model Output\n",
927
- " SELECT command_name, type FROM defense_security.Military_Cyber_Commands; USCYBERCOM, JTF-CND, offensive Cyber operations, 10th Fleet, Network Warfare SELECT command_name, type FROM military_cyber_Commands;\n",
928
- " SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH); t = t. SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
929
- " SELECT state, (libraries / population) AS libraries_per_capita FROM libraries ORDER BY libraries_per_capita DESC LIMIT 3; California SELECT state, t.population, t.tut FROM libraries t JOIN t ON t.state = t.state GROUP BY state ORDER BY t.tut DESC LIMIT 3;\n",
930
- "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH); The total number of users in Australia is 50. SELECT COUNT(*) FROM posts WHERE CUTS(CUTS.id, CUTS.created_at) = CUTS.id AND CUTS.id = CUTS.id WHERE CUTS.location = 'Australia' AND CUTS.created_at >= DATE_SUB(CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CUTS.CU\n",
931
- " SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country; a SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n"
932
- ]
933
- },
934
  {
935
  "data": {
936
  "application/vnd.jupyter.widget-view+json": {
937
- "model_id": "fb9a4b84525845e78668fbb5472ac4c8",
938
  "version_major": 2,
939
  "version_minor": 0
940
  },
@@ -948,7 +881,7 @@
948
  {
949
  "data": {
950
  "application/vnd.jupyter.widget-view+json": {
951
- "model_id": "5a92eb8c1607450d8babbce26891eb97",
952
  "version_major": 2,
953
  "version_minor": 0
954
  },
@@ -962,7 +895,7 @@
962
  {
963
  "data": {
964
  "application/vnd.jupyter.widget-view+json": {
965
- "model_id": "e5b5b1034f354abfbdfc46f0ff2b9349",
966
  "version_major": 2,
967
  "version_minor": 0
968
  },
@@ -977,8 +910,8 @@
977
  "name": "stderr",
978
  "output_type": "stream",
979
  "text": [
980
- "2025-03-17 18:29:02,580 - INFO - Using default tokenizer.\n",
981
- "2025-03-17 18:30:27,253 - INFO - Using default tokenizer.\n"
982
  ]
983
  },
984
  {
@@ -990,27 +923,49 @@
990
  "Evaluation Metrics:\n",
991
  "====================================================================================================\n",
992
  "ORIGINAL MODEL:\n",
993
- " ROUGE: {'rouge1': np.float64(0.033688028857640176), 'rouge2': np.float64(0.008171862977966522), 'rougeL': np.float64(0.030557406905046474), 'rougeLsum': np.float64(0.030592110084298876)}\n",
994
- " BLEU: {'bleu': 0.0036692781190090368, 'precisions': [0.02284408025462027, 0.004200643881640979, 0.002134841269783046, 0.0008848453895992066], 'brevity_penalty': 1.0, 'length_ratio': 1.1809102409373358, 'translation_length': 1421725, 'reference_length': 1203923}\n",
995
- " Fuzzy Match Score: 11.31%\n",
996
  " Exact Match Accuracy: 0.00%\n",
997
  "\n",
998
  "FINE-TUNED MODEL:\n",
999
- " ROUGE: {'rouge1': np.float64(0.6914345907518044), 'rouge2': np.float64(0.5453255406268581), 'rougeL': np.float64(0.6642891642898592), 'rougeLsum': np.float64(0.6642865716725223)}\n",
1000
- " BLEU: {'bleu': 0.31698443630421885, 'precisions': [0.46303833317311294, 0.34558772459086096, 0.2792686360724928, 0.2259198229483191], 'brevity_penalty': 1.0, 'length_ratio': 1.4083799379196178, 'translation_length': 1695581, 'reference_length': 1203923}\n",
1001
- " Fuzzy Match Score: 81.98%\n",
1002
- " Exact Match Accuracy: 16.39%\n",
1003
  "====================================================================================================\n"
1004
  ]
1005
  }
1006
  ],
1007
  "source": [
1008
- "from rapidfuzz import fuzz\n",
1009
- "import pandas as pd\n",
1010
  "import re\n",
 
 
1011
  "import evaluate\n",
1012
  "\n",
1013
- "# --- Helper Functions for SQL Normalization and Exact Match ---\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1014
  "def normalize_sql(sql):\n",
1015
  " \"\"\"Normalize SQL by stripping whitespace and lowercasing.\"\"\"\n",
1016
  " return \" \".join(sql.strip().lower().split())\n",
@@ -1026,7 +981,16 @@
1026
  " scores = [fuzz.token_set_ratio(pred, ref) for pred, ref in zip(predictions, references)]\n",
1027
  " return sum(scores) / len(scores) if scores else 0\n",
1028
  "\n",
1029
- "# --- Part A: Inference on 5 Examples with Real Responses (unchanged) ---\n",
 
 
 
 
 
 
 
 
 
1030
  "logger.info(\"Running inference on 5 examples (displaying real responses).\")\n",
1031
  "\n",
1032
  "num_examples = 5\n",
@@ -1034,7 +998,7 @@
1034
  "sample_contexts = dataset[\"test\"][:num_examples][\"context\"]\n",
1035
  "sample_human_responses = dataset[\"test\"][:num_examples][\"response\"]\n",
1036
  "\n",
1037
- "print(\"\\n\" + \"=\"*100)\n",
1038
  "for idx in range(num_examples):\n",
1039
  " prompt = f\"\"\"Context:\n",
1040
  "{sample_contexts[idx]}\n",
@@ -1044,14 +1008,12 @@
1044
  "\n",
1045
  "Response:\n",
1046
  "\"\"\"\n",
1047
- " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
1048
- " \n",
1049
- " # Generate outputs with both models using keyword arguments\n",
1050
- " orig_out_ids = original_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=200)\n",
1051
- " finetuned_out_ids = finetuned_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=200)\n",
1052
  " \n",
1053
- " orig_text = tokenizer.decode(orig_out_ids[0], skip_special_tokens=True)\n",
1054
- " finetuned_text = tokenizer.decode(finetuned_out_ids[0], skip_special_tokens=True)\n",
 
1055
  " \n",
1056
  " print(\"-\" * 100)\n",
1057
  " print(f\"Example {idx+1}\")\n",
@@ -1063,10 +1025,10 @@
1063
  " print(sample_human_responses[idx])\n",
1064
  " print(\"-\" * 100)\n",
1065
  " print(\"ORIGINAL MODEL OUTPUT:\")\n",
1066
- " print(orig_text)\n",
1067
  " print(\"-\" * 100)\n",
1068
  " print(\"FINE-TUNED MODEL OUTPUT:\")\n",
1069
- " print(finetuned_text)\n",
1070
  " print(\"=\" * 100 + \"\\n\")\n",
1071
  " clear_memory()\n",
1072
  "\n",
@@ -1077,32 +1039,46 @@
1077
  "all_original_responses = []\n",
1078
  "all_finetuned_responses = []\n",
1079
  "\n",
1080
- "batch_size = 128 # Adjust batch size based on your GPU memory\n",
1081
  "test_dataset = dataset[\"test\"]\n",
1082
  "\n",
1083
  "for i in range(0, len(test_dataset), batch_size):\n",
1084
  " # Slicing the dataset returns a dict of lists\n",
1085
- " batch = test_dataset[i:i+batch_size]\n",
1086
  " \n",
1087
- " # Construct prompts for each example in the batch by iterating over indices\n",
1088
  " prompts = [\n",
1089
  " f\"Context:\\n{batch['context'][j]}\\n\\nQuery:\\n{batch['query'][j]}\\n\\nResponse:\"\n",
1090
  " for j in range(len(batch[\"context\"]))\n",
1091
  " ]\n",
1092
  " \n",
1093
- " # Extend human responses for each example\n",
1094
  " all_human_responses.extend(batch[\"response\"])\n",
1095
  " \n",
1096
- " # Tokenize the batch of prompts\n",
1097
- " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True).to(device)\n",
1098
  " \n",
1099
- " # Generate outputs with both models for the batch\n",
1100
- " orig_ids = original_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=300)\n",
1101
- " finetuned_ids = finetuned_model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=300)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1102
  " \n",
1103
- " # Decode each sample in the batch\n",
1104
  " orig_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in orig_ids]\n",
1105
- " finetuned_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in finetuned_ids]\n",
1106
  " \n",
1107
  " all_original_responses.extend(orig_texts)\n",
1108
  " all_finetuned_responses.extend(finetuned_texts)\n",
@@ -1111,13 +1087,10 @@
1111
  "# Create a DataFrame for a quick comparison of results\n",
1112
  "zipped_all = list(zip(all_human_responses, all_original_responses, all_finetuned_responses))\n",
1113
  "df_full = pd.DataFrame(zipped_all, columns=[\"Human Response\", \"Original Model Output\", \"Fine-Tuned Model Output\"])\n",
1114
- "logger.info(\"Full test set comparison (first 5 rows):\\n%s\", df_full.head())\n",
1115
- "print(\"\\nFull Test Set Comparison (First 5 Rows):\")\n",
1116
- "print(df_full.head().to_string(index=False))\n",
1117
  "clear_memory()\n",
1118
  "\n",
1119
  "# --- Compute Evaluation Metrics ---\n",
1120
- "# Load evaluation libraries\n",
1121
  "rouge = evaluate.load(\"rouge\")\n",
1122
  "bleu = evaluate.load(\"bleu\")\n",
1123
  "\n",
@@ -1149,9 +1122,9 @@
1149
  "finetuned_fuzzy = compute_fuzzy_match(all_finetuned_responses, all_human_responses)\n",
1150
  "finetuned_exact = compute_exact_match(all_finetuned_responses, all_human_responses)\n",
1151
  "\n",
1152
- "print(\"\\n\" + \"=\"*100)\n",
1153
  "print(\"Evaluation Metrics:\")\n",
1154
- "print(\"=\"*100)\n",
1155
  "print(\"ORIGINAL MODEL:\")\n",
1156
  "print(f\" ROUGE: {orig_rouge}\")\n",
1157
  "print(f\" BLEU: {orig_bleu}\")\n",
@@ -1162,13 +1135,13 @@
1162
  "print(f\" BLEU: {finetuned_bleu}\")\n",
1163
  "print(f\" Fuzzy Match Score: {finetuned_fuzzy:.2f}%\")\n",
1164
  "print(f\" Exact Match Accuracy: {finetuned_exact:.2f}%\")\n",
1165
- "print(\"=\"*100)\n",
1166
- "clear_memory()\n"
1167
  ]
1168
  },
1169
  {
1170
  "cell_type": "code",
1171
- "execution_count": 32,
1172
  "id": "462546a7-6928-4723-b00e-23c3a4091d99",
1173
  "metadata": {},
1174
  "outputs": [
@@ -1176,7 +1149,7 @@
1176
  "name": "stderr",
1177
  "output_type": "stream",
1178
  "text": [
1179
- "2025-03-18 16:55:06,158 - INFO - Running inference with deterministic decoding and beam search.\n"
1180
  ]
1181
  },
1182
  {
@@ -1191,10 +1164,7 @@
1191
  "Retrieve the total order amount for each customer, showing only customers from the USA, and sort the result by total order amount in descending order.\n",
1192
  "\n",
1193
  "Response:\n",
1194
- "SELECT customers.name, SUM(orders.total_amount) as total_amount FROM customers INNER JOIN orders ON customers.id = orders.customer_id WHERE customers.country = 'USA' GROUP BY customers.name ORDER BY total_amount DESC;\n",
1195
- "\n",
1196
- "EXPECTED RESPONSE:\n",
1197
- "SELECT c.name, SUM(o.total_amount) as total_order_amount FROM customers c JOIN orders o ON c.id = o.customer_id WHERE c.country = 'USA' GROUP BY c.name ORDER BY total_order_amount DESC;\n"
1198
  ]
1199
  }
1200
  ],
@@ -1214,7 +1184,7 @@
1214
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1215
  "\n",
1216
  "# Load the fine-tuned model and tokenizer\n",
1217
- "model_name = \"text2sql_flant5base_finetuned\" # Directory of your fine-tuned model\n",
1218
  "finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
1219
  "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
1220
  "finetuned_model.to(device)\n",
@@ -1227,12 +1197,17 @@
1227
  " inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(device)\n",
1228
  " generated_ids = finetuned_model.generate(\n",
1229
  " input_ids=inputs[\"input_ids\"],\n",
1230
- " max_new_tokens=250, # Adjust based on query complexity\n",
1231
- " temperature=0.0, # Deterministic output\n",
1232
- " num_beams=3, # Beam search for better output quality\n",
1233
  " early_stopping=True, # Stop early if possible\n",
1234
  " )\n",
1235
- " return tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
 
 
 
 
 
1236
  "\n",
1237
  "# Sample context and query (example)\n",
1238
  "context = (\n",
@@ -1264,16 +1239,6 @@
1264
  "logger.info(\"Running inference with deterministic decoding and beam search.\")\n",
1265
  "generated_sql = run_inference(sample_prompt)\n",
1266
  "\n",
1267
- "# Define the expected response (this is a placeholder - update as necessary)\n",
1268
- "expected_response = (\n",
1269
- " \"SELECT c.name, SUM(o.total_amount) as total_order_amount \"\n",
1270
- " \"FROM customers c \"\n",
1271
- " \"JOIN orders o ON c.id = o.customer_id \"\n",
1272
- " \"WHERE c.country = 'USA' \"\n",
1273
- " \"GROUP BY c.name \"\n",
1274
- " \"ORDER BY total_order_amount DESC;\"\n",
1275
- ")\n",
1276
- "\n",
1277
  "# Print output in the given format\n",
1278
  "print(\"Prompt:\")\n",
1279
  "print(\"Context:\")\n",
@@ -1281,14 +1246,12 @@
1281
  "print(\"\\nQuery:\")\n",
1282
  "print(query)\n",
1283
  "print(\"\\nResponse:\")\n",
1284
- "print(generated_sql)\n",
1285
- "print(\"\\nEXPECTED RESPONSE:\")\n",
1286
- "print(expected_response)\n"
1287
  ]
1288
  },
1289
  {
1290
  "cell_type": "code",
1291
- "execution_count": 20,
1292
  "id": "a69f268e-bc69-4633-9c15-4e118c20178e",
1293
  "metadata": {},
1294
  "outputs": [
@@ -1319,22 +1282,22 @@
1319
  "# Load fine-tuned LoRA adapter model\n",
1320
  "lora_model = PeftModel.from_pretrained(base_model, lora_model_path)\n",
1321
  "\n",
1322
- "# Save the LoRA adapter separately (for users who want lightweight adapters)\n",
1323
  "lora_model.save_pretrained(lora_model_path)\n",
1324
  "tokenizer.save_pretrained(lora_model_path)\n",
1325
  "\n",
1326
- "# Merge LoRA into the base model to create a fully fine-tuned model\n",
1327
  "merged_model = lora_model.merge_and_unload()\n",
1328
  "\n",
1329
- "# Save the full fine-tuned model\n",
1330
  "merged_model.save_pretrained(full_model_output_path)\n",
1331
  "tokenizer.save_pretrained(full_model_output_path)\n",
1332
  "\n",
1333
- "# Save generation config (optional but recommended for inference settings)\n",
1334
  "generation_config = {\n",
1335
- " \"max_new_tokens\": 250,\n",
1336
- " \"temperature\": 0.0,\n",
1337
- " \"num_beams\": 3,\n",
1338
  " \"early_stopping\": True\n",
1339
  "}\n",
1340
  "with open(f\"{full_model_output_path}/generation_config.json\", \"w\") as f:\n",
@@ -1346,7 +1309,7 @@
1346
  },
1347
  {
1348
  "cell_type": "code",
1349
- "execution_count": 33,
1350
  "id": "f1c95dfc-6662-44d8-8ecc-bff414fecee5",
1351
  "metadata": {},
1352
  "outputs": [
@@ -1354,22 +1317,11 @@
1354
  "name": "stderr",
1355
  "output_type": "stream",
1356
  "text": [
1357
- "2025-03-18 16:55:46,428 - INFO - Running inference with beam search decoding.\n"
1358
- ]
1359
- },
1360
- {
1361
- "name": "stdout",
1362
- "output_type": "stream",
1363
- "text": [
1364
- "Prompt:\n",
1365
- "Context:\n",
1366
- "CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); INSERT INTO employees (id, name, department, salary) VALUES (1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), (3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); INSERT INTO projects (project_id, project_name, budget) VALUES (101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); INSERT INTO employee_projects (employee_id, project_id, role) VALUES (1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), (4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\n",
1367
- "\n",
1368
- "Query:\n",
1369
- "Find the names of employees who are working on the 'AI Research' project along with their roles.\n",
1370
- "\n",
1371
- "Response:\n",
1372
- "SELECT employees.name, employee_projects.role FROM employees INNER JOIN employee_projects ON employees.id = employee_projects.employee_id INNER JOIN projects ON employee_projects.project_id = projects.project_id WHERE projects.project_name = 'AI Research';\n"
1373
  ]
1374
  }
1375
  ],
@@ -1462,7 +1414,7 @@
1462
  {
1463
  "cell_type": "code",
1464
  "execution_count": null,
1465
- "id": "562458ed-53f4-44af-a7a3-e42a175c7245",
1466
  "metadata": {},
1467
  "outputs": [],
1468
  "source": []
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  "id": "5f167a6f-5139-46e6-afb2-a1fa4d12f3fd",
7
  "metadata": {},
8
  "outputs": [],
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 2,
36
  "id": "53684b5e-c27e-4eb9-815e-583aa194e096",
37
  "metadata": {},
38
  "outputs": [
 
55
  },
56
  {
57
  "cell_type": "code",
58
+ "execution_count": 3,
59
  "id": "a47bf3cd-752d-4d1c-9697-70098d6204fa",
60
  "metadata": {},
61
  "outputs": [],
 
69
  },
70
  {
71
  "cell_type": "code",
72
+ "execution_count": 4,
73
  "id": "f16df21e-9797-4f78-83a1-a2943759ba55",
74
  "metadata": {},
75
  "outputs": [],
 
81
  },
82
  {
83
  "cell_type": "code",
84
+ "execution_count": 5,
85
  "id": "196e83da-6c8c-4cd7-bd70-2598a5e2a16a",
86
  "metadata": {},
87
  "outputs": [],
 
95
  },
96
  {
97
  "cell_type": "code",
98
+ "execution_count": 6,
99
  "id": "cea22b9f-f309-4151-81ac-37547c8feeb0",
100
  "metadata": {},
101
  "outputs": [],
 
127
  },
128
  {
129
  "cell_type": "code",
130
+ "execution_count": 7,
131
  "id": "d4eb82ce-1713-40b6-981d-43ce35aaa6f6",
132
  "metadata": {},
133
  "outputs": [
 
135
  "name": "stderr",
136
  "output_type": "stream",
137
  "text": [
138
+ "2025-03-19 14:56:53,295 - INFO - Loading raw datasets from various sources...\n",
139
+ "2025-03-19 14:57:25,655 - INFO - Total rows before dropping duplicates: 490241\n",
140
+ "2025-03-19 14:57:27,208 - INFO - Total rows after dropping duplicates: 440785\n"
141
  ]
142
  }
143
  ],
 
170
  },
171
  {
172
  "cell_type": "code",
173
+ "execution_count": 8,
174
  "id": "8446814e-5a2c-48a4-8c01-059afcf1d3c1",
175
  "metadata": {},
176
  "outputs": [
 
179
  "output_type": "stream",
180
  "text": [
181
  "Token indices sequence length is longer than the specified maximum sequence length for this model (1113 > 512). Running this sequence through the model will result in indexing errors\n",
182
+ "2025-03-19 15:01:13,787 - INFO - Total rows after filtering by token length (prompt <= 500 and response <= 250 tokens): 398481\n"
183
  ]
184
  }
185
  ],
 
210
  },
211
  {
212
  "cell_type": "code",
213
+ "execution_count": 9,
214
  "id": "177e1e6d-9fbc-442d-9774-5a3e5234329f",
215
  "metadata": {},
216
  "outputs": [
 
218
  "name": "stderr",
219
  "output_type": "stream",
220
  "text": [
221
+ "2025-03-19 15:01:13,794 - INFO - Sample from filtered final_df:\n",
222
  " query \\\n",
223
  "0 Name the home team for carlton away team \n",
224
  "1 what will the population of Asia be when Latin... \n",
 
243
  },
244
  {
245
  "cell_type": "code",
246
+ "execution_count": 10,
247
  "id": "0b639efe-ebeb-4b34-bc3f-accf776ba0da",
248
  "metadata": {},
249
  "outputs": [
 
251
  "name": "stderr",
252
  "output_type": "stream",
253
  "text": [
254
+ "2025-03-19 15:01:14,006 - INFO - Final split sizes: Train: 338708, Test: 39848, Validation: 19925\n"
255
  ]
256
  },
257
  {
258
  "data": {
259
  "application/vnd.jupyter.widget-view+json": {
260
+ "model_id": "81e753f720e44f40b5f0dfa5263e2bf5",
261
  "version_major": 2,
262
  "version_minor": 0
263
  },
 
271
  {
272
  "data": {
273
  "application/vnd.jupyter.widget-view+json": {
274
+ "model_id": "59b1ce0d9ee548668dbc87b99d6e0951",
275
  "version_major": 2,
276
  "version_minor": 0
277
  },
 
285
  {
286
  "data": {
287
  "application/vnd.jupyter.widget-view+json": {
288
+ "model_id": "4a378405a0a24c13a81fc853550d01d6",
289
  "version_major": 2,
290
  "version_minor": 0
291
  },
 
300
  "name": "stderr",
301
  "output_type": "stream",
302
  "text": [
303
+ "2025-03-19 15:01:15,490 - INFO - Merged and Saved Dataset Successfully!\n",
304
+ "2025-03-19 15:01:15,497 - INFO - Dataset summary: DatasetDict({\n",
305
  " train: Dataset({\n",
306
  " features: ['query', 'context', 'response'],\n",
307
  " num_rows: 338708\n",
 
350
  },
351
  {
352
  "cell_type": "code",
353
+ "execution_count": 11,
354
  "id": "9f6e1095-d72d-4e22-b20d-683f1f84544c",
355
  "metadata": {},
356
  "outputs": [
 
358
  "name": "stderr",
359
  "output_type": "stream",
360
  "text": [
361
+ "2025-03-19 15:01:15,843 - INFO - Reloaded dataset from disk. Example from test split:\n",
362
  "{'query': \"Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\", 'context': \"CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\", 'response': 'SELECT command_name, type FROM defense_security.Military_Cyber_Commands;'}\n",
363
+ "2025-03-19 15:01:16,155 - INFO - Loaded Tokenized Dataset from disk.\n",
364
+ "2025-03-19 15:01:16,159 - INFO - Final tokenized dataset splits: dict_keys(['train', 'test', 'validation'])\n",
365
+ "2025-03-19 15:01:16,167 - INFO - Sample tokenized record from train split:\n",
366
  "{'input_ids': tensor([ 1193, 6327, 10, 205, 4386, 6048, 332, 17098, 953, 834,\n",
367
  " 4350, 834, 4013, 41, 234, 834, 11650, 584, 4280, 28027,\n",
368
  " 6, 550, 834, 11650, 584, 4280, 28027, 3, 61, 3,\n",
 
536
  },
537
  {
538
  "cell_type": "code",
539
+ "execution_count": 12,
540
  "id": "7f004e55-181c-47aa-9f3e-c7c1ceae780c",
541
  "metadata": {},
542
  "outputs": [
 
603
  },
604
  {
605
  "cell_type": "code",
606
+ "execution_count": 13,
607
  "id": "f50e56c7-98b3-42bc-9129-89f3eff802e7",
608
  "metadata": {},
609
  "outputs": [
 
611
  "name": "stderr",
612
  "output_type": "stream",
613
  "text": [
614
+ "2025-03-19 15:01:30,827 - INFO - Attempting to load the fine-tuned model...\n",
615
+ "2025-03-19 15:01:32,195 - INFO - Fine-tuned model loaded successfully.\n"
616
  ]
617
  }
618
  ],
 
715
  },
716
  {
717
  "cell_type": "code",
718
+ "execution_count": 14,
719
  "id": "f364eb6b-56cb-4533-8ef6-b5e7f56895aa",
720
  "metadata": {},
721
  "outputs": [
 
723
  "name": "stderr",
724
  "output_type": "stream",
725
  "text": [
726
+ "2025-03-19 15:01:32,235 - INFO - Running inference on 5 examples (displaying real responses).\n",
727
+ "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
728
+ " warnings.warn(\n"
729
  ]
730
  },
731
  {
 
751
  "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
752
  "----------------------------------------------------------------------------------------------------\n",
753
  "ORIGINAL MODEL OUTPUT:\n",
754
+ "USCYBERCOM, JTF-CND, Offensive Cyber Operations\n",
755
  "----------------------------------------------------------------------------------------------------\n",
756
  "FINE-TUNED MODEL OUTPUT:\n",
757
  "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
 
774
  "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
775
  "----------------------------------------------------------------------------------------------------\n",
776
  "ORIGINAL MODEL OUTPUT:\n",
777
+ "10000, 2022-01-01\n",
778
  "----------------------------------------------------------------------------------------------------\n",
779
  "FINE-TUNED MODEL OUTPUT:\n",
780
  "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
 
820
  "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH);\n",
821
  "----------------------------------------------------------------------------------------------------\n",
822
  "ORIGINAL MODEL OUTPUT:\n",
823
+ "The total number of posts made by users located in Australia is 50.\n",
824
  "----------------------------------------------------------------------------------------------------\n",
825
  "FINE-TUNED MODEL OUTPUT:\n",
826
  "SELECT COUNT(*) FROM posts p JOIN users u ON p.user_id = u.id WHERE u.location = 'Australia' AND p.created_at >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);\n",
 
832
  "name": "stderr",
833
  "output_type": "stream",
834
  "text": [
835
+ "2025-03-19 15:01:40,448 - INFO - Starting evaluation on the full test set using batching.\n"
836
  ]
837
  },
838
  {
 
856
  "SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country;\n",
857
  "----------------------------------------------------------------------------------------------------\n",
858
  "ORIGINAL MODEL OUTPUT:\n",
859
+ "1, 150, USA, 2, 200, Canada, 3, 120, Mexico\n",
860
  "----------------------------------------------------------------------------------------------------\n",
861
  "FINE-TUNED MODEL OUTPUT:\n",
862
  "SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n",
 
864
  "\n"
865
  ]
866
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  {
868
  "data": {
869
  "application/vnd.jupyter.widget-view+json": {
870
+ "model_id": "a7beecee09a34f9790be1e4538a87442",
871
  "version_major": 2,
872
  "version_minor": 0
873
  },
 
881
  {
882
  "data": {
883
  "application/vnd.jupyter.widget-view+json": {
884
+ "model_id": "763373c451c94f5e92bc6a6253109275",
885
  "version_major": 2,
886
  "version_minor": 0
887
  },
 
895
  {
896
  "data": {
897
  "application/vnd.jupyter.widget-view+json": {
898
+ "model_id": "afdce82cb8964da788756d783539ee8d",
899
  "version_major": 2,
900
  "version_minor": 0
901
  },
 
910
  "name": "stderr",
911
  "output_type": "stream",
912
  "text": [
913
+ "2025-03-19 16:47:58,173 - INFO - Using default tokenizer.\n",
914
+ "2025-03-19 16:49:07,668 - INFO - Using default tokenizer.\n"
915
  ]
916
  },
917
  {
 
923
  "Evaluation Metrics:\n",
924
  "====================================================================================================\n",
925
  "ORIGINAL MODEL:\n",
926
+ " ROUGE: {'rouge1': np.float64(0.05646642898660111), 'rouge2': np.float64(0.01562815013068162), 'rougeL': np.float64(0.05031267225420556), 'rougeLsum': np.float64(0.05036072587316542)}\n",
927
+ " BLEU: {'bleu': 0.003142147128241449, 'precisions': [0.12293406776920406, 0.03289697910893642, 0.018512080104175887, 0.008342750223825794], 'brevity_penalty': 0.11177079327444009, 'length_ratio': 0.3133514352662089, 'translation_length': 377251, 'reference_length': 1203923}\n",
928
+ " Fuzzy Match Score: 13.98%\n",
929
  " Exact Match Accuracy: 0.00%\n",
930
  "\n",
931
  "FINE-TUNED MODEL:\n",
932
+ " ROUGE: {'rouge1': np.float64(0.7538800834024002), 'rouge2': np.float64(0.6103863808522726), 'rougeL': np.float64(0.7262841884754194), 'rougeLsum': np.float64(0.7261852209847466)}\n",
933
+ " BLEU: {'bleu': 0.4719774431701209, 'precisions': [0.7603153442288385, 0.598309257795389, 0.5021259810303533, 0.42128998564638875], 'brevity_penalty': 0.8474086962179814, 'length_ratio': 0.8579477258927689, 'translation_length': 1032903, 'reference_length': 1203923}\n",
934
+ " Fuzzy Match Score: 85.62%\n",
935
+ " Exact Match Accuracy: 18.29%\n",
936
  "====================================================================================================\n"
937
  ]
938
  }
939
  ],
940
  "source": [
941
+ "import logging\n",
 
942
  "import re\n",
943
+ "import pandas as pd\n",
944
+ "from rapidfuzz import fuzz\n",
945
  "import evaluate\n",
946
  "\n",
947
+ "# Assuming tokenizer, device, original_model, finetuned_model, and dataset are already defined.\n",
948
+ "# Define a helper function for output post-processing.\n",
949
+ "def post_process_output(output_text: str) -> str:\n",
950
+ " \"\"\"Post-process the generated output to remove repeated text.\"\"\"\n",
951
+ " # Keep only the first valid SQL query (everything before the first semicolon)\n",
952
+ " return output_text.split(\";\")[0] + \";\" if \";\" in output_text else output_text\n",
953
+ "\n",
954
+ "# Define a helper function for generating outputs with the given generation parameters.\n",
955
+ "def generate_with_params(model, input_ids):\n",
956
+ " generated_ids = model.generate(\n",
957
+ " input_ids=input_ids,\n",
958
+ " max_new_tokens=100, \n",
959
+ " num_beams=5,\n",
960
+ " repetition_penalty=1.2,\n",
961
+ " temperature=0.1,\n",
962
+ " early_stopping=True\n",
963
+ " )\n",
964
+ " # Decode and post-process output\n",
965
+ " output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
966
+ " return output_text\n",
967
+ "\n",
968
+ "# Helper functions for SQL normalization and evaluation metrics\n",
969
  "def normalize_sql(sql):\n",
970
  " \"\"\"Normalize SQL by stripping whitespace and lowercasing.\"\"\"\n",
971
  " return \" \".join(sql.strip().lower().split())\n",
 
981
  " scores = [fuzz.token_set_ratio(pred, ref) for pred, ref in zip(predictions, references)]\n",
982
  " return sum(scores) / len(scores) if scores else 0\n",
983
  "\n",
984
+ "# Dummy function to free up memory if needed.\n",
985
+ "def clear_memory():\n",
986
+ " # If using torch.cuda, you can clear cache:\n",
987
+ " # torch.cuda.empty_cache()\n",
988
+ " pass\n",
989
+ "\n",
990
+ "logger = logging.getLogger(__name__)\n",
991
+ "logger.setLevel(logging.INFO)\n",
992
+ "\n",
993
+ "# --- Part A: Inference on 5 Examples with Real Responses ---\n",
994
  "logger.info(\"Running inference on 5 examples (displaying real responses).\")\n",
995
  "\n",
996
  "num_examples = 5\n",
 
998
  "sample_contexts = dataset[\"test\"][:num_examples][\"context\"]\n",
999
  "sample_human_responses = dataset[\"test\"][:num_examples][\"response\"]\n",
1000
  "\n",
1001
+ "print(\"\\n\" + \"=\" * 100)\n",
1002
  "for idx in range(num_examples):\n",
1003
  " prompt = f\"\"\"Context:\n",
1004
  "{sample_contexts[idx]}\n",
 
1008
  "\n",
1009
  "Response:\n",
1010
  "\"\"\"\n",
1011
+ " # Tokenize the prompt and move to device\n",
1012
+ " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n",
 
 
 
1013
  " \n",
1014
+ " # Generate outputs using the modified generation parameters\n",
1015
+ " orig_out = generate_with_params(original_model, inputs[\"input_ids\"])\n",
1016
+ " finetuned_out = post_process_output(generate_with_params(finetuned_model, inputs[\"input_ids\"]))\n",
1017
  " \n",
1018
  " print(\"-\" * 100)\n",
1019
  " print(f\"Example {idx+1}\")\n",
 
1025
  " print(sample_human_responses[idx])\n",
1026
  " print(\"-\" * 100)\n",
1027
  " print(\"ORIGINAL MODEL OUTPUT:\")\n",
1028
+ " print(orig_out)\n",
1029
  " print(\"-\" * 100)\n",
1030
  " print(\"FINE-TUNED MODEL OUTPUT:\")\n",
1031
+ " print(finetuned_out)\n",
1032
  " print(\"=\" * 100 + \"\\n\")\n",
1033
  " clear_memory()\n",
1034
  "\n",
 
1039
  "all_original_responses = []\n",
1040
  "all_finetuned_responses = []\n",
1041
  "\n",
1042
+ "batch_size = 128 # Adjust based on GPU memory\n",
1043
  "test_dataset = dataset[\"test\"]\n",
1044
  "\n",
1045
  "for i in range(0, len(test_dataset), batch_size):\n",
1046
  " # Slicing the dataset returns a dict of lists\n",
1047
+ " batch = test_dataset[i:i + batch_size]\n",
1048
  " \n",
1049
+ " # Construct prompts for each example in the batch\n",
1050
  " prompts = [\n",
1051
  " f\"Context:\\n{batch['context'][j]}\\n\\nQuery:\\n{batch['query'][j]}\\n\\nResponse:\"\n",
1052
  " for j in range(len(batch[\"context\"]))\n",
1053
  " ]\n",
1054
  " \n",
1055
+ " # Extend human responses\n",
1056
  " all_human_responses.extend(batch[\"response\"])\n",
1057
  " \n",
1058
+ " # Tokenize the batch of prompts with padding and truncation\n",
1059
+ " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(device)\n",
1060
  " \n",
1061
+ " # Generate outputs for the batch for both models\n",
1062
+ " orig_ids = original_model.generate(\n",
1063
+ " input_ids=inputs[\"input_ids\"],\n",
1064
+ " max_new_tokens=100,\n",
1065
+ " num_beams=5,\n",
1066
+ " repetition_penalty=1.2,\n",
1067
+ " temperature=0.1,\n",
1068
+ " early_stopping=True\n",
1069
+ " )\n",
1070
+ " finetuned_ids = finetuned_model.generate(\n",
1071
+ " input_ids=inputs[\"input_ids\"],\n",
1072
+ " max_new_tokens=100,\n",
1073
+ " num_beams=5,\n",
1074
+ " repetition_penalty=1.2,\n",
1075
+ " temperature=0.1,\n",
1076
+ " early_stopping=True\n",
1077
+ " )\n",
1078
  " \n",
1079
+ " # Decode and post-process each sample in the batch\n",
1080
  " orig_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in orig_ids]\n",
1081
+ " finetuned_texts = [post_process_output(tokenizer.decode(ids, skip_special_tokens=True)) for ids in finetuned_ids]\n",
1082
  " \n",
1083
  " all_original_responses.extend(orig_texts)\n",
1084
  " all_finetuned_responses.extend(finetuned_texts)\n",
 
1087
  "# Create a DataFrame for a quick comparison of results\n",
1088
  "zipped_all = list(zip(all_human_responses, all_original_responses, all_finetuned_responses))\n",
1089
  "df_full = pd.DataFrame(zipped_all, columns=[\"Human Response\", \"Original Model Output\", \"Fine-Tuned Model Output\"])\n",
1090
+ "df_full.to_csv('evaluation_results.csv', index=False)\n",
 
 
1091
  "clear_memory()\n",
1092
  "\n",
1093
  "# --- Compute Evaluation Metrics ---\n",
 
1094
  "rouge = evaluate.load(\"rouge\")\n",
1095
  "bleu = evaluate.load(\"bleu\")\n",
1096
  "\n",
 
1122
  "finetuned_fuzzy = compute_fuzzy_match(all_finetuned_responses, all_human_responses)\n",
1123
  "finetuned_exact = compute_exact_match(all_finetuned_responses, all_human_responses)\n",
1124
  "\n",
1125
+ "print(\"\\n\" + \"=\" * 100)\n",
1126
  "print(\"Evaluation Metrics:\")\n",
1127
+ "print(\"=\" * 100)\n",
1128
  "print(\"ORIGINAL MODEL:\")\n",
1129
  "print(f\" ROUGE: {orig_rouge}\")\n",
1130
  "print(f\" BLEU: {orig_bleu}\")\n",
 
1135
  "print(f\" BLEU: {finetuned_bleu}\")\n",
1136
  "print(f\" Fuzzy Match Score: {finetuned_fuzzy:.2f}%\")\n",
1137
  "print(f\" Exact Match Accuracy: {finetuned_exact:.2f}%\")\n",
1138
+ "print(\"=\" * 100)\n",
1139
+ "clear_memory()"
1140
  ]
1141
  },
1142
  {
1143
  "cell_type": "code",
1144
+ "execution_count": 15,
1145
  "id": "462546a7-6928-4723-b00e-23c3a4091d99",
1146
  "metadata": {},
1147
  "outputs": [
 
1149
  "name": "stderr",
1150
  "output_type": "stream",
1151
  "text": [
1152
+ "2025-03-19 16:51:05,225 - INFO - Running inference with deterministic decoding and beam search.\n"
1153
  ]
1154
  },
1155
  {
 
1164
  "Retrieve the total order amount for each customer, showing only customers from the USA, and sort the result by total order amount in descending order.\n",
1165
  "\n",
1166
  "Response:\n",
1167
+ "SELECT customer_id, SUM(total_amount) as total_amount FROM orders JOIN customers ON orders.customer_id = customers.id WHERE customers.country = 'USA' GROUP BY customer_id ORDER BY total_amount DESC;\n"
 
 
 
1168
  ]
1169
  }
1170
  ],
 
1184
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1185
  "\n",
1186
  "# Load the fine-tuned model and tokenizer\n",
1187
+ "model_name = \"text2sql_flant5base_finetuned\" \n",
1188
  "finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
1189
  "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
1190
  "finetuned_model.to(device)\n",
 
1197
  " inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(device)\n",
1198
  " generated_ids = finetuned_model.generate(\n",
1199
  " input_ids=inputs[\"input_ids\"],\n",
1200
+ " max_new_tokens=100, # Adjust based on query complexity\n",
1201
+ " temperature=0.1, # Deterministic output\n",
1202
+ " num_beams=5, # Beam search for better output quality\n",
1203
  " early_stopping=True, # Stop early if possible\n",
1204
  " )\n",
1205
+ " generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
1206
+ "\n",
1207
+ " # Post-processing to remove repeated text\n",
1208
+ " generated_sql = generated_sql.split(\";\")[0] + \";\" # Keep only the first valid SQL query\n",
1209
+ "\n",
1210
+ " return generated_sql\n",
1211
  "\n",
1212
  "# Sample context and query (example)\n",
1213
  "context = (\n",
 
1239
  "logger.info(\"Running inference with deterministic decoding and beam search.\")\n",
1240
  "generated_sql = run_inference(sample_prompt)\n",
1241
  "\n",
 
 
 
 
 
 
 
 
 
 
1242
  "# Print output in the given format\n",
1243
  "print(\"Prompt:\")\n",
1244
  "print(\"Context:\")\n",
 
1246
  "print(\"\\nQuery:\")\n",
1247
  "print(query)\n",
1248
  "print(\"\\nResponse:\")\n",
1249
+ "print(generated_sql)\n"
 
 
1250
  ]
1251
  },
1252
  {
1253
  "cell_type": "code",
1254
+ "execution_count": 16,
1255
  "id": "a69f268e-bc69-4633-9c15-4e118c20178e",
1256
  "metadata": {},
1257
  "outputs": [
 
1282
  "# Load fine-tuned LoRA adapter model\n",
1283
  "lora_model = PeftModel.from_pretrained(base_model, lora_model_path)\n",
1284
  "\n",
1285
+ "# Save the LoRA adapter separately (for users who want lightweight adapters)\n",
1286
  "lora_model.save_pretrained(lora_model_path)\n",
1287
  "tokenizer.save_pretrained(lora_model_path)\n",
1288
  "\n",
1289
+ "# Merge LoRA into the base model to create a fully fine-tuned model\n",
1290
  "merged_model = lora_model.merge_and_unload()\n",
1291
  "\n",
1292
+ "# Save the full fine-tuned model\n",
1293
  "merged_model.save_pretrained(full_model_output_path)\n",
1294
  "tokenizer.save_pretrained(full_model_output_path)\n",
1295
  "\n",
1296
+ "# Save generation config (optional but recommended for inference settings)\n",
1297
  "generation_config = {\n",
1298
+ " \"max_new_tokens\": 100,\n",
1299
+ " \"temperature\": 0.1,\n",
1300
+ " \"num_beams\": 5,\n",
1301
  " \"early_stopping\": True\n",
1302
  "}\n",
1303
  "with open(f\"{full_model_output_path}/generation_config.json\", \"w\") as f:\n",
 
1309
  },
1310
  {
1311
  "cell_type": "code",
1312
+ "execution_count": null,
1313
  "id": "f1c95dfc-6662-44d8-8ecc-bff414fecee5",
1314
  "metadata": {},
1315
  "outputs": [
 
1317
  "name": "stderr",
1318
  "output_type": "stream",
1319
  "text": [
1320
+ "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
1321
+ " warnings.warn(\n",
1322
+ "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
1323
+ " warnings.warn(\n",
1324
+ "2025-03-19 16:51:49,933 - INFO - Running inference with beam search decoding.\n"
 
 
 
 
 
 
 
 
 
 
 
1325
  ]
1326
  }
1327
  ],
 
1414
  {
1415
  "cell_type": "code",
1416
  "execution_count": null,
1417
+ "id": "97425ac4-ad46-4f38-b22d-071e161da20a",
1418
  "metadata": {},
1419
  "outputs": [],
1420
  "source": []