MikeTerekhov committed on
Commit
03182a1
·
verified ·
1 Parent(s): 2cfc97b

Create test_rag.py

Browse files
Files changed (1) hide show
  1. test_rag.py +364 -0
test_rag.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import warnings
3
+ import torch
4
+ import time
5
+ import math
6
+ import sqlite3 as sql
7
+
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from rag_metadata import SQLMetadataRetriever
10
+
11
# Silence library warnings for the whole evaluation run.
warnings.filterwarnings("ignore")

# ------------------------------
# Database: one shared connection/cursor for every executed query
# (adjust the DB path as needed)
# ------------------------------
connection = sql.connect('./nba-data/nba.sqlite')
cursor = connection.cursor()

# ------------------------------
# Dataset: tab-separated examples, with a short size summary
# ------------------------------
df = pd.read_csv("./train-data/sql_train.tsv", sep='\t')
print(f"Total dataset examples: {len(df)}")
print("\n")

# ------------------------------
# Tokenizer and model, both loaded from the same local checkpoint directory
# ------------------------------
_MODEL_DIR = "./deepseek-coder-1.3b-instruct"

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_DIR,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
# Reuse the tokenizer's pad token as the generation default.
model.generation_config.pad_token_id = tokenizer.pad_token_id

# ------------------------------
# RAG retriever; the schema metadata documents are added to it further down
# ------------------------------
retriever = SQLMetadataRetriever()
40
+
41
# Each schema document is kept as its own named constant so individual
# entries are easy to locate; the text of every document is runtime data
# handed to the retriever and is reproduced verbatim.

# Document 1: the "team" table.
_TEAM_TABLE_DOC = '''team Table
Stores information about NBA teams.
CREATE TABLE IF NOT EXISTS "team" (
"id" TEXT PRIMARY KEY, -- Unique identifier for the team
"full_name" TEXT, -- Full official name of the team (e.g., "Los Angeles Lakers")
"abbreviation" TEXT, -- Shortened team name (e.g., "LAL")
"nickname" TEXT, -- Commonly used nickname for the team (e.g., "Lakers")
"city" TEXT, -- City where the team is based
"state" TEXT, -- State where the team is located
"year_founded" REAL -- Year the team was established
);'''

# Document 2: the "game" table (per-game home/away statistics).
_GAME_TABLE_DOC = '''game Table
Contains detailed statistics for each NBA game, including home and away team performance.
CREATE TABLE IF NOT EXISTS "game" (
"season_id" TEXT, -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)
"team_id_home" TEXT, -- ID of the home team (matches "id" in team table)
"team_abbreviation_home" TEXT, -- Abbreviation of the home team
"team_name_home" TEXT, -- Full name of the home team
"game_id" TEXT PRIMARY KEY, -- Unique identifier for the game
"game_date" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)
"matchup_home" TEXT, -- Matchup details including opponent (e.g., "LAL vs. BOS")
"wl_home" TEXT, -- "W" if the home team won, "L" if they lost
"min" INTEGER, -- Total minutes played in the game
"fgm_home" REAL, -- Field goals made by the home team
"fga_home" REAL, -- Field goals attempted by the home team
"fg_pct_home" REAL, -- Field goal percentage of the home team
"fg3m_home" REAL, -- Three-point field goals made by the home team
"fg3a_home" REAL, -- Three-point attempts by the home team
"fg3_pct_home" REAL, -- Three-point field goal percentage of the home team
"ftm_home" REAL, -- Free throws made by the home team
"fta_home" REAL, -- Free throws attempted by the home team
"ft_pct_home" REAL, -- Free throw percentage of the home team
"oreb_home" REAL, -- Offensive rebounds by the home team
"dreb_home" REAL, -- Defensive rebounds by the home team
"reb_home" REAL, -- Total rebounds by the home team
"ast_home" REAL, -- Assists by the home team
"stl_home" REAL, -- Steals by the home team
"blk_home" REAL, -- Blocks by the home team
"tov_home" REAL, -- Turnovers by the home team
"pf_home" REAL, -- Personal fouls by the home team
"pts_home" REAL, -- Total points scored by the home team
"plus_minus_home" INTEGER, -- Plus/minus rating for the home team
"video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
"team_id_away" TEXT, -- ID of the away team
"team_abbreviation_away" TEXT, -- Abbreviation of the away team
"team_name_away" TEXT, -- Full name of the away team
"matchup_away" TEXT, -- Matchup details from the away team’s perspective
"wl_away" TEXT, -- "W" if the away team won, "L" if they lost
"fgm_away" REAL, -- Field goals made by the away team
"fga_away" REAL, -- Field goals attempted by the away team
"fg_pct_away" REAL, -- Field goal percentage of the away team
"fg3m_away" REAL, -- Three-point field goals made by the away team
"fg3a_away" REAL, -- Three-point attempts by the away team
"fg3_pct_away" REAL, -- Three-point field goal percentage of the away team
"ftm_away" REAL, -- Free throws made by the away team
"fta_away" REAL, -- Free throws attempted by the away team
"ft_pct_away" REAL, -- Free throw percentage of the away team
"oreb_away" REAL, -- Offensive rebounds by the away team
"dreb_away" REAL, -- Defensive rebounds by the away team
"reb_away" REAL, -- Total rebounds by the away team
"ast_away" REAL, -- Assists by the away team
"stl_away" REAL, -- Steals by the away team
"blk_away" REAL, -- Blocks by the away team
"tov_away" REAL, -- Turnovers by the away team
"pf_away" REAL, -- Personal fouls by the away team
"pts_away" REAL, -- Total points scored by the away team
"plus_minus_away" INTEGER, -- Plus/minus rating for the away team
"video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
"season_type" TEXT -- Regular season or playoffs
);
'''

# Document 3: the "other_stats" table (extra per-game statistics).
_OTHER_STATS_TABLE_DOC = '''other_stats Table
Stores additional statistics, linked to the game table via game_id.
CREATE TABLE IF NOT EXISTS "other_stats" (
"game_id" TEXT, -- Unique game identifier, matches id column from game table
"league_id" TEXT, -- League identifier
"team_id_home" TEXT, -- Home team identifier
"team_abbreviation_home" TEXT, -- Home team abbreviation
"team_city_home" TEXT, -- Home team city
"pts_paint_home" INTEGER, -- Points in the paint by the home team
"pts_2nd_chance_home" INTEGER, -- Second chance points by the home team
"pts_fb_home" INTEGER, -- Fast break points by the home team
"largest_lead_home" INTEGER,-- Largest lead by the home team
"lead_changes" INTEGER, -- Number of lead changes
"times_tied" INTEGER, -- Number of times the score was tied
"team_turnovers_home" INTEGER, -- Home team turnovers
"total_turnovers_home" INTEGER, -- Total turnovers by the home team
"team_rebounds_home" INTEGER, -- Home team rebounds
"pts_off_to_home" INTEGER, -- Points off turnovers by the home team
"team_id_away" TEXT, -- Away team identifier
"team_abbreviation_away" TEXT, -- Away team abbreviation
"pts_paint_away" INTEGER, -- Points in the paint by the away team
"pts_2nd_chance_away" INTEGER, -- Second chance points by the away team
"pts_fb_away" INTEGER, -- Fast break points by the away team
"largest_lead_away" INTEGER,-- Largest lead by the away team
"team_turnovers_away" INTEGER, -- Away team turnovers
"total_turnovers_away" INTEGER, -- Total turnovers by the away team
"team_rebounds_away" INTEGER, -- Away team rebounds
"pts_off_to_away" INTEGER -- Points off turnovers by the away team
);
'''

# Document 4: mapping between full team names and abbreviations.
_TEAM_NAMES_DOC = '''Team Name Information
In plaintext user questions, only the full team names will be used, but in the queries you may use either full names or abbreviations.
Full names are used with the game table, while abbreviations should be used with the other_stats table.
Team names and abbreviations (separated by |):
Atlanta Hawks|ATL, Boston Celtics|BOS, Cleveland Cavaliers|CLE, New Orleans Pelicans|NOP,
Chicago Bulls|CHI, Dallas Mavericks|DAL, Denver Nuggets|DEN, Golden State Warriors|GSW,
Houston Rockets|HOU, Los Angeles Clippers|LAC, Los Angeles Lakers|LAL, Miami Heat|MIA,
Milwaukee Bucks|MIL, Minnesota Timberwolves|MIN, Brooklyn Nets|BKN, New York Knicks|NYK,
Orlando Magic|ORL, Indiana Pacers|IND, Philadelphia 76ers|PHI, Phoenix Suns|PHX,
Portland Trail Blazers|POR, Sacramento Kings|SAC, San Antonio Spurs|SAS,
Oklahoma City Thunder|OKC, Toronto Raptors|TOR, Utah Jazz|UTA, Memphis Grizzlies|MEM,
Washington Wizards|WAS, Detroit Pistons|DET, Charlotte Hornets|CHA
'''

# Order is preserved: team, game, other_stats, team-name mapping.
metadata_docs = [
    _TEAM_TABLE_DOC,
    _GAME_TABLE_DOC,
    _OTHER_STATS_TABLE_DOC,
    _TEAM_NAMES_DOC,
]

# Register every document with the retriever in a single call.
retriever.add_documents(metadata_docs)
159
+
160
+ # ------------------------------
161
+ # Define a function to compare model output to ground truth
162
+ # ------------------------------
163
def compare_result(sample_query, sample_result, generated_query, actual_result):
    """Score one generated SQL query against its ground-truth example.

    Args:
        sample_query: Ground-truth SQL text from the dataset.
        sample_result: Expected result value from the dataset.
        generated_query: Model output. It may start with a "SQLite:" or
            "SQL:" label (with or without surrounding whitespace) and may
            carry extra text after the first statement.
        actual_result: Value obtained by executing the generated query.

    Returns:
        A tuple ``(overall_valid, query_match, result_match)``.
        ``query_match`` compares the two queries ignoring case and all
        whitespace; ``result_match`` compares the results numerically
        (absolute tolerance 1e-6) when both convert to ``float`` and as
        normalized strings otherwise; ``overall_valid`` is true only when
        both of the other flags are true.
    """
    # Trim first so a leading newline or a "SQLite:\nSELECT ..." layout does
    # not defeat the label check (the previous exact "SQLite: " prefix test
    # missed both).
    query = str(generated_query).strip()
    for label in ("SQLite:", "SQL:"):
        if query.startswith(label):
            query = query[len(label):]
            break

    # Keep only the first statement, including its terminating semicolon.
    semicolon_index = query.find(";")
    if semicolon_index != -1:
        query = query[:semicolon_index + 1]

    def clean_str(value):
        # Normalize for comparison: drop every whitespace character and
        # lowercase. str() keeps non-string cells (e.g. NaN) from crashing.
        return "".join(str(value).split()).lower()

    query_match = clean_str(query) == clean_str(sample_query)

    # Compare numerically when both values convert to float; the outcome is
    # computed once and reused for the debug line below.
    try:
        numeric_match = math.isclose(
            float(sample_result), float(actual_result), abs_tol=1e-6
        )
    except (TypeError, ValueError, OverflowError):
        numeric_match = None
        result_match = clean_str(sample_result) == clean_str(actual_result)
    else:
        result_match = numeric_match

    overall_valid = query_match and result_match

    # Debug output.
    print("DEBUG: Expected Result (from dataset):", sample_result)
    print("DEBUG: Actual DB Result:", actual_result)
    if numeric_match is None:
        print("DEBUG: Numeric Comparison: N/A")
    else:
        print("DEBUG: Numeric Comparison result:", numeric_match)

    return overall_valid, query_match, result_match
205
+
206
+
207
+ # ------------------------------
208
+ # Function to evaluate the model on a given dataset
209
+ # ------------------------------
210
def _build_prompt(schema_block, natural_query):
    """Assemble the few-shot text-to-SQL prompt for one user request.

    NOTE(review): the guidance line in the prompt spells
    "team_abbreviation away" without an underscore. It is left untouched
    because editing the prompt changes model outputs and therefore the
    reported metrics; fix it deliberately if desired.
    """
    return f"""
You are an AI assistant that generates SQL queries for an NBA database based on user questions.

### Relevant Schema:
{schema_block}

### Instructions:
- Generate a valid SQL query to retrieve relevant data from the database.
- Use column names correctly based on the provided schema.
- Output only the SQL query as plain text.

### Example Queries:
Use team_name_home and team_name_away to match teams to the game table.
Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.
To filter by season, use season_id = '2YYYY'.
Example: season_id = '22005' for 2005.
Ensure queries return relevant columns and avoid unnecessary joins.

Example User Requests and SQLite Queries
Request:
"What is the most points the Los Angeles Lakers have ever scored at home?"
SQLite:
SELECT MAX(pts_home)
FROM game
WHERE team_name_home = 'Los Angeles Lakers';

Request:
"Which teams are located in the state of California?"
SQLite:
SELECT full_name FROM team WHERE state = 'California';

Request:
"Which team had the highest number of team turnovers in an away game?"
SQLite:
SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;

Request:
"Which teams were founded before 1979?"
SQLite:
SELECT full_name FROM team WHERE year_founded < 1979;

Request:
"Find the Boston Celtics largest home victory margin in the 2008 season."
SQLite:
SELECT MAX(pts_home - pts_away) AS biggest_win
FROM game
WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';

Generate only the SQLite query prefaced by SQLite: and no other text. Now generate an SQLite query for the following user request.
Request: {natural_query}
"""


def _extract_sql(generated_text):
    """Return the first SQL statement contained in raw model output."""
    # Trim before checking the label so a leading newline does not leave
    # "SQLite:" glued to the front of the statement that gets executed.
    candidate = generated_text.strip()
    for label in ("SQLite:", "SQL:"):
        if candidate.startswith(label):
            candidate = candidate[len(label):].strip()
            break
    end = candidate.find(";")
    return candidate if end == -1 else candidate[:end + 1]


def _run_query(query):
    """Execute *query* on the shared cursor and return its first value.

    Returns the first column of the first row, "" when no rows come back,
    or an "Error executing query: ..." string when execution fails.
    """
    # NOTE(security): the statement comes straight from the model and runs
    # on the module-level connection; nothing here restricts it to SELECT.
    # Consider opening the database read-only for evaluation runs.
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    except Exception as e:
        # Deliberately broad: one bad generated statement must not abort a
        # long evaluation run; the failure is recorded as the result.
        return "Error executing query: " + str(e)
    if not rows:
        return ""
    first_row = rows[0]
    if isinstance(first_row, (tuple, list)) and len(first_row) > 0:
        return first_row[0]
    return first_row


def run_evaluation(nba_df, title, max_examples=None):
    """Evaluate the model on every row of *nba_df* and print a summary.

    Relies on the module-level ``retriever``, ``tokenizer``, ``model`` and
    ``cursor`` objects and on ``compare_result``.

    Args:
        nba_df: DataFrame with "natural_query", "sql_query" and "result"
            columns.
        title: Label printed in front of the summary.
        max_examples: Optional cap on the number of rows evaluated (replaces
            the old commented-out "CONTROL ITERS" break). ``None`` evaluates
            every row.

    Returns:
        None. Per-example details and the final ratios are printed.
    """
    counter = 0
    num_valid = 0
    num_sql_matched = 0
    num_result_matched = 0

    for _, row in nba_df.iterrows():
        if max_examples is not None and counter >= max_examples:
            break

        # Retrieve relevant schema chunks via RAG.
        relevant_schemas = retriever.retrieve(row["natural_query"], top_k=2)
        schema_block = "\n\n".join(relevant_schemas)

        # Build the prompt with instructions, schema, examples, and request.
        input_text = _build_prompt(schema_block, row["natural_query"])
        messages = [{'role': 'user', 'content': input_text}]
        prompt_text = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        inputs = tokenizer(prompt_text, return_tensors="pt", padding=True).to(model.device)

        # Greedy decoding: top_k/top_p only apply when sampling, so they are
        # no longer passed alongside do_sample=False.
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens (everything after the prompt).
        prompt_length = inputs["input_ids"].shape[1]
        generated_query = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Strip any label, keep the first statement, and run it.
        clean_query = _extract_sql(generated_query)
        actual_result = _run_query(clean_query)

        # Score the statement that was actually executed, so the text match
        # and the result match refer to the same query.
        valid, sql_matched, result_matched = compare_result(
            row["sql_query"], row["result"], clean_query, actual_result
        )
        print("=============================================")
        print(f"Overall Valid: {valid}")
        print(f"SQL Query Matched: {sql_matched}")
        print(f"Result Matched: {result_matched}")
        print("=============================================\n")

        # Print debug output.
        print("----- Ground Truth SQL Query -----")
        print(row["sql_query"])
        print("------------------------------------\n")
        print("----- Model Generated SQL Query -----")
        print(generated_query)
        print("---------------------------------------\n")

        print("----- Expected Result -----")
        print(row["result"])
        print("----- Actual DB Result -----")
        print(actual_result)
        print("-------------------------------------------------\n")

        if valid:
            num_valid += 1
        if sql_matched:
            num_sql_matched += 1
        if result_matched:
            num_result_matched += 1

        counter += 1

        if counter % 50 == 0:
            print("Completed " + str(counter))

    # Divide by the number of rows actually evaluated; fall back to 1 so an
    # empty dataset reports 0.0 instead of raising ZeroDivisionError.
    denominator = counter or 1

    print("\n" + title + " results:")
    print("Percent valid: " + str(num_valid / denominator))
    print("Percent SQLite matched: " + str(num_sql_matched / denominator))
    print("Percent result matched: " + str(num_result_matched / denominator))
    print("Dataset length: " + str(len(nba_df)))
358
+
359
+
360
+ # ------------------------------
361
+ # Run evaluation on the full training dataset
362
+ # ------------------------------
363
# Score the model against every example that was loaded, then report the size
# of the evaluated dataset.
run_evaluation(nba_df=df, title="All training data")
print(f"Dataset length: {len(df)}")