Adding completed pre-training testing runs to python notebook
Browse files- test_pretrained.ipynb +293 -38
test_pretrained.ipynb
CHANGED
@@ -26,9 +26,9 @@
|
|
26 |
"Total dataset examples: 1044\n",
|
27 |
"\n",
|
28 |
"\n",
|
29 |
-
"What
|
30 |
-
"SELECT
|
31 |
-
"
|
32 |
]
|
33 |
}
|
34 |
],
|
@@ -83,7 +83,7 @@
|
|
83 |
},
|
84 |
{
|
85 |
"cell_type": "code",
|
86 |
-
"execution_count":
|
87 |
"metadata": {},
|
88 |
"outputs": [],
|
89 |
"source": [
|
@@ -287,10 +287,9 @@
|
|
287 |
"output_type": "stream",
|
288 |
"text": [
|
289 |
"SQLite:\n",
|
290 |
-
"SELECT
|
291 |
-
"FROM game\n",
|
292 |
-
"WHERE
|
293 |
-
"OR (team_name_home = 'Brooklyn Nets' AND team_name_away = 'Toronto Raptors');\n",
|
294 |
"\n"
|
295 |
]
|
296 |
}
|
@@ -323,7 +322,7 @@
|
|
323 |
"output_type": "stream",
|
324 |
"text": [
|
325 |
"cleaned\n",
|
326 |
-
"(
|
327 |
]
|
328 |
}
|
329 |
],
|
@@ -368,14 +367,15 @@
|
|
368 |
"name": "stdout",
|
369 |
"output_type": "stream",
|
370 |
"text": [
|
371 |
-
"
|
372 |
-
"SELECT
|
373 |
-
"
|
374 |
"SQLite:\n",
|
375 |
-
"SELECT
|
376 |
-
"FROM game\n",
|
377 |
-
"WHERE
|
378 |
-
"
|
|
|
379 |
"\n",
|
380 |
"Statement valid? True\n",
|
381 |
"SQLite matched? False\n",
|
@@ -508,20 +508,9 @@
|
|
508 |
},
|
509 |
{
|
510 |
"cell_type": "code",
|
511 |
-
"execution_count":
|
512 |
"metadata": {},
|
513 |
-
"outputs": [
|
514 |
-
{
|
515 |
-
"name": "stdout",
|
516 |
-
"output_type": "stream",
|
517 |
-
"text": [
|
518 |
-
"Less than 90 results:\n",
|
519 |
-
"Percent valid: 0.0653061224489796\n",
|
520 |
-
"Percent SQLite matched: 0.00816326530612245\n",
|
521 |
-
"Percent result matched: 0.024489795918367346\n"
|
522 |
-
]
|
523 |
-
}
|
524 |
-
],
|
525 |
"source": [
|
526 |
"def run_evaluation(nba_df, title):\n",
|
527 |
" counter = 0\n",
|
@@ -550,27 +539,293 @@
|
|
550 |
" counter += 1\n",
|
551 |
" if counter % 50 == 0:\n",
|
552 |
" print(\"Completed \" + str(counter))\n",
|
553 |
-
" elif counter == 20:\n",
|
554 |
-
" break\n",
|
555 |
"\n",
|
556 |
" # Print evaluation results\n",
|
557 |
-
" print(title + \" results:\")\n",
|
558 |
" print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
|
559 |
" print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
|
560 |
-
" print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))
|
561 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
"less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
|
563 |
"run_evaluation(less_than_90_df, \"Less than 90\")\n",
|
564 |
-
"\
|
565 |
-
"# Run evaluation on all training data\n",
|
566 |
-
"#run_evaluation(df, \"All training data\")"
|
567 |
]
|
568 |
},
|
569 |
{
|
570 |
"cell_type": "markdown",
|
571 |
"metadata": {},
|
572 |
"source": [
|
573 |
-
"# Evaluate on
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
]
|
575 |
}
|
576 |
],
|
|
|
26 |
"Total dataset examples: 1044\n",
|
27 |
"\n",
|
28 |
"\n",
|
29 |
+
"What is the average number of tov in home games by the Miami Heat?\n",
|
30 |
+
"SELECT AVG(tov_home) FROM game WHERE team_name_home = 'Miami Heat';\n",
|
31 |
+
"14.627184466019418\n"
|
32 |
]
|
33 |
}
|
34 |
],
|
|
|
83 |
},
|
84 |
{
|
85 |
"cell_type": "code",
|
86 |
+
"execution_count": null,
|
87 |
"metadata": {},
|
88 |
"outputs": [],
|
89 |
"source": [
|
|
|
287 |
"output_type": "stream",
|
288 |
"text": [
|
289 |
"SQLite:\n",
|
290 |
+
"SELECT AVG(tov_home) \n",
|
291 |
+
"FROM game \n",
|
292 |
+
"WHERE team_name_home = 'Miami Heat';\n",
|
|
|
293 |
"\n"
|
294 |
]
|
295 |
}
|
|
|
322 |
"output_type": "stream",
|
323 |
"text": [
|
324 |
"cleaned\n",
|
325 |
+
"(14.627184466019418,)\n"
|
326 |
]
|
327 |
}
|
328 |
],
|
|
|
367 |
"name": "stdout",
|
368 |
"output_type": "stream",
|
369 |
"text": [
|
370 |
+
"How many times have the Houston Rockets won an away game while scoring at least 110 points?\n",
|
371 |
+
"SELECT COUNT(*) FROM game WHERE team_abbreviation_away = 'HOU' AND pts_away >= 110 AND wl_away = 'W';\n",
|
372 |
+
"425\n",
|
373 |
"SQLite:\n",
|
374 |
+
"SELECT COUNT(*) \n",
|
375 |
+
"FROM game \n",
|
376 |
+
"WHERE team_name_away = 'Houston Rockets' \n",
|
377 |
+
"AND wl_away = 'W' \n",
|
378 |
+
"AND pts_away >= 110;\n",
|
379 |
"\n",
|
380 |
"Statement valid? True\n",
|
381 |
"SQLite matched? False\n",
|
|
|
508 |
},
|
509 |
{
|
510 |
"cell_type": "code",
|
511 |
+
"execution_count": 7,
|
512 |
"metadata": {},
|
513 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
"source": [
|
515 |
"def run_evaluation(nba_df, title):\n",
|
516 |
" counter = 0\n",
|
|
|
539 |
" counter += 1\n",
|
540 |
" if counter % 50 == 0:\n",
|
541 |
" print(\"Completed \" + str(counter))\n",
|
|
|
|
|
542 |
"\n",
|
543 |
" # Print evaluation results\n",
|
544 |
+
" print(\"\\n\" + title + \" results:\")\n",
|
545 |
" print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
|
546 |
" print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
|
547 |
+
" print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))"
|
548 |
+
]
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"cell_type": "markdown",
|
552 |
+
"metadata": {},
|
553 |
+
"source": [
|
554 |
+
"# Evaluate on less than 90 dataset"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "code",
|
559 |
+
"execution_count": 8,
|
560 |
+
"metadata": {},
|
561 |
+
"outputs": [
|
562 |
+
{
|
563 |
+
"name": "stdout",
|
564 |
+
"output_type": "stream",
|
565 |
+
"text": [
|
566 |
+
"Completed 50\n",
|
567 |
+
"Completed 100\n",
|
568 |
+
"Completed 150\n",
|
569 |
+
"Completed 200\n",
|
570 |
+
"\n",
|
571 |
+
"Less than 90 results:\n",
|
572 |
+
"Percent valid: 0.8612244897959184\n",
|
573 |
+
"Percent SQLite matched: 0.4163265306122449\n",
|
574 |
+
"Percent result matched: 0.6530612244897959\n",
|
575 |
+
"Dataset length: 245\n"
|
576 |
+
]
|
577 |
+
}
|
578 |
+
],
|
579 |
+
"source": [
|
580 |
"less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
|
581 |
"run_evaluation(less_than_90_df, \"Less than 90\")\n",
|
582 |
+
"print(\"Dataset length: \" + str(len(less_than_90_df)))"
|
|
|
|
|
583 |
]
|
584 |
},
|
585 |
{
|
586 |
"cell_type": "markdown",
|
587 |
"metadata": {},
|
588 |
"source": [
|
589 |
+
"# Evaluate on game table queries"
|
590 |
+
]
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"cell_type": "code",
|
594 |
+
"execution_count": 9,
|
595 |
+
"metadata": {},
|
596 |
+
"outputs": [
|
597 |
+
{
|
598 |
+
"name": "stdout",
|
599 |
+
"output_type": "stream",
|
600 |
+
"text": [
|
601 |
+
"Completed 50\n",
|
602 |
+
"Completed 100\n",
|
603 |
+
"Completed 150\n",
|
604 |
+
"Completed 200\n",
|
605 |
+
"Completed 250\n",
|
606 |
+
"Completed 300\n",
|
607 |
+
"Completed 350\n",
|
608 |
+
"Completed 400\n",
|
609 |
+
"Completed 450\n",
|
610 |
+
"Completed 500\n",
|
611 |
+
"Completed 550\n",
|
612 |
+
"Completed 600\n",
|
613 |
+
"Completed 650\n",
|
614 |
+
"Completed 700\n",
|
615 |
+
"Completed 750\n",
|
616 |
+
"Completed 800\n",
|
617 |
+
"\n",
|
618 |
+
"Queries from game results:\n",
|
619 |
+
"Percent valid: 0.7708830548926014\n",
|
620 |
+
"Percent SQLite matched: 0.1431980906921241\n",
|
621 |
+
"Percent result matched: 0.40692124105011934\n",
|
622 |
+
"Dataset length: 838\n"
|
623 |
+
]
|
624 |
+
}
|
625 |
+
],
|
626 |
+
"source": [
|
627 |
+
"game_queries = pd.read_csv(\"./train-data/queries_from_game.tsv\", sep='\\t')\n",
|
628 |
+
"run_evaluation(game_queries, \"Queries from game\")\n",
|
629 |
+
"print(\"Dataset length: \" + str(len(game_queries)))"
|
630 |
+
]
|
631 |
+
},
|
632 |
+
{
|
633 |
+
"cell_type": "markdown",
|
634 |
+
"metadata": {},
|
635 |
+
"source": [
|
636 |
+
"## Evaluate on other stats queries"
|
637 |
+
]
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "code",
|
641 |
+
"execution_count": 10,
|
642 |
+
"metadata": {},
|
643 |
+
"outputs": [
|
644 |
+
{
|
645 |
+
"name": "stdout",
|
646 |
+
"output_type": "stream",
|
647 |
+
"text": [
|
648 |
+
"Completed 50\n",
|
649 |
+
"Completed 100\n",
|
650 |
+
"Completed 150\n",
|
651 |
+
"\n",
|
652 |
+
"Queries from other stats results:\n",
|
653 |
+
"Percent valid: 0.07792207792207792\n",
|
654 |
+
"Percent SQLite matched: 0.0\n",
|
655 |
+
"Percent result matched: 0.0\n",
|
656 |
+
"Dataset length: 154\n"
|
657 |
+
]
|
658 |
+
}
|
659 |
+
],
|
660 |
+
"source": [
|
661 |
+
"other_stats_queries = pd.read_csv(\"./train-data/queries_from_other_stats.tsv\", sep='\\t')\n",
|
662 |
+
"run_evaluation(other_stats_queries, \"Queries from other stats\")\n",
|
663 |
+
"print(\"Dataset length: \" + str(len(other_stats_queries)))"
|
664 |
+
]
|
665 |
+
},
|
666 |
+
{
|
667 |
+
"cell_type": "markdown",
|
668 |
+
"metadata": {},
|
669 |
+
"source": [
|
670 |
+
"## Evaluate on team queries"
|
671 |
+
]
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"cell_type": "code",
|
675 |
+
"execution_count": 11,
|
676 |
+
"metadata": {},
|
677 |
+
"outputs": [
|
678 |
+
{
|
679 |
+
"name": "stdout",
|
680 |
+
"output_type": "stream",
|
681 |
+
"text": [
|
682 |
+
"Completed 50\n",
|
683 |
+
"\n",
|
684 |
+
"Queries from team results:\n",
|
685 |
+
"Percent valid: 0.75\n",
|
686 |
+
"Percent SQLite matched: 0.2692307692307692\n",
|
687 |
+
"Percent result matched: 0.6153846153846154\n",
|
688 |
+
"Dataset length: 52\n"
|
689 |
+
]
|
690 |
+
}
|
691 |
+
],
|
692 |
+
"source": [
|
693 |
+
"team_queries = pd.read_csv(\"./train-data/queries_from_team.tsv\", sep='\\t')\n",
|
694 |
+
"run_evaluation(team_queries, \"Queries from team\")\n",
|
695 |
+
"print(\"Dataset length: \" + str(len(team_queries)))"
|
696 |
+
]
|
697 |
+
},
|
698 |
+
{
|
699 |
+
"cell_type": "markdown",
|
700 |
+
"metadata": {},
|
701 |
+
"source": [
|
702 |
+
"## Evaluate on queries requiring join statements"
|
703 |
+
]
|
704 |
+
},
|
705 |
+
{
|
706 |
+
"cell_type": "code",
|
707 |
+
"execution_count": 12,
|
708 |
+
"metadata": {},
|
709 |
+
"outputs": [
|
710 |
+
{
|
711 |
+
"name": "stdout",
|
712 |
+
"output_type": "stream",
|
713 |
+
"text": [
|
714 |
+
"Completed 50\n",
|
715 |
+
"Completed 100\n",
|
716 |
+
"Completed 150\n",
|
717 |
+
"\n",
|
718 |
+
"Queries with join results:\n",
|
719 |
+
"Percent valid: 0.06486486486486487\n",
|
720 |
+
"Percent SQLite matched: 0.0\n",
|
721 |
+
"Percent result matched: 0.010810810810810811\n",
|
722 |
+
"Dataset length: 185\n"
|
723 |
+
]
|
724 |
+
}
|
725 |
+
],
|
726 |
+
"source": [
|
727 |
+
"join_queries = pd.read_csv(\"./train-data/with_join.tsv\", sep='\\t')\n",
|
728 |
+
"run_evaluation(join_queries, \"Queries with join\")\n",
|
729 |
+
"print(\"Dataset length: \" + str(len(join_queries)))"
|
730 |
+
]
|
731 |
+
},
|
732 |
+
{
|
733 |
+
"cell_type": "markdown",
|
734 |
+
"metadata": {},
|
735 |
+
"source": [
|
736 |
+
"## Evaluate on queries not requiring join statements"
|
737 |
+
]
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"cell_type": "code",
|
741 |
+
"execution_count": 13,
|
742 |
+
"metadata": {},
|
743 |
+
"outputs": [
|
744 |
+
{
|
745 |
+
"name": "stdout",
|
746 |
+
"output_type": "stream",
|
747 |
+
"text": [
|
748 |
+
"Completed 50\n",
|
749 |
+
"Completed 100\n",
|
750 |
+
"Completed 150\n",
|
751 |
+
"Completed 200\n",
|
752 |
+
"Completed 250\n",
|
753 |
+
"Completed 300\n",
|
754 |
+
"Completed 350\n",
|
755 |
+
"Completed 400\n",
|
756 |
+
"Completed 450\n",
|
757 |
+
"Completed 500\n",
|
758 |
+
"Completed 550\n",
|
759 |
+
"Completed 600\n",
|
760 |
+
"Completed 650\n",
|
761 |
+
"Completed 700\n",
|
762 |
+
"Completed 750\n",
|
763 |
+
"Completed 800\n",
|
764 |
+
"Completed 850\n",
|
765 |
+
"\n",
|
766 |
+
"Queries without join results:\n",
|
767 |
+
"Percent valid: 0.7974388824214202\n",
|
768 |
+
"Percent SQLite matched: 0.1559953434225844\n",
|
769 |
+
"Percent result matched: 0.4318975552968568\n",
|
770 |
+
"Dataset length: 859\n"
|
771 |
+
]
|
772 |
+
}
|
773 |
+
],
|
774 |
+
"source": [
|
775 |
+
"no_join_queries = pd.read_csv(\"./train-data/without_join.tsv\", sep='\\t')\n",
|
776 |
+
"run_evaluation(no_join_queries, \"Queries without join\")\n",
|
777 |
+
"print(\"Dataset length: \" + str(len(no_join_queries)))"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "markdown",
|
782 |
+
"metadata": {},
|
783 |
+
"source": [
|
784 |
+
"## Evaluate on full training dataset"
|
785 |
+
]
|
786 |
+
},
|
787 |
+
{
|
788 |
+
"cell_type": "code",
|
789 |
+
"execution_count": 14,
|
790 |
+
"metadata": {},
|
791 |
+
"outputs": [
|
792 |
+
{
|
793 |
+
"name": "stdout",
|
794 |
+
"output_type": "stream",
|
795 |
+
"text": [
|
796 |
+
"Completed 50\n",
|
797 |
+
"Completed 100\n",
|
798 |
+
"Completed 150\n",
|
799 |
+
"Completed 200\n",
|
800 |
+
"Completed 250\n",
|
801 |
+
"Completed 300\n",
|
802 |
+
"Completed 350\n",
|
803 |
+
"Completed 400\n",
|
804 |
+
"Completed 450\n",
|
805 |
+
"Completed 500\n",
|
806 |
+
"Completed 550\n",
|
807 |
+
"Completed 600\n",
|
808 |
+
"Completed 650\n",
|
809 |
+
"Completed 700\n",
|
810 |
+
"Completed 750\n",
|
811 |
+
"Completed 800\n",
|
812 |
+
"Completed 850\n",
|
813 |
+
"Completed 900\n",
|
814 |
+
"Completed 950\n",
|
815 |
+
"Completed 1000\n",
|
816 |
+
"\n",
|
817 |
+
"All training data results:\n",
|
818 |
+
"Percent valid: 0.6676245210727969\n",
|
819 |
+
"Percent SQLite matched: 0.12835249042145594\n",
|
820 |
+
"Percent result matched: 0.35823754789272033\n",
|
821 |
+
"Dataset length: 1044\n"
|
822 |
+
]
|
823 |
+
}
|
824 |
+
],
|
825 |
+
"source": [
|
826 |
+
"# Run evaluation on all training data\n",
|
827 |
+
"run_evaluation(df, \"All training data\")\n",
|
828 |
+
"print(\"Dataset length: \" + str(len(df)))"
|
829 |
]
|
830 |
}
|
831 |
],
|