|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from rag_metadata import SQLMetadataRetriever |
|
import torch |
|
import time |
|
|
|
# Pick the compute device: the CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Local directory holding both the tokenizer files and the model checkpoint.
pretrain_path = "./deepseek-coder-1.3b-instruct"

# Tokenizer and causal-LM weights come from the same checkpoint; the weights
# are loaded in bfloat16 and placed on `device` through device_map.
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
model = AutoModelForCausalLM.from_pretrained(
    pretrain_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

# Retriever that indexes the schema documents and later selects the ones
# relevant to the user's question.
retriever = SQLMetadataRetriever()
|
# Compact one-line schema summaries for the team and game tables.
# NOTE(review): nothing in the visible code reads this list (only
# `metadata_docs` is handed to the retriever) — confirm it is still needed.
metadata_docs2 = [
    (
        "Table team: columns are id (Unique team identifier), "
        "full_name (Full team name, e.g., 'Los Angeles Lakers'), "
        "abbreviation (3-letter team code, e.g., 'LAL'), "
        "city, state, year_founded."
    ),
    (
        "Table game: columns are game_date (Date of the game), "
        "team_id_home, team_id_away (Unique IDs of home and away teams), "
        "team_name_home, team_name_away (Full names of the teams), "
        "pts_home, pts_away (Points scored), "
        "wl_home (W/L result), "
        "reb_home, reb_away (Total rebounds), "
        "ast_home, ast_away (Total assists), "
        "fgm_home, fg_pct_home (Field goals), "
        "fg3m_home (Three-pointers), "
        "ftm_home (Free throws), "
        "tov_home (Turnovers), "
        "and other game-related statistics."
    ),
]
|
# Schema documents handed to retriever.add_documents(): one list element per
# retrievable document. They are the team, game and other_stats CREATE TABLE
# statements (with per-column comments) plus a full-name|abbreviation lookup
# for every team. Retrieved documents are joined into the LLM prompt, so the
# column comments act as join hints and must name real columns.
metadata_docs = [

'''team Table

Stores information about NBA teams.

CREATE TABLE IF NOT EXISTS "team" (

"id" TEXT PRIMARY KEY, -- Unique identifier for the team

"full_name" TEXT, -- Full official name of the team (e.g., "Los Angeles Lakers")

"abbreviation" TEXT, -- Shortened team name (e.g., "LAL")

"nickname" TEXT, -- Commonly used nickname for the team (e.g., "Lakers")

"city" TEXT, -- City where the team is based

"state" TEXT, -- State where the team is located

"year_founded" REAL -- Year the team was established

);''',

'''

game Table

Contains detailed statistics for each NBA game, including home and away team performance.

CREATE TABLE IF NOT EXISTS "game" (

"season_id" TEXT, -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)

"team_id_home" TEXT, -- ID of the home team (matches "id" in team table)

"team_abbreviation_home" TEXT, -- Abbreviation of the home team

"team_name_home" TEXT, -- Full name of the home team

"game_id" TEXT PRIMARY KEY, -- Unique identifier for the game

"game_date" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)

"matchup_home" TEXT, -- Matchup details including opponent (e.g., "LAL vs. BOS")

"wl_home" TEXT, -- "W" if the home team won, "L" if they lost

"min" INTEGER, -- Total minutes played in the game

"fgm_home" REAL, -- Field goals made by the home team

"fga_home" REAL, -- Field goals attempted by the home team

"fg_pct_home" REAL, -- Field goal percentage of the home team

"fg3m_home" REAL, -- Three-point field goals made by the home team

"fg3a_home" REAL, -- Three-point attempts by the home team

"fg3_pct_home" REAL, -- Three-point field goal percentage of the home team

"ftm_home" REAL, -- Free throws made by the home team

"fta_home" REAL, -- Free throws attempted by the home team

"ft_pct_home" REAL, -- Free throw percentage of the home team

"oreb_home" REAL, -- Offensive rebounds by the home team

"dreb_home" REAL, -- Defensive rebounds by the home team

"reb_home" REAL, -- Total rebounds by the home team

"ast_home" REAL, -- Assists by the home team

"stl_home" REAL, -- Steals by the home team

"blk_home" REAL, -- Blocks by the home team

"tov_home" REAL, -- Turnovers by the home team

"pf_home" REAL, -- Personal fouls by the home team

"pts_home" REAL, -- Total points scored by the home team

"plus_minus_home" INTEGER, -- Plus/minus rating for the home team

"video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)

"team_id_away" TEXT, -- ID of the away team

"team_abbreviation_away" TEXT, -- Abbreviation of the away team

"team_name_away" TEXT, -- Full name of the away team

"matchup_away" TEXT, -- Matchup details from the away team’s perspective

"wl_away" TEXT, -- "W" if the away team won, "L" if they lost

"fgm_away" REAL, -- Field goals made by the away team

"fga_away" REAL, -- Field goals attempted by the away team

"fg_pct_away" REAL, -- Field goal percentage of the away team

"fg3m_away" REAL, -- Three-point field goals made by the away team

"fg3a_away" REAL, -- Three-point attempts by the away team

"fg3_pct_away" REAL, -- Three-point field goal percentage of the away team

"ftm_away" REAL, -- Free throws made by the away team

"fta_away" REAL, -- Free throws attempted by the away team

"ft_pct_away" REAL, -- Free throw percentage of the away team

"oreb_away" REAL, -- Offensive rebounds by the away team

"dreb_away" REAL, -- Defensive rebounds by the away team

"reb_away" REAL, -- Total rebounds by the away team

"ast_away" REAL, -- Assists by the away team

"stl_away" REAL, -- Steals by the away team

"blk_away" REAL, -- Blocks by the away team

"tov_away" REAL, -- Turnovers by the away team

"pf_away" REAL, -- Personal fouls by the away team

"pts_away" REAL, -- Total points scored by the away team

"plus_minus_away" INTEGER, -- Plus/minus rating for the away team

"video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)

"season_type" TEXT -- Regular season or playoffs

);

''',

# The game table's key column is "game_id" (it has no "id" column), so the
# join hint below must point at game.game_id.
'''

other_stats Table

Stores additional statistics, linked to the game table via game_id.

CREATE TABLE IF NOT EXISTS "other_stats" (

"game_id" TEXT, -- Unique game identifier, matches game_id column in the game table

"league_id" TEXT, -- League identifier

"team_id_home" TEXT, -- Home team identifier

"team_abbreviation_home" TEXT, -- Home team abbreviation

"team_city_home" TEXT, -- Home team city

"pts_paint_home" INTEGER, -- Points in the paint by the home team

"pts_2nd_chance_home" INTEGER, -- Second chance points by the home team

"pts_fb_home" INTEGER, -- Fast break points by the home team

"largest_lead_home" INTEGER,-- Largest lead by the home team

"lead_changes" INTEGER, -- Number of lead changes

"times_tied" INTEGER, -- Number of times the score was tied

"team_turnovers_home" INTEGER, -- Home team turnovers

"total_turnovers_home" INTEGER, -- Total turnovers by the home team

"team_rebounds_home" INTEGER, -- Home team rebounds

"pts_off_to_home" INTEGER, -- Points off turnovers by the home team

"team_id_away" TEXT, -- Away team identifier

"team_abbreviation_away" TEXT, -- Away team abbreviation

"pts_paint_away" INTEGER, -- Points in the paint by the away team

"pts_2nd_chance_away" INTEGER, -- Second chance points by the away team

"pts_fb_away" INTEGER, -- Fast break points by the away team

"largest_lead_away" INTEGER,-- Largest lead by the away team

"team_turnovers_away" INTEGER, -- Away team turnovers

"total_turnovers_away" INTEGER, -- Total turnovers by the away team

"team_rebounds_away" INTEGER, -- Away team rebounds

"pts_off_to_away" INTEGER -- Points off turnovers by the away team

);

''',

'''

Team Name Information

In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations.

The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.

Notice they are separated by the | character in the following list:



Atlanta Hawks|ATL

Boston Celtics|BOS

Cleveland Cavaliers|CLE

New Orleans Pelicans|NOP

Chicago Bulls|CHI

Dallas Mavericks|DAL

Denver Nuggets|DEN

Golden State Warriors|GSW

Houston Rockets|HOU

Los Angeles Clippers|LAC

Los Angeles Lakers|LAL

Miami Heat|MIA

Milwaukee Bucks|MIL

Minnesota Timberwolves|MIN

Brooklyn Nets|BKN

New York Knicks|NYK

Orlando Magic|ORL

Indiana Pacers|IND

Philadelphia 76ers|PHI

Phoenix Suns|PHX

Portland Trail Blazers|POR

Sacramento Kings|SAC

San Antonio Spurs|SAS

Oklahoma City Thunder|OKC

Toronto Raptors|TOR

Utah Jazz|UTA

Memphis Grizzlies|MEM

Washington Wizards|WAS

Detroit Pistons|DET

Charlotte Hornets|CHA

'''

]
|
# Index every schema document so the retriever can search over them.
retriever.add_documents(metadata_docs)

# Natural-language question to translate into SQL.
user_question = (
    "What is the most points ever scored by the New York Knicks at home?"
)

# Ask the retriever for its two best matches (top_k=2) for the question.
relevant_schemas = retriever.retrieve(user_question, top_k=2)

# Debug dump of what RAG selected, one numbered section per document.
print("---------------------------------------------")
print("INFO: Retrieved relevant documents from RAG:")
print()
for i, doc in enumerate(relevant_schemas):
    print("Relevant doc -> ", i + 1)
    print(doc)
    print("---------------------------------------------")

# Retrieved documents (strings) separated by a blank line, ready for the prompt.
schema_block = "\n\n".join(relevant_schemas)
|
|
|
|
|
# Single user message for the model: task statement, the retrieved schema
# documents, query-writing guidance, few-shot examples, then the question.
# String literals in the guidance use single quotes because in SQL (and in
# strict SQLite) double quotes denote identifiers, not strings.
# NOTE(review): "Output only the SQL query as plain text" and the closing
# "prefaced by SQLite:" instruction ask for different output formats; the
# generated text may therefore start with "SQLite:". Confirm which is wanted.
input_text = f"""

You are an AI assistant that generates SQL queries for an NBA database based on user questions.



### Relevant Schema:

{schema_block}



### Instructions:

- Generate a valid SQL query to retrieve relevant data from the database.

- Use column names correctly based on the provided schema.

- Output only the SQL query as plain text.



### Example Queries:

Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation_away to match teams to the other_stats table.



To filter by season, use season_id = '2YYYY'.



Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = '21972'. To get statistics from 2015, use a statement like: season_id = '22015'.



Ensure queries return relevant columns and avoid unnecessary joins.



Example User Requests and SQLite Queries

Request:

"What is the most points the Los Angeles Lakers have ever scored at home?"

SQLite:

SELECT MAX(pts_home)

FROM game

WHERE team_name_home = 'Los Angeles Lakers';



Request:

"Which teams are located in the state of California?"

SQLite:

SELECT full_name FROM team WHERE state = 'California';



Request:

"Which team had the highest number of team turnovers in an away game?"

SQLite:

SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;



Request:

"Which teams were founded before 1979?"

SQLite:

SELECT full_name FROM team WHERE year_founded < 1979;



Request:

"Find the Boston Celtics largest home victory margin in the 2008 season."

SQLite:

SELECT MAX(pts_home - pts_away) AS biggest_win

FROM game

WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';



Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:

{user_question}

"""
|
|
|
|
|
# Wrap the prompt as a single user turn and render it with the model's chat
# template; add_generation_prompt=True appends the assistant-turn header so
# the model answers instead of continuing the user message.
messages = [{"role": "user", "content": input_text}]
prompt_text = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

# The rendered template text already contains the special tokens the template
# emits (e.g. BOS), so tokenizing it again with the default
# add_special_tokens=True would prepend a duplicate BOS. A single prompt needs
# no padding.
inputs = tokenizer(
    prompt_text, return_tensors="pt", add_special_tokens=False
).to(model.device)

# Greedy decoding: the same question always produces the same SQL, which keeps
# runs reproducible and comparable. Generation stops at EOS or after 512 new
# tokens; EOS is also reused as the pad id.
# perf_counter is a monotonic high-resolution clock meant for measuring
# durations, unlike time.time(), which follows wall-clock adjustments.
start_time = time.perf_counter()
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=False,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
end_time = time.perf_counter()
|
|
|
|
|
print("Natural Language Query: ", user_question)
print()

# Decode only the newly generated tokens: the slice assumes outputs[0] starts
# with the prompt's token ids, so everything past the prompt length is the
# model's answer.
prompt_length = len(inputs["input_ids"][0])
new_token_ids = outputs[0][prompt_length:]
generated = tokenizer.decode(new_token_ids, skip_special_tokens=True)

print("Generated SQL Query:\n")
print(generated)

# Wall time spent inside generate(), rounded to hundredths of a second.
elapsed_seconds = round(end_time - start_time, 2)
print("\nExecution time:", elapsed_seconds, "seconds")
|
|