licesma commited on
Commit
1860a94
·
1 Parent(s): 2ea9086

Add second RAG notebook

Browse files
src/rag/table_documents.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ team_table_document = '''team Table
2
+ Stores information about NBA teams.
3
+ CREATE TABLE IF NOT EXISTS "team" (
4
+ "id" TEXT PRIMARY KEY, -- Unique identifier for the team
5
+ "full_name" TEXT, -- Full official name of the team (e.g., "Los Angeles Lakers")
6
+ "abbreviation" TEXT, -- Shortened team name (e.g., "LAL")
7
+ "nickname" TEXT, -- Commonly used nickname for the team (e.g., "Lakers")
8
+ "city" TEXT, -- City where the team is based
9
+ "state" TEXT, -- State where the team is located
10
+ "year_founded" REAL -- Year the team was established
11
+ );'''
12
+
13
+ game_table_document = '''game Table
14
+ Contains detailed statistics for each NBA game, including home and away team performance.
15
+ CREATE TABLE IF NOT EXISTS "game" (
16
+ "season_id" TEXT, -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)
17
+ "team_id_home" TEXT, -- ID of the home team (matches "id" in team table)
18
+ "team_abbreviation_home" TEXT, -- Abbreviation of the home team
19
+ "team_name_home" TEXT, -- Full name of the home team
20
+ "game_id" TEXT PRIMARY KEY, -- Unique identifier for the game
21
+ "game_date" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)
22
+ "matchup_home" TEXT, -- Matchup details including opponent (e.g., "LAL vs. BOS")
23
+ "wl_home" TEXT, -- "W" if the home team won, "L" if they lost
24
+ "min" INTEGER, -- Total minutes played in the game
25
+ "fgm_home" REAL, -- Field goals made by the home team
26
+ "fga_home" REAL, -- Field goals attempted by the home team
27
+ "fg_pct_home" REAL, -- Field goal percentage of the home team
28
+ "fg3m_home" REAL, -- Three-point field goals made by the home team
29
+ "fg3a_home" REAL, -- Three-point attempts by the home team
30
+ "fg3_pct_home" REAL, -- Three-point field goal percentage of the home team
31
+ "ftm_home" REAL, -- Free throws made by the home team
32
+ "fta_home" REAL, -- Free throws attempted by the home team
33
+ "ft_pct_home" REAL, -- Free throw percentage of the home team
34
+ "oreb_home" REAL, -- Offensive rebounds by the home team
35
+ "dreb_home" REAL, -- Defensive rebounds by the home team
36
+ "reb_home" REAL, -- Total rebounds by the home team
37
+ "ast_home" REAL, -- Assists by the home team
38
+ "stl_home" REAL, -- Steals by the home team
39
+ "blk_home" REAL, -- Blocks by the home team
40
+ "tov_home" REAL, -- Turnovers by the home team
41
+ "pf_home" REAL, -- Personal fouls by the home team
42
+ "pts_home" REAL, -- Total points scored by the home team
43
+ "plus_minus_home" INTEGER, -- Plus/minus rating for the home team
44
+ "video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
45
+ "team_id_away" TEXT, -- ID of the away team
46
+ "team_abbreviation_away" TEXT, -- Abbreviation of the away team
47
+ "team_name_away" TEXT, -- Full name of the away team
48
+ "matchup_away" TEXT, -- Matchup details from the away team’s perspective
49
+ "wl_away" TEXT, -- "W" if the away team won, "L" if they lost
50
+ "fgm_away" REAL, -- Field goals made by the away team
51
+ "fga_away" REAL, -- Field goals attempted by the away team
52
+ "fg_pct_away" REAL, -- Field goal percentage of the away team
53
+ "fg3m_away" REAL, -- Three-point field goals made by the away team
54
+ "fg3a_away" REAL, -- Three-point attempts by the away team
55
+ "fg3_pct_away" REAL, -- Three-point field goal percentage of the away team
56
+ "ftm_away" REAL, -- Free throws made by the away team
57
+ "fta_away" REAL, -- Free throws attempted by the away team
58
+ "ft_pct_away" REAL, -- Free throw percentage of the away team
59
+ "oreb_away" REAL, -- Offensive rebounds by the away team
60
+ "dreb_away" REAL, -- Defensive rebounds by the away team
61
+ "reb_away" REAL, -- Total rebounds by the away team
62
+ "ast_away" REAL, -- Assists by the away team
63
+ "stl_away" REAL, -- Steals by the away team
64
+ "blk_away" REAL, -- Blocks by the away team
65
+ "tov_away" REAL, -- Turnovers by the away team
66
+ "pf_away" REAL, -- Personal fouls by the away team
67
+ "pts_away" REAL, -- Total points scored by the away team
68
+ "plus_minus_away" INTEGER, -- Plus/minus rating for the away team
69
+ "video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
70
+ "season_type" TEXT -- Regular season or playoffs
71
+ );
72
+ '''
73
+
74
+ other_stats_table_document = '''other_stats Table
75
+ Stores additional statistics, linked to the game table via game_id.
76
+ CREATE TABLE IF NOT EXISTS "other_stats" (
77
+ "game_id" TEXT, -- Unique game identifier, matches id column from game table
78
+ "league_id" TEXT, -- League identifier
79
+ "team_id_home" TEXT, -- Home team identifier
80
+ "team_abbreviation_home" TEXT, -- Home team abbreviation
81
+ "team_city_home" TEXT, -- Home team city
82
+ "pts_paint_home" INTEGER, -- Points in the paint by the home team
83
+ "pts_2nd_chance_home" INTEGER, -- Second chance points by the home team
84
+ "pts_fb_home" INTEGER, -- Fast break points by the home team
85
+ "largest_lead_home" INTEGER,-- Largest lead by the home team
86
+ "lead_changes" INTEGER, -- Number of lead changes
87
+ "times_tied" INTEGER, -- Number of times the score was tied
88
+ "team_turnovers_home" INTEGER, -- Home team turnovers
89
+ "total_turnovers_home" INTEGER, -- Total turnovers by the home team
90
+ "team_rebounds_home" INTEGER, -- Home team rebounds
91
+ "pts_off_to_home" INTEGER, -- Points off turnovers by the home team
92
+ "team_id_away" TEXT, -- Away team identifier
93
+ "team_abbreviation_away" TEXT, -- Away team abbreviation
94
+ "pts_paint_away" INTEGER, -- Points in the paint by the away team
95
+ "pts_2nd_chance_away" INTEGER, -- Second chance points by the away team
96
+ "pts_fb_away" INTEGER, -- Fast break points by the away team
97
+ "largest_lead_away" INTEGER,-- Largest lead by the away team
98
+ "team_turnovers_away" INTEGER, -- Away team turnovers
99
+ "total_turnovers_away" INTEGER, -- Total turnovers by the away team
100
+ "team_rebounds_away" INTEGER, -- Away team rebounds
101
+ "pts_off_to_away" INTEGER -- Points off turnovers by the away team
102
+ );
103
+ '''
104
+
src/rag/table_retriever.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ team_table_document = '''team Table
2
+ Stores information about NBA teams.
3
+ CREATE TABLE IF NOT EXISTS "team" (
4
+ "id" TEXT PRIMARY KEY, -- Unique identifier for the team
5
+ "full_name" TEXT, -- Full official name of the team (e.g., "Los Angeles Lakers")
6
+ "abbreviation" TEXT, -- Shortened team name (e.g., "LAL")
7
+ "nickname" TEXT, -- Commonly used nickname for the team (e.g., "Lakers")
8
+ "city" TEXT, -- City where the team is based
9
+ "state" TEXT, -- State where the team is located
10
+ "year_founded" REAL -- Year the team was established
11
+ );'''
12
+
13
+ game_table_document = '''game Table
14
+ Contains detailed statistics for each NBA game, including home and away team performance.
15
+ CREATE TABLE IF NOT EXISTS "game" (
16
+ "season_id" TEXT, -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)
17
+ "team_id_home" TEXT, -- ID of the home team (matches "id" in team table)
18
+ "team_abbreviation_home" TEXT, -- Abbreviation of the home team
19
+ "team_name_home" TEXT, -- Full name of the home team
20
+ "game_id" TEXT PRIMARY KEY, -- Unique identifier for the game
21
+ "game_date" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)
22
+ "matchup_home" TEXT, -- Matchup details including opponent (e.g., "LAL vs. BOS")
23
+ "wl_home" TEXT, -- "W" if the home team won, "L" if they lost
24
+ "min" INTEGER, -- Total minutes played in the game
25
+ "fgm_home" REAL, -- Field goals made by the home team
26
+ "fga_home" REAL, -- Field goals attempted by the home team
27
+ "fg_pct_home" REAL, -- Field goal percentage of the home team
28
+ "fg3m_home" REAL, -- Three-point field goals made by the home team
29
+ "fg3a_home" REAL, -- Three-point attempts by the home team
30
+ "fg3_pct_home" REAL, -- Three-point field goal percentage of the home team
31
+ "ftm_home" REAL, -- Free throws made by the home team
32
+ "fta_home" REAL, -- Free throws attempted by the home team
33
+ "ft_pct_home" REAL, -- Free throw percentage of the home team
34
+ "oreb_home" REAL, -- Offensive rebounds by the home team
35
+ "dreb_home" REAL, -- Defensive rebounds by the home team
36
+ "reb_home" REAL, -- Total rebounds by the home team
37
+ "ast_home" REAL, -- Assists by the home team
38
+ "stl_home" REAL, -- Steals by the home team
39
+ "blk_home" REAL, -- Blocks by the home team
40
+ "tov_home" REAL, -- Turnovers by the home team
41
+ "pf_home" REAL, -- Personal fouls by the home team
42
+ "pts_home" REAL, -- Total points scored by the home team
43
+ "plus_minus_home" INTEGER, -- Plus/minus rating for the home team
44
+ "video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
45
+ "team_id_away" TEXT, -- ID of the away team
46
+ "team_abbreviation_away" TEXT, -- Abbreviation of the away team
47
+ "team_name_away" TEXT, -- Full name of the away team
48
+ "matchup_away" TEXT, -- Matchup details from the away team’s perspective
49
+ "wl_away" TEXT, -- "W" if the away team won, "L" if they lost
50
+ "fgm_away" REAL, -- Field goals made by the away team
51
+ "fga_away" REAL, -- Field goals attempted by the away team
52
+ "fg_pct_away" REAL, -- Field goal percentage of the away team
53
+ "fg3m_away" REAL, -- Three-point field goals made by the away team
54
+ "fg3a_away" REAL, -- Three-point attempts by the away team
55
+ "fg3_pct_away" REAL, -- Three-point field goal percentage of the away team
56
+ "ftm_away" REAL, -- Free throws made by the away team
57
+ "fta_away" REAL, -- Free throws attempted by the away team
58
+ "ft_pct_away" REAL, -- Free throw percentage of the away team
59
+ "oreb_away" REAL, -- Offensive rebounds by the away team
60
+ "dreb_away" REAL, -- Defensive rebounds by the away team
61
+ "reb_away" REAL, -- Total rebounds by the away team
62
+ "ast_away" REAL, -- Assists by the away team
63
+ "stl_away" REAL, -- Steals by the away team
64
+ "blk_away" REAL, -- Blocks by the away team
65
+ "tov_away" REAL, -- Turnovers by the away team
66
+ "pf_away" REAL, -- Personal fouls by the away team
67
+ "pts_away" REAL, -- Total points scored by the away team
68
+ "plus_minus_away" INTEGER, -- Plus/minus rating for the away team
69
+ "video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
70
+ "season_type" TEXT -- Regular season or playoffs
71
+ );
72
+ '''
73
+
74
+ other_stats_table_document = '''other_stats Table
75
+ Stores additional statistics, linked to the game table via game_id.
76
+ CREATE TABLE IF NOT EXISTS "other_stats" (
77
+ "game_id" TEXT, -- Unique game identifier, matches id column from game table
78
+ "league_id" TEXT, -- League identifier
79
+ "team_id_home" TEXT, -- Home team identifier
80
+ "team_abbreviation_home" TEXT, -- Home team abbreviation
81
+ "team_city_home" TEXT, -- Home team city
82
+ "pts_paint_home" INTEGER, -- Points in the paint by the home team
83
+ "pts_2nd_chance_home" INTEGER, -- Second chance points by the home team
84
+ "pts_fb_home" INTEGER, -- Fast break points by the home team
85
+ "largest_lead_home" INTEGER,-- Largest lead by the home team
86
+ "lead_changes" INTEGER, -- Number of lead changes
87
+ "times_tied" INTEGER, -- Number of times the score was tied
88
+ "team_turnovers_home" INTEGER, -- Home team turnovers
89
+ "total_turnovers_home" INTEGER, -- Total turnovers by the home team
90
+ "team_rebounds_home" INTEGER, -- Home team rebounds
91
+ "pts_off_to_home" INTEGER, -- Points off turnovers by the home team
92
+ "team_id_away" TEXT, -- Away team identifier
93
+ "team_abbreviation_away" TEXT, -- Away team abbreviation
94
+ "pts_paint_away" INTEGER, -- Points in the paint by the away team
95
+ "pts_2nd_chance_away" INTEGER, -- Second chance points by the away team
96
+ "pts_fb_away" INTEGER, -- Fast break points by the away team
97
+ "largest_lead_away" INTEGER,-- Largest lead by the away team
98
+ "team_turnovers_away" INTEGER, -- Away team turnovers
99
+ "total_turnovers_away" INTEGER, -- Total turnovers by the away team
100
+ "team_rebounds_away" INTEGER, -- Away team rebounds
101
+ "pts_off_to_away" INTEGER -- Points off turnovers by the away team
102
+ );
103
+ '''
104
+
105
+ team_name_document = '''Team Name Information
106
+ In plaintext user questions, only the full team names will be used, but in the queries you may use either full names or abbreviations.
107
+ Full names are used with the game table, while abbreviations should be used with the other_stats table.
108
+ Team names and abbreviations (separated by |):
109
+ Atlanta Hawks|ATL, Boston Celtics|BOS, Cleveland Cavaliers|CLE, New Orleans Pelicans|NOP,
110
+ Chicago Bulls|CHI, Dallas Mavericks|DAL, Denver Nuggets|DEN, Golden State Warriors|GSW,
111
+ Houston Rockets|HOU, Los Angeles Clippers|LAC, Los Angeles Lakers|LAL, Miami Heat|MIA,
112
+ Milwaukee Bucks|MIL, Minnesota Timberwolves|MIN, Brooklyn Nets|BKN, New York Knicks|NYK,
113
+ Orlando Magic|ORL, Indiana Pacers|IND, Philadelphia 76ers|PHI, Phoenix Suns|PHX,
114
+ Portland Trail Blazers|POR, Sacramento Kings|SAC, San Antonio Spurs|SAS,
115
+ Oklahoma City Thunder|OKC, Toronto Raptors|TOR, Utah Jazz|UTA, Memphis Grizzlies|MEM,
116
+ Washington Wizards|WAS, Detroit Pistons|DET, Charlotte Hornets|CHA
117
+ '''
118
+
119
+ def retrieve_doc(has_team_schema, has_game_schema, has_other_stats_schema, has_team_names = True):
120
+ documents = []
121
+ # Now scores should be a 1D tensor with length equal to available_docs
122
+ if has_team_schema:
123
+ documents.append(team_table_document)
124
+ if has_game_schema:
125
+ documents.append(game_table_document)
126
+ if has_other_stats_schema:
127
+ documents.append(other_stats_table_document)
128
+ if has_team_names:
129
+ documents.append(team_name_document)
130
+
131
+ return documents
src/rag/team_documents.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ team_name_document = '''Team Name Information
2
+ In plaintext user questions, only the full team names will be used, but in the queries you may use either full names or abbreviations.
3
+ Full names are used with the game table, while abbreviations should be used with the other_stats table.
4
+ Team names and abbreviations (separated by |):
5
+ Atlanta Hawks|ATL, Boston Celtics|BOS, Cleveland Cavaliers|CLE, New Orleans Pelicans|NOP,
6
+ Chicago Bulls|CHI, Dallas Mavericks|DAL, Denver Nuggets|DEN, Golden State Warriors|GSW,
7
+ Houston Rockets|HOU, Los Angeles Clippers|LAC, Los Angeles Lakers|LAL, Miami Heat|MIA,
8
+ Milwaukee Bucks|MIL, Minnesota Timberwolves|MIN, Brooklyn Nets|BKN, New York Knicks|NYK,
9
+ Orlando Magic|ORL, Indiana Pacers|IND, Philadelphia 76ers|PHI, Phoenix Suns|PHX,
10
+ Portland Trail Blazers|POR, Sacramento Kings|SAC, San Antonio Spurs|SAS,
11
+ Oklahoma City Thunder|OKC, Toronto Raptors|TOR, Utah Jazz|UTA, Memphis Grizzlies|MEM,
12
+ Washington Wizards|WAS, Detroit Pistons|DET, Charlotte Hornets|CHA
13
+ '''
test_rag_2.ipynb ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9ba5b9ac",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Notebook to evaluate RAG performance"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "afeb236f",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import pandas as pd\n",
19
+ "import warnings\n",
20
+ "import torch\n",
21
+ "import time\n",
22
+ "import math\n",
23
+ "import sqlite3 as sql\n",
24
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
25
+ "from huggingface_hub import snapshot_download\n",
26
+ "import sys\n",
27
+ "import os"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "b7c75665",
33
+ "metadata": {},
34
+ "source": [
35
+ "## Create RAG document store"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "0e202df8",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "is_google_colab=False"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "cc6c4ccd",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "current_path = \"./\"\n",
56
+ "\n",
57
+ "def get_path(rel_path):\n",
58
+ " return os.path.join(current_path, rel_path)\n",
59
+ "\n",
60
+ "if is_google_colab:\n",
61
+ " hugging_face_path = snapshot_download(\n",
62
+ " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
63
+ " repo_type=\"model\", \n",
64
+ " allow_patterns=[\"src/*\", \"train-data/*\", \"deepseek-coder-1.3b-instruct/*\", \"nba-data/*\"], \n",
65
+ " )\n",
66
+ " sys.path.append(hugging_face_path)\n",
67
+ " current_path = hugging_face_path"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "id": "d589714b",
74
+ "metadata": {},
75
+ "outputs": [
76
+ {
77
+ "name": "stderr",
78
+ "output_type": "stream",
79
+ "text": [
80
+ "/opt/anaconda3/envs/CSCI544/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
81
+ " from .autonotebook import tqdm as notebook_tqdm\n"
82
+ ]
83
+ },
84
+ {
85
+ "name": "stdout",
86
+ "output_type": "stream",
87
+ "text": [
88
+ "Total dataset examples: 1044\n",
89
+ "\n",
90
+ "\n"
91
+ ]
92
+ }
93
+ ],
94
+ "source": [
95
+ "\n",
96
+ "\n",
97
+ "warnings.filterwarnings(\"ignore\")\n",
98
+ "\n",
99
+ "# Establish a database connection once (adjust the DB path as needed)\n",
100
+ "connection = sql.connect(get_path('nba-data/nba.sqlite'))\n",
101
+ "cursor = connection.cursor()\n",
102
+ "\n",
103
+ "# ------------------------------\n",
104
+ "# Load dataset and print summary\n",
105
+ "# ------------------------------\n",
106
+ "df = pd.read_csv(get_path(\"train-data/expanded_ql_train.tsv\"), sep='\\t')\n",
107
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
108
+ "print(\"\\n\")\n",
109
+ "\n",
110
+ "# ------------------------------\n",
111
+ "# Load tokenizer and model\n",
112
+ "# ------------------------------\n",
113
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
114
+ "tokenizer = AutoTokenizer.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"))\n",
115
+ "model = AutoModelForCausalLM.from_pretrained(get_path(\n",
116
+ " \"deepseek-coder-1.3b-instruct\"),\n",
117
+ " torch_dtype=torch.bfloat16,\n",
118
+ " device_map=device\n",
119
+ ")\n",
120
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
121
+ "\n",
122
+ "\n"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "499d2745",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Define compare result function for evaluation process"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 2,
136
+ "id": "268561cd",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "from src.evaluation.compare_result import compare_result\n",
141
+ "from src.rag.table_retriever import retrieve_doc"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "markdown",
146
+ "id": "e7393ccb",
147
+ "metadata": {},
148
+ "source": [
149
+ "## Create evaluation loop for RAG model"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "id": "500f003b",
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "# ------------------------------\n",
160
+ "# Function to evaluate the model on a given dataset\n",
161
+ "# ------------------------------\n",
162
+ "def run_evaluation(nba_df, title):\n",
163
+ " counter = 0\n",
164
+ " num_valid = 0\n",
165
+ " num_sql_matched = 0\n",
166
+ " num_result_matched = 0\n",
167
+ " for index, row in nba_df.iterrows():\n",
168
+ " # Retrieve relevant schema chunks via RAG\n",
169
+ " relevant_schemas = retrieve_doc(row['team_flag'], row['game_flag'], row['other_stats_flag'], False)\n",
170
+ " schema_block = \"\\n\\n\".join(relevant_schemas)\n",
171
+ " \n",
172
+ " #print(row[\"natural_query\"])\n",
173
+ " #print(row[\"sql_query\"])\n",
174
+ " #print(schema_block)\n",
175
+ " #return\n",
176
+ " # Build the prompt with instructions, schema, examples, and current request.\n",
177
+ " input_text = f\"\"\"\n",
178
+ "You are an AI assistant that generates SQLite queries for an NBA database based on user questions.\n",
179
+ "\n",
180
+ "### Relevant Schema:\n",
181
+ "{schema_block}\n",
182
+ "\n",
183
+ "### Instructions:\n",
184
+ "- Generate a valid SQLite query to retrieve relevant data from the database.\n",
185
+ "- Use column names correctly based on the provided schema.\n",
186
+ "- Output only the SQLite query as plain text.\n",
187
+ "\n",
188
+ "### Team Name Information:\n",
189
+ "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
190
+ "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
191
+ "Notice they are separated by the | character in the following list:\n",
192
+ "\n",
193
+ "Atlanta Hawks|ATL\n",
194
+ "Boston Celtics|BOS\n",
195
+ "Cleveland Cavaliers|CLE\n",
196
+ "New Orleans Pelicans|NOP\n",
197
+ "Chicago Bulls|CHI\n",
198
+ "Dallas Mavericks|DAL\n",
199
+ "Denver Nuggets|DEN\n",
200
+ "Golden State Warriors|GSW\n",
201
+ "Houston Rockets|HOU\n",
202
+ "Los Angeles Clippers|LAC\n",
203
+ "Los Angeles Lakers|LAL\n",
204
+ "Miami Heat|MIA\n",
205
+ "Milwaukee Bucks|MIL\n",
206
+ "Minnesota Timberwolves|MIN\n",
207
+ "Brooklyn Nets|BKN\n",
208
+ "New York Knicks|NYK\n",
209
+ "Orlando Magic|ORL\n",
210
+ "Indiana Pacers|IND\n",
211
+ "Philadelphia 76ers|PHI\n",
212
+ "Phoenix Suns|PHX\n",
213
+ "Portland Trail Blazers|POR\n",
214
+ "Sacramento Kings|SAC\n",
215
+ "San Antonio Spurs|SAS\n",
216
+ "Oklahoma City Thunder|OKC\n",
217
+ "Toronto Raptors|TOR\n",
218
+ "Utah Jazz|UTA\n",
219
+ "Memphis Grizzlies|MEM\n",
220
+ "Washington Wizards|WAS\n",
221
+ "Detroit Pistons|DET\n",
222
+ "Charlotte Hornets|CHA\n",
223
+ "\n",
224
+ "### Query Guidelines:\n",
225
+ "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
226
+ "\n",
227
+ "To filter by season, use season_id = '2YYYY'.\n",
228
+ "\n",
229
+ "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
230
+ "\n",
231
+ "Ensure queries return relevant columns and avoid unnecessary joins.\n",
232
+ "\n",
233
+ "### Example User Requests and SQLite Queries\n",
234
+ "Request:\n",
235
+ "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
236
+ "SQLite:\n",
237
+ "SELECT MAX(pts_home)\n",
238
+ "FROM game\n",
239
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
240
+ "\n",
241
+ "Request:\n",
242
+ "\"Which teams are located in the state of California?\"\n",
243
+ "SQLite:\n",
244
+ "SELECT full_name FROM team WHERE state = 'California';\n",
245
+ "\n",
246
+ "Request:\n",
247
+ "\"Which team had the highest number of team turnovers in an away game?\"\n",
248
+ "SQLite:\n",
249
+ "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
250
+ "\n",
251
+ "Request:\n",
252
+ "\"Which teams were founded before 1979?\"\n",
253
+ "SQLite:\n",
254
+ "SELECT full_name FROM team WHERE year_founded < 1979;\n",
255
+ "\n",
256
+ "Request:\n",
257
+ "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
258
+ "SQLite:\n",
259
+ "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
260
+ "FROM game\n",
261
+ "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
262
+ "\n",
263
+ "Generate only the SQLite query prefaced by SQLite: and no other text. Now generate an SQLite query for the following user request.\n",
264
+ "Request: {row[\"natural_query\"]}\n",
265
+ "\"\"\"\n",
266
+ " messages = [{'role': 'user', 'content': input_text}]\n",
267
+ " prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
268
+ " inputs = tokenizer(prompt_text, return_tensors=\"pt\", padding=True).to(model.device)\n",
269
+ " \n",
270
+ " outputs = model.generate(\n",
271
+ " **inputs,\n",
272
+ " max_new_tokens=512,\n",
273
+ " do_sample=False,\n",
274
+ " top_k=50,\n",
275
+ " top_p=0.95,\n",
276
+ " num_return_sequences=1,\n",
277
+ " eos_token_id=tokenizer.eos_token_id,\n",
278
+ " pad_token_id=tokenizer.eos_token_id\n",
279
+ " )\n",
280
+ " \n",
281
+ " # Decode the model output.\n",
282
+ " generated_query = tokenizer.decode(outputs[0][len(inputs[\"input_ids\"][0]):], skip_special_tokens=True)\n",
283
+ " \n",
284
+ " # Clean generated query: remove any prefix and truncate after first semicolon.\n",
285
+ " if generated_query.startswith(\"SQLite:\"):\n",
286
+ " clean_query = generated_query[len(\"SQLite:\"):].strip()\n",
287
+ " elif generated_query.startswith(\"SQL:\"):\n",
288
+ " clean_query = generated_query[len(\"SQL:\"):].strip()\n",
289
+ " else:\n",
290
+ " clean_query = generated_query.strip()\n",
291
+ " \n",
292
+ " semicolon_idx = clean_query.find(\";\")\n",
293
+ " if semicolon_idx != -1:\n",
294
+ " clean_query = clean_query[:semicolon_idx+1]\n",
295
+ " \n",
296
+ " # Execute the cleaned query on the SQLite DB to obtain the actual result.\n",
297
+ " \"\"\"\n",
298
+ " try:\n",
299
+ " cursor.execute(clean_query)\n",
300
+ " rows = cursor.fetchall()\n",
301
+ " if rows and isinstance(rows[0], (tuple, list)) and len(rows[0]) > 0:\n",
302
+ " actual_result = rows[0][0]\n",
303
+ " elif rows:\n",
304
+ " actual_result = rows[0]\n",
305
+ " else:\n",
306
+ " actual_result = \"\"\n",
307
+ " except Exception as e:\n",
308
+ " actual_result = \"Error executing query: \" + str(e)\n",
309
+ " \"\"\"\n",
310
+ " \n",
311
+ " # Compare the ground truth query and expected result to the generated query and actual result.\n",
312
+ " valid, sql_matched, result_matched = compare_result(cursor, row[\"sql_query\"], row[\"result\"], generated_query)\n",
313
+ " \"\"\"\n",
314
+ " print(\"=============================================\")\n",
315
+ " print(f\"Overall Valid: {valid}\")\n",
316
+ " print(f\"SQL Query Matched: {sql_matched}\")\n",
317
+ " print(f\"Result Matched: {result_matched}\")\n",
318
+ " print(\"=============================================\\n\")\n",
319
+ " \n",
320
+ " # Print debug output.\n",
321
+ " print(\"----- Ground Truth SQL Query -----\")\n",
322
+ " print(row[\"sql_query\"])\n",
323
+ " print(\"------------------------------------\\n\")\n",
324
+ " print(\"----- Model Generated SQL Query -----\")\n",
325
+ " print(generated_query)\n",
326
+ " print(\"---------------------------------------\\n\")\n",
327
+ " \n",
328
+ " print(\"----- Expected Result -----\")\n",
329
+ " print(row[\"result\"])\n",
330
+ " print(\"----- Actual DB Result -----\")\n",
331
+ " print(actual_result)\n",
332
+ " print(\"-------------------------------------------------\\n\")\n",
333
+ " \"\"\"\n",
334
+ " if valid:\n",
335
+ " num_valid += 1\n",
336
+ " if sql_matched:\n",
337
+ " num_sql_matched += 1\n",
338
+ " if result_matched:\n",
339
+ " num_result_matched += 1\n",
340
+ " \n",
341
+ " counter += 1\n",
342
+ "\n",
343
+ " # CONTROL ITERS\n",
344
+ " # if counter == 2:\n",
345
+ " # break\n",
346
+ " \n",
347
+ " if counter % 50 == 0:\n",
348
+ " print(\"Completed \" + str(counter))\n",
349
+ " \n",
350
+ " print(\"\\n\" + title + \" results:\")\n",
351
+ " print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
352
+ " print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
353
+ " print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))\n",
354
+ " print(\"Dataset length: \" + str(len(nba_df)))\n",
355
+ " print(\"-------------------\")\n",
356
+ " print(\"Num queries tested: \", counter)\n",
357
+ " print(\"Num correct queries: \", num_result_matched)\n",
358
+ " print(\"Acc: \", (num_result_matched / counter)*100)\n",
359
+ " print(\"-------------------\")\n",
360
+ " "
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "markdown",
365
+ "id": "9c23d082",
366
+ "metadata": {},
367
+ "source": [
368
+ "## Run evaluation using RAG"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": 7,
374
+ "id": "6eb6a1c1",
375
+ "metadata": {},
376
+ "outputs": [
377
+ {
378
+ "name": "stdout",
379
+ "output_type": "stream",
380
+ "text": [
381
+ "Completed 50\n",
382
+ "Completed 100\n",
383
+ "Completed 150\n",
384
+ "Completed 200\n",
385
+ "Completed 250\n",
386
+ "Completed 300\n",
387
+ "Completed 350\n",
388
+ "Completed 400\n",
389
+ "Completed 450\n",
390
+ "Completed 500\n",
391
+ "Completed 550\n",
392
+ "Completed 600\n",
393
+ "Completed 650\n",
394
+ "Completed 700\n",
395
+ "Completed 750\n",
396
+ "Completed 800\n",
397
+ "Completed 850\n",
398
+ "Completed 900\n",
399
+ "Completed 950\n",
400
+ "Completed 1000\n",
401
+ "\n",
402
+ "All training data results:\n",
403
+ "Percent valid: 0.7988505747126436\n",
404
+ "Percent SQLite matched: 0.13409961685823754\n",
405
+ "Percent result matched: 0.3850574712643678\n",
406
+ "Dataset length: 1044\n",
407
+ "-------------------\n",
408
+ "Num queries tested: 1044\n",
409
+ "Num correct queries: 402\n",
410
+ "Acc: 38.50574712643678\n",
411
+ "-------------------\n",
412
+ "Dataset length: 1044\n"
413
+ ]
414
+ }
415
+ ],
416
+ "source": [
417
+ "# ------------------------------\n",
418
+ "# Run evaluation on the full training dataset\n",
419
+ "# ------------------------------\n",
420
+ "run_evaluation(df, \"All training data\")\n",
421
+ "print(\"Dataset length: \" + str(len(df)))"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "markdown",
426
+ "id": "f298cfa1",
427
+ "metadata": {},
428
+ "source": [
429
+ "## Run RAG evaluation on small query dataset"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "id": "121855db",
436
+ "metadata": {},
437
+ "outputs": [
438
+ {
439
+ "name": "stdout",
440
+ "output_type": "stream",
441
+ "text": [
442
+ "Completed 50\n",
443
+ "Completed 100\n",
444
+ "Completed 150\n",
445
+ "Completed 200\n",
446
+ "\n",
447
+ "Less than 90 results:\n",
448
+ "Percent valid: 0.8979591836734694\n",
449
+ "Percent SQLite matched: 0.37551020408163266\n",
450
+ "Percent result matched: 0.7061224489795919\n",
451
+ "Dataset length: 245\n",
452
+ "-------------------\n",
453
+ "Num queries tested: 245\n",
454
+ "Num correct queries: 173\n",
455
+ "Acc: 70.61224489795919\n",
456
+ "-------------------\n",
457
+ "Dataset length: 245\n"
458
+ ]
459
+ }
460
+ ],
461
+ "source": [
462
+ "less_than_90_df = pd.read_csv(get_path(\"train-data/less_than_90.tsv\"), sep='\\t')\n",
463
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
464
+ "print(\"Dataset length: \" + str(len(less_than_90_df)))"
465
+ ]
466
+ }
467
+ ],
468
+ "metadata": {
469
+ "kernelspec": {
470
+ "display_name": "CSCI544",
471
+ "language": "python",
472
+ "name": "python3"
473
+ },
474
+ "language_info": {
475
+ "codemirror_mode": {
476
+ "name": "ipython",
477
+ "version": 3
478
+ },
479
+ "file_extension": ".py",
480
+ "mimetype": "text/x-python",
481
+ "name": "python",
482
+ "nbconvert_exporter": "python",
483
+ "pygments_lexer": "ipython3",
484
+ "version": "3.11.11"
485
+ }
486
+ },
487
+ "nbformat": 4,
488
+ "nbformat_minor": 5
489
+ }
train-data/expanded_sql_train.tsv ADDED
The diff for this file is too large to render. See raw diff