{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the pre-trained DeepSeek Coder 1.3B Instruct model on a dataset generated with ChatGPT (GPT-4o)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the dataset into a pandas DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total dataset examples: 1044\n",
      "\n",
      "\n",
      "What is the highest number of assists recorded by the Indiana Pacers in a single home game?\n",
      "SELECT MAX(ast_home)  FROM game  WHERE team_name_home = 'Indiana Pacers';\n",
      "44.0\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load the tab-separated dataset and report how many examples it contains\n",
    "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
    "print(f\"Total dataset examples: {len(df)}\")\n",
    "print(\"\\n\")\n",
    "\n",
    "# Draw one random example and show its question, reference SQL, and expected result\n",
    "sample = df.sample(n=1)\n",
    "print(sample[\"natural_query\"].values[0])\n",
    "print(sample[\"sql_query\"].values[0])\n",
    "print(sample[\"result\"].values[0])"
   ]
  },
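  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `df.sample(n=1)` is called without a seed, re-running the cell above may pick a different example each time. The sketch below uses a small made-up DataFrame as a stand-in for `sql_train.tsv` (same column names, invented rows) to show that passing `random_state` makes the draw repeatable, and that `.iloc[0]` returns the sampled row as a Series whose fields can be read by column name instead of through `.values[0]`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical stand-in for sql_train.tsv: same three column names, invented rows\n",
    "df = pd.DataFrame({\n",
    "    'natural_query': ['Question A', 'Question B', 'Question C'],\n",
    "    'sql_query': ['SELECT 1;', 'SELECT 2;', 'SELECT 3;'],\n",
    "    'result': ['1', '2', '3'],\n",
    "})\n",
    "\n",
    "# The same random_state selects the same row on every call\n",
    "first = df.sample(n=1, random_state=0)\n",
    "second = df.sample(n=1, random_state=0)\n",
    "\n",
    "# .iloc[0] turns the one-row DataFrame into a Series indexed by column name\n",
    "row = first.iloc[0]\n",
    "print(row['natural_query'], row['sql_query'], row['result'])\n",
    "```\n",
    "\n",
    "The seed `0` is an arbitrary choice; any fixed integer gives a repeatable draw."
   ]
  },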
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pre-trained DeepSeek model with the `transformers` and PyTorch packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "# Set device to cuda if available, otherwise CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load model and tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a prompt that gives the model the database schema, instructions, and example queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = \"\"\"You are an AI assistant that generates SQLite queries for an NBA database based on user questions. The database consists of two tables:\n",
    "\n",
    "1. `team` - Stores information about NBA teams.\n",
    "   - `id`: Unique team identifier.\n",
    "   - `full_name`: Full team name (e.g., \"Los Angeles Lakers\").\n",
    "   - `abbreviation`: 3-letter team code (e.g., \"LAL\").\n",
    "   - `city`, `state`: Location of the team.\n",
    "   - `year_founded`: The year the team was founded.\n",
    "\n",
    "2. `game` - Stores details of individual games.\n",
    "   - `game_date`: Date of the game.\n",
    "   - `team_id_home`, `team_id_away`: Unique IDs of home and away teams.\n",
    "   - `team_name_home`, `team_name_away`: Full names of the teams.\n",
    "   - `pts_home`, `pts_away`: Points scored by home and away teams.\n",
    "   - `wl_home`: \"W\" if the home team won, \"L\" if they lost.\n",
    "   - `reb_home`, `reb_away`: Total rebounds.\n",
    "   - `ast_home`, `ast_away`: Total assists.\n",
    "   - Other statistics include field goals (`fgm_home`, `fg_pct_home`), three-pointers (`fg3m_home`), free throws (`ftm_home`), and turnovers (`tov_home`).\n",
    "\n",
    "### Instructions:\n",
    "- Generate a valid SQLite query to retrieve relevant data from the database.\n",
    "- Use column names correctly based on the provided schema.\n",
    "- Ensure the query is well-structured and avoids unnecessary joins.\n",
    "- Format the query with proper indentation.\n",
    "\n",
    "### Example Queries:\n",
    "User: \"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
    "SQLite:\n",
    "SELECT MAX(pts_home) \n",
    "FROM game \n",
    "WHERE team_name_home = 'Los Angeles Lakers';\n",
    "\n",
    "User: \"List all games where the Golden State Warriors scored more than 130 points.\" \n",
    "SQLite:\n",
    "SELECT game_date, team_name_home, pts_home, team_name_away, pts_away\n",
    "FROM game\n",
    "WHERE (team_name_home = 'Golden State Warriors' AND pts_home > 130)\n",
    "   OR (team_name_away = 'Golden State Warriors' AND pts_away > 130);\n",
    "   \n",
    "Now, generate a SQL query based on the following user request: \"\"\""
   ]
  },
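  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The few-shot section of `input_text` follows a fixed pattern: a quoted `User:` question, a `SQLite:` line, then the query, with a blank line between examples. If examples need to be added or swapped, generating that section from (question, query) pairs keeps the pattern consistent. The sketch below is one way to do that; `format_examples` is a hypothetical helper, and its output follows the layout of the prompt's examples rather than reproducing the notebook's string character for character.\n",
    "\n",
    "```python\n",
    "def format_examples(examples: list[tuple[str, str]]) -> str:\n",
    "    # One block per example: quoted question, 'SQLite:' cue, then the query\n",
    "    blocks = ['User: \"' + question + '\"\\nSQLite:\\n' + query + '\\n' for question, query in examples]\n",
    "    # Joining with a newline leaves one blank line between consecutive examples\n",
    "    return '### Example Queries:\\n' + '\\n'.join(blocks)\n",
    "\n",
    "examples = [\n",
    "    ('What is the most points the Los Angeles Lakers have ever scored at home?',\n",
    "     \"SELECT MAX(pts_home)\\nFROM game\\nWHERE team_name_home = 'Los Angeles Lakers';\"),\n",
    "]\n",
    "print(format_examples(examples))\n",
    "```\n",
    "\n",
    "The returned string could replace the hand-written example section while leaving the rest of `input_text` unchanged. If examples are drawn from `df`, keeping them out of the evaluated set avoids testing the model on questions whose answers appear in its prompt."
   ]
  },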
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test model performance on a single example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n",
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SQLite:\n",
      "SELECT MAX(ast_home) \n",
      "FROM game \n",
      "WHERE team_name_home = 'Indiana Pacers';\n",
      "\n"
     ]
    }
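   ],
   "source": [
    "# Create a chat message from the prompt plus the sampled question, then run the model\n",
    "message = [{'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
    "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# Greedy decoding: top_k and top_p only apply when do_sample=True, so they are omitted.\n",
    "# A single unpadded prompt has an all-ones attention mask, and EOS doubles as the pad token.\n",
    "outputs = model.generate(inputs, attention_mask=torch.ones_like(inputs), max_new_tokens=512, do_sample=False,\n",
    "                         num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "# Decode only the newly generated tokens (everything after the prompt)\n",
    "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
    "print(query_output)"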
   ]
  },
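  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a decoder-only model like this one, `model.generate` returns each sequence as the prompt tokens followed by the newly generated tokens, which is why the cell above slices `outputs[0]` at `len(inputs[0])` before decoding. The plain-Python sketch below mimics that step with made-up token IDs. It also cuts the continuation at the first end-of-sequence (EOS) ID, which matters when several sequences are generated as a batch and the finished ones are padded (when `pad_token_id` is set to `eos_token_id`, that padding is the EOS token itself). `new_tokens_only` is a hypothetical helper, not a `transformers` function.\n",
    "\n",
    "```python\n",
    "def new_tokens_only(sequence: list[int], prompt_len: int, eos_id: int) -> list[int]:\n",
    "    # Drop the prompt tokens that generate() echoes back\n",
    "    continuation = sequence[prompt_len:]\n",
    "    # Stop at the first EOS; anything after it is padding that reuses the EOS ID\n",
    "    if eos_id in continuation:\n",
    "        continuation = continuation[:continuation.index(eos_id)]\n",
    "    return continuation\n",
    "\n",
    "prompt = [1, 15, 16, 17]                       # made-up prompt token IDs\n",
    "generated = [1, 15, 16, 17, 40, 41, 42, 2, 2]  # made-up output: prompt, 3 new tokens, EOS, padding\n",
    "print(new_tokens_only(generated, len(prompt), eos_id=2))  # [40, 41, 42]\n",
    "```\n",
    "\n",
    "With real tensors the same slice applies to `outputs[0]`, and `tokenizer.decode(..., skip_special_tokens=True)` already drops EOS and padding tokens from the decoded text, so the explicit cut is only needed when working with token IDs directly."
   ]
  },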
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the generated query on the SQLite database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cleaned\n",
      "(44.0,)\n"
     ]
    }
   ],
   "source": [
    "import sqlite3 as sql\n",
    "\n",
    "# Create a connection to the SQLite database\n",
    "connection = sql.connect('./nba-data/nba.sqlite')\n",
    "cursor = connection.cursor()\n",
    "\n",
    "# Remove a leading \"SQLite:\" or \"SQL:\" label from the model output, then run the query and print the result\n",
    "query = query_output.strip()\n",
    "for prefix in (\"SQLite:\", \"SQL:\"):\n",
    "    if query.startswith(prefix):\n",
    "        print(\"cleaned\")\n",
    "        query = query[len(prefix):]\n",
    "        break\n",
    "cursor.execute(query)\n",
    "rows = cursor.fetchall()\n",
    "for row in rows:\n",
    "    print(row)"
   ]
  },
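  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only removes a leading `SQLite:` or `SQL:` label. Instruction-tuned code models often wrap SQL in a markdown code fence instead (three backticks followed by `sql`), which `sqlite3` would reject as a syntax error. The sketch below assumes that output format, which this notebook has not shown, and strips fence lines as well as the label before running the query against a small made-up in-memory table that reuses two column names from the `game` schema. `extract_sql` and `FENCE` are hypothetical names.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "# Three backticks (the markdown fence marker), built with * 3 so the example embeds cleanly in markdown\n",
    "FENCE = '`' * 3\n",
    "\n",
    "def extract_sql(text: str) -> str:\n",
    "    # Drop markdown fence lines, such as an opening sql fence and the closing fence\n",
    "    lines = [ln for ln in text.strip().splitlines() if not ln.strip().startswith(FENCE)]\n",
    "    query = '\\n'.join(lines).strip()\n",
    "    # Drop a leading 'SQLite:' or 'SQL:' label\n",
    "    for prefix in ('SQLite:', 'SQL:'):\n",
    "        if query.startswith(prefix):\n",
    "            query = query[len(prefix):].strip()\n",
    "            break\n",
    "    return query\n",
    "\n",
    "# Made-up in-memory table reusing two column names from the game schema\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.execute('CREATE TABLE game (team_name_home TEXT, ast_home REAL)')\n",
    "conn.executemany('INSERT INTO game VALUES (?, ?)',\n",
    "                 [('Indiana Pacers', 31.0), ('Indiana Pacers', 44.0), ('Boston Celtics', 29.0)])\n",
    "\n",
    "# Simulated model output: a fenced block that also repeats the SQLite: label\n",
    "raw = '\\n'.join([FENCE + 'sql', 'SQLite:', 'SELECT MAX(ast_home)', 'FROM game',\n",
    "                 \"WHERE team_name_home = 'Indiana Pacers';\", FENCE])\n",
    "print(conn.execute(extract_sql(raw)).fetchall())  # [(44.0,)]\n",
    "```\n",
    "\n",
    "Applied to `query_output`, `extract_sql` would return the same query as the cell above, apart from surrounding whitespace, since that output carries only the `SQLite:` label."
   ]
  },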
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a function that compares the model output with the ground-truth query and result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cleaned\n",
      "[(44.0,)]\n",
      "\n",
      "SELECT MAX(ast_home) \n",
      "FROM game \n",
      "WHERE team_name_home = 'Indiana Pacers';\n",
      "\n",
      "SELECT MAX(ast_home)  FROM game  WHERE team_name_home = 'Indiana Pacers';\n",
      "44.0\n",
      "44.0\n",
      "SQL matched? True\n",
      "Result matched? True\n"
     ]
    }
   ],
   "source": [
    "def compare_result(sample_query, sample_result, query_output):\n",
    "    # Clean the model output so that only the SQL query remains\n",
    "    query = query_output.strip()\n",
    "    for prefix in (\"SQLite:\", \"SQL:\"):\n",
    "        if query.startswith(prefix):\n",
    "            query = query[len(prefix):]\n",
    "            break\n",
    "\n",
    "    # Try to execute the query; any failure counts as a failure of the model\n",
    "    try:\n",
    "        cursor.execute(query)\n",
    "        rows = cursor.fetchall()\n",
    "    except Exception:\n",
    "        return False, False\n",
    "\n",
    "    # Expected results containing \"|\" (several values) are not compared yet and count as matches\n",
    "    if \"|\" in str(sample_result):\n",
    "        return True, True\n",
    "\n",
    "    # Strip all whitespace before comparing queries, since spacing, newlines, and tabs may differ\n",
    "    query = \"\".join(query.split())\n",
    "    sample_query = \"\".join(sample_query.split())\n",
    "\n",
    "    # An empty result set cannot match the expected value, so check for rows before indexing\n",
    "    result_matched = len(rows) > 0 and str(rows[0][0]) == str(sample_result)\n",
    "    return (query == sample_query), result_matched\n",
    "\n",
    "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
    "print(\"SQL matched? \" + str(result[0]))\n",
    "print(\"Result matched? \" + str(result[1]))"
   ]
  }
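  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing the query text and a single stored value is a strict test: an equivalent query written differently counts as an SQL mismatch, and answers with several rows or columns are hard to store as one value. A common alternative in text-to-SQL evaluation is execution matching, which runs both the predicted and the reference query and compares their full result sets. The sketch below is a minimal version of that idea on a made-up in-memory table; `results_match` is a hypothetical helper, and ignoring row order (via `collections.Counter`) is a simplifying assumption that is too lenient for questions whose answer depends on `ORDER BY`.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "from collections import Counter\n",
    "\n",
    "def results_match(conn: sqlite3.Connection, predicted: str, reference: str) -> bool:\n",
    "    # A predicted query that fails to execute counts as a miss\n",
    "    try:\n",
    "        got = conn.execute(predicted).fetchall()\n",
    "    except sqlite3.Error:\n",
    "        return False\n",
    "    expected = conn.execute(reference).fetchall()\n",
    "    # Compare rows as multisets: row order is ignored, duplicate rows still count\n",
    "    return Counter(got) == Counter(expected)\n",
    "\n",
    "# Made-up in-memory table reusing two column names from the game schema\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.execute('CREATE TABLE game (team_name_home TEXT, pts_home REAL)')\n",
    "conn.executemany('INSERT INTO game VALUES (?, ?)',\n",
    "                 [('Indiana Pacers', 120.0), ('Boston Celtics', 131.0), ('Boston Celtics', 99.0)])\n",
    "\n",
    "pairs = [\n",
    "    # Same rows in a different order: counts as a match\n",
    "    ('SELECT team_name_home FROM game WHERE pts_home > 100 ORDER BY pts_home DESC',\n",
    "     'SELECT team_name_home FROM game WHERE pts_home > 100'),\n",
    "    # Unknown column: execution fails, so it counts as a miss\n",
    "    ('SELECT points FROM game', 'SELECT pts_home FROM game'),\n",
    "]\n",
    "scores = [results_match(conn, predicted, reference) for predicted, reference in pairs]\n",
    "print(f'Execution accuracy: {sum(scores)}/{len(scores)}')  # Execution accuracy: 1/2\n",
    "```\n",
    "\n",
    "In this notebook, `reference` would be the `sql_query` column and `predicted` the cleaned model output; averaging the scores over the 1,044 examples would give an execution-accuracy figure for the model."
   ]
  }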
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}