{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "20e6b5e9-1b58-4da4-b46a-3400db5bc662",
   "metadata": {},
   "source": [
    "# Imports and Library Setup\n",
    "\n",
    "This code block sets up the environment by importing all necessary libraries and modules required for the project. Below is a detailed explanation of each import:\n",
    "\n",
    "## Standard Libraries\n",
    "- **time**: Provides time-related functions, useful for measuring execution time or creating timestamps.\n",
    "  \n",
    "- **logging**: Enables configurable logging, which is essential for tracking the execution flow and debugging.\n",
    "\n",
    "- **re**: Supports regular expression operations for pattern matching and text manipulation.\n",
    "\n",
    "- **random**: Supplies functions to generate random numbers and make random selections, which can aid in data sampling and ensuring reproducibility.\n",
    "\n",
    "- **gc**: Interfaces with Python’s garbage collection system, allowing manual memory management (e.g., freeing up unused memory).\n",
    "\n",
    "## Data Processing Libraries\n",
    "- **numpy (`np`)**: A fundamental package for numerical computations and handling multi-dimensional arrays.\n",
    "\n",
    "- **pandas (`pd`)**: A powerful library for data manipulation and analysis, especially with DataFrame structures that handle tabular data efficiently.\n",
    "\n",
    "## Deep Learning Libraries\n",
    "- **torch**: PyTorch is used for tensor computations and building deep learning models. It supports GPU acceleration, which is crucial for training large models.\n",
    "\n",
    "- **evaluate**: A library designed for model evaluation, offering standardized metrics to compare model outputs against reference data.\n",
    "\n",
    "## Hugging Face Ecosystem\n",
    "- **datasets (Dataset, DatasetDict, load_from_disk)**: \n",
    "  - *Dataset*: Represents a single dataset, allowing for efficient data manipulation.\n",
    "\n",
    "  - *DatasetDict*: A container to hold multiple datasets (e.g., train, test, validation splits).\n",
    "\n",
    "  - *load_from_disk*: Facilitates loading previously saved datasets, which helps in preserving preprocessing efforts.\n",
    "\n",
    "- **transformers**: A library for working with state-of-the-art pre-trained models and tokenizers.\n",
    "  - **AutoModelForSeq2SeqLM**: Automatically loads a model suited for sequence-to-sequence tasks (e.g., translation, summarization).\n",
    "\n",
    "\n",
    "  - **AutoTokenizer**: Automatically loads the appropriate tokenizer for the chosen model.\n",
    "\n",
    "\n",
    "  - **TrainingArguments**: Encapsulates all the hyperparameters and configurations needed for training a model.\n",
    "\n",
    "\n",
    "  - **Trainer**: Provides an easy-to-use interface for training, evaluating, and performing predictions with models.\n",
    "\n",
    "\n",
    "  - **GenerationConfig**: Allows configuration of parameters for text generation, such as max token limits or beam search settings.\n",
    "\n",
    "\n",
    "  - **BitsAndBytesConfig**: Configures low-bit quantization settings, enabling efficient fine-tuning and inference with reduced precision.\n",
    "\n",
    "- **EarlyStoppingCallback (from transformers.trainer_callback)**: Implements early stopping during training to prevent overfitting and save computational resources.\n",
    "\n",
    "## PEFT (Parameter-Efficient Fine-Tuning)\n",
    "- **LoraConfig**: Configures Low-Rank Adaptation (LoRA), a method for fine-tuning large models efficiently by updating only a small subset of parameters.\n",
    "\n",
    "- **get_peft_model**: Applies the PEFT method to a base model, integrating LoRA layers into it.\n",
    "\n",
    "- **prepare_model_for_kbit_training**: Prepares a model for training with low-bit (quantized) precision, optimizing memory usage and computational efficiency.\n",
    "\n",
    "---\n",
    "\n",
    "These imports collectively provide all the functionalities needed for:\n",
    "- Data handling and preprocessing.\n",
    "\n",
    "- Loading and configuring state-of-the-art models.\n",
    "\n",
    "- Fine-tuning using efficient techniques like LoRA.\n",
    "\n",
    "- Evaluating model performance with standardized metrics.\n",
    "\n",
    "This robust setup forms the backbone for building, training, and evaluating sequence-to-sequence models tailored for tasks such as text-to-SQL conversion.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f167a6f-5139-46e6-afb2-a1fa4d12f3fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import logging\n",
    "import re\n",
    "import random\n",
    "import gc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import evaluate\n",
    "\n",
    "from datasets import Dataset, DatasetDict, load_from_disk\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    GenerationConfig,\n",
    "    BitsAndBytesConfig,\n",
    ")\n",
    "from transformers.trainer_callback import EarlyStoppingCallback\n",
    "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training"
   ]
  },
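  {
   "cell_type": "markdown",
   "id": "3f2a9c41-7d15-4b6e-9a02-6c1e8d4f5a21",
   "metadata": {},
   "source": [
    "As a hedged illustration of how the PEFT and quantization imports above typically fit together for a seq2seq model, the minimal sketch below pairs `BitsAndBytesConfig`, `prepare_model_for_kbit_training`, `LoraConfig`, and `get_peft_model`. The rank, alpha, dropout, and target modules are illustrative assumptions, not values taken from this notebook.\n",
    "\n",
    "```python\n",
    "# Illustrative sketch only: hyperparameters are assumptions, not this notebook's settings\n",
    "quant_config = BitsAndBytesConfig(load_in_8bit=True)  # low-bit weights to cut GPU memory\n",
    "\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(\n",
    "    \"google/flan-t5-base\",\n",
    "    quantization_config=quant_config,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "# Make the quantized model safe to fine-tune (casts norms, enables input gradients)\n",
    "base_model = prepare_model_for_kbit_training(base_model)\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    r=16,                       # assumed LoRA rank\n",
    "    lora_alpha=32,              # assumed scaling factor\n",
    "    lora_dropout=0.05,          # assumed dropout on the LoRA layers\n",
    "    target_modules=[\"q\", \"v\"],  # T5 attention projections (assumed)\n",
    "    task_type=\"SEQ_2_SEQ_LM\",\n",
    ")\n",
    "\n",
    "peft_model = get_peft_model(base_model, lora_config)\n",
    "peft_model.print_trainable_parameters()  # only a small fraction of weights will train\n",
    "```"
   ]
  },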
  {
   "cell_type": "markdown",
   "id": "080d7108-82c7-41da-ab2f-d0e8401c6ea9",
   "metadata": {},
   "source": [
    "# Device and Computation Optimization Setup\n",
    "\n",
    "This block of code is dedicated to optimizing the runtime performance and ensuring that computations are performed on the most suitable hardware.\n",
    "\n",
    "## Enabling cuDNN Benchmark\n",
    "\n",
    "- **Purpose:**\n",
    "\n",
    "    - Enables cuDNN benchmarking in PyTorch to automatically select the fastest algorithm for the given fixed input sizes.\n",
    "\n",
    "- **Benefit:**\n",
    "\n",
    "    - Can significantly speed up computations, particularly for operations like convolutions in deep learning models when input dimensions remain constant.\n",
    "\n",
    "## Setting the Computation Device\n",
    "\n",
    "- **Device Selection:**\n",
    "\n",
    "    - Checks if a CUDA-compatible GPU is available. If so, it sets the device to \"cuda\" otherwise, it defaults to the CPU.\n",
    "\n",
    "- **Output:**\n",
    "\n",
    "    - The print(device) statement confirms which device is selected.\n",
    "\n",
    "- **Impact:**\n",
    "\n",
    "    - Utilizing a GPU when available can dramatically accelerate model training and inference by leveraging parallel processing capabilities.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "53684b5e-c27e-4eb9-815e-583aa194e096",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Enable cudnn benchmark for fixed input sizes (can speed up computation)\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "# Set device to RTX 4090\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc8d1e5e-f312-43e8-b6cd-645b49a74839",
   "metadata": {},
   "source": [
    "# Setting Random Seeds for Reproducibility\n",
    "\n",
    "This block of code sets a fixed random seed across multiple libraries to ensure that experiments and model training are reproducible. Consistent seeding is crucial for debugging, comparing model performance, and sharing reproducible research.\n",
    "\n",
    "- **random.seed(42):**\n",
    "Sets the seed for Python's built-in random module. This ensures that any random numbers generated using this module (e.g., for data shuffling or random sampling) remain consistent across runs.\n",
    "\n",
    "- **np.random.seed(42):**\n",
    "Sets the seed for NumPy's random number generator. Since NumPy is widely used for numerical operations and generating random arrays, this makes sure that any randomness in these operations is controlled.\n",
    "\n",
    "- **torch.manual_seed(42):**\n",
    "Sets the seed for PyTorch's CPU-based random number generator. This is essential for ensuring that model initialization and any other operations involving randomness in PyTorch produce the same results in every run.\n",
    "\n",
    "- **torch.cuda.manual_seed_all(42):**\n",
    "Ensures that all random operations on GPUs (like random weight initialization or dropout) are also deterministic.\n",
    "\n",
    "- **Overall Impact:**\n",
    "By fixing the seed value (42 in this case) across all relevant libraries (Python's random, NumPy, and PyTorch for both CPU and GPU), this code block guarantees that the entire pipeline produces the same results on every run, thus enhancing reproducibility and reliability in experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a47bf3cd-752d-4d1c-9697-70098d6204fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(42)"
   ]
  },
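  {
   "cell_type": "markdown",
   "id": "b7d40f2e-1c3a-4e88-9f61-2d5c0a8e7b34",
   "metadata": {},
   "source": [
    "As a side note, `transformers` bundles these same calls into one helper; the snippet below is an equivalent alternative to the cell above, not what this notebook uses.\n",
    "\n",
    "```python\n",
    "from transformers import set_seed\n",
    "\n",
    "# Seeds Python's random, NumPy, and PyTorch (CPU and all GPUs) in one call\n",
    "set_seed(42)\n",
    "```"
   ]
  },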
  {
   "cell_type": "markdown",
   "id": "e4cc7a07-1345-458b-9afd-5d88db903e31",
   "metadata": {},
   "source": [
    "# Memory Clearing Utility Function\n",
    "\n",
    "This function is designed to free up system and GPU memory during runtime. It is especially useful in scenarios where you are working with large datasets or deep learning models, as it helps mitigate memory leaks and prevents out-of-memory errors.\n",
    "\n",
    "- **`gc.collect()`:**\n",
    "Invokes Python’s garbage collector to identify and free up memory that is no longer in use.\n",
    "This step helps in cleaning up any residual objects that are not referenced, ensuring efficient memory management on the CPU.\n",
    "\n",
    "- **`torch.cuda.empty_cache()`:**\n",
    "Clears the cached memory allocated by PyTorch on the GPU.\n",
    "Although PyTorch reuses cached memory to improve performance, manually clearing the cache can be beneficial when switching between models or after heavy memory usage, as it makes more GPU memory available for subsequent operations.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f16df21e-9797-4f78-83a1-a2943759ba55",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clear_memory():\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  },
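  {
   "cell_type": "markdown",
   "id": "c9e81a5d-2f64-4b07-8a13-7e0d6f4c2b98",
   "metadata": {},
   "source": [
    "To see what `clear_memory()` actually frees, it can be bracketed with PyTorch's allocator counters; a minimal usage sketch, assuming a CUDA device is present:\n",
    "\n",
    "```python\n",
    "if torch.cuda.is_available():\n",
    "    before = torch.cuda.memory_reserved() / 1024**2  # MiB held by the caching allocator\n",
    "    clear_memory()\n",
    "    after = torch.cuda.memory_reserved() / 1024**2\n",
    "    print(f\"GPU memory reserved: {before:.1f} MiB -> {after:.1f} MiB\")\n",
    "```"
   ]
  },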
  {
   "cell_type": "markdown",
   "id": "8d1fcaa4-20a8-4b09-97bb-57304a4d9ef5",
   "metadata": {},
   "source": [
    "# Logging Configuration Setup\n",
    "\n",
    "This block configures the Python logging module to ensure that log messages are informative and consistently formatted across the application. Here's a detailed breakdown:\n",
    "\n",
    "- **`logging.basicConfig(...)`:**\n",
    "Sets up the basic configuration for the logging system:\n",
    "\n",
    "- **`level=logging.INFO`:**\n",
    "Specifies that the logging level is set to INFO. This means that all log messages with a severity of INFO, WARNING, ERROR, or CRITICAL will be captured and displayed. Debug messages (DEBUG level) will be ignored unless the level is set to DEBUG.\n",
    "\n",
    "\n",
    "- **`format=\"%(asctime)s - %(levelname)s - %(message)s\"`:**\n",
    "Defines the format of the log messages:\n",
    "    - **`%(asctime)s`:** Inserts a timestamp into each log message, which helps in tracking when each event occurred.\n",
    "    \n",
    "    - **`%(levelname)s`:** Inserts the log level (e.g., INFO, WARNING) of the message, providing insight into the severity or importance of the logged event.\n",
    "    \n",
    "    - **`%(message)s`:** Inserts the actual log message content.\n",
    "\n",
    "\n",
    "- **`logger = logging.getLogger(__name__)`:**\n",
    "Retrieves a logger instance specific to the current module:\n",
    "\n",
    "**`__name__`** is a special Python variable that holds the name of the current module. This ensures that the logger is uniquely identified by the module's name, making it easier to locate log messages in larger applications with multiple modules.\n",
    "The returned logger object is then used to log messages throughout the module with methods like logger.info(), logger.error(), etc.\n",
    "\n",
    "### Overall, this setup standardizes logging across the application, making it easier to monitor the program's behavior, debug issues, and record the sequence of events during execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "196e83da-6c8c-4cd7-bd70-2598a5e2a16a",
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
    ")\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ca521b-0534-4911-bd25-4c25896f16a4",
   "metadata": {},
   "source": [
    "# Data Cleaning Utility Functions\n",
    "\n",
    "This block defines two functions essential for preparing and cleaning text data within pandas DataFrames. These functions help standardize text fields and ensure DataFrame columns are properly formatted for subsequent processing or model training.\n",
    "\n",
    "---\n",
    "\n",
    "## Function: `preprocess`\n",
    "\n",
    "### Explanation:\n",
    "\n",
    "- **Purpose:**\n",
    "The preprocess function is designed to clean a text string by removing unnecessary whitespace and newline characters, ensuring that the text is formatted consistently.\n",
    "\n",
    "- **Input Parameter:**\n",
    "\n",
    "    - `text (str)`: The text string to be cleaned.\n",
    "\n",
    "- **Functionality:**\n",
    "\n",
    "    - **Type Check:**\n",
    "        The function first checks if the input is a string. If it is not, it returns an empty string to avoid processing invalid data.\n",
    "\n",
    "    - **Newline Replacement:**\n",
    "        The `text.replace('\\n', ' ')` part converts newline characters into spaces, flattening the text into a single line.\n",
    "\n",
    "    - **Whitespace Reduction:**\n",
    "        The `re.sub(r'\\s+', ' ', ...)` call uses a regular expression to replace multiple consecutive whitespace characters (spaces, tabs, etc.) with a single space.\n",
    "\n",
    "    - **Trimming:**\n",
    "        Finally, `.strip()` removes any leading or trailing whitespace from the resulting string.\n",
    "### Output:\n",
    "**A cleaned string with extra spaces and newline characters removed.**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cea22b9f-f309-4151-81ac-37547c8feeb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(text: str) -> str:\n",
    "    \"\"\"Remove extra whitespaces and newlines from a text string.\"\"\"\n",
    "    if not isinstance(text, str):\n",
    "        return \"\"\n",
    "    return re.sub(r'\\s+', ' ', text.replace('\\n', ' ')).strip()\n",
    "\n",
    "def clean_df(df, rename=None, drop=None, select=None):\n",
    "    \"\"\"\n",
    "    Clean and rename dataframe columns:\n",
    "      - drop: list of columns to drop\n",
    "      - rename: dict mapping old column names to new names\n",
    "      - select: list of columns to keep in final order\n",
    "    \"\"\"\n",
    "    if drop:\n",
    "        df = df.drop(columns=drop, errors='ignore')\n",
    "    if rename:\n",
    "        df = df.rename(columns=rename)\n",
    "    for col in ['query', 'context', 'response']:\n",
    "        if col in df.columns:\n",
    "            df[col] = df[col].apply(preprocess)\n",
    "    if select:\n",
    "        df = df[select]\n",
    "    return df"
   ]
  },
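  {
   "cell_type": "markdown",
   "id": "d4a72b91-5e08-4c36-b2f9-8a1d3e6c5f07",
   "metadata": {},
   "source": [
    "A quick illustration of what these helpers do, on a toy DataFrame that is not part of the actual pipeline:\n",
    "\n",
    "```python\n",
    "demo = pd.DataFrame({\n",
    "    \"question\": [\"How many\\n   users? \"],\n",
    "    \"answer\": [\"SELECT COUNT(*)\\nFROM users;\"],\n",
    "    \"extra\": [\"dropped\"],\n",
    "})\n",
    "\n",
    "# Drop the unused column, unify names, and normalize whitespace in one call\n",
    "cleaned = clean_df(demo, rename={\"question\": \"query\", \"answer\": \"response\"}, drop=[\"extra\"])\n",
    "print(cleaned.iloc[0][\"response\"])  # -> SELECT COUNT(*) FROM users;\n",
    "```"
   ]
  },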
  {
   "cell_type": "markdown",
   "id": "20a3d1e9-5291-4402-bb5f-5b120a1cca20",
   "metadata": {},
   "source": [
    "# Loading, Cleaning, and Merging Raw Datasets\n",
    "\n",
    "This code block handles the entire process of loading raw datasets from various sources, cleaning and standardizing their columns, and merging them into a single DataFrame ready for further processing. Below is a detailed explanation of each step:\n",
    "\n",
    "1. **Logging the Start of Dataset Loading:**\n",
    "   This log message indicates that the process of loading the datasets is about to begin, providing visibility into the workflow execution.\n",
    "\n",
    "2. **Loading Datasets**\n",
    "\n",
    "3. **Cleaning and Standardizing Column Names:**\n",
    "   To ensure consistency across all datasets, each DataFrame is cleaned using the clean_df function, which:\n",
    "    - Renames columns to a unified naming convention (query, context, response).\n",
    "    \n",
    "    - Drops unnecessary columns.\n",
    "    \n",
    "    - Applies text preprocessing to remove extraneous whitespaces and newlines.\n",
    "\n",
    "4. **Concatenating DataFrames:**\n",
    "    - All the cleaned DataFrames are merged into a single DataFrame (final_df) using pd.concat, which resets the index.\n",
    "    \n",
    "    - A log statement records the total number of rows in the merged DataFrame before any duplicates are removed.\n",
    "\n",
    "5. **Final DataFrame Cleanup:**\n",
    "    - **Column Reordering:**\n",
    "          The DataFrame is restructured to enforce a specific column order: query, context, and response.\n",
    "    \n",
    "    - **Dropping Missing Values:**\n",
    "          Any rows that have missing values in these critical columns are removed using dropna().\n",
    "    \n",
    "    - **Removing Duplicates:**\n",
    "          Duplicate rows are dropped to ensure each example is unique.\n",
    "\n",
    "A final log statement records the number of rows remaining after cleaning, providing a clear picture of the dataset size for subsequent steps.\n",
    "\n",
    "### Overall, these steps ensure that diverse raw datasets are uniformly processed, cleaned, and merged into a coherent format, ready for tokenization and model training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d4eb82ce-1713-40b6-981d-43ce35aaa6f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 14:56:53,295 - INFO - Loading raw datasets from various sources...\n",
      "2025-03-19 14:57:25,655 - INFO - Total rows before dropping duplicates: 490241\n",
      "2025-03-19 14:57:27,208 - INFO - Total rows after dropping duplicates: 440785\n"
     ]
    }
   ],
   "source": [
    "logger.info(\"Loading raw datasets from various sources...\")\n",
    "\n",
    "# Load datasets\n",
    "df1 = pd.read_json(\"hf://datasets/Clinton/Text-to-sql-v1/texttosqlv2.jsonl\", lines=True)\n",
    "df2 = pd.read_json(\"hf://datasets/b-mc2/sql-create-context/sql_create_context_v4.json\")\n",
    "df3 = pd.read_parquet(\"hf://datasets/gretelai/synthetic_text_to_sql/synthetic_text_to_sql_train.snappy.parquet\")\n",
    "df4 = pd.read_json(\"hf://datasets/knowrohit07/know_sql/know_sql_val3{ign}.json\")\n",
    "\n",
    "# Clean and rename columns to unify to 'query', 'context', 'response'\n",
    "df1 = clean_df(df1, rename={'instruction': 'query', 'input': 'context'}, drop=['source', 'text'])\n",
    "df2 = clean_df(df2, rename={'question': 'query', 'answer': 'response'})\n",
    "df3 = clean_df(df3, rename={'sql_prompt': 'query', 'sql_context': 'context', 'sql': 'response'},\n",
    "                select=['query', 'context', 'response'])\n",
    "df4 = clean_df(df4, rename={'question': 'query', 'answer': 'response'})\n",
    "\n",
    "# Concatenate all DataFrames\n",
    "final_df = pd.concat([df1, df2, df3, df4], ignore_index=True)\n",
    "logger.info(\"Total rows before dropping duplicates: %d\", len(final_df))\n",
    "\n",
    "# Force correct column order and drop rows with missing fields\n",
    "final_df = final_df[['query', 'context', 'response']]\n",
    "final_df = final_df.dropna(subset=['query', 'context', 'response'])\n",
    "final_df = final_df.drop_duplicates()\n",
    "logger.info(\"Total rows after dropping duplicates: %d\", len(final_df))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cca0d00e-5c35-4e37-819e-760ed16d60f6",
   "metadata": {},
   "source": [
    "# Tokenizer Initialization and Token Length Filtering\n",
    "\n",
    "This block of code prepares the tokenizer and filters the merged DataFrame (`final_df`) based on token length constraints for both prompts and responses. Here’s a detailed breakdown:\n",
    "\n",
    "1. **Tokenizer Initialization:**\n",
    "    - **Purpose:**\n",
    "        Loads the pre-trained tokenizer corresponding to the \"`google/flan-t5-base`\" model. This tokenizer will be used to convert text (prompts and responses) into token IDs required by the model.\n",
    "    \n",
    "    - **Context:**\n",
    "        Using a pre-trained tokenizer ensures that text is tokenized in a manner consistent with the model’s training, which is essential for achieving good performance.\n",
    "2. **Setting Maximum Token Lengths:**\n",
    "    - **Purpose:**\n",
    "        These variables define the maximum allowed token lengths for the prompt and response respectively.\n",
    "    \n",
    "    - **Benefit:**\n",
    "        Enforcing token limits prevents overly long inputs that could lead to out-of-memory errors or inefficient processing. The limits are chosen based on model constraints and expected input sizes.\n",
    "3. **Defining the Token Length Filter Function:**\n",
    "    - **Purpose:**\n",
    "        This function evaluates each row of the DataFrame to determine if its prompt and response, once tokenized, fit within the predefined token limits.\n",
    "    \n",
    "    - **Steps within the Function:**\n",
    "        - **Prompt Construction:**\n",
    "            Combines the context and query fields from the row with designated markers (`\"Context:\"`, `\"Query:\"`, `\"Response:\"`) to create a full prompt.\n",
    "        \n",
    "        - **Tokenization:**\n",
    "            The prompt and response are tokenized without truncation (truncation=False) to capture the full token count.\n",
    "            add_special_tokens=True ensures that any model-specific tokens (like start-of-sequence or end-of-sequence tokens) are included.\n",
    "        \n",
    "        - **Token Length Check:**\n",
    "            The function returns True if both the prompt and response have token counts within their respective limits, otherwise False.\n",
    "4. **Filtering the DataFrame:**\n",
    "    - **Purpose:**\n",
    "        Applies the tokenize_length_filter function to each row of final_df and filters out any rows that do not meet the token length constraints.\n",
    "    \n",
    "    - **Result:**\n",
    "        The resulting DataFrame contains only those rows where both the prompt and the response are within the acceptable token lengths.\n",
    "5. **Logging the Result:**\n",
    "    - **Purpose:**\n",
    "        Logs the number of rows remaining in the DataFrame after applying the token length filter.\n",
    "    \n",
    "    - **Benefit:**\n",
    "        This log statement provides insight into how much data remains after filtering, which is useful for monitoring and debugging the preprocessing pipeline.\n",
    "\n",
    "### Overall Impact: \n",
    "    - By initializing the tokenizer and filtering based on token length, this code ensures that the inputs provided to the model are within manageable sizes, thus preventing errors during training or inference and ensuring efficient processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8446814e-5a2c-48a4-8c01-059afcf1d3c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1113 > 512). Running this sequence through the model will result in indexing errors\n",
      "2025-03-19 15:01:13,787 - INFO - Total rows after filtering by token length (prompt <= 500 and response <= 250 tokens): 398481\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
    "\n",
    "max_length_prompt = 500\n",
    "max_length_response = 250\n",
    "\n",
    "def tokenize_length_filter(row):\n",
    "    start_prompt = \"Context:\\n\"\n",
    "    middle_prompt = \"\\n\\nQuery:\\n\"\n",
    "    end_prompt = \"\\n\\nResponse:\\n\"\n",
    "    \n",
    "    # Construct the prompt as used in the tokenize_function\n",
    "    prompt = f\"{start_prompt}{row['context']}{middle_prompt}{row['query']}{end_prompt}\"\n",
    "    \n",
    "    # Encode without truncation to get the full token count\n",
    "    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True, truncation=False)\n",
    "    response_tokens = tokenizer.encode(row['response'], add_special_tokens=True, truncation=False)\n",
    "    \n",
    "    return len(prompt_tokens) <= max_length_prompt and len(response_tokens) <= max_length_response\n",
    "\n",
    "final_df = final_df[final_df.apply(tokenize_length_filter, axis=1)]\n",
    "logger.info(\"Total rows after filtering by token length (prompt <= %d and response <= %d tokens): %d\", \n",
    "            max_length_prompt, max_length_response, len(final_df))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "177e1e6d-9fbc-442d-9774-5a3e5234329f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:13,794 - INFO - Sample from filtered final_df:\n",
      "                                               query  \\\n",
      "0           Name the home team for carlton away team   \n",
      "1  what will the population of Asia be when Latin...   \n",
      "2  How many faculty members do we have for each g...   \n",
      "\n",
      "                                             context  \\\n",
      "0  CREATE TABLE table_name_77 ( home_team VARCHAR...   \n",
      "1  CREATE TABLE table_22767 ( \"Year\" real, \"World...   \n",
      "2  CREATE TABLE Student ( StuID INTEGER, LName VA...   \n",
      "\n",
      "                                            response  \n",
      "0  SELECT home_team FROM table_name_77 WHERE away...  \n",
      "1  SELECT \"Asia\" FROM table_22767 WHERE \"Latin Am...  \n",
      "2  SELECT Sex, COUNT(*) FROM Faculty GROUP BY Sex...  \n"
     ]
    }
   ],
   "source": [
    "logger.info(\"Sample from filtered final_df:\\n%s\", final_df.head(3))\n",
    "clear_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66fe2c58-4cee-4566-a505-81193658f7ae",
   "metadata": {},
   "source": [
    "# Splitting DataFrame into Train, Test, and Validation Sets and Converting to Hugging Face Dataset\n",
    "\n",
    "This code block is responsible for dividing the cleaned DataFrame (`final_df`) into training, testing, and validation subsets, and then converting these splits into Hugging Face Datasets for further processing.\n",
    "\n",
    "---\n",
    "\n",
    "1. **Splitting the DataFrame:**\n",
    "\n",
    "    - **Function Purpose:**\n",
    "    The `split_dataframe` function divides a DataFrame into three distinct subsets:\n",
    "\n",
    "        - **Training Set:** Contains 85% of the data (by default).\n",
    "    \n",
    "        - **Test Set:** Contains 10% of the data.\n",
    "    \n",
    "        - **Validation Set:** Contains the remaining 5% of the data.\n",
    "\n",
    "    - **Detailed Steps:**\n",
    "\n",
    "        - **Determine Total Rows:**\n",
    "            - `n = len(df)` calculates the total number of rows in the DataFrame.\n",
    "\n",
    "        - **Calculate Split Indices:**\n",
    "            - **`train_end = int(n * train_frac)`** determines the index where the training data should end.\n",
    "        \n",
    "            - **`test_end = train_end + int(n * test_frac)`** calculates the index where the test data should end, immediately following the training set.\n",
    "\n",
    "        - **Slice the DataFrame:**\n",
    "            - **`train_df = df.iloc[:train_end].reset_index(drop=True)`** selects the training portion and resets the index.\n",
    "        \n",
    "            - **`test_df = df.iloc[train_end:test_end].reset_index(drop=True)`** selects the test portion and resets the index.\n",
    "        \n",
    "            - **`val_df = df.iloc[test_end:].reset_index(drop=True)`** selects the validation portion and resets the index.\n",
    "          \n",
    "        - **Return Splits:**\n",
    "            The function returns the three DataFrame splits.\n",
    " \n",
    "2. **Converting DataFrame Splits to Hugging Face Datasets:**\n",
    "\n",
    "    - **Conversion Process:**\n",
    "        Each pandas DataFrame (`train_df`, `test_df`, `val_df`) is converted into a Hugging Face `Dataset` using `Dataset.from_pandas()`. This allows efficient manipulation and processing during training and evaluation.\n",
    "\n",
    "    - **Dataset Dictionary:**\n",
    "        A `DatasetDict` is created to organize the datasets under clearly labeled splits: `'train'`, `'test'`, and `'validation'`. This standard structure is widely used in Hugging Face pipelines.\n",
    "\n",
    "3. **Saving the Dataset to Disk and Memory Cleanup:**\n",
    "\n",
    "    - **Saving:**\n",
    "        `dataset.save_to_disk(\"merged_dataset\")` saves the merged dataset to disk. This ensures that the preprocessed data is persisted and can be reloaded in future sessions without repeating the preprocessing steps.\n",
    "\n",
    "    - **Logging:**\n",
    "        Log messages confirm that the dataset has been successfully merged, saved, and provide a summary of its structure.\n",
    "\n",
    "    - **Memory Cleanup:**\n",
    "        `clear_memory()` is called to free up any residual memory, which is especially important after handling large datasets.\n",
    "\n",
    "### Overall Impact:\n",
    "\n",
    "**This entire code block streamlines the data preparation process by:**\n",
    "\n",
    "    - Splitting the dataset into train, test, and validation sets.\n",
    "    \n",
    "    - Converting these splits into a format compatible with Hugging Face's ecosystem.\n",
    "    \n",
    "    - Saving the processed dataset for efficient future reuse.\n",
    "    \n",
    "    - Logging key information and cleaning up memory to maintain optimal performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b639efe-ebeb-4b34-bc3f-accf776ba0da",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:14,006 - INFO - Final split sizes: Train: 338708, Test: 39848, Validation: 19925\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81e753f720e44f40b5f0dfa5263e2bf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/338708 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59b1ce0d9ee548668dbc87b99d6e0951",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/39848 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a378405a0a24c13a81fc853550d01d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/19925 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:15,490 - INFO - Merged and Saved Dataset Successfully!\n",
      "2025-03-19 15:01:15,497 - INFO - Dataset summary: DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['query', 'context', 'response'],\n",
      "        num_rows: 338708\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['query', 'context', 'response'],\n",
      "        num_rows: 39848\n",
      "    })\n",
      "    validation: Dataset({\n",
      "        features: ['query', 'context', 'response'],\n",
      "        num_rows: 19925\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "def split_dataframe(df, train_frac=0.85, test_frac=0.1, val_frac=0.05):\n",
    "    n = len(df)\n",
    "    train_end = int(n * train_frac)\n",
    "    test_end = train_end + int(n * test_frac)\n",
    "    train_df = df.iloc[:train_end].reset_index(drop=True)\n",
    "    test_df = df.iloc[train_end:test_end].reset_index(drop=True)\n",
    "    val_df = df.iloc[test_end:].reset_index(drop=True)\n",
    "    return train_df, test_df, val_df\n",
    "\n",
    "train_df, test_df, val_df = split_dataframe(final_df)\n",
    "logger.info(\"Final split sizes: Train: %d, Test: %d, Validation: %d\", len(train_df), len(test_df), len(val_df))\n",
    "\n",
    "# Convert splits to Hugging Face Datasets\n",
    "train_dataset = Dataset.from_pandas(train_df)\n",
    "test_dataset = Dataset.from_pandas(test_df)\n",
    "val_dataset = Dataset.from_pandas(val_df)\n",
    "\n",
    "dataset = DatasetDict({\n",
    "    'train': train_dataset,\n",
    "    'test': test_dataset,\n",
    "    'validation': val_dataset\n",
    "})\n",
    "\n",
    "dataset.save_to_disk(\"merged_dataset\")\n",
    "logger.info(\"Merged and Saved Dataset Successfully!\")\n",
    "logger.info(\"Dataset summary: %s\", dataset)\n",
    "clear_memory()"
   ]
  },
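  {
   "cell_type": "markdown",
   "id": "e6b35c80-9d17-4f42-a8e5-3c2b7d1f9a64",
   "metadata": {},
   "source": [
    "One design note: `split_dataframe` slices the DataFrame positionally, without shuffling. For reference, a shuffled alternative using the `datasets` API could look like the sketch below (an alternative, not what this notebook does; the fractions reproduce the same 85/10/5 proportions):\n",
    "\n",
    "```python\n",
    "# Shuffled 85/10/5 split: hold out 15%, then a third of that becomes validation\n",
    "tmp = Dataset.from_pandas(final_df, preserve_index=False).train_test_split(test_size=0.15, seed=42)\n",
    "holdout = tmp[\"test\"].train_test_split(test_size=1/3, seed=42)\n",
    "dataset_alt = DatasetDict({\n",
    "    \"train\": tmp[\"train\"],\n",
    "    \"test\": holdout[\"train\"],       # ~10% of the data\n",
    "    \"validation\": holdout[\"test\"],  # ~5% of the data\n",
    "})\n",
    "```"
   ]
  },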
  {
   "cell_type": "markdown",
   "id": "4aad832b-518c-41f2-97d4-ba02212f1aac",
   "metadata": {},
   "source": [
    "# Reloading and Tokenizing the Dataset for T5 Fine-Tuning\n",
    "\n",
    "This code block covers several critical steps to prepare the dataset for fine-tuning a T5 model. It includes reloading a previously merged dataset from disk, initializing the tokenizer, defining a custom tokenization function, and finally either loading an existing tokenized version or creating one if it doesn't exist. Below is a detailed explanation of each part of the code:\n",
    "\n",
    "---\n",
    "\n",
    "1. **Reloading the Dataset from Disk:**\n",
    "    - **Purpose:**\n",
    "      The merged dataset, previously saved under the name `\"merged_dataset\"`, is reloaded using the `load_from_disk function`.\n",
    "        \n",
    "    - **Logging:**\n",
    "      A sample from the test split of the reloaded dataset is logged. This helps verify that the dataset was loaded correctly and gives a quick glance at the structure of the data.\n",
    "\n",
    "2. **Initializing the Tokenizer:**\n",
    "    - **Purpose:**\n",
    "        - Sets the model name to `\"google/flan-t5-base\"`, which corresponds to the pre-trained T5 model.\n",
    "        \n",
    "        - Loads the associated tokenizer using `AutoTokenizer.from_pretrained`, ensuring that tokenization is consistent with how the model was originally trained.\n",
    "\n",
    "3. **Defining the Tokenization Function:**\n",
    "\n",
    "    - **Function Overview:**\n",
    "        - The `tokenize_function` prepares a batch of examples for T5 fine-tuning by constructing a standardized prompt and tokenizing both the prompt and its corresponding response.\n",
    "\n",
    "    - **Prompt Construction:**\n",
    "        - Three strings (`start_prompt`, `middle_prompt`, and `end_prompt`) define the structure of the prompt.\n",
    "        \n",
    "        - Each example’s prompt is created by combining the `context` and `query` fields from the batch with these fixed markers.\n",
    "          \n",
    "    - **Tokenization of Inputs and Labels:**\n",
    "        - **Prompts:**\n",
    "            - The constructed prompts are tokenized with a maximum length of 512 tokens, using padding and truncation to ensure uniformity.\n",
    "        \n",
    "        - **Responses:**\n",
    "            - The responses are tokenized separately with a maximum length of 256 tokens.\n",
    "\n",
    "        - **Label Preparation:**\n",
    "            - Tokenized responses are processed so that any token matching the pad token ID is replaced with `-100`. This ensures that the loss function will ignore these padded positions during training.\n",
    "\n",
    "    - **Updating the Batch:**\n",
    "        - **The function adds three keys to the batch dictionary:**\n",
    "            - `'input_ids'`: The tokenized prompt IDs.\n",
    "            \n",
    "            - `'attention_mask'`: The attention masks corresponding to the prompts.\n",
    "            \n",
    "            - `'labels'`: The processed tokenized responses.\n",
    "\n",
    "        - **The updated batch is then returned.**\n",
    "\n",
    "4. **Loading or Creating the Tokenized Dataset:**\n",
    "\n",
    "    - **Purpose:**\n",
    "        - The code attempts to load a pre-tokenized version of the dataset from disk, saved under the name `\"tokenized_datasets\"`.\n",
    "\n",
    "    - **Error Handling:**\n",
    "        - If the tokenized dataset is not found (which throws an exception), the code logs that it will create a new tokenized dataset.\n",
    "\n",
    "    - **Dataset Mapping:**\n",
    "        - Uses the map function on the original dataset with the `tokenize_function`:\n",
    "\n",
    "        - **Batched Processing:**\n",
    "            - `batched=True` allows processing multiple examples at once for efficiency.\n",
    "\n",
    "        - **Removing Unnecessary Columns:**\n",
    "            - The original columns query, context, and response are removed after tokenization since they are no longer needed.\n",
    "\n",
    "        - **Parallel Processing:**\n",
    "            - `num_proc=8` utilizes 8 processes to speed up the tokenization.\n",
    "\n",
    "    - **Saving:**\n",
    "        - The newly tokenized dataset is saved to disk as `\"tokenized_datasets\"` for future reuse.\n",
    "\n",
    "        - Appropriate log messages are generated to confirm the process.\n",
    "     \n",
    "5. **Formatting and Logging the Final Tokenized Dataset:**\n",
    "\n",
    "    - **Setting the Format:**\n",
    "        - `tokenized_datasets.set_format(\"torch\")` converts the dataset into a PyTorch-compatible format, which is necessary for efficient training.\n",
    "\n",
    "    - **Logging the Dataset Structure:**\n",
    "        - The keys of the tokenized dataset (typically `'train'`, `'test'`, and `'validation'`) are logged to verify the available splits.\n",
    "    \n",
    "        - A sample record from the training split is also logged to ensure that the tokenization was successful and the data is formatted as expected.\n",
    "\n",
    "### Overall Impact:\n",
    "    - This comprehensive block of code efficiently reloads, tokenizes, and formats the dataset for T5 fine-tuning. \n",
    "    \n",
    "    - It leverages robust error handling to either load an existing tokenized dataset or generate a new one when necessary, ensuring that the subsequent training process will proceed smoothly with data in the correct format.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9f6e1095-d72d-4e22-b20d-683f1f84544c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:15,843 - INFO - Reloaded dataset from disk. Example from test split:\n",
      "{'query': \"Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\", 'context': \"CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\", 'response': 'SELECT command_name, type FROM defense_security.Military_Cyber_Commands;'}\n",
      "2025-03-19 15:01:16,155 - INFO - Loaded Tokenized Dataset from disk.\n",
      "2025-03-19 15:01:16,159 - INFO - Final tokenized dataset splits: dict_keys(['train', 'test', 'validation'])\n",
      "2025-03-19 15:01:16,167 - INFO - Sample tokenized record from train split:\n",
      "{'input_ids': tensor([ 1193,  6327,    10,   205,  4386,  6048,   332, 17098,   953,   834,\n",
      "         4350,   834,  4013,    41,   234,   834, 11650,   584,  4280, 28027,\n",
      "            6,   550,   834, 11650,   584,  4280, 28027,     3,    61,     3,\n",
      "        27569,    10,  5570,     8,   234,   372,    21,   443,  7377,   550,\n",
      "          372, 16361,    10,     3,     1,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "            0,     0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0]), 'labels': tensor([    3, 23143, 14196,   234,   834, 11650, 21680,   953,   834,  4350,\n",
      "          834,  4013,   549, 17444,   427,   550,   834, 11650,  3274,    96,\n",
      "         1720,  7377,   121,     1,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,\n",
      "         -100,  -100,  -100,  -100,  -100,  -100])}\n"
     ]
    }
   ],
   "source": [
    "dataset = load_from_disk(\"merged_dataset\")\n",
    "logger.info(\"Reloaded dataset from disk. Example from test split:\\n%s\", dataset['test'][0])\n",
    "\n",
    "model_name = \"google/flan-t5-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "def tokenize_function(batch: dict) -> dict:\n",
    "    \"\"\"\n",
    "    Tokenizes a batch of examples for T5 fine-tuning.\n",
    "    Constructs a prompt in the format:\n",
    "      Context:\n",
    "      <context>\n",
    "      \n",
    "      Query:\n",
    "      <query>\n",
    "      \n",
    "      Response:\n",
    "    \"\"\"\n",
    "    start_prompt = \"Context:\\n\"\n",
    "    middle_prompt = \"\\n\\nQuery:\\n\"\n",
    "    end_prompt = \"\\n\\nResponse:\\n\"\n",
    "\n",
    "    prompts = [\n",
    "        f\"{start_prompt}{ctx}{middle_prompt}{qry}{end_prompt}\"\n",
    "        for ctx, qry in zip(batch['context'], batch['query'])\n",
    "    ]\n",
    "\n",
    "    tokenized_inputs = tokenizer(\n",
    "        prompts,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        max_length=512\n",
    "    )\n",
    "    tokenized_labels = tokenizer(\n",
    "        batch['response'],\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        max_length=256\n",
    "    )\n",
    "    labels = [\n",
    "        [-100 if token == tokenizer.pad_token_id else token for token in seq]\n",
    "        for seq in tokenized_labels['input_ids']\n",
    "    ]\n",
    "\n",
    "    batch['input_ids'] = tokenized_inputs['input_ids']\n",
    "    batch['attention_mask'] = tokenized_inputs['attention_mask']\n",
    "    batch['labels'] = labels\n",
    "    return batch\n",
    "\n",
    "try:\n",
    "    tokenized_datasets = load_from_disk(\"tokenized_datasets\")\n",
    "    logger.info(\"Loaded Tokenized Dataset from disk.\")\n",
    "except Exception as e:\n",
    "    logger.info(\"Tokenized dataset not found. Creating a new one...\")\n",
    "    tokenized_datasets = dataset.map(\n",
    "        tokenize_function,\n",
    "        batched=True,\n",
    "        remove_columns=['query', 'context', 'response'],\n",
    "        num_proc=8\n",
    "    )\n",
    "    tokenized_datasets.save_to_disk(\"tokenized_datasets\")\n",
    "    logger.info(\"Tokenized and Saved Dataset.\")\n",
    "\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "logger.info(\"Final tokenized dataset splits: %s\", tokenized_datasets.keys())\n",
    "logger.info(\"Sample tokenized record from train split:\\n%s\", tokenized_datasets['train'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8101b0de-8ff2-42f0-8e9c-48e8f52778f0",
   "metadata": {},
   "source": [
    "# Baseline Zero-Shot Generation Using the Pre-Trained T5 Model\n",
    "\n",
    "This code demonstrates how to perform zero-shot text generation using the pre-trained T5 model (\"google/flan-t5-base\"). The process involves loading the model and its tokenizer, selecting a sample from the test dataset, constructing a prompt, generating a response with the model, and finally printing both the human-provided response and the model's output. Additionally, the memory is cleared at the end to manage GPU resources efficiently.\n",
    "\n",
    "---\n",
    "\n",
    "1. ***Model and Tokenizer Initialization:***\n",
    "   \n",
    "    - ***Model Name:***\n",
    "        - The variable model_name is set to `\"google/flan-t5-base\"`, indicating the specific pre-trained T5 model being used.\n",
    "\n",
    "    - ***Tokenizer Initialization:***\n",
    "        - The tokenizer corresponding to the T5 model is loaded using `AutoTokenizer.from_pretrained()`. This ensures that the text is tokenized in a manner consistent with the model’s training.\n",
    "\n",
    "    - ***Model Loading:***\n",
    "        - The T5 model is loaded with the data type `torch.bfloat16` for improved memory efficiency and performance on compatible hardware.\n",
    "\n",
    "    - ***Device Assignment:***\n",
    "        - The model is moved to the computation device (GPU or CPU) specified by the variable `device` to leverage hardware acceleration.\n",
    "          \n",
    "2. ***Selecting a Test Example:***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - A single example (at index 0) is selected from the test split of the dataset.\n",
    "\n",
    "    - ***Data Extraction:***\n",
    "        - The `query`, `context`, and `response` fields are extracted from the chosen test example. These fields represent the input query, accompanying context, and the human-provided response respectively.\n",
    "\n",
    "3. ***Constructing the Input Prompt***\n",
    "\n",
    "    - ***Prompt Format:***\n",
    "        - A multi-line formatted string (f-string) is created to structure the input prompt.\n",
    "\n",
    "        - The prompt is divided into three parts:\n",
    "            - Context: Introduced by the line `\"Context:\"` followed by the actual context.\n",
    "\n",
    "            - Query: Introduced by `\"Query:\"` followed by the query.\n",
    "\n",
    "            - Response: A label `\"Response:\"` which indicates where the model should generate its answer.\n",
    "    \n",
    "    - ***Purpose:***\n",
    "        - This structured prompt format guides the model during text generation by clearly delineating the context and query sections.\n",
    "         \n",
    "4. ***Tokenizing the Prompt and Generating Output***\n",
    "\n",
    "    - ***Tokenization:***\n",
    "        - The prompt is tokenized using the tokenizer. The parameter `return_tensors='pt'` ensures that the output is in PyTorch tensor format, suitable for model processing.\n",
    "\n",
    "        - The tokenized inputs are moved to the device (GPU/CPU) for computation.\n",
    "\n",
    "    - ***Text Generation:***\n",
    "        - The model's `generate()` method is called on the tokenized `input_ids` to produce output tokens. The parameter `max_new_tokens=200` sets a limit on the number of tokens the model can generate.\n",
    "\n",
    "    - ***Decoding:***\n",
    "        - The generated tokens are decoded back into a human-readable string using `tokenizer.decode()`. The parameter `skip_special_tokens=True` removes any special tokens (e.g., `<pad>`, `<eos>`) from the output.\n",
    "\n",
    "5. ***Displaying the Input and Output:***\n",
    "\n",
    "    - ***Visual Separation:***\n",
    "        - A dashed line is created (`dash_line`) for clear visual separation of the output sections.\n",
    "\n",
    "    - ***Output Sections:***\n",
    "        - The constructed input prompt is printed.\n",
    "\n",
    "        - The human-provided (baseline) response is printed, serving as a reference.\n",
    "     \n",
    "        - The model-generated output is printed, representing the zero-shot generation result.\n",
    "     \n",
    "6. ***Memory Cleanup***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - The `clear_memory()` function is called to release unused GPU memory. This is essential in managing resources efficiently, especially after performing memory-intensive operations like model inference.\n",
    "\n",
    "### Overall Impact:\n",
    "\n",
    "    - This code block establishes a baseline for zero-shot text generation with the T5 model. \n",
    "    \n",
    "    - By clearly structuring the prompt, tokenizing input, and generating a response, it sets the stage for comparing the model's performance against human responses. \n",
    "    \n",
    "    - The process, along with memory management practices, ensures that the system runs efficiently during inference.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7f004e55-181c-47aa-9f3e-c7c1ceae780c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n",
      "\n",
      "Query:\n",
      "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "BASELINE HUMAN ANSWER:\n",
      "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "MODEL GENERATION - ZERO SHOT:\n",
      "USCYBERCOM, JTF-CND, Offensive Cyber Operations, 10th Fleet, Network Warfare\n"
     ]
    }
   ],
   "source": [
    "model_name = 'google/flan-t5-base'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
    "original_model = original_model.to(device)\n",
    "\n",
    "index = 0\n",
    "query = dataset['test'][index]['query']\n",
    "context = dataset['test'][index]['context']\n",
    "response = dataset['test'][index]['response']\n",
    "\n",
    "prompt = f\"\"\"Context:\n",
    "{context}\n",
    "\n",
    "Query:\n",
    "{query}\n",
    "\n",
    "Response:\n",
    "\"\"\"\n",
    "inputs = tokenizer(prompt, return_tensors='pt').to(device)\n",
    "baseline_output = tokenizer.decode(\n",
    "    original_model.generate(\n",
    "        inputs[\"input_ids\"],\n",
    "        max_new_tokens=200,\n",
    "    )[0],\n",
    "    skip_special_tokens=True\n",
    ")\n",
    "dash_line = '-' * 100\n",
    "print(dash_line)\n",
    "print(f'INPUT PROMPT:\\n{prompt}')\n",
    "print(dash_line)\n",
    "print(f'BASELINE HUMAN ANSWER:\\n{response}\\n')\n",
    "print(dash_line)\n",
    "print(f'MODEL GENERATION - ZERO SHOT:\\n{baseline_output}')\n",
    "clear_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "361801d0-8898-4bcc-bc50-1b99d070c8f0",
   "metadata": {},
   "source": [
    "# Fine-Tuning Setup with QLORA: Detailed Explanation\n",
    "\n",
    "This section is dedicated to setting up and launching the fine-tuning process using a QLORA-based approach. The code first attempts to load an already fine-tuned model from disk; if it isn’t available, it initializes the base model with QLORA modifications for efficient parameter-efficient fine-tuning. The following detailed explanation breaks down every part of the code:\n",
    "\n",
    "---\n",
    "\n",
    "1. ***Attempting to Load the Fine-Tuned Model:***\n",
    "\n",
    "    - ***Objective:***\n",
    "        - The code first attempts to load an existing fine-tuned model (`text2sql_flant5base_finetuned`) from disk.\n",
    "\n",
    "    - ***Steps:***\n",
    "        - ***Logging:***\n",
    "            - A log message indicates the attempt to load the fine-tuned model.\n",
    "\n",
    "        - ***Loading the Model:***\n",
    "            - `AutoModelForSeq2SeqLM.from_pretrained(\"text2sql_flant5base_finetuned\")` loads the model that has already been fine-tuned.\n",
    "\n",
    "        - ***Tokenizer Initialization:***\n",
    "            - The corresponding tokenizer is loaded using the base model identifier (`google/flan-t5-base`).\n",
    "\n",
    "        - ***Device Assignment:***\n",
    "            - The model is moved to the computation device (GPU/CPU) specified by the variable `device`.\n",
    "\n",
    "        - ***Training Flag:***\n",
    "            - The variable `to_train` is set to `False`, indicating that fine-tuning is not required if the model is already available.\n",
    " \n",
    "        - ***Confirmation Logging:***\n",
    "            - A log message confirms that the fine-tuned model was successfully loaded.\n",
    "\n",
    "        - ***If the fine-tuned model is not found, the code enters the `except` block.***\n",
    "\n",
    "            - ***Actions Taken:***\n",
    "\n",
    "                - ***Logging:***\n",
    "                  \n",
    "                    - The absence of the fine-tuned model is logged.\n",
    "\n",
    "                    - A subsequent log message indicates that the process is shifting to initialize the base model for QLORA fine-tuning.\n",
    "                \n",
    "                - ***Training Flag:***\n",
    "                    - `to_train` is set to `True`, signifying that the fine-tuning process will commence.\n",
    "\n",
    "2. ***QLORA-Specific Model Initialization:***\n",
    "\n",
    "    - ***Setting Up Quantization Configuration:***\n",
    "        - ***Purpose:***\n",
    "            - ***Quantization with BitsAndBytes:***\n",
    "                - The `BitsAndBytesConfig` is used to configure the model for 4-bit quantization. This is crucial for reducing memory usage and speeding up fine-tuning while maintaining model performance.\n",
    "            \n",
    "        - ***Parameters:***\n",
    "            - ***`load_in_4bit=True`:***\n",
    "                - Instructs the library to load the model in 4-bit precision.      \n",
    "            \n",
    "            - ***`bnb_4bit_quant_type=\"nf4\"`:***\n",
    "                - Specifies the quantization type. NF4 is a specific quantization scheme that optimizes the trade-off between precision and efficiency.\n",
    "\n",
    "  \n",
    "            - ***`bnb_4bit_use_double_quant=True`:***\n",
    "                - Enables double quantization, which can improve quantization quality.\n",
    "        \n",
    "            - ***`bnb_4bit_compute_dtype=torch.bfloat16`:***\n",
    "                - Sets the compute data type to `bfloat16`, offering a balance between speed and numerical stability.\n",
    "           \n",
    "    - ***Loading the Base Model with Quantization:***\n",
    "        - ***Purpose:***\n",
    "            - ***Loading the Base Model:***\n",
    "                - The base T5 model is loaded with the quantization configuration applied. This means the model is prepared to operate in a lower-precision (4-bit) environment.\n",
    "\n",
    "        - ***Key Parameters:***\n",
    "            - ***`quantization_config=quant_config`:***\n",
    "                - Applies the quantization settings defined earlier.\n",
    "\n",
    "            - ***`device_map=\"auto\"`:***\n",
    "                - Automatically assigns parts of the model to available devices, which is especially useful for large models.\n",
    "            - ***`torch_dtype=torch.bfloat16`:***\n",
    "                - Ensures that the model computations use the `bfloat16` data type, optimizing for performance on compatible hardware.\n",
    "\n",
    "    - ***Preparing the Model for k-Bit Training:***\n",
    "\n",
    "        - ***Purpose:***\n",
    "            - This function call further prepares the model to be fine-tuned in a low-bit (quantized) setting.\n",
    "\n",
    "        - ***Impact:***\n",
    "            - It typically involves adjustments such as modifying certain layers or parameters to better support k-bit (in this case, 4-bit) precision during training.\n",
    "\n",
    "    - ***Configuring LoRA (Low-Rank Adaptation):***\n",
    "        - ***Purpose:***\n",
    "            - ***Parameter-Efficient Fine-Tuning (PEFT):***\n",
    "                - LoRA is a method to fine-tune large language models efficiently by injecting trainable low-rank adaptation matrices into the model’s weights. This reduces the number of parameters that need to be updated during training.\n",
    "        \n",
    "        - ***Parameters Explained:***\n",
    "            - ***`r=32`:***\n",
    "                - The rank of the LoRA matrices. A higher rank allows for more expressiveness but at the cost of additional parameters.\n",
    "\n",
    "            - ***`lora_alpha=64`:***\n",
    "                - A scaling factor that balances the impact of the LoRA updates relative to the original weights.\n",
    "\n",
    "            - ***`target_modules=[\"q\", \"v\"]`:***\n",
    "                - Specifies that only the query and value projection matrices in the model will be adapted using LoRA. This focuses the fine-tuning on key components of the transformer architecture.\n",
    "\n",
    "            - ***`lora_dropout=0.1`:***\n",
    "                - Applies a dropout rate of 10% to the LoRA layers to prevent overfitting.\n",
    "\n",
    "            - ***`bias=\"none\"`:***\n",
    "                - Indicates that no bias terms will be added or modified in the LoRA layers.\n",
    "\n",
    "            - ***`task_type=\"SEQ_2_SEQ_LM\"`:***\n",
    "                - Specifies that the task is sequence-to-sequence language modeling, which tailors the LoRA modifications to this setting.\n",
    "\n",
    "    - ***Integrating LoRA with the Model:***\n",
    "\n",
    "        - ***Purpose:***\n",
    "            - ***PEFT Integration:***\n",
    "                - The `get_peft_model` function wraps the base model with the LoRA configuration. This integrates the low-rank adapters into the model, allowing efficient fine-tuning by updating only a small subset of parameters.\n",
    "        \n",
    "        - ***Impact:***\n",
    "            - This step is critical in reducing the computational cost and memory footprint during training, making it feasible to fine-tune large models on limited hardware.\n",
    "\n",
    "    - ***Logging and Memory Cleanup:***\n",
    "\n",
    "        - ***Logging:***\n",
    "            - A log message confirms that the base model has been successfully loaded and is ready for QLORA fine-tuning.\n",
    "\n",
    "        - ***Memory Cleanup:***\n",
    "            - `clear_memory()` is called to free up any unused memory, which is especially important after loading and modifying large models.\n",
    "\n",
    "3. ***Initiating the Training Process (if Required):***\n",
    "\n",
    "    - Condition:\n",
    "        - The training process is only initiated if `to_train` is `True` (i.e., when no fine-tuned model was found, and the model has been prepared for QLORA fine-tuning).\n",
    "    \n",
    "    - Output Directory:\n",
    "        - An output directory is dynamically created using the current timestamp to ensure unique directory names for each training session.\n",
    "    \n",
    "    - Logging:\n",
    "        - A log message confirms the start of the training process and displays the output directory path.\n",
    "\n",
    "    - ***Calculating Training Steps and Warmup Steps:***\n",
    "\n",
    "        - ***Purpose:***\n",
    "            - Computes the total number of training steps based on:\n",
    "            \n",
    "            - The number of training samples.\n",
    "\n",
    "            - The per-device training batch size.\n",
    "\n",
    "            - The total number of training epochs.\n",
    "\n",
    "            - Warmup steps are set to 10% of the total steps to gradually ramp up the learning rate at the beginning of training.\n",
    "         \n",
    "        - ***Math Operations:***\n",
    "            - `math.ceil()` ensures that the number of steps is rounded up to cover all training samples.\n",
    "\n",
    "        - ***Logging:***\n",
    "            - Logs provide visibility into the total training and warmup steps.\n",
    "\n",
    "    - ***Setting Up Training Arguments:***\n",
    "  \n",
    "        - ***Purpose:***\n",
    "            - This block creates a `TrainingArguments` object that encapsulates all hyperparameters and training configurations.\n",
    "    \n",
    "        - ***Key Hyperparameters and Strategies:***\n",
    "            - ***`out_dir`:***\n",
    "                - Where the model checkpoints and logs will be saved.\n",
    "\n",
    "            - ***`gradient_checkpointing`:***\n",
    "                - Enabled to reduce memory usage by storing only necessary activations.\n",
    "\n",
    "            - ***`gradient_accumulation_steps`:***\n",
    "                - Accumulates gradients over multiple steps (2 in this case) to simulate a larger batch size.\n",
    "\n",
    "            - ***`Learning Rate and Optimizer`:***\n",
    "                - `learning_rate=2e-4` is set with the `optim=adamw_bnb_8bit`, a memory-efficient version of AdamW.\n",
    "            \n",
    "            - ***`per_device_train_batch_size` and `per_device_train_batch_size`:***\n",
    "                - Specifies per-device batch sizes for both training and evaluation.\n",
    "\n",
    "            - ***`weight_decay`:***\n",
    "                - A regularization parameter to prevent overfitting.\n",
    "\n",
    "            - ***Logging and Evaluation Strategies:***\n",
    "                - Logging occurs every 200 steps.\n",
    "                \n",
    "                - Evaluation and model saving occur at the end of each epoch.\n",
    "         \n",
    "            - ***Checkpoint Management:***\n",
    "                - Limits the total number of saved checkpoints to 3.\n",
    "                \n",
    "                - Automatically loads the best model at the end of training based on evaluation loss.\n",
    "                  \n",
    "            - ***Mixed Precision:***\n",
    "                - `bf16` is enabled to leverage bfloat16 precision for faster computation with minimal loss in accuracy.\n",
    "            \n",
    "            - ***Learning Rate Scheduler:`lr_scheduler_type`***\n",
    "                - A cosine scheduler is used, along with a warmup ratio of 10% of the total steps.\n",
    "    \n",
    "    - ***Creating the Trainer and Launching Training:***\n",
    "        - ***Trainer Initialization:***\n",
    "            - ***A `Trainer` object is created with:***\n",
    "                - The QLORA-prepared `finetuned_model`.\n",
    "\n",
    "                - The training arguments defined earlier.\n",
    "\n",
    "                - The training and validation datasets.\n",
    "\n",
    "                - A callback (`EarlyStoppingCallback`) that stops training if the evaluation loss does not improve for 2 consecutive epochs.\n",
    "\n",
    "        - ***Training Execution:***\n",
    "            - The training process is started using `trainer.train()`, and log messages provide feedback before and after training.\n",
    "\n",
    "    - ***Saving the Fine-Tuned Model and Final Cleanup:***\n",
    "        - ***Model Saving:***\n",
    "            - After training, the fine-tuned model is saved to disk under the path `\"text2sql_flant5base_finetuned\"`.\n",
    "\n",
    "        - ***Logging and Cleanup:***\n",
    "            - A log message confirms that the model has been saved successfully.\n",
    "            - The `clear_memory()` function is called to release any unused GPU memory.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f50e56c7-98b3-42bc-9129-89f3eff802e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:30,827 - INFO - Attempting to load the fine-tuned model...\n",
      "2025-03-19 15:01:32,195 - INFO - Fine-tuned model loaded successfully.\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "try:\n",
    "    logger.info(\"Attempting to load the fine-tuned model...\")\n",
    "    finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"text2sql_flant5base_finetuned\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
    "    finetuned_model = finetuned_model.to(device)\n",
    "    to_train = False\n",
    "    logger.info(\"Fine-tuned model loaded successfully.\")\n",
    "except Exception as e:\n",
    "    logger.info(\"Fine-tuned model not found.\")\n",
    "    logger.info(\"Initializing model and tokenizer for QLORA fine-tuning...\")\n",
    "    to_train = True\n",
    "\n",
    "    quant_config = BitsAndBytesConfig(\n",
    "        load_in_4bit=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    )\n",
    "\n",
    "    finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\n",
    "        model_name,\n",
    "        quantization_config=quant_config,\n",
    "        device_map=\"auto\",\n",
    "        torch_dtype=torch.bfloat16,\n",
    "    )\n",
    "    finetuned_model = prepare_model_for_kbit_training(finetuned_model)\n",
    "    \n",
    "    lora_config = LoraConfig(\n",
    "        r=32,\n",
    "        lora_alpha=64,\n",
    "        target_modules=[\"q\", \"v\"],\n",
    "        lora_dropout=0.1,\n",
    "        bias=\"none\",\n",
    "        task_type=\"SEQ_2_SEQ_LM\"\n",
    "    )\n",
    "    finetuned_model = get_peft_model(finetuned_model, lora_config)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "    logger.info(\"Base model loaded and prepared for QLORA fine-tuning.\")\n",
    "    clear_memory()\n",
    "\n",
    "if to_train:\n",
    "    output_dir = f\"./sql-training-{int(time.time())}\"\n",
    "    logger.info(\"Starting training. Output directory: %s\", output_dir)\n",
    "\n",
    "    # Compute total training steps:\n",
    "    num_train_samples = len(tokenized_datasets[\"train\"])\n",
    "    per_device_train_batch_size = 64\n",
    "    per_device_eval_batch_size = 64\n",
    "    num_train_epochs = 6\n",
    "    # Assuming no gradient accumulation beyond the per-device batch size\n",
    "    total_steps = math.ceil(num_train_samples / per_device_train_batch_size) * num_train_epochs\n",
    "    # Set warmup steps as 10% of total steps (adjust as needed)\n",
    "    warmup_steps = int(total_steps * 0.1)\n",
    "    \n",
    "    logger.info(\"Total training steps: %d, Warmup steps (10%%): %d\", total_steps, warmup_steps)\n",
    "    \n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=output_dir,\n",
    "        gradient_checkpointing=True,\n",
    "        gradient_checkpointing_kwargs={\"use_reentrant\": True},\n",
    "        gradient_accumulation_steps = 2,\n",
    "        learning_rate=2e-4,\n",
    "        optim=\"adamw_bnb_8bit\",  # Memory-efficient optimizer\n",
    "        num_train_epochs=num_train_epochs,\n",
    "        per_device_train_batch_size=per_device_train_batch_size,\n",
    "        per_device_eval_batch_size=per_device_eval_batch_size,\n",
    "        weight_decay=0.01,\n",
    "        logging_steps=200, \n",
    "        logging_dir=f\"{output_dir}/logs\",\n",
    "        eval_strategy=\"epoch\",  # Evaluate at the end of each epoch\n",
    "        save_strategy=\"epoch\",  # Save the model at the end of each epoch\n",
    "        save_total_limit=3,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_loss\",\n",
    "        bf16=True,  \n",
    "        warmup_ratio=0.1,  # Warmup 10% of total steps\n",
    "        lr_scheduler_type=\"cosine\",\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=finetuned_model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_datasets[\"train\"],\n",
    "        eval_dataset=tokenized_datasets[\"validation\"],\n",
    "        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n",
    "    )\n",
    "    logger.info(\"Beginning fine-tuning...\")\n",
    "    trainer.train()\n",
    "    logger.info(\"Training completed.\")\n",
    "    save_path = \"text2sql_flant5base_finetuned\"\n",
    "    finetuned_model.save_pretrained(save_path)\n",
    "    logger.info(\"Model saved to %s\", save_path)\n",
    "    clear_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7367a7ea-b2ae-4d45-be62-775ed16d89d9",
   "metadata": {},
   "source": [
    "# Evaluation Pipeline: Detailed Explanation\n",
    "\n",
    "This code segment implements a comprehensive pipeline for evaluating the performance of both an original (baseline) and a fine-tuned T5 model on a text-to-SQL task. It covers helper function definitions, example-based inference, batch evaluation over the entire test set, and computation of multiple evaluation metrics. Below is an in-depth explanation of every component:\n",
    "\n",
    "---\n",
    "\n",
    "1. ***Helper Functions for Post-Processing, Generation, and Evaluation:***\n",
    "\n",
    "    - ***`post_process_output`:***\n",
    "        - ***Purpose:***\n",
    "            - This function processes the raw text generated by the model to ensure it contains only a single, complete SQL query.\n",
    "\n",
    "        - ***Mechanism:***\n",
    "            - It splits the output text at the first semicolon (`;`).\n",
    "\n",
    "            - If a semicolon is found, only the portion before (and including) it is returned. Otherwise, the entire output is returned.\n",
    "\n",
    "    - ***`generate_with_params`:***\n",
    "        - ***Purpose:***\n",
    "            - This helper function wraps the model's `generate()` method to produce text outputs with specified generation parameters.\n",
    "\n",
    "        - ***Parameters Explained:***\n",
    "            - ***`max_new_tokens=100`:***\n",
    "                - Limits the number of tokens to generate.\n",
    "\n",
    "            - ***`num_beams=5`:***\n",
    "                - Uses beam search with 5 beams for better-quality outputs.\n",
    "\n",
    "            - ***`repetition_penalty=1.2`:***\n",
    "                - Penalizes repetitive token generation.\n",
    "\n",
    "            - ***`temperature=0.1`:***\n",
    "                - Low temperature makes the output more deterministic.\n",
    "\n",
    "            - ***`early_stopping=True`:***\n",
    "                - Stops generation as soon as the end-of-sequence token is produced.\n",
    "        \n",
    "        - ***Post-Processing:***\n",
    "            - After generation, the output tokens are decoded into a string, and special tokens are skipped.\n",
    "    \n",
    "    - ***`normalize_sql`:***\n",
    "        - ***Purpose:***\n",
    "            - Converts SQL queries into a normalized form by lowercasing and removing extra spaces, facilitating fair comparisons between predictions and references.\n",
    "\n",
    "    - ***`compute_exact_match`:***\n",
    "        - ***Purpose:***\n",
    "            - Calculates the percentage of predictions that exactly match the corresponding reference queries after normalization.\n",
    "\n",
    "        - ***Mechanism:***\n",
    "            - Compares each pair of normalized predictions and references; computes the ratio of exact matches.\n",
    "         \n",
    "    - ***`compute_fuzzy_match`:***\n",
    "        - ***Purpose:***\n",
    "            - Computes an average fuzzy matching score, which accounts for partial matches between generated queries and references.\n",
    "\n",
    "        - ***Mechanism:***\n",
    "            - Uses `fuzz.token_set_ratio` to score similarity for each prediction-reference pair and averages the results.\n",
    "\n",
    "2. ***Part A: Inference on 5 Examples (Qualitative Evaluation):***\n",
    "\n",
    "    - ***Extracting Samples from the Test Set:***\n",
    "        - ***Purpose:***\n",
    "            - Retrieves the first 5 examples from the test split, extracting the `query`, `context`, and `response` fields for each sample.\n",
    "\n",
    "    - ***Looping Through Each Example for Inference:***\n",
    "        - ***Process Overview:***\n",
    "            - ***Prompt Construction:***\n",
    "                - For each example, a prompt is built by concatenating the context and query with headers (`\"Context:\"`, `\"Query:\"`, and `\"Response:\"`) to instruct the model.\n",
    "\n",
    "            - ***Tokenization:***\n",
    "                - The prompt is tokenized into PyTorch tensors with a maximum length of 512 tokens and moved to the designated device.\n",
    "\n",
    "            - ***Generation:***\n",
    "                - The original model produces an output using `generate_with_params()`.\n",
    "\n",
    "                - The fine-tuned model output is post-processed using `post_process_output()` to remove any repetitions.\n",
    "\n",
    "            - ***Output Display:***\n",
    "                - For each example, the input prompt, human response, original model output, and fine-tuned model output are printed with visual dividers.\n",
    "\n",
    "            - ***Memory Management:***\n",
    "                - The `clear_memory()` function is called after processing each example to maintain optimal resource usage.\n",
    "\n",
    "3. ***Part B: Evaluation on the Full Test Set with Batching (Quantitative Evaluation):***\n",
    "\n",
    "    - ***Initializing Response Lists and Batch Size:***\n",
    "        - ***Purpose:***\n",
    "            - Initializes lists to store the human-provided responses, and the outputs from both the original and fine-tuned models.\n",
    "            - Sets the batch size for processing the test dataset in chunks (here, 128 examples per batch).\n",
    "    \n",
    "    - ***Processing the Test Set in Batches:***\n",
    "        - ***Steps:***\n",
    "            - ***Batch Slicing:***\n",
    "                - The test dataset is processed in batches using slicing, which returns a dictionary of lists for each field (e.g., context, query, response).\n",
    "\n",
    "            - ***Prompt Construction:***\n",
    "                - For each example in the batch, a prompt is constructed in the same format as in Part A.\n",
    "\n",
    "            - ***Response Collection:***\n",
    "                - Human responses from the batch are collected into `all_human_responses`.\n",
    "\n",
    "            - ***Tokenization:***\n",
    "                - The batch of prompts is tokenized with padding and truncation (max length 512) and moved to the device.\n",
    "\n",
    "            - ***Generation:***\n",
    "                - Both models generate outputs for the entire batch using the same generation parameters as before.\n",
    "            \n",
    "            - ***Decoding and Post-Processing:***\n",
    "                - The generated token IDs are decoded into strings. For the fine-tuned model, the output is post-processed to ensure only the first valid SQL query is retained.\n",
    "            \n",
    "            - ***Extending Response Lists:***\n",
    "                - The decoded outputs are added to the corresponding response lists.\n",
    "\n",
    "            - ***Memory Management:***\n",
    "                - The `clear_memory()` function is called after each batch to maintain resource efficiency.\n",
    "\n",
    "    - ***Creating and Saving a Comparison DataFrame:***\n",
    "        - ***Purpose:***\n",
    "            - Zips together human responses, original model outputs, and fine-tuned model outputs into a list of tuples.\n",
    "            \n",
    "            - Creates a pandas DataFrame from the zipped list, allowing for easy visualization and comparison.\n",
    "            \n",
    "            - Saves the DataFrame to a CSV file (`evaluation_results.csv`) for further analysis.\n",
    "            \n",
    "            - Calls `clear_memory()` to free resources.\n",
    "         \n",
    "4. ***Computing Evaluation Metrics:***\n",
    "\n",
    "    - ***Computing Metrics for Original and Fine-tuned Model:***\n",
    "        - ***Details:***\n",
    "            - ***`ROUGE`:***\n",
    "                - Evaluates the overlap of n-grams between predictions and references. Options like use_aggregator and use_stemmer help produce a more robust score.\n",
    "\n",
    "            - ***`BLEU`:***\n",
    "                - Computes precision-based scores by comparing the generated text to one or more reference texts.\n",
    "\n",
    "            - ***`Fuzzy Matching`:***\n",
    "                - Uses the custom compute_fuzzy_match function (with rapidfuzz) to measure soft similarity.\n",
    "\n",
    "            - ***`Exact Match`:***\n",
    "                - Uses the compute_exact_match function to calculate the percentage of predictions that exactly match the references after normalization.\n",
    "\n",
    "        - ***The results of the evaluation metrics for both models are printed in a formatted output.***\n",
    "            - This provides a clear quantitative assessment of the improvements (or differences) between the original and fine-tuned models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f364eb6b-56cb-4533-8ef6-b5e7f56895aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:32,235 - INFO - Running inference on 5 examples (displaying real responses).\n",
      "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "====================================================================================================\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Example 1\n",
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n",
      "\n",
      "Query:\n",
      "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "HUMAN RESPONSE:\n",
      "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
      "----------------------------------------------------------------------------------------------------\n",
      "ORIGINAL MODEL OUTPUT:\n",
      "USCYBERCOM, JTF-CND, Offensive Cyber Operations\n",
      "----------------------------------------------------------------------------------------------------\n",
      "FINE-TUNED MODEL OUTPUT:\n",
      "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n",
      "====================================================================================================\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Example 2\n",
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE TABLE incidents (id INT, cause VARCHAR(255), cost INT, date DATE); INSERT INTO incidents (id, cause, cost, date) VALUES (1, 'insider threat', 10000, '2022-01-01'); INSERT INTO incidents (id, cause, cost, date) VALUES (2, 'phishing', 5000, '2022-01-02');\n",
      "\n",
      "Query:\n",
      "Find the total cost of all security incidents caused by insider threats in the last 6 months\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "HUMAN RESPONSE:\n",
      "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
      "----------------------------------------------------------------------------------------------------\n",
      "ORIGINAL MODEL OUTPUT:\n",
      "10000, 2022-01-01\n",
      "----------------------------------------------------------------------------------------------------\n",
      "FINE-TUNED MODEL OUTPUT:\n",
      "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n",
      "====================================================================================================\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Example 3\n",
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE TABLE libraries (name VARCHAR(255), state VARCHAR(255), population DECIMAL(10,2), libraries DECIMAL(5,2)); INSERT INTO libraries (name, state, population, libraries) VALUES ('Library1', 'California', 39512223, 3154), ('Library2', 'Texas', 29528404, 2212), ('Library3', 'Florida', 21644287, 1835);\n",
      "\n",
      "Query:\n",
      "Show the top 3 states with the most public libraries per capita.\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "HUMAN RESPONSE:\n",
      "SELECT state, (libraries / population) AS libraries_per_capita FROM libraries ORDER BY libraries_per_capita DESC LIMIT 3;\n",
      "----------------------------------------------------------------------------------------------------\n",
      "ORIGINAL MODEL OUTPUT:\n",
      "California, 39512223, 3154\n",
      "----------------------------------------------------------------------------------------------------\n",
      "FINE-TUNED MODEL OUTPUT:\n",
      "SELECT state, population, RANK() OVER (ORDER BY population DESC) as rank FROM libraries GROUP BY state ORDER BY rank DESC LIMIT 3;\n",
      "====================================================================================================\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Example 4\n",
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE TABLE users (id INT, location VARCHAR(50)); CREATE TABLE posts (id INT, user_id INT, created_at DATETIME);\n",
      "\n",
      "Query:\n",
      "What is the total number of posts made by users located in Australia, in the last month?\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "HUMAN RESPONSE:\n",
      "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH);\n",
      "----------------------------------------------------------------------------------------------------\n",
      "ORIGINAL MODEL OUTPUT:\n",
      "The total number of posts made by users located in Australia is 50.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "FINE-TUNED MODEL OUTPUT:\n",
      "SELECT COUNT(*) FROM posts p JOIN users u ON p.user_id = u.id WHERE u.location = 'Australia' AND p.created_at >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);\n",
      "====================================================================================================\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 15:01:40,448 - INFO - Starting evaluation on the full test set using batching.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Example 5\n",
      "----------------------------------------------------------------------------------------------------\n",
      "INPUT PROMPT:\n",
      "Context:\n",
      "CREATE TABLE WindFarms (FarmID INT, FarmName VARCHAR(255), Capacity DECIMAL(5,2), Country VARCHAR(255)); INSERT INTO WindFarms (FarmID, FarmName, Capacity, Country) VALUES (1, 'WindFarm1', 150, 'USA'), (2, 'WindFarm2', 200, 'Canada'), (3, 'WindFarm3', 120, 'Mexico');\n",
      "\n",
      "Query:\n",
      "List the total installed capacity of wind farms in the WindEnergy schema for each country?\n",
      "\n",
      "Response:\n",
      "\n",
      "----------------------------------------------------------------------------------------------------\n",
      "HUMAN RESPONSE:\n",
      "SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country;\n",
      "----------------------------------------------------------------------------------------------------\n",
      "ORIGINAL MODEL OUTPUT:\n",
      "1, 150, USA, 2, 200, Canada, 3, 120, Mexico\n",
      "----------------------------------------------------------------------------------------------------\n",
      "FINE-TUNED MODEL OUTPUT:\n",
      "SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n",
      "====================================================================================================\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7beecee09a34f9790be1e4538a87442",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "763373c451c94f5e92bc6a6253109275",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afdce82cb8964da788756d783539ee8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 16:47:58,173 - INFO - Using default tokenizer.\n",
      "2025-03-19 16:49:07,668 - INFO - Using default tokenizer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "====================================================================================================\n",
      "Evaluation Metrics:\n",
      "====================================================================================================\n",
      "ORIGINAL MODEL:\n",
      "  ROUGE: {'rouge1': np.float64(0.05646642898660111), 'rouge2': np.float64(0.01562815013068162), 'rougeL': np.float64(0.05031267225420556), 'rougeLsum': np.float64(0.05036072587316542)}\n",
      "  BLEU: {'bleu': 0.003142147128241449, 'precisions': [0.12293406776920406, 0.03289697910893642, 0.018512080104175887, 0.008342750223825794], 'brevity_penalty': 0.11177079327444009, 'length_ratio': 0.3133514352662089, 'translation_length': 377251, 'reference_length': 1203923}\n",
      "  Fuzzy Match Score: 13.98%\n",
      "  Exact Match Accuracy: 0.00%\n",
      "\n",
      "FINE-TUNED MODEL:\n",
      "  ROUGE: {'rouge1': np.float64(0.7538800834024002), 'rouge2': np.float64(0.6103863808522726), 'rougeL': np.float64(0.7262841884754194), 'rougeLsum': np.float64(0.7261852209847466)}\n",
      "  BLEU: {'bleu': 0.4719774431701209, 'precisions': [0.7603153442288385, 0.598309257795389, 0.5021259810303533, 0.42128998564638875], 'brevity_penalty': 0.8474086962179814, 'length_ratio': 0.8579477258927689, 'translation_length': 1032903, 'reference_length': 1203923}\n",
      "  Fuzzy Match Score: 85.62%\n",
      "  Exact Match Accuracy: 18.29%\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "import re\n",
    "import pandas as pd\n",
    "from rapidfuzz import fuzz\n",
    "import evaluate\n",
    "\n",
    "# Assuming tokenizer, device, original_model, finetuned_model, and dataset are already defined.\n",
    "# Define a helper function for output post-processing.\n",
    "def post_process_output(output_text: str) -> str:\n",
    "    \"\"\"Post-process the generated output to remove repeated text.\"\"\"\n",
    "    # Keep only the first valid SQL query (everything before the first semicolon)\n",
    "    return output_text.split(\";\")[0] + \";\" if \";\" in output_text else output_text\n",
    "\n",
    "# Define a helper function for generating outputs with the given generation parameters.\n",
    "def generate_with_params(model, input_ids):\n",
    "    generated_ids = model.generate(\n",
    "        input_ids=input_ids,\n",
    "        max_new_tokens=100, \n",
    "        num_beams=5,\n",
    "        repetition_penalty=1.2,\n",
    "        temperature=0.1,\n",
    "        early_stopping=True\n",
    "    )\n",
    "    # Decode and post-process output\n",
    "    output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "    return output_text\n",
    "\n",
    "# Helper functions for SQL normalization and evaluation metrics\n",
    "def normalize_sql(sql):\n",
    "    \"\"\"Normalize SQL by stripping whitespace and lowercasing.\"\"\"\n",
    "    return \" \".join(sql.strip().lower().split())\n",
    "\n",
    "def compute_exact_match(predictions, references):\n",
    "    \"\"\"Computes the exact match accuracy after normalization.\"\"\"\n",
    "    matches = sum(1 for pred, ref in zip(predictions, references)\n",
    "                  if normalize_sql(pred) == normalize_sql(ref))\n",
    "    return (matches / len(predictions)) * 100 if predictions else 0\n",
    "\n",
    "def compute_fuzzy_match(predictions, references):\n",
    "    \"\"\"Computes a soft matching score using token_set_ratio from rapidfuzz.\"\"\"\n",
    "    scores = [fuzz.token_set_ratio(pred, ref) for pred, ref in zip(predictions, references)]\n",
    "    return sum(scores) / len(scores) if scores else 0\n",
    "\n",
    "# Dummy function to free up memory if needed.\n",
    "def clear_memory():\n",
    "    # If using torch.cuda, you can clear cache:\n",
    "    # torch.cuda.empty_cache()\n",
    "    pass\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "logger.setLevel(logging.INFO)\n",
    "\n",
    "# --- Part A: Inference on 5 Examples with Real Responses ---\n",
    "logger.info(\"Running inference on 5 examples (displaying real responses).\")\n",
    "\n",
    "num_examples = 5\n",
    "sample_queries = dataset[\"test\"][:num_examples][\"query\"]\n",
    "sample_contexts = dataset[\"test\"][:num_examples][\"context\"]\n",
    "sample_human_responses = dataset[\"test\"][:num_examples][\"response\"]\n",
    "\n",
    "print(\"\\n\" + \"=\" * 100)\n",
    "for idx in range(num_examples):\n",
    "    prompt = f\"\"\"Context:\n",
    "{sample_contexts[idx]}\n",
    "\n",
    "Query:\n",
    "{sample_queries[idx]}\n",
    "\n",
    "Response:\n",
    "\"\"\"\n",
    "    # Tokenize the prompt and move to device\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n",
    "    \n",
    "    # Generate outputs using the modified generation parameters\n",
    "    orig_out = generate_with_params(original_model, inputs[\"input_ids\"])\n",
    "    finetuned_out = post_process_output(generate_with_params(finetuned_model, inputs[\"input_ids\"]))\n",
    "    \n",
    "    print(\"-\" * 100)\n",
    "    print(f\"Example {idx+1}\")\n",
    "    print(\"-\" * 100)\n",
    "    print(\"INPUT PROMPT:\")\n",
    "    print(prompt)\n",
    "    print(\"-\" * 100)\n",
    "    print(\"HUMAN RESPONSE:\")\n",
    "    print(sample_human_responses[idx])\n",
    "    print(\"-\" * 100)\n",
    "    print(\"ORIGINAL MODEL OUTPUT:\")\n",
    "    print(orig_out)\n",
    "    print(\"-\" * 100)\n",
    "    print(\"FINE-TUNED MODEL OUTPUT:\")\n",
    "    print(finetuned_out)\n",
    "    print(\"=\" * 100 + \"\\n\")\n",
    "    clear_memory()\n",
    "\n",
    "# --- Part B: Evaluation on Full Test Set with Batching (Optimized) ---\n",
    "logger.info(\"Starting evaluation on the full test set using batching.\")\n",
    "\n",
    "all_human_responses = []\n",
    "all_original_responses = []\n",
    "all_finetuned_responses = []\n",
    "\n",
    "batch_size = 128  # Adjust based on GPU memory\n",
    "test_dataset = dataset[\"test\"]\n",
    "\n",
    "for i in range(0, len(test_dataset), batch_size):\n",
    "    # Slicing the dataset returns a dict of lists\n",
    "    batch = test_dataset[i:i + batch_size]\n",
    "    \n",
    "    # Construct prompts for each example in the batch\n",
    "    prompts = [\n",
    "        f\"Context:\\n{batch['context'][j]}\\n\\nQuery:\\n{batch['query'][j]}\\n\\nResponse:\"\n",
    "        for j in range(len(batch[\"context\"]))\n",
    "    ]\n",
    "    \n",
    "    # Extend human responses\n",
    "    all_human_responses.extend(batch[\"response\"])\n",
    "    \n",
    "    # Tokenize the batch of prompts with padding and truncation\n",
    "    inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(device)\n",
    "    \n",
    "    # Generate outputs for the batch for both models\n",
    "    orig_ids = original_model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        max_new_tokens=100,\n",
    "        num_beams=5,\n",
    "        repetition_penalty=1.2,\n",
    "        temperature=0.1,\n",
    "        early_stopping=True\n",
    "    )\n",
    "    finetuned_ids = finetuned_model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        max_new_tokens=100,\n",
    "        num_beams=5,\n",
    "        repetition_penalty=1.2,\n",
    "        temperature=0.1,\n",
    "        early_stopping=True\n",
    "    )\n",
    "    \n",
    "    # Decode and post-process each sample in the batch\n",
    "    orig_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in orig_ids]\n",
    "    finetuned_texts = [post_process_output(tokenizer.decode(ids, skip_special_tokens=True)) for ids in finetuned_ids]\n",
    "    \n",
    "    all_original_responses.extend(orig_texts)\n",
    "    all_finetuned_responses.extend(finetuned_texts)\n",
    "    clear_memory()\n",
    "\n",
    "# Create a DataFrame for a quick comparison of results\n",
    "zipped_all = list(zip(all_human_responses, all_original_responses, all_finetuned_responses))\n",
    "df_full = pd.DataFrame(zipped_all, columns=[\"Human Response\", \"Original Model Output\", \"Fine-Tuned Model Output\"])\n",
    "df_full.to_csv('evaluation_results.csv', index=False)\n",
    "clear_memory()\n",
    "\n",
    "# --- Compute Evaluation Metrics ---\n",
    "rouge = evaluate.load(\"rouge\")\n",
    "bleu = evaluate.load(\"bleu\")\n",
    "\n",
    "# Compute metrics for the original (non-fine-tuned) model\n",
    "orig_rouge = rouge.compute(\n",
    "    predictions=all_original_responses,\n",
    "    references=all_human_responses,\n",
    "    use_aggregator=True,\n",
    "    use_stemmer=True,\n",
    ")\n",
    "orig_bleu = bleu.compute(\n",
    "    predictions=all_original_responses,\n",
    "    references=[[ref] for ref in all_human_responses]\n",
    ")\n",
    "orig_fuzzy = compute_fuzzy_match(all_original_responses, all_human_responses)\n",
    "orig_exact = compute_exact_match(all_original_responses, all_human_responses)\n",
    "\n",
    "# Compute metrics for the fine-tuned model\n",
    "finetuned_rouge = rouge.compute(\n",
    "    predictions=all_finetuned_responses,\n",
    "    references=all_human_responses,\n",
    "    use_aggregator=True,\n",
    "    use_stemmer=True,\n",
    ")\n",
    "finetuned_bleu = bleu.compute(\n",
    "    predictions=all_finetuned_responses,\n",
    "    references=[[ref] for ref in all_human_responses]\n",
    ")\n",
    "finetuned_fuzzy = compute_fuzzy_match(all_finetuned_responses, all_human_responses)\n",
    "finetuned_exact = compute_exact_match(all_finetuned_responses, all_human_responses)\n",
    "\n",
    "print(\"\\n\" + \"=\" * 100)\n",
    "print(\"Evaluation Metrics:\")\n",
    "print(\"=\" * 100)\n",
    "print(\"ORIGINAL MODEL:\")\n",
    "print(f\"  ROUGE: {orig_rouge}\")\n",
    "print(f\"  BLEU: {orig_bleu}\")\n",
    "print(f\"  Fuzzy Match Score: {orig_fuzzy:.2f}%\")\n",
    "print(f\"  Exact Match Accuracy: {orig_exact:.2f}%\\n\")\n",
    "print(\"FINE-TUNED MODEL:\")\n",
    "print(f\"  ROUGE: {finetuned_rouge}\")\n",
    "print(f\"  BLEU: {finetuned_bleu}\")\n",
    "print(f\"  Fuzzy Match Score: {finetuned_fuzzy:.2f}%\")\n",
    "print(f\"  Exact Match Accuracy: {finetuned_exact:.2f}%\")\n",
    "print(\"=\" * 100)\n",
    "clear_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63cd485b-2af8-48b8-b05b-ba7bffa632f2",
   "metadata": {},
   "source": [
    "# Inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "462546a7-6928-4723-b00e-23c3a4091d99",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-03-19 16:51:05,225 - INFO - Running inference with deterministic decoding and beam search.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt:\n",
      "Context:\n",
      "CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), (3, 'Charlie', 'Canada'), (4, 'David', 'USA'); INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES (101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), (103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), (105, 4, 900, '2024-03-05');\n",
      "\n",
      "Query:\n",
      "Retrieve the total order amount for each customer, showing only customers from the USA, and sort the result by total order amount in descending order.\n",
      "\n",
      "Response:\n",
      "SELECT customer_id, SUM(total_amount) as total_amount FROM orders JOIN customers ON orders.customer_id = customers.id WHERE customers.country = 'USA' GROUP BY customer_id ORDER BY total_amount DESC;\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "import logging\n",
    "\n",
    "# Set up logging\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format=\"%(asctime)s - %(levelname)s - %(message)s\",\n",
    ")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Ensure device is set (GPU if available)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the fine-tuned model and tokenizer\n",
    "model_name = \"text2sql_flant5base_finetuned\" \n",
    "finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n",
    "finetuned_model.to(device)\n",
    "\n",
    "def run_inference(prompt_text: str) -> str:\n",
    "    \"\"\"\n",
    "    Runs inference on the fine-tuned model using deterministic decoding\n",
    "    with beam search, returning the generated SQL query.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(device)\n",
    "    generated_ids = finetuned_model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        max_new_tokens=100,   # Adjust based on query complexity\n",
    "        temperature=0.1,      # Deterministic output\n",
    "        num_beams=5,          # Beam search for better output quality\n",
    "        early_stopping=True,  # Stop early if possible\n",
    "    )\n",
    "    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "\n",
    "    # Post-processing to remove repeated text\n",
    "    generated_sql = generated_sql.split(\";\")[0] + \";\"  # Keep only the first valid SQL query\n",
    "\n",
    "    return generated_sql\n",
    "\n",
    "# Sample context and query (example)\n",
    "context = (\n",
    "    \"CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); \"\n",
    "    \"CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), \"\n",
    "    \"order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); \"\n",
    "    \"INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), \"\n",
    "    \"(3, 'Charlie', 'Canada'), (4, 'David', 'USA'); \"\n",
    "    \"INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES \"\n",
    "    \"(101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), \"\n",
    "    \"(103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), \"\n",
    "    \"(105, 4, 900, '2024-03-05');\"\n",
    ")\n",
    "query = (\n",
    "    \"Retrieve the total order amount for each customer, showing only customers from the USA, \"\n",
    "    \"and sort the result by total order amount in descending order.\"\n",
    ")\n",
    "\n",
    "# Construct the prompt\n",
    "sample_prompt = f\"\"\"Context:\n",
    "{context}\n",
    "\n",
    "Query:\n",
    "{query}\n",
    "\n",
    "Response:\n",
    "\"\"\"\n",
    "\n",
    "logger.info(\"Running inference with deterministic decoding and beam search.\")\n",
    "generated_sql = run_inference(sample_prompt)\n",
    "\n",
    "# Print output in the given format\n",
    "print(\"Prompt:\")\n",
    "print(\"Context:\")\n",
    "print(context)\n",
    "print(\"\\nQuery:\")\n",
    "print(query)\n",
    "print(\"\\nResponse:\")\n",
    "print(generated_sql)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fad246dc-59d0-4757-b83a-f672f389e59f",
   "metadata": {},
   "source": [
    "# Merging and Saving a Fine-Tuned Model with LoRA Adapter\n",
    "\n",
    "This section of code demonstrates how to save a fine-tuned LoRA adapter separately, merge it into the base model to create a fully fine-tuned model, and then save the complete model along with a generation configuration file. Below is an in-depth explanation of every part of the code:\n",
    "\n",
    "---\n",
    "\n",
    "1. ***Defining Paths and Model Identifiers:***\n",
    "\n",
    "    - ***`base_model_name`:***\n",
    "        - Specifies the identifier for the pre-trained T5 model provided by Google (`google/flan-t5-base`).\n",
    "\n",
    "    - ***`lora_model_path`:***\n",
    "        - Indicates the directory where the fine-tuned LoRA adapter is stored. This adapter contains the additional parameters that were updated during fine-tuning.\n",
    "\n",
    "    - ***`full_model_output_path`:***\n",
    "        - Defines the directory where the final merged model (base model with integrated LoRA adapter) will be saved.\n",
    "        This full model can be used for inference without needing to load the adapter separately.\n",
    "\n",
    "2. ***Loading the Base Model and Tokenizer:***\n",
    "\n",
    "    - ***Base Model Loading:***\n",
    "        - The base T5 model is loaded using `AutoModelForSeq2SeqLM.from_pretrained` with a specified data type (`torch.bfloat16`) to optimize memory usage and computational performance.\n",
    "\n",
    "    - ***Tokenizer Loading:***\n",
    "        - The corresponding tokenizer is loaded to ensure that text is pre-processed consistently with the model's requirements.\n",
    "     \n",
    "3. ***Loading the Fine-Tuned LoRA Adapter:***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - This line loads the fine-tuned LoRA adapter into the base model. The adapter contains modifications tailored to a specific task (e.g., text-to-SQL) and enables efficient fine-tuning by updating only a small subset of parameters.\n",
    "\n",
    "    - ***Mechanism:***\n",
    "        - `PeftModel.from_pretrained` wraps the base model with the fine-tuned LoRA adapter parameters stored at `lora_model_path`.\n",
    "     \n",
    "4. ***Saving the LoRA Adapter Separately:***\n",
    "\n",
    "    - ***Why Save Separately?***\n",
    "        - Saving the LoRA adapter by itself allows users who need a lightweight model for further fine-tuning or deployment to load only the adapter parameters rather than the entire model.\n",
    "\n",
    "    - ***Actions:***\n",
    "        - The LoRA adapter and the associated tokenizer are saved to `lora_model_path` using their respective `save_pretrained` methods.\n",
    "     \n",
    "5. ***Merging the LoRA Adapter with the Base Model:***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - The merging process integrates the LoRA adapter parameters into the base model. After merging, the model becomes fully fine-tuned and no longer requires the separate adapter.\n",
    "\n",
    "    - ***Method:***\n",
    "        - `merge_and_unload() combines the adapter weights with the base model’s weights and then unloads the adapter, resulting in a standalone, fully fine-tuned model.\n",
    "\n",
    "6. ***Saving the Fully Merged Fine-Tuned Model:***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - After merging, the complete fine-tuned model (now including the LoRA modifications) is saved to disk for future inference or deployment.\n",
    "\n",
    "    - ***Actions:***\n",
    "        - The `save_pretrained` method is called on both the merged model and the tokenizer, storing them at `full_model_output_path`.\n",
    "     \n",
    "7. ***Saving the Generation Configuration:***\n",
    "\n",
    "    - ***Purpose:***\n",
    "        - Saving a generation configuration is optional but highly recommended. It ensures that inference settings (such as maximum tokens to generate, temperature, beam search settings, etc.) are preserved and can be loaded alongside the model.\n",
    "\n",
    "    - ***Configuration Details:***\n",
    "        - ***`max_new_tokens`:***\n",
    "            - Limits the number of tokens the model will generate in a single inference call.\n",
    "\n",
    "        - ***`temperature`:***\n",
    "            - Controls the randomness of predictions; a lower value (0.1) makes the output more deterministic.\n",
    "\n",
    "        - ***`num_beams`:***\n",
    "            - Specifies the number of beams used in beam search for more diverse and high-quality generation.\n",
    "\n",
    "        - ***`early_stopping`:***\n",
    "            -  Enables stopping the generation process early if certain conditions are met.\n",
    "\n",
    "    - ***Saving Process:***\n",
    "        - The configuration dictionary is saved as a JSON file in the model output directory, allowing for easy retrieval during inference\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a69f268e-bc69-4633-9c15-4e118c20178e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ LoRA adapter saved at: text2sql_flant5base_finetuned\n",
      "✅ Fully merged fine-tuned model saved at: text2sql_flant5base_finetuned_full\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import json\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from peft import PeftModel\n",
    "\n",
    "# Define paths\n",
    "base_model_name = \"google/flan-t5-base\"  # Base model name\n",
    "lora_model_path = \"text2sql_flant5base_finetuned\"  # Folder where LoRA adapter is saved\n",
    "full_model_output_path = \"text2sql_flant5base_finetuned_full\"  # For merged full model\n",
    "\n",
    "# Load base model and tokenizer\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model_name)\n",
    "\n",
    "# Load fine-tuned LoRA adapter model\n",
    "lora_model = PeftModel.from_pretrained(base_model, lora_model_path)\n",
    "\n",
    "# ✅ Save the LoRA adapter separately (for users who want lightweight adapters)\n",
    "lora_model.save_pretrained(lora_model_path)\n",
    "tokenizer.save_pretrained(lora_model_path)\n",
    "\n",
    "# ✅ Merge LoRA into the base model to create a fully fine-tuned model\n",
    "merged_model = lora_model.merge_and_unload()\n",
    "\n",
    "# ✅ Save the full fine-tuned model\n",
    "merged_model.save_pretrained(full_model_output_path)\n",
    "tokenizer.save_pretrained(full_model_output_path)\n",
    "\n",
    "# ✅ Save generation config (optional but recommended for inference settings)\n",
    "generation_config = {\n",
    "    \"max_new_tokens\": 100,\n",
    "    \"temperature\": 0.1,\n",
    "    \"num_beams\": 5,\n",
    "    \"early_stopping\": True\n",
    "}\n",
    "with open(f\"{full_model_output_path}/generation_config.json\", \"w\") as f:\n",
    "    json.dump(generation_config, f)\n",
    "\n",
    "print(f\"✅ LoRA adapter saved at: {lora_model_path}\")\n",
    "print(f\"✅ Fully merged fine-tuned model saved at: {full_model_output_path}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "940fd86e-b7ec-4417-ade7-7e9eebfb0642",
   "metadata": {},
   "source": [
    "# Inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f1c95dfc-6662-44d8-8ecc-bff414fecee5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "2025-03-19 16:51:49,933 - INFO - Running inference with beam search decoding.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt:\n",
      "Context:\n",
      "CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); INSERT INTO employees (id, name, department, salary) VALUES (1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), (3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); INSERT INTO projects (project_id, project_name, budget) VALUES (101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); INSERT INTO employee_projects (employee_id, project_id, role) VALUES (1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), (4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\n",
      "\n",
      "Query:\n",
      "Find the names of employees who are working on the 'AI Research' project along with their roles.\n",
      "\n",
      "Response:\n",
      "SELECT employees.name, employee_projects.role FROM employees INNER JOIN employee_projects ON employees.id = employee_projects.employee_id INNER JOIN projects ON employee_projects.project_id = projects.project_id WHERE projects.project_name = 'AI Research';\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "import logging\n",
    "\n",
    "# Set up logging\n",
    "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Ensure device is set (GPU if available)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the fine-tuned model and tokenizer\n",
    "model_name = \"aarohanverma/text2sql-flan-t5-base-qlora-finetuned\"\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"aarohanverma/text2sql-flan-t5-base-qlora-finetuned\")\n",
    "\n",
    "# Ensure decoder start token is set\n",
    "if model.config.decoder_start_token_id is None:\n",
    "    model.config.decoder_start_token_id = tokenizer.pad_token_id\n",
    "\n",
    "def run_inference(prompt_text: str) -> str:\n",
    "    \"\"\"\n",
    "    Runs inference on the fine-tuned model using beam search with fixes for repetition.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(prompt_text, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        decoder_start_token_id=model.config.decoder_start_token_id, \n",
    "        max_new_tokens=100,  \n",
    "        temperature=0.1, \n",
    "        num_beams=5, \n",
    "        repetition_penalty=1.2,  \n",
    "        early_stopping=True,  \n",
    "    )\n",
    "\n",
    "    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "\n",
    "    # Post-processing to remove repeated text\n",
    "    generated_sql = generated_sql.split(\";\")[0] + \";\"  # Keep only the first valid SQL query\n",
    "\n",
    "    return generated_sql\n",
    "\n",
    "# Example usage:\n",
    "context = (\n",
    "    \"CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); \"\n",
    "    \"CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); \"\n",
    "    \"CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), \"\n",
    "    \"FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); \"\n",
    "    \"INSERT INTO employees (id, name, department, salary) VALUES \"\n",
    "    \"(1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), \"\n",
    "    \"(3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); \"\n",
    "    \"INSERT INTO projects (project_id, project_name, budget) VALUES \"\n",
    "    \"(101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); \"\n",
    "    \"INSERT INTO employee_projects (employee_id, project_id, role) VALUES \"\n",
    "    \"(1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), \"\n",
    "    \"(4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\"\n",
    ")\n",
    "\n",
    "query = (\"Find the names of employees who are working on the 'AI Research' project along with their roles.\")\n",
    "\n",
    "\n",
    "\n",
    "# Construct the prompt\n",
    "sample_prompt = f\"\"\"Context:\n",
    "{context}\n",
    "\n",
    "Query:\n",
    "{query}\n",
    "\n",
    "Response:\n",
    "\"\"\"\n",
    "\n",
    "logger.info(\"Running inference with beam search decoding.\")\n",
    "generated_sql = run_inference(sample_prompt)\n",
    "\n",
    "print(\"Prompt:\")\n",
    "print(\"Context:\")\n",
    "print(context)\n",
    "print(\"\\nQuery:\")\n",
    "print(query)\n",
    "print(\"\\nResponse:\")\n",
    "print(generated_sql)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97425ac4-ad46-4f38-b22d-071e161da20a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}