added fine-tune testing on validation set with short queries only
Browse files- finetune_model.ipynb +716 -0
finetune_model.ipynb
CHANGED
@@ -3329,6 +3329,722 @@
|
|
3329 |
"\n",
|
3330 |
"# break"
|
3331 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3332 |
}
|
3333 |
],
|
3334 |
"metadata": {
|
|
|
3329 |
"\n",
|
3330 |
"# break"
|
3331 |
]
|
3332 |
+
},
|
3333 |
+
{
|
3334 |
+
"cell_type": "markdown",
|
3335 |
+
"metadata": {},
|
3336 |
+
"source": [
|
3337 |
+
"## Test validation set only on short queries"
|
3338 |
+
]
|
3339 |
+
},
|
3340 |
+
{
|
3341 |
+
"cell_type": "code",
|
3342 |
+
"execution_count": 43,
|
3343 |
+
"metadata": {},
|
3344 |
+
"outputs": [
|
3345 |
+
{
|
3346 |
+
"name": "stderr",
|
3347 |
+
"output_type": "stream",
|
3348 |
+
"text": [
|
3349 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3350 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n",
|
3351 |
+
"c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\bitsandbytes\\autograd\\_functions.py:315: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
|
3352 |
+
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
|
3353 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3354 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3355 |
+
]
|
3356 |
+
},
|
3357 |
+
{
|
3358 |
+
"name": "stdout",
|
3359 |
+
"output_type": "stream",
|
3360 |
+
"text": [
|
3361 |
+
"SELECT MAX(fg_pct_home) FROM game WHERE team_name_home = 'Miami Heat';\n",
|
3362 |
+
"SQLite:\n",
|
3363 |
+
"SELECT MAX(fg_pct_home) FROM game WHERE team_name_home = 'Miami Heat'\n",
|
3364 |
+
"['0.675']\n",
|
3365 |
+
"['0.675']\n",
|
3366 |
+
"Statement valid? True\n",
|
3367 |
+
"SQLite matched? True\n",
|
3368 |
+
"Result matched? True\n",
|
3369 |
+
"\n",
|
3370 |
+
"\n"
|
3371 |
+
]
|
3372 |
+
},
|
3373 |
+
{
|
3374 |
+
"name": "stderr",
|
3375 |
+
"output_type": "stream",
|
3376 |
+
"text": [
|
3377 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3378 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3379 |
+
]
|
3380 |
+
},
|
3381 |
+
{
|
3382 |
+
"name": "stdout",
|
3383 |
+
"output_type": "stream",
|
3384 |
+
"text": [
|
3385 |
+
"SELECT MAX(pts_home) FROM game WHERE team_abbreviation_home = 'GSW';\n",
|
3386 |
+
"SQLite:\n",
|
3387 |
+
"SELECT MAX(pts_home) FROM game WHERE team_name_home = 'Golden State Warriors'\n",
|
3388 |
+
"['149.0']\n",
|
3389 |
+
"['155.0']\n",
|
3390 |
+
"Statement valid? True\n",
|
3391 |
+
"SQLite matched? False\n",
|
3392 |
+
"Result matched? False\n",
|
3393 |
+
"\n",
|
3394 |
+
"\n"
|
3395 |
+
]
|
3396 |
+
},
|
3397 |
+
{
|
3398 |
+
"name": "stderr",
|
3399 |
+
"output_type": "stream",
|
3400 |
+
"text": [
|
3401 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3402 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3403 |
+
]
|
3404 |
+
},
|
3405 |
+
{
|
3406 |
+
"name": "stdout",
|
3407 |
+
"output_type": "stream",
|
3408 |
+
"text": [
|
3409 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'CHI';\n",
|
3410 |
+
"SQLite:\n",
|
3411 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Chicago Bulls'\n",
|
3412 |
+
"Statement valid? False\n",
|
3413 |
+
"SQLite matched? False\n",
|
3414 |
+
"Result matched? False\n",
|
3415 |
+
"\n",
|
3416 |
+
"\n"
|
3417 |
+
]
|
3418 |
+
},
|
3419 |
+
{
|
3420 |
+
"name": "stderr",
|
3421 |
+
"output_type": "stream",
|
3422 |
+
"text": [
|
3423 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3424 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3425 |
+
]
|
3426 |
+
},
|
3427 |
+
{
|
3428 |
+
"name": "stdout",
|
3429 |
+
"output_type": "stream",
|
3430 |
+
"text": [
|
3431 |
+
"SELECT MAX(pts_fb_home) FROM other_stats WHERE team_abbreviation_home = 'BKN';\n",
|
3432 |
+
"SQLite:\n",
|
3433 |
+
"SELECT MAX(pts_fb_home) FROM game WHERE team_name_home = 'Brooklyn Nets'\n",
|
3434 |
+
"Statement valid? False\n",
|
3435 |
+
"SQLite matched? False\n",
|
3436 |
+
"Result matched? False\n",
|
3437 |
+
"\n",
|
3438 |
+
"\n"
|
3439 |
+
]
|
3440 |
+
},
|
3441 |
+
{
|
3442 |
+
"name": "stderr",
|
3443 |
+
"output_type": "stream",
|
3444 |
+
"text": [
|
3445 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3446 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3447 |
+
]
|
3448 |
+
},
|
3449 |
+
{
|
3450 |
+
"name": "stdout",
|
3451 |
+
"output_type": "stream",
|
3452 |
+
"text": [
|
3453 |
+
"SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'LAL';\n",
|
3454 |
+
"SQLite:\n",
|
3455 |
+
"SELECT MAX(pts_away) FROM game WHERE team_name_away = 'Los Angeles Lakers'\n",
|
3456 |
+
"['153.0']\n",
|
3457 |
+
"['153.0']\n",
|
3458 |
+
"Statement valid? True\n",
|
3459 |
+
"SQLite matched? False\n",
|
3460 |
+
"Result matched? True\n",
|
3461 |
+
"\n",
|
3462 |
+
"\n"
|
3463 |
+
]
|
3464 |
+
},
|
3465 |
+
{
|
3466 |
+
"name": "stderr",
|
3467 |
+
"output_type": "stream",
|
3468 |
+
"text": [
|
3469 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3470 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3471 |
+
]
|
3472 |
+
},
|
3473 |
+
{
|
3474 |
+
"name": "stdout",
|
3475 |
+
"output_type": "stream",
|
3476 |
+
"text": [
|
3477 |
+
"SELECT MAX(pts_fb_away) FROM other_stats WHERE team_abbreviation_away = 'HOU';\n",
|
3478 |
+
"SQLite:\n",
|
3479 |
+
"SELECT MAX(pts_fb_away) FROM other_stats WHERE team_name_away = 'Houston Rockets'\n",
|
3480 |
+
"Statement valid? False\n",
|
3481 |
+
"SQLite matched? False\n",
|
3482 |
+
"Result matched? False\n",
|
3483 |
+
"\n",
|
3484 |
+
"\n"
|
3485 |
+
]
|
3486 |
+
},
|
3487 |
+
{
|
3488 |
+
"name": "stderr",
|
3489 |
+
"output_type": "stream",
|
3490 |
+
"text": [
|
3491 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3492 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3493 |
+
]
|
3494 |
+
},
|
3495 |
+
{
|
3496 |
+
"name": "stdout",
|
3497 |
+
"output_type": "stream",
|
3498 |
+
"text": [
|
3499 |
+
"SELECT MAX(pts_paint_away) FROM other_stats WHERE team_abbreviation_away = 'MIL';\n",
|
3500 |
+
"SQLite:\n",
|
3501 |
+
"SELECT MAX(pts_off_to_away) FROM other_stats WHERE team_name_away = 'Milwaukee Bucks'\n",
|
3502 |
+
"Statement valid? False\n",
|
3503 |
+
"SQLite matched? False\n",
|
3504 |
+
"Result matched? False\n",
|
3505 |
+
"\n",
|
3506 |
+
"\n"
|
3507 |
+
]
|
3508 |
+
},
|
3509 |
+
{
|
3510 |
+
"name": "stderr",
|
3511 |
+
"output_type": "stream",
|
3512 |
+
"text": [
|
3513 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3514 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3515 |
+
]
|
3516 |
+
},
|
3517 |
+
{
|
3518 |
+
"name": "stdout",
|
3519 |
+
"output_type": "stream",
|
3520 |
+
"text": [
|
3521 |
+
"SELECT MIN(pts_away) FROM game WHERE team_abbreviation_away = 'GSW';\n",
|
3522 |
+
"SQLite:\n",
|
3523 |
+
"SELECT MIN(pts_away) FROM game WHERE team_name_away = 'Golden State Warriors'\n",
|
3524 |
+
"['65.0']\n",
|
3525 |
+
"['65.0']\n",
|
3526 |
+
"Statement valid? True\n",
|
3527 |
+
"SQLite matched? False\n",
|
3528 |
+
"Result matched? True\n",
|
3529 |
+
"\n",
|
3530 |
+
"\n"
|
3531 |
+
]
|
3532 |
+
},
|
3533 |
+
{
|
3534 |
+
"name": "stderr",
|
3535 |
+
"output_type": "stream",
|
3536 |
+
"text": [
|
3537 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3538 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3539 |
+
]
|
3540 |
+
},
|
3541 |
+
{
|
3542 |
+
"name": "stdout",
|
3543 |
+
"output_type": "stream",
|
3544 |
+
"text": [
|
3545 |
+
"SELECT COUNT(*) FROM game WHERE team_abbreviation_away = 'TOR' AND pts_away = 100;\n",
|
3546 |
+
"SQLite:\n",
|
3547 |
+
"SELECT COUNT(*) FROM other_stats WHERE team_name_away = 'Toronto Raptors' AND pts_away = 100\n",
|
3548 |
+
"Statement valid? False\n",
|
3549 |
+
"SQLite matched? False\n",
|
3550 |
+
"Result matched? False\n",
|
3551 |
+
"\n",
|
3552 |
+
"\n"
|
3553 |
+
]
|
3554 |
+
},
|
3555 |
+
{
|
3556 |
+
"name": "stderr",
|
3557 |
+
"output_type": "stream",
|
3558 |
+
"text": [
|
3559 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3560 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3561 |
+
]
|
3562 |
+
},
|
3563 |
+
{
|
3564 |
+
"name": "stdout",
|
3565 |
+
"output_type": "stream",
|
3566 |
+
"text": [
|
3567 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'BOS';\n",
|
3568 |
+
"SQLite:\n",
|
3569 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Boston Celtics'\n",
|
3570 |
+
"Statement valid? False\n",
|
3571 |
+
"SQLite matched? False\n",
|
3572 |
+
"Result matched? False\n",
|
3573 |
+
"\n",
|
3574 |
+
"\n"
|
3575 |
+
]
|
3576 |
+
},
|
3577 |
+
{
|
3578 |
+
"name": "stderr",
|
3579 |
+
"output_type": "stream",
|
3580 |
+
"text": [
|
3581 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3582 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3583 |
+
]
|
3584 |
+
},
|
3585 |
+
{
|
3586 |
+
"name": "stdout",
|
3587 |
+
"output_type": "stream",
|
3588 |
+
"text": [
|
3589 |
+
"SELECT MAX(fg3m_home) FROM game WHERE team_name_home = 'Brooklyn Nets';\n",
|
3590 |
+
"SQLite: SELECT MAX(fg3m_home) FROM game WHERE team_name_home = 'Brooklyn Nets'\n",
|
3591 |
+
"['22.0']\n",
|
3592 |
+
"['22.0']\n",
|
3593 |
+
"Statement valid? True\n",
|
3594 |
+
"SQLite matched? True\n",
|
3595 |
+
"Result matched? True\n",
|
3596 |
+
"\n",
|
3597 |
+
"\n",
|
3598 |
+
"SELECT MAX(pts_2nd_chance_home) FROM other_stats WHERE team_abbreviation_home = 'CHI';\n",
|
3599 |
+
"SQLite:\n",
|
3600 |
+
"SELECT MAX(pts_2nd_chance_home) FROM game WHERE team_name_home = 'Chicago Bulls'\n",
|
3601 |
+
"Statement valid? False\n",
|
3602 |
+
"SQLite matched? False\n",
|
3603 |
+
"Result matched? False\n",
|
3604 |
+
"\n",
|
3605 |
+
"\n"
|
3606 |
+
]
|
3607 |
+
},
|
3608 |
+
{
|
3609 |
+
"name": "stderr",
|
3610 |
+
"output_type": "stream",
|
3611 |
+
"text": [
|
3612 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3613 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3614 |
+
]
|
3615 |
+
},
|
3616 |
+
{
|
3617 |
+
"name": "stdout",
|
3618 |
+
"output_type": "stream",
|
3619 |
+
"text": [
|
3620 |
+
"SELECT COUNT(*) FROM game WHERE (pts_home + pts_away) >= 250 AND season_id = '22005';\n",
|
3621 |
+
"SQLite:\n",
|
3622 |
+
"SELECT COUNT(*) FROM game WHERE (pts_home + pts_away) >= 250 AND season_id = '22005'\n",
|
3623 |
+
"['6']\n",
|
3624 |
+
"['6']\n",
|
3625 |
+
"Statement valid? True\n",
|
3626 |
+
"SQLite matched? True\n",
|
3627 |
+
"Result matched? True\n",
|
3628 |
+
"\n",
|
3629 |
+
"\n"
|
3630 |
+
]
|
3631 |
+
},
|
3632 |
+
{
|
3633 |
+
"name": "stderr",
|
3634 |
+
"output_type": "stream",
|
3635 |
+
"text": [
|
3636 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3637 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n",
|
3638 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3639 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3640 |
+
]
|
3641 |
+
},
|
3642 |
+
{
|
3643 |
+
"name": "stdout",
|
3644 |
+
"output_type": "stream",
|
3645 |
+
"text": [
|
3646 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_abbreviation_home = 'MIA';\n",
|
3647 |
+
"SQLite:\n",
|
3648 |
+
"SELECT MAX(largest_lead_home) FROM other_stats WHERE team_name_home = 'Miami Heat'\n",
|
3649 |
+
"Statement valid? False\n",
|
3650 |
+
"SQLite matched? False\n",
|
3651 |
+
"Result matched? False\n",
|
3652 |
+
"\n",
|
3653 |
+
"\n"
|
3654 |
+
]
|
3655 |
+
},
|
3656 |
+
{
|
3657 |
+
"name": "stderr",
|
3658 |
+
"output_type": "stream",
|
3659 |
+
"text": [
|
3660 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3661 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3662 |
+
]
|
3663 |
+
},
|
3664 |
+
{
|
3665 |
+
"name": "stdout",
|
3666 |
+
"output_type": "stream",
|
3667 |
+
"text": [
|
3668 |
+
"SELECT MAX(reb_home) FROM game WHERE team_abbreviation_home = 'NYK';\n",
|
3669 |
+
"SQLite:\n",
|
3670 |
+
"SELECT MAX(reb_home) FROM game WHERE team_name_home = 'New York Knicks'\n",
|
3671 |
+
"['75.0']\n",
|
3672 |
+
"['75.0']\n",
|
3673 |
+
"Statement valid? True\n",
|
3674 |
+
"SQLite matched? False\n",
|
3675 |
+
"Result matched? True\n",
|
3676 |
+
"\n",
|
3677 |
+
"\n"
|
3678 |
+
]
|
3679 |
+
},
|
3680 |
+
{
|
3681 |
+
"name": "stderr",
|
3682 |
+
"output_type": "stream",
|
3683 |
+
"text": [
|
3684 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3685 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3686 |
+
]
|
3687 |
+
},
|
3688 |
+
{
|
3689 |
+
"name": "stdout",
|
3690 |
+
"output_type": "stream",
|
3691 |
+
"text": [
|
3692 |
+
"SELECT MAX(pts_fb_home) FROM other_stats WHERE team_abbreviation_home = 'HOU';\n",
|
3693 |
+
"SQLite:\n",
|
3694 |
+
"SELECT MAX(pts_fb_home) FROM game WHERE team_name_home = 'Houston Rockets'\n",
|
3695 |
+
"Statement valid? False\n",
|
3696 |
+
"SQLite matched? False\n",
|
3697 |
+
"Result matched? False\n",
|
3698 |
+
"\n",
|
3699 |
+
"\n"
|
3700 |
+
]
|
3701 |
+
},
|
3702 |
+
{
|
3703 |
+
"name": "stderr",
|
3704 |
+
"output_type": "stream",
|
3705 |
+
"text": [
|
3706 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3707 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3708 |
+
]
|
3709 |
+
},
|
3710 |
+
{
|
3711 |
+
"name": "stdout",
|
3712 |
+
"output_type": "stream",
|
3713 |
+
"text": [
|
3714 |
+
"SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'CLE';\n",
|
3715 |
+
"SQLite:\n",
|
3716 |
+
"SELECT MAX(pts_away) FROM game WHERE team_name_away = 'Cleveland Cavaliers' AND season_type = 'Regular'\n",
|
3717 |
+
"['140.0']\n",
|
3718 |
+
"['None']\n",
|
3719 |
+
"Statement valid? True\n",
|
3720 |
+
"SQLite matched? False\n",
|
3721 |
+
"Result matched? False\n",
|
3722 |
+
"\n",
|
3723 |
+
"\n"
|
3724 |
+
]
|
3725 |
+
},
|
3726 |
+
{
|
3727 |
+
"name": "stderr",
|
3728 |
+
"output_type": "stream",
|
3729 |
+
"text": [
|
3730 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3731 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3732 |
+
]
|
3733 |
+
},
|
3734 |
+
{
|
3735 |
+
"name": "stdout",
|
3736 |
+
"output_type": "stream",
|
3737 |
+
"text": [
|
3738 |
+
"SELECT MAX(total_turnovers_home) FROM other_stats WHERE team_abbreviation_home = 'PHI';\n",
|
3739 |
+
"SQLite:\n",
|
3740 |
+
"SELECT MAX(total_turnovers_home) FROM other_stats WHERE team_name_home = 'Philadelphia 76ers'\n",
|
3741 |
+
"Statement valid? False\n",
|
3742 |
+
"SQLite matched? False\n",
|
3743 |
+
"Result matched? False\n",
|
3744 |
+
"\n",
|
3745 |
+
"\n"
|
3746 |
+
]
|
3747 |
+
},
|
3748 |
+
{
|
3749 |
+
"name": "stderr",
|
3750 |
+
"output_type": "stream",
|
3751 |
+
"text": [
|
3752 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3753 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3754 |
+
]
|
3755 |
+
},
|
3756 |
+
{
|
3757 |
+
"name": "stdout",
|
3758 |
+
"output_type": "stream",
|
3759 |
+
"text": [
|
3760 |
+
"SELECT MAX(largest_lead_away) FROM other_stats WHERE team_abbreviation_away = 'NYK';\n",
|
3761 |
+
"SQLite:\n",
|
3762 |
+
"SELECT MAX(largest_lead_away) FROM other_stats WHERE team_name_away = 'New York Knicks'\n",
|
3763 |
+
"Statement valid? False\n",
|
3764 |
+
"SQLite matched? False\n",
|
3765 |
+
"Result matched? False\n",
|
3766 |
+
"\n",
|
3767 |
+
"\n"
|
3768 |
+
]
|
3769 |
+
},
|
3770 |
+
{
|
3771 |
+
"name": "stderr",
|
3772 |
+
"output_type": "stream",
|
3773 |
+
"text": [
|
3774 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3775 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3776 |
+
]
|
3777 |
+
},
|
3778 |
+
{
|
3779 |
+
"name": "stdout",
|
3780 |
+
"output_type": "stream",
|
3781 |
+
"text": [
|
3782 |
+
"SELECT MAX(total_turnovers_away) FROM other_stats WHERE team_abbreviation_away = 'BKN';\n",
|
3783 |
+
"SQLite:\n",
|
3784 |
+
"SELECT MAX(team_turnovers_away) FROM other_stats WHERE team_name_away = 'Brooklyn Nets'\n",
|
3785 |
+
"Statement valid? False\n",
|
3786 |
+
"SQLite matched? False\n",
|
3787 |
+
"Result matched? False\n",
|
3788 |
+
"\n",
|
3789 |
+
"\n"
|
3790 |
+
]
|
3791 |
+
},
|
3792 |
+
{
|
3793 |
+
"name": "stderr",
|
3794 |
+
"output_type": "stream",
|
3795 |
+
"text": [
|
3796 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3797 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3798 |
+
]
|
3799 |
+
},
|
3800 |
+
{
|
3801 |
+
"name": "stdout",
|
3802 |
+
"output_type": "stream",
|
3803 |
+
"text": [
|
3804 |
+
"SELECT COUNT(*) FROM game WHERE ABS(pts_home - pts_away) = 1;\n",
|
3805 |
+
"SQLite:\n",
|
3806 |
+
"SELECT COUNT(*) FROM game WHERE pts_home - pts_away = 1\n",
|
3807 |
+
"['3050']\n",
|
3808 |
+
"['1572']\n",
|
3809 |
+
"Statement valid? True\n",
|
3810 |
+
"SQLite matched? False\n",
|
3811 |
+
"Result matched? False\n",
|
3812 |
+
"\n",
|
3813 |
+
"\n"
|
3814 |
+
]
|
3815 |
+
},
|
3816 |
+
{
|
3817 |
+
"name": "stderr",
|
3818 |
+
"output_type": "stream",
|
3819 |
+
"text": [
|
3820 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3821 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3822 |
+
]
|
3823 |
+
},
|
3824 |
+
{
|
3825 |
+
"name": "stdout",
|
3826 |
+
"output_type": "stream",
|
3827 |
+
"text": [
|
3828 |
+
"SELECT COUNT(*) FROM game WHERE team_abbreviation_home = 'LAL' AND season_id = '22010';\n",
|
3829 |
+
"SQLite: SELECT COUNT(*) FROM game WHERE team_name_home = 'Los Angeles Lakers' AND season_id = '22010'\n",
|
3830 |
+
"['41']\n",
|
3831 |
+
"['41']\n",
|
3832 |
+
"Statement valid? True\n",
|
3833 |
+
"SQLite matched? False\n",
|
3834 |
+
"Result matched? True\n",
|
3835 |
+
"\n",
|
3836 |
+
"\n"
|
3837 |
+
]
|
3838 |
+
},
|
3839 |
+
{
|
3840 |
+
"name": "stderr",
|
3841 |
+
"output_type": "stream",
|
3842 |
+
"text": [
|
3843 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3844 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3845 |
+
]
|
3846 |
+
},
|
3847 |
+
{
|
3848 |
+
"name": "stdout",
|
3849 |
+
"output_type": "stream",
|
3850 |
+
"text": [
|
3851 |
+
"SELECT MAX(pts_home) FROM game;\n",
|
3852 |
+
"SQLite:\n",
|
3853 |
+
"SELECT game_id, MAX(pts_home) AS highest_scoring_home_game FROM game GROUP BY game_id ORDER BY highest_scoring_home_game DESC LIMIT 1\n",
|
3854 |
+
"['192.0']\n",
|
3855 |
+
"['0031600001', '192.0']\n",
|
3856 |
+
"Statement valid? True\n",
|
3857 |
+
"SQLite matched? False\n",
|
3858 |
+
"Result matched? True\n",
|
3859 |
+
"\n",
|
3860 |
+
"\n"
|
3861 |
+
]
|
3862 |
+
},
|
3863 |
+
{
|
3864 |
+
"name": "stderr",
|
3865 |
+
"output_type": "stream",
|
3866 |
+
"text": [
|
3867 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3868 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3869 |
+
]
|
3870 |
+
},
|
3871 |
+
{
|
3872 |
+
"name": "stdout",
|
3873 |
+
"output_type": "stream",
|
3874 |
+
"text": [
|
3875 |
+
"SELECT MIN(fg_pct_away) FROM game;\n",
|
3876 |
+
"SQLite:\n",
|
3877 |
+
"SELECT MIN(fg_pct_away) FROM game\n",
|
3878 |
+
"['0.156']\n",
|
3879 |
+
"['0.156']\n",
|
3880 |
+
"Statement valid? True\n",
|
3881 |
+
"SQLite matched? True\n",
|
3882 |
+
"Result matched? True\n",
|
3883 |
+
"\n",
|
3884 |
+
"\n"
|
3885 |
+
]
|
3886 |
+
},
|
3887 |
+
{
|
3888 |
+
"name": "stderr",
|
3889 |
+
"output_type": "stream",
|
3890 |
+
"text": [
|
3891 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3892 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3893 |
+
]
|
3894 |
+
},
|
3895 |
+
{
|
3896 |
+
"name": "stdout",
|
3897 |
+
"output_type": "stream",
|
3898 |
+
"text": [
|
3899 |
+
"SELECT COUNT(*) FROM game WHERE ABS(pts_home - pts_away) > 40;\n",
|
3900 |
+
"SQLite:\n",
|
3901 |
+
"SELECT COUNT(*) FROM game WHERE pts_home > 40 OR pts_away > 40\n",
|
3902 |
+
"['307']\n",
|
3903 |
+
"['65696']\n",
|
3904 |
+
"Statement valid? True\n",
|
3905 |
+
"SQLite matched? False\n",
|
3906 |
+
"Result matched? False\n",
|
3907 |
+
"\n",
|
3908 |
+
"\n"
|
3909 |
+
]
|
3910 |
+
},
|
3911 |
+
{
|
3912 |
+
"name": "stderr",
|
3913 |
+
"output_type": "stream",
|
3914 |
+
"text": [
|
3915 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3916 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3917 |
+
]
|
3918 |
+
},
|
3919 |
+
{
|
3920 |
+
"name": "stdout",
|
3921 |
+
"output_type": "stream",
|
3922 |
+
"text": [
|
3923 |
+
"SELECT MAX(stl_away) FROM game;\n",
|
3924 |
+
"SQLite:\n",
|
3925 |
+
"SELECT MAX(stl_away) FROM game\n",
|
3926 |
+
"['27.0']\n",
|
3927 |
+
"['27.0']\n",
|
3928 |
+
"Statement valid? True\n",
|
3929 |
+
"SQLite matched? True\n",
|
3930 |
+
"Result matched? True\n",
|
3931 |
+
"\n",
|
3932 |
+
"\n"
|
3933 |
+
]
|
3934 |
+
},
|
3935 |
+
{
|
3936 |
+
"name": "stderr",
|
3937 |
+
"output_type": "stream",
|
3938 |
+
"text": [
|
3939 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3940 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3941 |
+
]
|
3942 |
+
},
|
3943 |
+
{
|
3944 |
+
"name": "stdout",
|
3945 |
+
"output_type": "stream",
|
3946 |
+
"text": [
|
3947 |
+
"SELECT COUNT(*) FROM game WHERE pts_home >= 120 AND pts_away >= 120;\n",
|
3948 |
+
"SQLite:\n",
|
3949 |
+
"SELECT COUNT(*) FROM game WHERE pts_home >= 120 AND pts_away >= 120\n",
|
3950 |
+
"['3043']\n",
|
3951 |
+
"['3043']\n",
|
3952 |
+
"Statement valid? True\n",
|
3953 |
+
"SQLite matched? True\n",
|
3954 |
+
"Result matched? True\n",
|
3955 |
+
"\n",
|
3956 |
+
"\n"
|
3957 |
+
]
|
3958 |
+
},
|
3959 |
+
{
|
3960 |
+
"name": "stderr",
|
3961 |
+
"output_type": "stream",
|
3962 |
+
"text": [
|
3963 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
3964 |
+
"Setting `pad_token_id` to `eos_token_id`:32022 for open-end generation.\n"
|
3965 |
+
]
|
3966 |
+
},
|
3967 |
+
{
|
3968 |
+
"name": "stdout",
|
3969 |
+
"output_type": "stream",
|
3970 |
+
"text": [
|
3971 |
+
"SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30;\n",
|
3972 |
+
"SQLite:\n",
|
3973 |
+
"SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30\n",
|
3974 |
+
"['11305']\n",
|
3975 |
+
"['11305']\n",
|
3976 |
+
"Statement valid? True\n",
|
3977 |
+
"SQLite matched? True\n",
|
3978 |
+
"Result matched? True\n",
|
3979 |
+
"\n",
|
3980 |
+
"\n",
|
3981 |
+
"SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n",
|
3982 |
+
"SQLite:\n",
|
3983 |
+
"SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers'\n",
|
3984 |
+
"['44.0']\n",
|
3985 |
+
"['44.0']\n",
|
3986 |
+
"Statement valid? True\n",
|
3987 |
+
"SQLite matched? True\n",
|
3988 |
+
"Result matched? True\n",
|
3989 |
+
"\n",
|
3990 |
+
"\n",
|
3991 |
+
"Percent valid: 0.5862068965517241\n",
|
3992 |
+
"Percent SQLite matched: 0.27586206896551724\n",
|
3993 |
+
"Percent result matched: 0.4482758620689655\n"
|
3994 |
+
]
|
3995 |
+
}
|
3996 |
+
],
|
3997 |
+
"source": [
|
3998 |
+
"num_valid = 0\n",
|
3999 |
+
"num_sql_matched = 0\n",
|
4000 |
+
"num_result_matched = 0\n",
|
4001 |
+
"counter = 0\n",
|
4002 |
+
"for v in val_dataset:\n",
|
4003 |
+
" # Obtain sample natural language question and sql_query\n",
|
4004 |
+
" #v = val_dataset[random.randint(0, len(val_dataset) - 1)]\n",
|
4005 |
+
" full_example = tokenizer.decode(v[\"input_ids\"], skip_special_tokens=True)\n",
|
4006 |
+
" user_prompt = full_example[:prompt_length]\n",
|
4007 |
+
" question, sql_query = full_example[prompt_length:].split(\"SQLite:\\n\")\n",
|
4008 |
+
" #print(question)\n",
|
4009 |
+
" #print(sql_query)\n",
|
4010 |
+
"\n",
|
4011 |
+
" if len(sql_query) <= 90:\n",
|
4012 |
+
" # Obtain model output\n",
|
4013 |
+
" input_text = \"How many points to the Los Angeles Lakers average at home?\"\n",
|
4014 |
+
" message = [{'role': 'user', 'content': input_prompt + question}]\n",
|
4015 |
+
" inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
|
4016 |
+
"\n",
|
4017 |
+
" # Generate SQL query\n",
|
4018 |
+
" outputs = model.generate(\n",
|
4019 |
+
" inputs,\n",
|
4020 |
+
" max_new_tokens=256,\n",
|
4021 |
+
" eos_token_id=tokenizer.convert_tokens_to_ids(\"<|endofsql|>\")\n",
|
4022 |
+
" )\n",
|
4023 |
+
" model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
4024 |
+
"\n",
|
4025 |
+
" print(sql_query)\n",
|
4026 |
+
" print(model_output.split(\";\")[0])\n",
|
4027 |
+
" #print()\n",
|
4028 |
+
" #print(model_output)\n",
|
4029 |
+
" result = compare_result(sql_query, model_output)\n",
|
4030 |
+
" print(\"Statement valid? \" + str(result[0]))\n",
|
4031 |
+
" print(\"SQLite matched? \" + str(result[1]))\n",
|
4032 |
+
" print(\"Result matched? \" + str(result[2]))\n",
|
4033 |
+
" print()\n",
|
4034 |
+
" print()\n",
|
4035 |
+
" counter += 1\n",
|
4036 |
+
"\n",
|
4037 |
+
" if result[0]:\n",
|
4038 |
+
" num_valid += 1\n",
|
4039 |
+
" if result[1]:\n",
|
4040 |
+
" num_sql_matched += 1\n",
|
4041 |
+
" if result[2]:\n",
|
4042 |
+
" num_result_matched += 1\n",
|
4043 |
+
"\n",
|
4044 |
+
"print(\"Percent valid: \" + str(num_valid / counter))\n",
|
4045 |
+
"print(\"Percent SQLite matched: \" + str(num_sql_matched / counter))\n",
|
4046 |
+
"print(\"Percent result matched: \" + str(num_result_matched / counter))"
|
4047 |
+
]
|
4048 |
}
|
4049 |
],
|
4050 |
"metadata": {
|