DeanGumas commited on
Commit
f94c0b8
·
1 Parent(s): 98e9ca0

added fine-tune testing on validation set with short queries only

Browse files
Files changed (1) hide show
  1. finetune_model.ipynb +716 -0
finetune_model.ipynb CHANGED
@@ -3329,6 +3329,722 @@
3329
  "\n",
3330
  "# break"
3331
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3332
  }
3333
  ],
3334
  "metadata": {
 
3329
  "\n",
3330
  "# break"
3331
  ]
3332
+ },
3333
+ {
3334
+ "cell_type": "markdown",
3335
+ "metadata": {},
3336
+ "source": [
3337
+ "## Test validation set only on short queries"
3338
+ ]
3339
+ },
3340
+ {
3341
+ "cell_type": "code",
3342
+ "execution_count": 43,
3343
+ "metadata": {},
3344
+ "outputs": [
3345
+ {
3346
+ "name": "stderr",
3347
+ "output_type": "stream",
3348
+ "text": [
3349
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3350
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n",
3351
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
3352
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
3353
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3354
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3355
+ ]
3356
+ },
3357
+ {
3358
+ "name": "stdout",
3359
+ "output_type": "stream",
3360
+ "text": [
3361
+ "SELECT MAX(fg_pct_home) FROM game WHERE team_name_home = 'Miami Heat';\n",
3362
+ "SQLite:\n",
3363
+ "SELECT MAX(fg_pct_home) FROM game WHERE team_name_home = 'Miami Heat'\n",
3364
+ "['0.675']\n",
3365
+ "['0.675']\n",
3366
+ "Statement valid? True\n",
3367
+ "SQLite matched? True\n",
3368
+ "Result matched? True\n",
3369
+ "\n",
3370
+ "\n"
3371
+ ]
3372
+ },
3373
+ {
3374
+ "name": "stderr",
3375
+ "output_type": "stream",
3376
+ "text": [
3377
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3378
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3379
+ ]
3380
+ },
3381
+ {
3382
+ "name": "stdout",
3383
+ "output_type": "stream",
3384
+ "text": [
3385
+ "SELECT MAX(pts_home) FROM game WHERE team_abbreviation_home = 'GSW';\n",
3386
+ "SQLite:\n",
3387
+ "SELECT MAX(pts_home) FROM game WHERE team_name_home = 'Golden State Warriors'\n",
3388
+ "['149.0']\n",
3389
+ "['155.0']\n",
3390
+ "Statement valid? True\n",
3391
+ "SQLite matched? False\n",
3392
+ "Result matched? False\n",
3393
+ "\n",
3394
+ "\n"
3395
+ ]
3396
+ },
3397
+ {
3398
+ "name": "stderr",
3399
+ "output_type": "stream",
3400
+ "text": [
3401
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3402
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3403
+ ]
3404
+ },
3405
+ {
3406
+ "name": "stdout",
3407
+ "output_type": "stream",
3408
+ "text": [
3409
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'CHI';\n",
3410
+ "SQLite:\n",
3411
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Chicago Bulls'\n",
3412
+ "Statement valid? False\n",
3413
+ "SQLite matched? False\n",
3414
+ "Result matched? False\n",
3415
+ "\n",
3416
+ "\n"
3417
+ ]
3418
+ },
3419
+ {
3420
+ "name": "stderr",
3421
+ "output_type": "stream",
3422
+ "text": [
3423
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3424
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3425
+ ]
3426
+ },
3427
+ {
3428
+ "name": "stdout",
3429
+ "output_type": "stream",
3430
+ "text": [
3431
+ "SELECT MAX(pts_fb_home) FROM other_stats WHERE team_abbreviation_home = 'BKN';\n",
3432
+ "SQLite:\n",
3433
+ "SELECT MAX(pts_fb_home) FROM game WHERE team_name_home = 'Brooklyn Nets'\n",
3434
+ "Statement valid? False\n",
3435
+ "SQLite matched? False\n",
3436
+ "Result matched? False\n",
3437
+ "\n",
3438
+ "\n"
3439
+ ]
3440
+ },
3441
+ {
3442
+ "name": "stderr",
3443
+ "output_type": "stream",
3444
+ "text": [
3445
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3446
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3447
+ ]
3448
+ },
3449
+ {
3450
+ "name": "stdout",
3451
+ "output_type": "stream",
3452
+ "text": [
3453
+ "SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'LAL';\n",
3454
+ "SQLite:\n",
3455
+ "SELECT MAX(pts_away) FROM game WHERE team_name_away = 'Los Angeles Lakers'\n",
3456
+ "['153.0']\n",
3457
+ "['153.0']\n",
3458
+ "Statement valid? True\n",
3459
+ "SQLite matched? False\n",
3460
+ "Result matched? True\n",
3461
+ "\n",
3462
+ "\n"
3463
+ ]
3464
+ },
3465
+ {
3466
+ "name": "stderr",
3467
+ "output_type": "stream",
3468
+ "text": [
3469
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3470
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3471
+ ]
3472
+ },
3473
+ {
3474
+ "name": "stdout",
3475
+ "output_type": "stream",
3476
+ "text": [
3477
+ "SELECT MAX(pts_fb_away) FROM other_stats WHERE team_abbreviation_away = 'HOU';\n",
3478
+ "SQLite:\n",
3479
+ "SELECT MAX(pts_fb_away) FROM other_stats WHERE team_name_away = 'Houston Rockets'\n",
3480
+ "Statement valid? False\n",
3481
+ "SQLite matched? False\n",
3482
+ "Result matched? False\n",
3483
+ "\n",
3484
+ "\n"
3485
+ ]
3486
+ },
3487
+ {
3488
+ "name": "stderr",
3489
+ "output_type": "stream",
3490
+ "text": [
3491
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3492
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3493
+ ]
3494
+ },
3495
+ {
3496
+ "name": "stdout",
3497
+ "output_type": "stream",
3498
+ "text": [
3499
+ "SELECT MAX(pts_paint_away) FROM other_stats WHERE team_abbreviation_away = 'MIL';\n",
3500
+ "SQLite:\n",
3501
+ "SELECT MAX(pts_off_to_away) FROM other_stats WHERE team_name_away = 'Milwaukee Bucks'\n",
3502
+ "Statement valid? False\n",
3503
+ "SQLite matched? False\n",
3504
+ "Result matched? False\n",
3505
+ "\n",
3506
+ "\n"
3507
+ ]
3508
+ },
3509
+ {
3510
+ "name": "stderr",
3511
+ "output_type": "stream",
3512
+ "text": [
3513
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3514
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3515
+ ]
3516
+ },
3517
+ {
3518
+ "name": "stdout",
3519
+ "output_type": "stream",
3520
+ "text": [
3521
+ "SELECT MIN(pts_away) FROM game WHERE team_abbreviation_away = 'GSW';\n",
3522
+ "SQLite:\n",
3523
+ "SELECT MIN(pts_away) FROM game WHERE team_name_away = 'Golden State Warriors'\n",
3524
+ "['65.0']\n",
3525
+ "['65.0']\n",
3526
+ "Statement valid? True\n",
3527
+ "SQLite matched? False\n",
3528
+ "Result matched? True\n",
3529
+ "\n",
3530
+ "\n"
3531
+ ]
3532
+ },
3533
+ {
3534
+ "name": "stderr",
3535
+ "output_type": "stream",
3536
+ "text": [
3537
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3538
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3539
+ ]
3540
+ },
3541
+ {
3542
+ "name": "stdout",
3543
+ "output_type": "stream",
3544
+ "text": [
3545
+ "SELECT COUNT(*) FROM game WHERE team_abbreviation_away = 'TOR' AND pts_away = 100;\n",
3546
+ "SQLite:\n",
3547
+ "SELECT COUNT(*) FROM other_stats WHERE team_name_away = 'Toronto Raptors' AND pts_away = 100\n",
3548
+ "Statement valid? False\n",
3549
+ "SQLite matched? False\n",
3550
+ "Result matched? False\n",
3551
+ "\n",
3552
+ "\n"
3553
+ ]
3554
+ },
3555
+ {
3556
+ "name": "stderr",
3557
+ "output_type": "stream",
3558
+ "text": [
3559
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3560
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3561
+ ]
3562
+ },
3563
+ {
3564
+ "name": "stdout",
3565
+ "output_type": "stream",
3566
+ "text": [
3567
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'BOS';\n",
3568
+ "SQLite:\n",
3569
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Boston Celtics'\n",
3570
+ "Statement valid? False\n",
3571
+ "SQLite matched? False\n",
3572
+ "Result matched? False\n",
3573
+ "\n",
3574
+ "\n"
3575
+ ]
3576
+ },
3577
+ {
3578
+ "name": "stderr",
3579
+ "output_type": "stream",
3580
+ "text": [
3581
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3582
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3583
+ ]
3584
+ },
3585
+ {
3586
+ "name": "stdout",
3587
+ "output_type": "stream",
3588
+ "text": [
3589
+ "SELECT MAX(fg3m_home) FROM game WHERE team_name_home = 'Brooklyn Nets';\n",
3590
+ "SQLite: SELECT MAX(fg3m_home) FROM game WHERE team_name_home = 'Brooklyn Nets'\n",
3591
+ "['22.0']\n",
3592
+ "['22.0']\n",
3593
+ "Statement valid? True\n",
3594
+ "SQLite matched? True\n",
3595
+ "Result matched? True\n",
3596
+ "\n",
3597
+ "\n",
3598
+ "SELECT MAX(pts_2nd_chance_home) FROM other_stats WHERE team_abbreviation_home = 'CHI';\n",
3599
+ "SQLite:\n",
3600
+ "SELECT MAX(pts_2nd_chance_home) FROM game WHERE team_name_home = 'Chicago Bulls'\n",
3601
+ "Statement valid? False\n",
3602
+ "SQLite matched? False\n",
3603
+ "Result matched? False\n",
3604
+ "\n",
3605
+ "\n"
3606
+ ]
3607
+ },
3608
+ {
3609
+ "name": "stderr",
3610
+ "output_type": "stream",
3611
+ "text": [
3612
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3613
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3614
+ ]
3615
+ },
3616
+ {
3617
+ "name": "stdout",
3618
+ "output_type": "stream",
3619
+ "text": [
3620
+ "SELECT COUNT(*) FROM game WHERE (pts_home + pts_away) >= 250 AND season_id = '22005';\n",
3621
+ "SQLite:\n",
3622
+ "SELECT COUNT(*) FROM game WHERE (pts_home + pts_away) >= 250 AND season_id = '22005'\n",
3623
+ "['6']\n",
3624
+ "['6']\n",
3625
+ "Statement valid? True\n",
3626
+ "SQLite matched? True\n",
3627
+ "Result matched? True\n",
3628
+ "\n",
3629
+ "\n"
3630
+ ]
3631
+ },
3632
+ {
3633
+ "name": "stderr",
3634
+ "output_type": "stream",
3635
+ "text": [
3636
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3637
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n",
3638
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3639
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3640
+ ]
3641
+ },
3642
+ {
3643
+ "name": "stdout",
3644
+ "output_type": "stream",
3645
+ "text": [
3646
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'MIA';\n",
3647
+ "SQLite:\n",
3648
+ "SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Miami Heat'\n",
3649
+ "Statement valid? False\n",
3650
+ "SQLite matched? False\n",
3651
+ "Result matched? False\n",
3652
+ "\n",
3653
+ "\n"
3654
+ ]
3655
+ },
3656
+ {
3657
+ "name": "stderr",
3658
+ "output_type": "stream",
3659
+ "text": [
3660
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3661
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3662
+ ]
3663
+ },
3664
+ {
3665
+ "name": "stdout",
3666
+ "output_type": "stream",
3667
+ "text": [
3668
+ "SELECT MAX(reb_home) FROM game WHERE team_abbreviation_home = 'NYK';\n",
3669
+ "SQLite:\n",
3670
+ "SELECT MAX(reb_home) FROM game WHERE team_name_home = 'New York Knicks'\n",
3671
+ "['75.0']\n",
3672
+ "['75.0']\n",
3673
+ "Statement valid? True\n",
3674
+ "SQLite matched? False\n",
3675
+ "Result matched? True\n",
3676
+ "\n",
3677
+ "\n"
3678
+ ]
3679
+ },
3680
+ {
3681
+ "name": "stderr",
3682
+ "output_type": "stream",
3683
+ "text": [
3684
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3685
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3686
+ ]
3687
+ },
3688
+ {
3689
+ "name": "stdout",
3690
+ "output_type": "stream",
3691
+ "text": [
3692
+ "SELECT MAX(pts_fb_home) FROM other_stats WHERE team_abbreviation_home = 'HOU';\n",
3693
+ "SQLite:\n",
3694
+ "SELECT MAX(pts_fb_home) FROM game WHERE team_name_home = 'Houston Rockets'\n",
3695
+ "Statement valid? False\n",
3696
+ "SQLite matched? False\n",
3697
+ "Result matched? False\n",
3698
+ "\n",
3699
+ "\n"
3700
+ ]
3701
+ },
3702
+ {
3703
+ "name": "stderr",
3704
+ "output_type": "stream",
3705
+ "text": [
3706
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3707
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3708
+ ]
3709
+ },
3710
+ {
3711
+ "name": "stdout",
3712
+ "output_type": "stream",
3713
+ "text": [
3714
+ "SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'CLE';\n",
3715
+ "SQLite:\n",
3716
+ "SELECT MAX(pts_away) FROM game WHERE team_name_away = 'Cleveland Cavaliers' AND season_type = 'Regular'\n",
3717
+ "['140.0']\n",
3718
+ "['None']\n",
3719
+ "Statement valid? True\n",
3720
+ "SQLite matched? False\n",
3721
+ "Result matched? False\n",
3722
+ "\n",
3723
+ "\n"
3724
+ ]
3725
+ },
3726
+ {
3727
+ "name": "stderr",
3728
+ "output_type": "stream",
3729
+ "text": [
3730
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3731
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3732
+ ]
3733
+ },
3734
+ {
3735
+ "name": "stdout",
3736
+ "output_type": "stream",
3737
+ "text": [
3738
+ "SELECT MAX(total_turnovers_home) FROM other_stats WHERE team_abbreviation_home = 'PHI';\n",
3739
+ "SQLite:\n",
3740
+ "SELECT MAX(total_turnovers_home) FROM other_stats WHERE team_name_home = 'Philadelphia 76ers'\n",
3741
+ "Statement valid? False\n",
3742
+ "SQLite matched? False\n",
3743
+ "Result matched? False\n",
3744
+ "\n",
3745
+ "\n"
3746
+ ]
3747
+ },
3748
+ {
3749
+ "name": "stderr",
3750
+ "output_type": "stream",
3751
+ "text": [
3752
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3753
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3754
+ ]
3755
+ },
3756
+ {
3757
+ "name": "stdout",
3758
+ "output_type": "stream",
3759
+ "text": [
3760
+ "SELECT MAX(largest_lead_away) FROM other_stats WHERE team_abbreviation_away = 'NYK';\n",
3761
+ "SQLite:\n",
3762
+ "SELECT MAX(largest_lead_away) FROM other_stats WHERE team_name_away = 'New York Knicks'\n",
3763
+ "Statement valid? False\n",
3764
+ "SQLite matched? False\n",
3765
+ "Result matched? False\n",
3766
+ "\n",
3767
+ "\n"
3768
+ ]
3769
+ },
3770
+ {
3771
+ "name": "stderr",
3772
+ "output_type": "stream",
3773
+ "text": [
3774
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3775
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3776
+ ]
3777
+ },
3778
+ {
3779
+ "name": "stdout",
3780
+ "output_type": "stream",
3781
+ "text": [
3782
+ "SELECT MAX(total_turnovers_away) FROM other_stats WHERE team_abbreviation_away = 'BKN';\n",
3783
+ "SQLite:\n",
3784
+ "SELECT MAX(team_turnovers_away) FROM other_stats WHERE team_name_away = 'Brooklyn Nets'\n",
3785
+ "Statement valid? False\n",
3786
+ "SQLite matched? False\n",
3787
+ "Result matched? False\n",
3788
+ "\n",
3789
+ "\n"
3790
+ ]
3791
+ },
3792
+ {
3793
+ "name": "stderr",
3794
+ "output_type": "stream",
3795
+ "text": [
3796
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3797
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3798
+ ]
3799
+ },
3800
+ {
3801
+ "name": "stdout",
3802
+ "output_type": "stream",
3803
+ "text": [
3804
+ "SELECT COUNT(*) FROM game WHERE ABS(pts_home - pts_away) = 1;\n",
3805
+ "SQLite:\n",
3806
+ "SELECT COUNT(*) FROM game WHERE pts_home - pts_away = 1\n",
3807
+ "['3050']\n",
3808
+ "['1572']\n",
3809
+ "Statement valid? True\n",
3810
+ "SQLite matched? False\n",
3811
+ "Result matched? False\n",
3812
+ "\n",
3813
+ "\n"
3814
+ ]
3815
+ },
3816
+ {
3817
+ "name": "stderr",
3818
+ "output_type": "stream",
3819
+ "text": [
3820
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3821
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3822
+ ]
3823
+ },
3824
+ {
3825
+ "name": "stdout",
3826
+ "output_type": "stream",
3827
+ "text": [
3828
+ "SELECT COUNT(*) FROM game WHERE team_abbreviation_home = 'LAL' AND season_id = '22010';\n",
3829
+ "SQLite: SELECT COUNT(*) FROM game WHERE team_name_home = 'Los Angeles Lakers' AND season_id = '22010'\n",
3830
+ "['41']\n",
3831
+ "['41']\n",
3832
+ "Statement valid? True\n",
3833
+ "SQLite matched? False\n",
3834
+ "Result matched? True\n",
3835
+ "\n",
3836
+ "\n"
3837
+ ]
3838
+ },
3839
+ {
3840
+ "name": "stderr",
3841
+ "output_type": "stream",
3842
+ "text": [
3843
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3844
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3845
+ ]
3846
+ },
3847
+ {
3848
+ "name": "stdout",
3849
+ "output_type": "stream",
3850
+ "text": [
3851
+ "SELECT MAX(pts_home) FROM game;\n",
3852
+ "SQLite:\n",
3853
+ "SELECT game_id, MAX(pts_home) AS highest_scoring_home_game FROM game GROUP BY game_id ORDER BY highest_scoring_home_game DESC LIMIT 1\n",
3854
+ "['192.0']\n",
3855
+ "['0031600001', '192.0']\n",
3856
+ "Statement valid? True\n",
3857
+ "SQLite matched? False\n",
3858
+ "Result matched? True\n",
3859
+ "\n",
3860
+ "\n"
3861
+ ]
3862
+ },
3863
+ {
3864
+ "name": "stderr",
3865
+ "output_type": "stream",
3866
+ "text": [
3867
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3868
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3869
+ ]
3870
+ },
3871
+ {
3872
+ "name": "stdout",
3873
+ "output_type": "stream",
3874
+ "text": [
3875
+ "SELECT MIN(fg_pct_away) FROM game;\n",
3876
+ "SQLite:\n",
3877
+ "SELECT MIN(fg_pct_away) FROM game\n",
3878
+ "['0.156']\n",
3879
+ "['0.156']\n",
3880
+ "Statement valid? True\n",
3881
+ "SQLite matched? True\n",
3882
+ "Result matched? True\n",
3883
+ "\n",
3884
+ "\n"
3885
+ ]
3886
+ },
3887
+ {
3888
+ "name": "stderr",
3889
+ "output_type": "stream",
3890
+ "text": [
3891
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3892
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3893
+ ]
3894
+ },
3895
+ {
3896
+ "name": "stdout",
3897
+ "output_type": "stream",
3898
+ "text": [
3899
+ "SELECT COUNT(*) FROM game WHERE ABS(pts_home - pts_away) > 40;\n",
3900
+ "SQLite:\n",
3901
+ "SELECT COUNT(*) FROM game WHERE pts_home > 40 OR pts_away > 40\n",
3902
+ "['307']\n",
3903
+ "['65696']\n",
3904
+ "Statement valid? True\n",
3905
+ "SQLite matched? False\n",
3906
+ "Result matched? False\n",
3907
+ "\n",
3908
+ "\n"
3909
+ ]
3910
+ },
3911
+ {
3912
+ "name": "stderr",
3913
+ "output_type": "stream",
3914
+ "text": [
3915
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3916
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3917
+ ]
3918
+ },
3919
+ {
3920
+ "name": "stdout",
3921
+ "output_type": "stream",
3922
+ "text": [
3923
+ "SELECT MAX(stl_away) FROM game;\n",
3924
+ "SQLite:\n",
3925
+ "SELECT MAX(stl_away) FROM game\n",
3926
+ "['27.0']\n",
3927
+ "['27.0']\n",
3928
+ "Statement valid? True\n",
3929
+ "SQLite matched? True\n",
3930
+ "Result matched? True\n",
3931
+ "\n",
3932
+ "\n"
3933
+ ]
3934
+ },
3935
+ {
3936
+ "name": "stderr",
3937
+ "output_type": "stream",
3938
+ "text": [
3939
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3940
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3941
+ ]
3942
+ },
3943
+ {
3944
+ "name": "stdout",
3945
+ "output_type": "stream",
3946
+ "text": [
3947
+ "SELECT COUNT(*) FROM game WHERE pts_home >= 120 AND pts_away >= 120;\n",
3948
+ "SQLite:\n",
3949
+ "SELECT COUNT(*) FROM game WHERE pts_home >= 120 AND pts_away >= 120\n",
3950
+ "['3043']\n",
3951
+ "['3043']\n",
3952
+ "Statement valid? True\n",
3953
+ "SQLite matched? True\n",
3954
+ "Result matched? True\n",
3955
+ "\n",
3956
+ "\n"
3957
+ ]
3958
+ },
3959
+ {
3960
+ "name": "stderr",
3961
+ "output_type": "stream",
3962
+ "text": [
3963
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
3964
+ "Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
3965
+ ]
3966
+ },
3967
+ {
3968
+ "name": "stdout",
3969
+ "output_type": "stream",
3970
+ "text": [
3971
+ "SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30;\n",
3972
+ "SQLite:\n",
3973
+ "SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30\n",
3974
+ "['11305']\n",
3975
+ "['11305']\n",
3976
+ "Statement valid? True\n",
3977
+ "SQLite matched? True\n",
3978
+ "Result matched? True\n",
3979
+ "\n",
3980
+ "\n",
3981
+ "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n",
3982
+ "SQLite:\n",
3983
+ "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers'\n",
3984
+ "['44.0']\n",
3985
+ "['44.0']\n",
3986
+ "Statement valid? True\n",
3987
+ "SQLite matched? True\n",
3988
+ "Result matched? True\n",
3989
+ "\n",
3990
+ "\n",
3991
+ "Percent valid: 0.5862068965517241\n",
3992
+ "Percent SQLite matched: 0.27586206896551724\n",
3993
+ "Percent result matched: 0.4482758620689655\n"
3994
+ ]
3995
+ }
3996
+ ],
3997
+ "source": [
3998
+ "num_valid = 0\n",
3999
+ "num_sql_matched = 0\n",
4000
+ "num_result_matched = 0\n",
4001
+ "counter = 0\n",
4002
+ "for v in val_dataset:\n",
4003
+ " # Obtain sample natural language question and sql_query\n",
4004
+ " #v = val_dataset[random.randint(0, len(val_dataset) - 1)]\n",
4005
+ " full_example = tokenizer.decode(v[\"input_ids\"], skip_special_tokens=True)\n",
4006
+ " user_prompt = full_example[:prompt_length]\n",
4007
+ " question, sql_query = full_example[prompt_length:].split(\"SQLite:\\n\")\n",
4008
+ " #print(question)\n",
4009
+ " #print(sql_query)\n",
4010
+ "\n",
4011
+ " if len(sql_query) <= 90:\n",
4012
+ " # Obtain model output\n",
4013
+ " input_text = \"How many points to the Los Angeles Lakers average at home?\"\n",
4014
+ " message = [{'role': 'user', 'content': input_prompt + question}]\n",
4015
+ " inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
4016
+ "\n",
4017
+ " # Generate SQL query\n",
4018
+ " outputs = model.generate(\n",
4019
+ " inputs,\n",
4020
+ " max_new_tokens=256,\n",
4021
+ " eos_token_id=tokenizer.convert_tokens_to_ids(\"<|endofsql|>\")\n",
4022
+ " )\n",
4023
+ " model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
4024
+ "\n",
4025
+ " print(sql_query)\n",
4026
+ " print(model_output.split(\";\")[0])\n",
4027
+ " #print()\n",
4028
+ " #print(model_output)\n",
4029
+ " result = compare_result(sql_query, model_output)\n",
4030
+ " print(\"Statement valid? \" + str(result[0]))\n",
4031
+ " print(\"SQLite matched? \" + str(result[1]))\n",
4032
+ " print(\"Result matched? \" + str(result[2]))\n",
4033
+ " print()\n",
4034
+ " print()\n",
4035
+ " counter += 1\n",
4036
+ "\n",
4037
+ " if result[0]:\n",
4038
+ " num_valid += 1\n",
4039
+ " if result[1]:\n",
4040
+ " num_sql_matched += 1\n",
4041
+ " if result[2]:\n",
4042
+ " num_result_matched += 1\n",
4043
+ "\n",
4044
+ "print(\"Percent valid: \" + str(num_valid / counter))\n",
4045
+ "print(\"Percent SQLite matched: \" + str(num_sql_matched / counter))\n",
4046
+ "print(\"Percent result matched: \" + str(num_result_matched / counter))"
4047
+ ]
4048
  }
4049
  ],
4050
  "metadata": {