DeanGumas commited on
Commit
a30f35d
·
1 Parent(s): 9f2b199

Adding completed pre-training testing runs to python notebook

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +293 -38
test_pretrained.ipynb CHANGED
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "What was the combined rebound total for the Toronto Raptors and Brooklyn Nets in their highest scoring game against each other?\n",
30
- "SELECT MAX(g.pts_home + g.pts_away) AS total_points, g.reb_home + g.reb_away AS total_rebounds FROM game g WHERE (g.team_name_home = 'Toronto Raptors' AND g.team_name_away = 'Brooklyn Nets') OR (g.team_name_home = 'Brooklyn Nets' AND g.team_name_away = 'Toronto Raptors') ORDER BY total_points DESC LIMIT 1;\n",
31
- "272.0 | 101.0 \n"
32
  ]
33
  }
34
  ],
@@ -83,7 +83,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 3,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
@@ -287,10 +287,9 @@
287
  "output_type": "stream",
288
  "text": [
289
  "SQLite:\n",
290
- "SELECT SUM(reb_home + reb_away) AS combined_rebounds\n",
291
- "FROM game\n",
292
- "WHERE (team_name_home = 'Toronto Raptors' AND team_name_away = 'Brooklyn Nets')\n",
293
- "OR (team_name_home = 'Brooklyn Nets' AND team_name_away = 'Toronto Raptors');\n",
294
  "\n"
295
  ]
296
  }
@@ -323,7 +322,7 @@
323
  "output_type": "stream",
324
  "text": [
325
  "cleaned\n",
326
- "(4350.0,)\n"
327
  ]
328
  }
329
  ],
@@ -368,14 +367,15 @@
368
  "name": "stdout",
369
  "output_type": "stream",
370
  "text": [
371
- "What was the three-point shooting percentage for the Los Angeles Clippers in games against the Los Angeles Lakers?\n",
372
- "SELECT AVG( CASE WHEN team_name_home = 'LA Clippers' THEN fg3_pct_home ELSE fg3_pct_away END ) AS avg_3pt_percentage FROM game WHERE (team_name_home = 'LA Clippers' AND team_name_away = 'Los Angeles Lakers') OR (team_name_home = 'Los Angeles Lakers' AND team_name_away = 'LA Clippers');\n",
373
- "0.3734705882\n",
374
  "SQLite:\n",
375
- "SELECT team_name_home, team_name_away, AVG(fg3_pct_home) AS three_point_percentage\n",
376
- "FROM game\n",
377
- "WHERE team_name_home = 'Los Angeles Clippers' AND team_name_away = 'Los Angeles Lakers'\n",
378
- "GROUP BY team_name_home, team_name_away;\n",
 
379
  "\n",
380
  "Statement valid? True\n",
381
  "SQLite matched? False\n",
@@ -508,20 +508,9 @@
508
  },
509
  {
510
  "cell_type": "code",
511
- "execution_count": 9,
512
  "metadata": {},
513
- "outputs": [
514
- {
515
- "name": "stdout",
516
- "output_type": "stream",
517
- "text": [
518
- "Less than 90 results:\n",
519
- "Percent valid: 0.0653061224489796\n",
520
- "Percent SQLite matched: 0.00816326530612245\n",
521
- "Percent result matched: 0.024489795918367346\n"
522
- ]
523
- }
524
- ],
525
  "source": [
526
  "def run_evaluation(nba_df, title):\n",
527
  " counter = 0\n",
@@ -550,27 +539,293 @@
550
  " counter += 1\n",
551
  " if counter % 50 == 0:\n",
552
  " print(\"Completed \" + str(counter))\n",
553
- " elif counter == 20:\n",
554
- " break\n",
555
  "\n",
556
  " # Print evaluation results\n",
557
- " print(title + \" results:\")\n",
558
  " print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
559
  " print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
560
- " print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))\n",
561
- "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
563
  "run_evaluation(less_than_90_df, \"Less than 90\")\n",
564
- "\n",
565
- "# Run evaluation on all training data\n",
566
- "#run_evaluation(df, \"All training data\")"
567
  ]
568
  },
569
  {
570
  "cell_type": "markdown",
571
  "metadata": {},
572
  "source": [
573
- "# Evaluate on less than 90 dataset"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  ]
575
  }
576
  ],
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "What is the average number of tov in home games by the Miami Heat?\n",
30
+ "SELECT AVG(tov_home) FROM game WHERE team_name_home = 'Miami Heat';\n",
31
+ "14.627184466019418\n"
32
  ]
33
  }
34
  ],
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": null,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
 
287
  "output_type": "stream",
288
  "text": [
289
  "SQLite:\n",
290
+ "SELECT AVG(tov_home) \n",
291
+ "FROM game \n",
292
+ "WHERE team_name_home = 'Miami Heat';\n",
 
293
  "\n"
294
  ]
295
  }
 
322
  "output_type": "stream",
323
  "text": [
324
  "cleaned\n",
325
+ "(14.627184466019418,)\n"
326
  ]
327
  }
328
  ],
 
367
  "name": "stdout",
368
  "output_type": "stream",
369
  "text": [
370
+ "How many times have the Houston Rockets won an away game while scoring at least 110 points?\n",
371
+ "SELECT COUNT(*) FROM game WHERE team_abbreviation_away = 'HOU' AND pts_away >= 110 AND wl_away = 'W';\n",
372
+ "425\n",
373
  "SQLite:\n",
374
+ "SELECT COUNT(*) \n",
375
+ "FROM game \n",
376
+ "WHERE team_name_away = 'Houston Rockets' \n",
377
+ "AND wl_away = 'W' \n",
378
+ "AND pts_away >= 110;\n",
379
  "\n",
380
  "Statement valid? True\n",
381
  "SQLite matched? False\n",
 
508
  },
509
  {
510
  "cell_type": "code",
511
+ "execution_count": 7,
512
  "metadata": {},
513
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
514
  "source": [
515
  "def run_evaluation(nba_df, title):\n",
516
  " counter = 0\n",
 
539
  " counter += 1\n",
540
  " if counter % 50 == 0:\n",
541
  " print(\"Completed \" + str(counter))\n",
 
 
542
  "\n",
543
  " # Print evaluation results\n",
544
+ " print(\"\\n\" + title + \" results:\")\n",
545
  " print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
546
  " print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
547
+ " print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "markdown",
552
+ "metadata": {},
553
+ "source": [
554
+ "# Evaluate on less than 90 dataset"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": 8,
560
+ "metadata": {},
561
+ "outputs": [
562
+ {
563
+ "name": "stdout",
564
+ "output_type": "stream",
565
+ "text": [
566
+ "Completed 50\n",
567
+ "Completed 100\n",
568
+ "Completed 150\n",
569
+ "Completed 200\n",
570
+ "\n",
571
+ "Less than 90 results:\n",
572
+ "Percent valid: 0.8612244897959184\n",
573
+ "Percent SQLite matched: 0.4163265306122449\n",
574
+ "Percent result matched: 0.6530612244897959\n",
575
+ "Dataset length: 245\n"
576
+ ]
577
+ }
578
+ ],
579
+ "source": [
580
  "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
581
  "run_evaluation(less_than_90_df, \"Less than 90\")\n",
582
+ "print(\"Dataset length: \" + str(len(less_than_90_df)))"
 
 
583
  ]
584
  },
585
  {
586
  "cell_type": "markdown",
587
  "metadata": {},
588
  "source": [
589
+ "# Evaluate on game table queries"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 9,
595
+ "metadata": {},
596
+ "outputs": [
597
+ {
598
+ "name": "stdout",
599
+ "output_type": "stream",
600
+ "text": [
601
+ "Completed 50\n",
602
+ "Completed 100\n",
603
+ "Completed 150\n",
604
+ "Completed 200\n",
605
+ "Completed 250\n",
606
+ "Completed 300\n",
607
+ "Completed 350\n",
608
+ "Completed 400\n",
609
+ "Completed 450\n",
610
+ "Completed 500\n",
611
+ "Completed 550\n",
612
+ "Completed 600\n",
613
+ "Completed 650\n",
614
+ "Completed 700\n",
615
+ "Completed 750\n",
616
+ "Completed 800\n",
617
+ "\n",
618
+ "Queries from game results:\n",
619
+ "Percent valid: 0.7708830548926014\n",
620
+ "Percent SQLite matched: 0.1431980906921241\n",
621
+ "Percent result matched: 0.40692124105011934\n",
622
+ "Dataset length: 838\n"
623
+ ]
624
+ }
625
+ ],
626
+ "source": [
627
+ "game_queries = pd.read_csv(\"./train-data/queries_from_game.tsv\", sep='\\t')\n",
628
+ "run_evaluation(game_queries, \"Queries from game\")\n",
629
+ "print(\"Dataset length: \" + str(len(game_queries)))"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "markdown",
634
+ "metadata": {},
635
+ "source": [
636
+ "## Evaluate on other stats queries"
637
+ ]
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "execution_count": 10,
642
+ "metadata": {},
643
+ "outputs": [
644
+ {
645
+ "name": "stdout",
646
+ "output_type": "stream",
647
+ "text": [
648
+ "Completed 50\n",
649
+ "Completed 100\n",
650
+ "Completed 150\n",
651
+ "\n",
652
+ "Queries from other stats results:\n",
653
+ "Percent valid: 0.07792207792207792\n",
654
+ "Percent SQLite matched: 0.0\n",
655
+ "Percent result matched: 0.0\n",
656
+ "Dataset length: 154\n"
657
+ ]
658
+ }
659
+ ],
660
+ "source": [
661
+ "other_stats_queries = pd.read_csv(\"./train-data/queries_from_other_stats.tsv\", sep='\\t')\n",
662
+ "run_evaluation(other_stats_queries, \"Queries from other stats\")\n",
663
+ "print(\"Dataset length: \" + str(len(other_stats_queries)))"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "markdown",
668
+ "metadata": {},
669
+ "source": [
670
+ "## Evaluate on team queries"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "code",
675
+ "execution_count": 11,
676
+ "metadata": {},
677
+ "outputs": [
678
+ {
679
+ "name": "stdout",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "Completed 50\n",
683
+ "\n",
684
+ "Queries from team results:\n",
685
+ "Percent valid: 0.75\n",
686
+ "Percent SQLite matched: 0.2692307692307692\n",
687
+ "Percent result matched: 0.6153846153846154\n",
688
+ "Dataset length: 52\n"
689
+ ]
690
+ }
691
+ ],
692
+ "source": [
693
+ "team_queries = pd.read_csv(\"./train-data/queries_from_team.tsv\", sep='\\t')\n",
694
+ "run_evaluation(team_queries, \"Queries from team\")\n",
695
+ "print(\"Dataset length: \" + str(len(team_queries)))"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "markdown",
700
+ "metadata": {},
701
+ "source": [
702
+ "## Evaluate on queries requiring join statements"
703
+ ]
704
+ },
705
+ {
706
+ "cell_type": "code",
707
+ "execution_count": 12,
708
+ "metadata": {},
709
+ "outputs": [
710
+ {
711
+ "name": "stdout",
712
+ "output_type": "stream",
713
+ "text": [
714
+ "Completed 50\n",
715
+ "Completed 100\n",
716
+ "Completed 150\n",
717
+ "\n",
718
+ "Queries with join results:\n",
719
+ "Percent valid: 0.06486486486486487\n",
720
+ "Percent SQLite matched: 0.0\n",
721
+ "Percent result matched: 0.010810810810810811\n",
722
+ "Dataset length: 185\n"
723
+ ]
724
+ }
725
+ ],
726
+ "source": [
727
+ "join_queries = pd.read_csv(\"./train-data/with_join.tsv\", sep='\\t')\n",
728
+ "run_evaluation(join_queries, \"Queries with join\")\n",
729
+ "print(\"Dataset length: \" + str(len(join_queries)))"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "markdown",
734
+ "metadata": {},
735
+ "source": [
736
+ "## Evaluate on queries not requiring join statements"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": 13,
742
+ "metadata": {},
743
+ "outputs": [
744
+ {
745
+ "name": "stdout",
746
+ "output_type": "stream",
747
+ "text": [
748
+ "Completed 50\n",
749
+ "Completed 100\n",
750
+ "Completed 150\n",
751
+ "Completed 200\n",
752
+ "Completed 250\n",
753
+ "Completed 300\n",
754
+ "Completed 350\n",
755
+ "Completed 400\n",
756
+ "Completed 450\n",
757
+ "Completed 500\n",
758
+ "Completed 550\n",
759
+ "Completed 600\n",
760
+ "Completed 650\n",
761
+ "Completed 700\n",
762
+ "Completed 750\n",
763
+ "Completed 800\n",
764
+ "Completed 850\n",
765
+ "\n",
766
+ "Queries without join results:\n",
767
+ "Percent valid: 0.7974388824214202\n",
768
+ "Percent SQLite matched: 0.1559953434225844\n",
769
+ "Percent result matched: 0.4318975552968568\n",
770
+ "Dataset length: 859\n"
771
+ ]
772
+ }
773
+ ],
774
+ "source": [
775
+ "no_join_queries = pd.read_csv(\"./train-data/without_join.tsv\", sep='\\t')\n",
776
+ "run_evaluation(no_join_queries, \"Queries without join\")\n",
777
+ "print(\"Dataset length: \" + str(len(no_join_queries)))"
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "markdown",
782
+ "metadata": {},
783
+ "source": [
784
+ "## Evaluate on full training dataset"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": 14,
790
+ "metadata": {},
791
+ "outputs": [
792
+ {
793
+ "name": "stdout",
794
+ "output_type": "stream",
795
+ "text": [
796
+ "Completed 50\n",
797
+ "Completed 100\n",
798
+ "Completed 150\n",
799
+ "Completed 200\n",
800
+ "Completed 250\n",
801
+ "Completed 300\n",
802
+ "Completed 350\n",
803
+ "Completed 400\n",
804
+ "Completed 450\n",
805
+ "Completed 500\n",
806
+ "Completed 550\n",
807
+ "Completed 600\n",
808
+ "Completed 650\n",
809
+ "Completed 700\n",
810
+ "Completed 750\n",
811
+ "Completed 800\n",
812
+ "Completed 850\n",
813
+ "Completed 900\n",
814
+ "Completed 950\n",
815
+ "Completed 1000\n",
816
+ "\n",
817
+ "All training data results:\n",
818
+ "Percent valid: 0.6676245210727969\n",
819
+ "Percent SQLite matched: 0.12835249042145594\n",
820
+ "Percent result matched: 0.35823754789272033\n",
821
+ "Dataset length: 1044\n"
822
+ ]
823
+ }
824
+ ],
825
+ "source": [
826
+ "# Run evaluation on all training data\n",
827
+ "run_evaluation(df, \"All training data\")\n",
828
+ "print(\"Dataset length: \" + str(len(df)))"
829
  ]
830
  }
831
  ],