licesma commited on
Commit
c1d6d12
·
1 Parent(s): 30ff71b

Add a source folder

Browse files
src/evaluation/compare_result.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ def compare_result(cursor, sample_query, sample_result, query_output):
4
+ # Clean model output to only have the query output
5
+ if query_output[0:7] == "SQLite:":
6
+ query = query_output[7:]
7
+ elif query_output[0:4] == "SQL:":
8
+ query = query_output[4:]
9
+ else:
10
+ query = query_output
11
+
12
+ # Try to execute query, if it fails, then this is a failure of the model
13
+ try:
14
+ # Execute query and obtain result
15
+ cursor.execute(query)
16
+ rows = cursor.fetchall()
17
+
18
+ # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.
19
+ query = query.replace(" ", "").replace("\n", "").replace("\t", "")
20
+ sample_query = sample_query.replace(" ", "").replace("\n", "").replace("\t", "")
21
+ query_match = (query == sample_query)
22
+
23
+ # If the queries match, the results clearly also match
24
+ if query_match:
25
+ return True, True, True
26
+
27
+ # Check if this is a multi-line query
28
+ if "|" in sample_result or "(" in sample_result:
29
+ #print(rows)
30
+ # Create list of results by stripping separators and splitting on them
31
+ if "(" in sample_result:
32
+ sample_result = sample_result.replace("(", "").replace(")", "")
33
+ result_list = sample_result.split(",")
34
+ else:
35
+ result_list = sample_result.split("|")
36
+
37
+ # Strip all results in list
38
+ for i in range(len(result_list)):
39
+ result_list[i] = str(result_list[i]).strip()
40
+
41
+ # Loop through model result and see if it matches training example
42
+ result = False
43
+ for row in rows:
44
+ for r in row:
45
+ for res in result_list:
46
+ try:
47
+ if math.isclose(float(r), float(res), abs_tol=0.5):
48
+ return True, query_match, True
49
+ except:
50
+ if r in res or res in r:
51
+ return True, query_match, True
52
+
53
+ # Check if the model returned a sum of examples as opposed to the whole thing
54
+ if len(rows) == 1:
55
+ for r in rows[0]:
56
+ if r == str(len(result_list)):
57
+ return True, query_match, True
58
+
59
+ return True, query_match, result
60
+ # Else the sample result is a single value or string
61
+ else:
62
+ #print(rows)
63
+ result = False
64
+ # Loop through model result and see if it contains the sample result
65
+ for row in rows:
66
+ for r in row:
67
+ # Check by string
68
+ if str(r) in str(sample_result):
69
+ try:
70
+ if math.isclose(float(r), float(sample_result), abs_tol=0.5):
71
+ return True, query_match, True
72
+ except:
73
+ return True, query_match, True
74
+ # Check by number, using try incase the cast as float fails
75
+ try:
76
+ if math.isclose(float(r), float(sample_result), abs_tol=0.5):
77
+ return True, query_match, True
78
+ except:
79
+ pass
80
+
81
+ # Check if the model returned a list of examples instead of a total sum (both acceptable)
82
+ try:
83
+ if len(rows) > 1 and len(rows) == int(sample_result):
84
+ return True, query_match, True
85
+ if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):
86
+ return True, query_match, True
87
+ except:
88
+ pass
89
+
90
+ # Compare results and return
91
+ return True, query_match, result
92
+ except:
93
+ return False, False, False
src/prompts/prompt.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_text = """You are an AI assistant that converts natural language queries into valid SQLite queries.
2
+ Database Schema and Explanations
3
+
4
+ team Table
5
+ Stores information about NBA teams.
6
+ CREATE TABLE IF NOT EXISTS "team" (
7
+ "id" TEXT PRIMARY KEY, -- Unique identifier for the team
8
+ "full_name" TEXT, -- Full official name of the team (e.g., "Los Angeles Lakers")
9
+ "abbreviation" TEXT, -- Shortened team name (e.g., "LAL")
10
+ "nickname" TEXT, -- Commonly used nickname for the team (e.g., "Lakers")
11
+ "city" TEXT, -- City where the team is based
12
+ "state" TEXT, -- State where the team is located
13
+ "year_founded" REAL -- Year the team was established
14
+ );
15
+
16
+ game Table
17
+ Contains detailed statistics for each NBA game, including home and away team performance.
18
+ CREATE TABLE IF NOT EXISTS "game" (
19
+ "season_id" TEXT, -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)
20
+ "team_id_home" TEXT, -- ID of the home team (matches "id" in team table)
21
+ "team_abbreviation_home" TEXT, -- Abbreviation of the home team
22
+ "team_name_home" TEXT, -- Full name of the home team
23
+ "game_id" TEXT PRIMARY KEY, -- Unique identifier for the game
24
+ "game_date" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)
25
+ "matchup_home" TEXT, -- Matchup details including opponent (e.g., "LAL vs. BOS")
26
+ "wl_home" TEXT, -- "W" if the home team won, "L" if they lost
27
+ "min" INTEGER, -- Total minutes played in the game
28
+ "fgm_home" REAL, -- Field goals made by the home team
29
+ "fga_home" REAL, -- Field goals attempted by the home team
30
+ "fg_pct_home" REAL, -- Field goal percentage of the home team
31
+ "fg3m_home" REAL, -- Three-point field goals made by the home team
32
+ "fg3a_home" REAL, -- Three-point attempts by the home team
33
+ "fg3_pct_home" REAL, -- Three-point field goal percentage of the home team
34
+ "ftm_home" REAL, -- Free throws made by the home team
35
+ "fta_home" REAL, -- Free throws attempted by the home team
36
+ "ft_pct_home" REAL, -- Free throw percentage of the home team
37
+ "oreb_home" REAL, -- Offensive rebounds by the home team
38
+ "dreb_home" REAL, -- Defensive rebounds by the home team
39
+ "reb_home" REAL, -- Total rebounds by the home team
40
+ "ast_home" REAL, -- Assists by the home team
41
+ "stl_home" REAL, -- Steals by the home team
42
+ "blk_home" REAL, -- Blocks by the home team
43
+ "tov_home" REAL, -- Turnovers by the home team
44
+ "pf_home" REAL, -- Personal fouls by the home team
45
+ "pts_home" REAL, -- Total points scored by the home team
46
+ "plus_minus_home" INTEGER, -- Plus/minus rating for the home team
47
+ "video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
48
+ "team_id_away" TEXT, -- ID of the away team
49
+ "team_abbreviation_away" TEXT, -- Abbreviation of the away team
50
+ "team_name_away" TEXT, -- Full name of the away team
51
+ "matchup_away" TEXT, -- Matchup details from the away team’s perspective
52
+ "wl_away" TEXT, -- "W" if the away team won, "L" if they lost
53
+ "fgm_away" REAL, -- Field goals made by the away team
54
+ "fga_away" REAL, -- Field goals attempted by the away team
55
+ "fg_pct_away" REAL, -- Field goal percentage of the away team
56
+ "fg3m_away" REAL, -- Three-point field goals made by the away team
57
+ "fg3a_away" REAL, -- Three-point attempts by the away team
58
+ "fg3_pct_away" REAL, -- Three-point field goal percentage of the away team
59
+ "ftm_away" REAL, -- Free throws made by the away team
60
+ "fta_away" REAL, -- Free throws attempted by the away team
61
+ "ft_pct_away" REAL, -- Free throw percentage of the away team
62
+ "oreb_away" REAL, -- Offensive rebounds by the away team
63
+ "dreb_away" REAL, -- Defensive rebounds by the away team
64
+ "reb_away" REAL, -- Total rebounds by the away team
65
+ "ast_away" REAL, -- Assists by the away team
66
+ "stl_away" REAL, -- Steals by the away team
67
+ "blk_away" REAL, -- Blocks by the away team
68
+ "tov_away" REAL, -- Turnovers by the away team
69
+ "pf_away" REAL, -- Personal fouls by the away team
70
+ "pts_away" REAL, -- Total points scored by the away team
71
+ "plus_minus_away" INTEGER, -- Plus/minus rating for the away team
72
+ "video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
73
+ "season_type" TEXT -- Regular season or playoffs
74
+ );
75
+
76
+ other_stats Table
77
+ Stores additional statistics, linked to the game table via game_id.
78
+ CREATE TABLE IF NOT EXISTS "other_stats" (
79
+ "game_id" TEXT, -- Unique game identifier, matches id column from game table
80
+ "league_id" TEXT, -- League identifier
81
+ "team_id_home" TEXT, -- Home team identifier
82
+ "team_abbreviation_home" TEXT, -- Home team abbreviation
83
+ "team_city_home" TEXT, -- Home team city
84
+ "pts_paint_home" INTEGER, -- Points in the paint by the home team
85
+ "pts_2nd_chance_home" INTEGER, -- Second chance points by the home team
86
+ "pts_fb_home" INTEGER, -- Fast break points by the home team
87
+ "largest_lead_home" INTEGER,-- Largest lead by the home team
88
+ "lead_changes" INTEGER, -- Number of lead changes
89
+ "times_tied" INTEGER, -- Number of times the score was tied
90
+ "team_turnovers_home" INTEGER, -- Home team turnovers
91
+ "total_turnovers_home" INTEGER, -- Total turnovers by the home team
92
+ "team_rebounds_home" INTEGER, -- Home team rebounds
93
+ "pts_off_to_home" INTEGER, -- Points off turnovers by the home team
94
+ "team_id_away" TEXT, -- Away team identifier
95
+ "team_abbreviation_away" TEXT, -- Away team abbreviation
96
+ "pts_paint_away" INTEGER, -- Points in the paint by the away team
97
+ "pts_2nd_chance_away" INTEGER, -- Second chance points by the away team
98
+ "pts_fb_away" INTEGER, -- Fast break points by the away team
99
+ "largest_lead_away" INTEGER,-- Largest lead by the away team
100
+ "team_turnovers_away" INTEGER, -- Away team turnovers
101
+ "total_turnovers_away" INTEGER, -- Total turnovers by the away team
102
+ "team_rebounds_away" INTEGER, -- Away team rebounds
103
+ "pts_off_to_away" INTEGER -- Points off turnovers by the away team
104
+ );
105
+
106
+
107
+ Team Name Information
108
+ In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations.
109
+ The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.
110
+ Notice they are separated by the | character in the following list:
111
+
112
+ Atlanta Hawks|ATL
113
+ Boston Celtics|BOS
114
+ Cleveland Cavaliers|CLE
115
+ New Orleans Pelicans|NOP
116
+ Chicago Bulls|CHI
117
+ Dallas Mavericks|DAL
118
+ Denver Nuggets|DEN
119
+ Golden State Warriors|GSW
120
+ Houston Rockets|HOU
121
+ Los Angeles Clippers|LAC
122
+ Los Angeles Lakers|LAL
123
+ Miami Heat|MIA
124
+ Milwaukee Bucks|MIL
125
+ Minnesota Timberwolves|MIN
126
+ Brooklyn Nets|BKN
127
+ New York Knicks|NYK
128
+ Orlando Magic|ORL
129
+ Indiana Pacers|IND
130
+ Philadelphia 76ers|PHI
131
+ Phoenix Suns|PHX
132
+ Portland Trail Blazers|POR
133
+ Sacramento Kings|SAC
134
+ San Antonio Spurs|SAS
135
+ Oklahoma City Thunder|OKC
136
+ Toronto Raptors|TOR
137
+ Utah Jazz|UTA
138
+ Memphis Grizzlies|MEM
139
+ Washington Wizards|WAS
140
+ Detroit Pistons|DET
141
+ Charlotte Hornets|CHA
142
+
143
+ Query Guidelines
144
+ Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.
145
+
146
+ To filter by season, use season_id = '2YYYY'.
147
+
148
+ Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = "21972". To get statistics from 2015, use a statement like: season_id = "22015".
149
+
150
+ Ensure queries return relevant columns and avoid unnecessary joins.
151
+
152
+ Example User Requests and SQLite Queries
153
+ Request:
154
+ "What is the most points the Los Angeles Lakers have ever scored at home?"
155
+ SQLite:
156
+ SELECT MAX(pts_home) FROM game WHERE team_name_home = 'Los Angeles Lakers';
157
+
158
+ Request:
159
+ "Which teams are located in the state of California?"
160
+ SQLite:
161
+ SELECT full_name FROM team WHERE state = 'California';
162
+
163
+ Request:
164
+ "Which team had the highest number of team turnovers in an away game?"
165
+ SQLite:
166
+ SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;
167
+
168
+ Request:
169
+ "Which teams were founded before 1979?"
170
+ SQLite:
171
+ SELECT full_name FROM team WHERE year_founded < 1979;
172
+
173
+ Request:
174
+ "Find the Boston Celtics largest home victory margin in the 2008 season."
175
+ SQLite:
176
+ SELECT MAX(pts_home - pts_away) AS biggest_win FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';
177
+
178
+ Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:
179
+ """
test_pretrained.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": 1,
20
  "metadata": {},
21
  "outputs": [
22
  {
@@ -26,9 +26,9 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "Which team committed the fewest total turnovers in an away game that resulted in a win?\n",
30
- "SELECT team_abbreviation_away FROM other_stats WHERE game_id IN (SELECT game_id FROM game WHERE wl_away = 'W') ORDER BY total_turnovers_away ASC LIMIT 1;\n",
31
- "PHX\n"
32
  ]
33
  }
34
  ],
@@ -58,7 +58,7 @@
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 2,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
@@ -83,193 +83,11 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 3,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
90
- "input_text = \"\"\"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
91
- "Database Schema and Explanations\n",
92
- "\n",
93
- "team Table\n",
94
- "Stores information about NBA teams.\n",
95
- "CREATE TABLE IF NOT EXISTS \"team\" (\n",
96
- " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
97
- " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
98
- " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
99
- " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
100
- " \"city\" TEXT, -- City where the team is based\n",
101
- " \"state\" TEXT, -- State where the team is located\n",
102
- " \"year_founded\" REAL -- Year the team was established\n",
103
- ");\n",
104
- "\n",
105
- "game Table\n",
106
- "Contains detailed statistics for each NBA game, including home and away team performance.\n",
107
- "CREATE TABLE IF NOT EXISTS \"game\" (\n",
108
- " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
109
- " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
110
- " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
111
- " \"team_name_home\" TEXT, -- Full name of the home team\n",
112
- " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
113
- " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
114
- " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
115
- " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
116
- " \"min\" INTEGER, -- Total minutes played in the game\n",
117
- " \"fgm_home\" REAL, -- Field goals made by the home team\n",
118
- " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
119
- " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
120
- " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
121
- " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
122
- " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
123
- " \"ftm_home\" REAL, -- Free throws made by the home team\n",
124
- " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
125
- " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
126
- " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
127
- " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
128
- " \"reb_home\" REAL, -- Total rebounds by the home team\n",
129
- " \"ast_home\" REAL, -- Assists by the home team\n",
130
- " \"stl_home\" REAL, -- Steals by the home team\n",
131
- " \"blk_home\" REAL, -- Blocks by the home team\n",
132
- " \"tov_home\" REAL, -- Turnovers by the home team\n",
133
- " \"pf_home\" REAL, -- Personal fouls by the home team\n",
134
- " \"pts_home\" REAL, -- Total points scored by the home team\n",
135
- " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
136
- " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
137
- " \"team_id_away\" TEXT, -- ID of the away team\n",
138
- " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
139
- " \"team_name_away\" TEXT, -- Full name of the away team\n",
140
- " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
141
- " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
142
- " \"fgm_away\" REAL, -- Field goals made by the away team\n",
143
- " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
144
- " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
145
- " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
146
- " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
147
- " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
148
- " \"ftm_away\" REAL, -- Free throws made by the away team\n",
149
- " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
150
- " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
151
- " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
152
- " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
153
- " \"reb_away\" REAL, -- Total rebounds by the away team\n",
154
- " \"ast_away\" REAL, -- Assists by the away team\n",
155
- " \"stl_away\" REAL, -- Steals by the away team\n",
156
- " \"blk_away\" REAL, -- Blocks by the away team\n",
157
- " \"tov_away\" REAL, -- Turnovers by the away team\n",
158
- " \"pf_away\" REAL, -- Personal fouls by the away team\n",
159
- " \"pts_away\" REAL, -- Total points scored by the away team\n",
160
- " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
161
- " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
162
- " \"season_type\" TEXT -- Regular season or playoffs\n",
163
- ");\n",
164
- "\n",
165
- "other_stats Table\n",
166
- "Stores additional statistics, linked to the game table via game_id.\n",
167
- "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
168
- " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
169
- " \"league_id\" TEXT, -- League identifier\n",
170
- " \"team_id_home\" TEXT, -- Home team identifier\n",
171
- " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
172
- " \"team_city_home\" TEXT, -- Home team city\n",
173
- " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
174
- " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
175
- " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
176
- " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
177
- " \"lead_changes\" INTEGER, -- Number of lead changes \n",
178
- " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
179
- " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
180
- " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
181
- " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
182
- " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
183
- " \"team_id_away\" TEXT, -- Away team identifier\n",
184
- " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
185
- " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
186
- " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
187
- " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
188
- " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
189
- " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
190
- " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
191
- " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
192
- " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
193
- ");\n",
194
- "\n",
195
- "\n",
196
- "Team Name Information\n",
197
- "In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. \n",
198
- "The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.\n",
199
- "Notice they are separated by the | character in the following list:\n",
200
- "\n",
201
- "Atlanta Hawks|ATL\n",
202
- "Boston Celtics|BOS\n",
203
- "Cleveland Cavaliers|CLE\n",
204
- "New Orleans Pelicans|NOP\n",
205
- "Chicago Bulls|CHI\n",
206
- "Dallas Mavericks|DAL\n",
207
- "Denver Nuggets|DEN\n",
208
- "Golden State Warriors|GSW\n",
209
- "Houston Rockets|HOU\n",
210
- "Los Angeles Clippers|LAC\n",
211
- "Los Angeles Lakers|LAL\n",
212
- "Miami Heat|MIA\n",
213
- "Milwaukee Bucks|MIL\n",
214
- "Minnesota Timberwolves|MIN\n",
215
- "Brooklyn Nets|BKN\n",
216
- "New York Knicks|NYK\n",
217
- "Orlando Magic|ORL\n",
218
- "Indiana Pacers|IND\n",
219
- "Philadelphia 76ers|PHI\n",
220
- "Phoenix Suns|PHX\n",
221
- "Portland Trail Blazers|POR\n",
222
- "Sacramento Kings|SAC\n",
223
- "San Antonio Spurs|SAS\n",
224
- "Oklahoma City Thunder|OKC\n",
225
- "Toronto Raptors|TOR\n",
226
- "Utah Jazz|UTA\n",
227
- "Memphis Grizzlies|MEM\n",
228
- "Washington Wizards|WAS\n",
229
- "Detroit Pistons|DET\n",
230
- "Charlotte Hornets|CHA\n",
231
- "\n",
232
- "Query Guidelines\n",
233
- "Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.\n",
234
- "\n",
235
- "To filter by season, use season_id = '2YYYY'.\n",
236
- "\n",
237
- "Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = \"21972\". To get statistics from 2015, use a statement like: season_id = \"22015\".\n",
238
- "\n",
239
- "Ensure queries return relevant columns and avoid unnecessary joins.\n",
240
- "\n",
241
- "Example User Requests and SQLite Queries\n",
242
- "Request:\n",
243
- "\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
244
- "SQLite:\n",
245
- "SELECT MAX(pts_home) \n",
246
- "FROM game \n",
247
- "WHERE team_name_home = 'Los Angeles Lakers';\n",
248
- "\n",
249
- "Request:\n",
250
- "\"Which teams are located in the state of California?\"\n",
251
- "SQLite:\n",
252
- "SELECT full_name FROM team WHERE state = 'California';\n",
253
- "\n",
254
- "Request:\n",
255
- "\"Which team had the highest number of team turnovers in an away game?\"\n",
256
- "SQLite:\n",
257
- "SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;\n",
258
- "\n",
259
- "Request:\n",
260
- "\"Which teams were founded before 1979?\"\n",
261
- "SQLite:\n",
262
- "SELECT full_name FROM team WHERE year_founded < 1979;\n",
263
- "\n",
264
- "Request:\n",
265
- "\"Find the Boston Celtics largest home victory margin in the 2008 season.\"\n",
266
- "SQLite:\n",
267
- "SELECT MAX(pts_home - pts_away) AS biggest_win\n",
268
- "FROM game\n",
269
- "WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';\n",
270
- "\n",
271
- "Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:\n",
272
- "\"\"\""
273
  ]
274
  },
275
  {
@@ -281,7 +99,7 @@
281
  },
282
  {
283
  "cell_type": "code",
284
- "execution_count": 4,
285
  "metadata": {},
286
  "outputs": [
287
  {
@@ -289,9 +107,7 @@
289
  "output_type": "stream",
290
  "text": [
291
  "SQLite:\n",
292
- "SELECT team_abbreviation_away \n",
293
- "FROM other_stats \n",
294
- "WHERE wl_away = 'W' AND total_turnovers_away < (SELECT MIN(total_turnovers_away) FROM other_stats WHERE wl_away = 'L');\n",
295
  "\n"
296
  ]
297
  }
@@ -316,7 +132,7 @@
316
  },
317
  {
318
  "cell_type": "code",
319
- "execution_count": 5,
320
  "metadata": {},
321
  "outputs": [
322
  {
@@ -361,29 +177,24 @@
361
  },
362
  {
363
  "cell_type": "code",
364
- "execution_count": 6,
365
  "metadata": {},
366
  "outputs": [
367
  {
368
- "name": "stdout",
369
- "output_type": "stream",
370
- "text": [
371
- "What is the largest lead the Minnesota Timberwolves had at home?\n",
372
- "SELECT MAX(largest_lead_home) as max_lead FROM other_stats WHERE team_abbreviation_home = 'MIN';\n",
373
- "48.0\n",
374
- "SQLite:\n",
375
- "SELECT MAX(largest_lead_home) \n",
376
- "FROM other_stats \n",
377
- "WHERE team_name_home = 'Minnesota Timberwolves';\n",
378
- "\n",
379
- "Statement valid? False\n",
380
- "SQLite matched? False\n",
381
- "Result matched? False\n"
382
  ]
383
  }
384
  ],
385
  "source": [
386
  "import math\n",
 
387
  "\n",
388
  "def compare_result(sample_query, sample_result, query_output):\n",
389
  " # Clean model output to only have the query output\n",
@@ -479,9 +290,18 @@
479
  "\n",
480
  "# Obtain sample\n",
481
  "sample = df.sample(n=1)\n",
 
 
 
 
 
 
 
 
482
  "print(sample[\"natural_query\"].values[0])\n",
483
  "print(sample[\"sql_query\"].values[0])\n",
484
  "print(sample[\"result\"].values[0])\n",
 
485
  "\n",
486
  "# Create message with sample query and run model\n",
487
  "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
@@ -495,7 +315,12 @@
495
  "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
496
  "print(\"Statement valid? \" + str(result[0]))\n",
497
  "print(\"SQLite matched? \" + str(result[1]))\n",
498
- "print(\"Result matched? \" + str(result[2]))"
 
 
 
 
 
499
  ]
500
  },
501
  {
@@ -830,7 +655,7 @@
830
  ],
831
  "metadata": {
832
  "kernelspec": {
833
- "display_name": "Python 3",
834
  "language": "python",
835
  "name": "python3"
836
  },
@@ -844,7 +669,7 @@
844
  "name": "python",
845
  "nbconvert_exporter": "python",
846
  "pygments_lexer": "ipython3",
847
- "version": "3.12.6"
848
  }
849
  },
850
  "nbformat": 4,
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": 2,
20
  "metadata": {},
21
  "outputs": [
22
  {
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "Which team had the largest lead in a single game in the 2001 season?\n",
30
+ "SELECT g.team_name_home AS team, os.largest_lead_home AS lead FROM other_stats os JOIN game g ON os.game_id = g.game_id WHERE g.season_id = '22001' ORDER BY os.largest_lead_home DESC LIMIT 1;\n",
31
+ "Portland Trail Blazers|47\n"
32
  ]
33
  }
34
  ],
 
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": 3,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": 19,
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
90
+ "from src.prompts.prompt import input_text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  ]
92
  },
93
  {
 
99
  },
100
  {
101
  "cell_type": "code",
102
+ "execution_count": 5,
103
  "metadata": {},
104
  "outputs": [
105
  {
 
107
  "output_type": "stream",
108
  "text": [
109
  "SQLite:\n",
110
+ "SELECT team_abbreviation_home FROM other_stats WHERE lead_changes = 1 AND season_id = '2001';\n",
 
 
111
  "\n"
112
  ]
113
  }
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 17,
136
  "metadata": {},
137
  "outputs": [
138
  {
 
177
  },
178
  {
179
  "cell_type": "code",
180
+ "execution_count": null,
181
  "metadata": {},
182
  "outputs": [
183
  {
184
+ "ename": "ImportError",
185
+ "evalue": "cannot import name 'compare_result_two' from 'src.evaluation.compare_result' (/Users/esteban/Documents/USC/spring_2025/NLP/SQL-Generation/src/evaluation/compare_result.py)",
186
+ "output_type": "error",
187
+ "traceback": [
188
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
189
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
190
+ "Cell \u001b[0;32mIn[30], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmath\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mevaluation\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcompare_result\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m compare_result_two\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompare_result\u001b[39m(sample_query, sample_result, query_output):\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Clean model output to only have the query output\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m query_output[\u001b[38;5;241m0\u001b[39m:\u001b[38;5;241m7\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSQLite:\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
191
+ "\u001b[0;31mImportError\u001b[0m: cannot import name 'compare_result_two' from 'src.evaluation.compare_result' (/Users/esteban/Documents/USC/spring_2025/NLP/SQL-Generation/src/evaluation/compare_result.py)"
 
 
 
 
 
 
192
  ]
193
  }
194
  ],
195
  "source": [
196
  "import math\n",
197
+ "from src.evaluation.compare_result import compare_result_two\n",
198
  "\n",
199
  "def compare_result(sample_query, sample_result, query_output):\n",
200
  " # Clean model output to only have the query output\n",
 
290
  "\n",
291
  "# Obtain sample\n",
292
  "sample = df.sample(n=1)\n",
293
+ "sample_dic = {\n",
294
+ " \"natural_query\": \"How many home games did the Miami Heat play in the 2021 season?\",\n",
295
+ " \"sql_query\": \"SELECT COUNT(*) FROM game WHERE team_name_home = 'Miami Heat' AND season_id = '22021';\",\n",
296
+ " \"result\": 41.0\n",
297
+ "}\n",
298
+ "\n",
299
+ "sample = pd.DataFrame([sample_dic])\n",
300
+ "\"\"\"\n",
301
  "print(sample[\"natural_query\"].values[0])\n",
302
  "print(sample[\"sql_query\"].values[0])\n",
303
  "print(sample[\"result\"].values[0])\n",
304
+ "\"\"\"\n",
305
  "\n",
306
  "# Create message with sample query and run model\n",
307
  "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
 
315
  "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
316
  "print(\"Statement valid? \" + str(result[0]))\n",
317
  "print(\"SQLite matched? \" + str(result[1]))\n",
318
+ "print(\"Result matched? \" + str(result[2]))\n",
319
+ "\n",
320
+ "result_two = compare_result_two(cursor, sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
321
+ "print(\"Statement valid? \" + str(result_two[0]))\n",
322
+ "print(\"SQLite matched? \" + str(result_two[1]))\n",
323
+ "print(\"Result matched? \" + str(result_two[2]))"
324
  ]
325
  },
326
  {
 
655
  ],
656
  "metadata": {
657
  "kernelspec": {
658
+ "display_name": "CSCI544",
659
  "language": "python",
660
  "name": "python3"
661
  },
 
669
  "name": "python",
670
  "nbconvert_exporter": "python",
671
  "pygments_lexer": "ipython3",
672
+ "version": "3.11.11"
673
  }
674
  },
675
  "nbformat": 4,