Update the test_pretrained notebook to evaluate model outputs
Browse files
- test_pretrained.ipynb +205 -57
- train-data/sql_train.tsv +1 -1
test_pretrained.ipynb
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
},
|
17 |
{
|
18 |
"cell_type": "code",
|
19 |
-
"execution_count":
|
20 |
"metadata": {},
|
21 |
"outputs": [
|
22 |
{
|
@@ -26,9 +26,9 @@
|
|
26 |
"Total dataset examples: 1044\n",
|
27 |
"\n",
|
28 |
"\n",
|
29 |
-
"
|
30 |
-
"SELECT
|
31 |
-
"
|
32 |
]
|
33 |
}
|
34 |
],
|
@@ -56,7 +56,7 @@
|
|
56 |
},
|
57 |
{
|
58 |
"cell_type": "code",
|
59 |
-
"execution_count":
|
60 |
"metadata": {},
|
61 |
"outputs": [],
|
62 |
"source": [
|
@@ -80,50 +80,153 @@
|
|
80 |
},
|
81 |
{
|
82 |
"cell_type": "code",
|
83 |
-
"execution_count":
|
84 |
"metadata": {},
|
85 |
"outputs": [],
|
86 |
"source": [
|
87 |
-
"input_text = \"\"\"You are an AI assistant that
|
|
|
88 |
"\n",
|
89 |
-
"
|
90 |
-
"
|
91 |
-
"
|
92 |
-
"
|
93 |
-
"
|
94 |
-
"
|
|
|
|
|
|
|
|
|
|
|
95 |
"\n",
|
96 |
-
"
|
97 |
-
"
|
98 |
-
"
|
99 |
-
"
|
100 |
-
"
|
101 |
-
"
|
102 |
-
"
|
103 |
-
"
|
104 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
"\n",
|
106 |
-
"
|
107 |
-
"
|
108 |
-
"
|
109 |
-
"
|
110 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
"\n",
|
112 |
-
"
|
113 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
"SQLite:\n",
|
115 |
"SELECT MAX(pts_home) \n",
|
116 |
"FROM game \n",
|
117 |
"WHERE team_name_home = 'Los Angeles Lakers';\n",
|
118 |
"\n",
|
119 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
"SQLite:\n",
|
121 |
-
"SELECT
|
122 |
"FROM game\n",
|
123 |
-
"WHERE
|
124 |
-
"
|
125 |
-
"
|
126 |
-
"
|
|
|
|
|
127 |
]
|
128 |
},
|
129 |
{
|
@@ -135,7 +238,7 @@
|
|
135 |
},
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
-
"execution_count":
|
139 |
"metadata": {},
|
140 |
"outputs": [
|
141 |
{
|
@@ -153,9 +256,9 @@
|
|
153 |
"output_type": "stream",
|
154 |
"text": [
|
155 |
"SQLite:\n",
|
156 |
-
"SELECT
|
157 |
-
"FROM
|
158 |
-
"WHERE
|
159 |
"\n"
|
160 |
]
|
161 |
}
|
@@ -180,7 +283,7 @@
|
|
180 |
},
|
181 |
{
|
182 |
"cell_type": "code",
|
183 |
-
"execution_count":
|
184 |
"metadata": {},
|
185 |
"outputs": [
|
186 |
{
|
@@ -188,7 +291,11 @@
|
|
188 |
"output_type": "stream",
|
189 |
"text": [
|
190 |
"cleaned\n",
|
191 |
-
"(
|
|
|
|
|
|
|
|
|
192 |
]
|
193 |
}
|
194 |
],
|
@@ -222,25 +329,34 @@
|
|
222 |
},
|
223 |
{
|
224 |
"cell_type": "code",
|
225 |
-
"execution_count":
|
226 |
"metadata": {},
|
227 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
{
|
229 |
"name": "stdout",
|
230 |
"output_type": "stream",
|
231 |
"text": [
|
232 |
-
"
|
233 |
-
"
|
234 |
-
"\n",
|
235 |
-
"
|
|
|
236 |
"FROM game \n",
|
237 |
-
"WHERE team_name_home = '
|
|
|
|
|
238 |
"\n",
|
239 |
-
"
|
240 |
-
"
|
241 |
-
"
|
242 |
-
"SQL matched? True\n",
|
243 |
-
"Result matched? True\n"
|
244 |
]
|
245 |
}
|
246 |
],
|
@@ -260,19 +376,51 @@
|
|
260 |
" cursor.execute(query)\n",
|
261 |
" rows = cursor.fetchall()\n",
|
262 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
263 |
" # Check if this is a multi-line query\n",
|
264 |
" if \"|\" in sample_result:\n",
|
265 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
" else:\n",
|
267 |
-
"
|
268 |
-
"
|
269 |
-
"
|
|
|
|
|
|
|
270 |
"\n",
|
271 |
" # Compare results and return\n",
|
272 |
-
" return
|
273 |
" except:\n",
|
274 |
" return False, False\n",
|
275 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
277 |
"print(\"SQL matched? \" + str(result[0]))\n",
|
278 |
"print(\"Result matched? \" + str(result[1]))"
|
|
|
16 |
},
|
17 |
{
|
18 |
"cell_type": "code",
|
19 |
+
"execution_count": 7,
|
20 |
"metadata": {},
|
21 |
"outputs": [
|
22 |
{
|
|
|
26 |
"Total dataset examples: 1044\n",
|
27 |
"\n",
|
28 |
"\n",
|
29 |
+
"List the full names of all teams founded in the 1980s.\n",
|
30 |
+
"SELECT full_name FROM team WHERE year_founded BETWEEN 1980 AND 1989;\n",
|
31 |
+
"Dallas Mavericks, Miami Heat, Minnesota Timberwolves, Orlando Magic, Charlotte Hornets\n"
|
32 |
]
|
33 |
}
|
34 |
],
|
|
|
56 |
},
|
57 |
{
|
58 |
"cell_type": "code",
|
59 |
+
"execution_count": 8,
|
60 |
"metadata": {},
|
61 |
"outputs": [],
|
62 |
"source": [
|
|
|
80 |
},
|
81 |
{
|
82 |
"cell_type": "code",
|
83 |
+
"execution_count": 9,
|
84 |
"metadata": {},
|
85 |
"outputs": [],
|
86 |
"source": [
|
87 |
+
"input_text = \"\"\"You are an AI assistant that converts natural language queries into valid SQLite queries.\n",
|
88 |
+
"Database Schema and Explanations\n",
|
89 |
"\n",
|
90 |
+
"team Table\n",
|
91 |
+
"Stores information about NBA teams.\n",
|
92 |
+
"CREATE TABLE IF NOT EXISTS \"team\" (\n",
|
93 |
+
" \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
|
94 |
+
" \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
|
95 |
+
" \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
|
96 |
+
" \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
|
97 |
+
" \"city\" TEXT, -- City where the team is based\n",
|
98 |
+
" \"state\" TEXT, -- State where the team is located\n",
|
99 |
+
" \"year_founded\" REAL -- Year the team was established\n",
|
100 |
+
");\n",
|
101 |
"\n",
|
102 |
+
"game Table\n",
|
103 |
+
"Contains detailed statistics for each NBA game, including home and away team performance.\n",
|
104 |
+
"CREATE TABLE IF NOT EXISTS \"game\" (\n",
|
105 |
+
" \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
|
106 |
+
" \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
|
107 |
+
" \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
|
108 |
+
" \"team_name_home\" TEXT, -- Full name of the home team\n",
|
109 |
+
" \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
|
110 |
+
" \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
|
111 |
+
" \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
|
112 |
+
" \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
|
113 |
+
" \"min\" INTEGER, -- Total minutes played in the game\n",
|
114 |
+
" \"fgm_home\" REAL, -- Field goals made by the home team\n",
|
115 |
+
" \"fga_home\" REAL, -- Field goals attempted by the home team\n",
|
116 |
+
" \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
|
117 |
+
" \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
|
118 |
+
" \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
|
119 |
+
" \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
|
120 |
+
" \"ftm_home\" REAL, -- Free throws made by the home team\n",
|
121 |
+
" \"fta_home\" REAL, -- Free throws attempted by the home team\n",
|
122 |
+
" \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
|
123 |
+
" \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
|
124 |
+
" \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
|
125 |
+
" \"reb_home\" REAL, -- Total rebounds by the home team\n",
|
126 |
+
" \"ast_home\" REAL, -- Assists by the home team\n",
|
127 |
+
" \"stl_home\" REAL, -- Steals by the home team\n",
|
128 |
+
" \"blk_home\" REAL, -- Blocks by the home team\n",
|
129 |
+
" \"tov_home\" REAL, -- Turnovers by the home team\n",
|
130 |
+
" \"pf_home\" REAL, -- Personal fouls by the home team\n",
|
131 |
+
" \"pts_home\" REAL, -- Total points scored by the home team\n",
|
132 |
+
" \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
|
133 |
+
" \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
134 |
+
" \"team_id_away\" TEXT, -- ID of the away team\n",
|
135 |
+
" \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
|
136 |
+
" \"team_name_away\" TEXT, -- Full name of the away team\n",
|
137 |
+
" \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
|
138 |
+
" \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
|
139 |
+
" \"fgm_away\" REAL, -- Field goals made by the away team\n",
|
140 |
+
" \"fga_away\" REAL, -- Field goals attempted by the away team\n",
|
141 |
+
" \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
|
142 |
+
" \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
|
143 |
+
" \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
|
144 |
+
" \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
|
145 |
+
" \"ftm_away\" REAL, -- Free throws made by the away team\n",
|
146 |
+
" \"fta_away\" REAL, -- Free throws attempted by the away team\n",
|
147 |
+
" \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
|
148 |
+
" \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
|
149 |
+
" \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
|
150 |
+
" \"reb_away\" REAL, -- Total rebounds by the away team\n",
|
151 |
+
" \"ast_away\" REAL, -- Assists by the away team\n",
|
152 |
+
" \"stl_away\" REAL, -- Steals by the away team\n",
|
153 |
+
" \"blk_away\" REAL, -- Blocks by the away team\n",
|
154 |
+
" \"tov_away\" REAL, -- Turnovers by the away team\n",
|
155 |
+
" \"pf_away\" REAL, -- Personal fouls by the away team\n",
|
156 |
+
" \"pts_away\" REAL, -- Total points scored by the away team\n",
|
157 |
+
" \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
|
158 |
+
" \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
|
159 |
+
" \"season_type\" TEXT -- Regular season or playoffs\n",
|
160 |
+
");\n",
|
161 |
"\n",
|
162 |
+
"other_stats Table\n",
|
163 |
+
"Stores additional game statistics, linked to the game table via game_id.\n",
|
164 |
+
"CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
|
165 |
+
" \"game_id\" TEXT, -- Unique game identifier (links to \"game\" table)\n",
|
166 |
+
" \"league_id\" TEXT, -- League identifier\n",
|
167 |
+
" \"team_id_home\" TEXT, -- Home team identifier\n",
|
168 |
+
" \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
|
169 |
+
" \"team_city_home\" TEXT, -- Home team city\n",
|
170 |
+
" \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
|
171 |
+
" \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
|
172 |
+
" \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
|
173 |
+
" \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
|
174 |
+
" \"lead_changes\" INTEGER, -- Number of lead changes in the game\n",
|
175 |
+
" \"times_tied\" INTEGER, -- Number of times the score was tied\n",
|
176 |
+
" \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
|
177 |
+
" \"total_turnovers_home\" INTEGER, -- Total turnovers in the game\n",
|
178 |
+
" \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
|
179 |
+
" \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
|
180 |
+
" \"team_id_away\" TEXT, -- Away team identifier\n",
|
181 |
+
" \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
|
182 |
+
" \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
|
183 |
+
" \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
|
184 |
+
" \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
|
185 |
+
" \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
|
186 |
+
" \"total_turnovers_away\" INTEGER, -- Total turnovers in the game\n",
|
187 |
+
" \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
|
188 |
+
" \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
|
189 |
+
");\n",
|
190 |
"\n",
|
191 |
+
"\n",
|
192 |
+
"Query Guidelines\n",
|
193 |
+
"Use team_name_home and team_name_away to match teams.\n",
|
194 |
+
"\n",
|
195 |
+
"To filter by season, use season_id = '2YYYY'.\n",
|
196 |
+
"\n",
|
197 |
+
"Example: To get games from 2005, use season_id = '22005'.\n",
|
198 |
+
"\n",
|
199 |
+
"The game_id column links the game and other_stats tables.\n",
|
200 |
+
"\n",
|
201 |
+
"Ensure queries return relevant columns and avoid unnecessary joins.\n",
|
202 |
+
"\n",
|
203 |
+
"Example User Requests and SQLite Queries\n",
|
204 |
+
"Request:\n",
|
205 |
+
"\"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
|
206 |
"SQLite:\n",
|
207 |
"SELECT MAX(pts_home) \n",
|
208 |
"FROM game \n",
|
209 |
"WHERE team_name_home = 'Los Angeles Lakers';\n",
|
210 |
"\n",
|
211 |
+
"Request:\n",
|
212 |
+
"\"How many points did the Miami Heat score on January 10, 2010?\"\n",
|
213 |
+
"SQLite:\n",
|
214 |
+
"SELECT team_name_home, pts_home, team_name_away, pts_away \n",
|
215 |
+
"FROM game \n",
|
216 |
+
"WHERE DATE(game_date) = '2010-01-10' \n",
|
217 |
+
"AND (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');\n",
|
218 |
+
"\n",
|
219 |
+
"Request:\n",
|
220 |
+
"\"Which team won the most home games in the 2000 season?\"\n",
|
221 |
"SQLite:\n",
|
222 |
+
"SELECT team_name_home, COUNT(*) AS wins\n",
|
223 |
"FROM game\n",
|
224 |
+
"WHERE wl_home = 'W' AND season_id = '22000'\n",
|
225 |
+
"GROUP BY team_name_home\n",
|
226 |
+
"ORDER BY wins DESC\n",
|
227 |
+
"LIMIT 1;\n",
|
228 |
+
"\n",
|
229 |
+
"Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following question: \"\"\""
|
230 |
]
|
231 |
},
|
232 |
{
|
|
|
238 |
},
|
239 |
{
|
240 |
"cell_type": "code",
|
241 |
+
"execution_count": 10,
|
242 |
"metadata": {},
|
243 |
"outputs": [
|
244 |
{
|
|
|
256 |
"output_type": "stream",
|
257 |
"text": [
|
258 |
"SQLite:\n",
|
259 |
+
"SELECT full_name \n",
|
260 |
+
"FROM team \n",
|
261 |
+
"WHERE year_founded BETWEEN 1980 AND 1989;\n",
|
262 |
"\n"
|
263 |
]
|
264 |
}
|
|
|
283 |
},
|
284 |
{
|
285 |
"cell_type": "code",
|
286 |
+
"execution_count": 11,
|
287 |
"metadata": {},
|
288 |
"outputs": [
|
289 |
{
|
|
|
291 |
"output_type": "stream",
|
292 |
"text": [
|
293 |
"cleaned\n",
|
294 |
+
"('Dallas Mavericks',)\n",
|
295 |
+
"('Miami Heat',)\n",
|
296 |
+
"('Minnesota Timberwolves',)\n",
|
297 |
+
"('Orlando Magic',)\n",
|
298 |
+
"('Charlotte Hornets',)\n"
|
299 |
]
|
300 |
}
|
301 |
],
|
|
|
329 |
},
|
330 |
{
|
331 |
"cell_type": "code",
|
332 |
+
"execution_count": 67,
|
333 |
"metadata": {},
|
334 |
"outputs": [
|
335 |
+
{
|
336 |
+
"name": "stderr",
|
337 |
+
"output_type": "stream",
|
338 |
+
"text": [
|
339 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
340 |
+
"Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
|
341 |
+
]
|
342 |
+
},
|
343 |
{
|
344 |
"name": "stdout",
|
345 |
"output_type": "stream",
|
346 |
"text": [
|
347 |
+
"How many times did the Minnesota Timberwolves lose at home in the 2004 season despite recording more steals and blocks than their opponent?\n",
|
348 |
+
"SELECT COUNT(*) FROM game g WHERE g.team_abbreviation_home = 'MIN' AND g.wl_home = 'L' AND g.stl_home > g.stl_away AND g.blk_home > g.blk_away AND g.season_id = '22004';\n",
|
349 |
+
"0\n",
|
350 |
+
"SQLite:\n",
|
351 |
+
"SELECT COUNT(*) \n",
|
352 |
"FROM game \n",
|
353 |
+
"WHERE team_name_home = 'Minnesota Timberwolves' \n",
|
354 |
+
"AND wl_home = 'L' \n",
|
355 |
+
"AND season_id = '22004';\n",
|
356 |
"\n",
|
357 |
+
"[(17,)]\n",
|
358 |
+
"SQL matched? False\n",
|
359 |
+
"Result matched? False\n"
|
|
|
|
|
360 |
]
|
361 |
}
|
362 |
],
|
|
|
376 |
" cursor.execute(query)\n",
|
377 |
" rows = cursor.fetchall()\n",
|
378 |
"\n",
|
379 |
+
" # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.\n",
|
380 |
+
" query = query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
381 |
+
" sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
382 |
+
" query_match = (query == sample_query)\n",
|
383 |
+
"\n",
|
384 |
" # Check if this is a multi-line query\n",
|
385 |
" if \"|\" in sample_result:\n",
|
386 |
+
" result_list = sample_result.split(\"|\") \n",
|
387 |
+
" for i in range(len(result_list)):\n",
|
388 |
+
" result_list[i] = str(result_list[i]).strip()\n",
|
389 |
+
" result = False\n",
|
390 |
+
" for row in rows:\n",
|
391 |
+
" for r in row:\n",
|
392 |
+
" if str(r) in result_list:\n",
|
393 |
+
" return query_match, True\n",
|
394 |
+
" print(rows)\n",
|
395 |
+
" return query_match, result\n",
|
396 |
" else:\n",
|
397 |
+
" print(rows)\n",
|
398 |
+
" result = False\n",
|
399 |
+
" for row in rows:\n",
|
400 |
+
" for r in row:\n",
|
401 |
+
" if str(r) == str(sample_result):\n",
|
402 |
+
" return query_match, True\n",
|
403 |
"\n",
|
404 |
" # Compare results and return\n",
|
405 |
+
" return query_match, result\n",
|
406 |
" except:\n",
|
407 |
" return False, False\n",
|
408 |
"\n",
|
409 |
+
"# Obtain sample\n",
|
410 |
+
"sample = df.sample(n=1)\n",
|
411 |
+
"print(sample[\"natural_query\"].values[0])\n",
|
412 |
+
"print(sample[\"sql_query\"].values[0])\n",
|
413 |
+
"print(sample[\"result\"].values[0])\n",
|
414 |
+
"\n",
|
415 |
+
"# Create message with sample query and run model\n",
|
416 |
+
"message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
|
417 |
+
"inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
|
418 |
+
"outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
|
419 |
+
"\n",
|
420 |
+
"# Print output\n",
|
421 |
+
"query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
422 |
+
"print(query_output)\n",
|
423 |
+
"\n",
|
424 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
425 |
"print(\"SQL matched? \" + str(result[0]))\n",
|
426 |
"print(\"Result matched? \" + str(result[1]))"
|
train-data/sql_train.tsv
CHANGED
@@ -470,7 +470,7 @@ What is the highest combined pts in any game involving the Miami Heat? SELECT MA
|
|
470 |
How many away games did the Milwaukee Bucks play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Milwaukee Bucks' AND season_id = '22018'; 41.0
|
471 |
In which season did the Cleveland Cavaliers have the highest average ft_pct at home? SELECT season_id, AVG(ft_pct_home) as avg_stat FROM game WHERE team_name_home = 'Cleveland Cavaliers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 2014.0
|
472 |
In which season did the Golden State Warriors have the highest average ft_pct at home? SELECT season_id, AVG(ft_pct_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 2016.0
|
473 |
-
In which season did the Houston Rockets have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Houston Rockets' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1;
|
474 |
What is the highest combined ast in any game involving the Los Angeles Lakers? SELECT MAX(ast_home + ast_away) FROM game WHERE team_name_home = 'Los Angeles Lakers' OR team_name_away = 'Los Angeles Lakers'; 86.0
|
475 |
How many away games did the Chicago Bulls play in the 2022 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Chicago Bulls' AND season_id = '22022'; 41.0
|
476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|
|
|
470 |
How many away games did the Milwaukee Bucks play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Milwaukee Bucks' AND season_id = '22018'; 41.0
|
471 |
In which season did the Cleveland Cavaliers have the highest average ft_pct at home? SELECT season_id, AVG(ft_pct_home) as avg_stat FROM game WHERE team_name_home = 'Cleveland Cavaliers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 2014.0
|
472 |
In which season did the Golden State Warriors have the highest average ft_pct at home? SELECT season_id, AVG(ft_pct_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 2016.0
|
473 |
+
In which season did the Houston Rockets have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Houston Rockets' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 41984|58.0
|
474 |
What is the highest combined ast in any game involving the Los Angeles Lakers? SELECT MAX(ast_home + ast_away) FROM game WHERE team_name_home = 'Los Angeles Lakers' OR team_name_away = 'Los Angeles Lakers'; 86.0
|
475 |
How many away games did the Chicago Bulls play in the 2022 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Chicago Bulls' AND season_id = '22022'; 41.0
|
476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|