DeanGumas committed on
Commit
4be750a
·
1 Parent(s): c5c0647

Renamed demo script and added initial pre-trained test python notebook

Files changed (2)
  1. test.py → demo.py +0 -0
  2. test_pretrained.ipynb +303 -0
test.py → demo.py RENAMED
File without changes
test_pretrained.ipynb ADDED
@@ -0,0 +1,303 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## First load dataset into pandas dataframe"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total dataset examples: 1044\n",
+ "\n",
+ "\n",
+ "What is the highest number of assists recorded by the Indiana Pacers in a single home game?\n",
+ "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n",
+ "44.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd \n",
+ "\n",
+ "# Load dataset and check length\n",
+ "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
+ "print(\"\\n\")\n",
+ "\n",
+ "# Test sampling\n",
+ "sample = df.sample(n=1)\n",
+ "print(sample[\"natural_query\"].values[0])\n",
+ "print(sample[\"sql_query\"].values[0])\n",
+ "print(sample[\"result\"].values[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load pre-trained DeepSeek model using transformers and pytorch packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "import torch\n",
+ "\n",
+ "# Set device to cuda if available, otherwise CPU\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "# Load model and tokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
+ "model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create prompt to setup the model for better performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_text = \"\"\"You are an AI assistant that generates SQLite queries for an NBA database based on user questions. The database consists of two tables:\n",
+ "\n",
+ "1. `team` - Stores information about NBA teams.\n",
+ " - `id`: Unique team identifier.\n",
+ " - `full_name`: Full team name (e.g., \"Los Angeles Lakers\").\n",
+ " - `abbreviation`: 3-letter team code (e.g., \"LAL\").\n",
+ " - `city`, `state`: Location of the team.\n",
+ " - `year_founded`: The year the team was founded.\n",
+ "\n",
+ "2. `game` - Stores details of individual games.\n",
+ " - `game_date`: Date of the game.\n",
+ " - `team_id_home`, `team_id_away`: Unique IDs of home and away teams.\n",
+ " - `team_name_home`, `team_name_away`: Full names of the teams.\n",
+ " - `pts_home`, `pts_away`: Points scored by home and away teams.\n",
+ " - `wl_home`: \"W\" if the home team won, \"L\" if they lost.\n",
+ " - `reb_home`, `reb_away`: Total rebounds.\n",
+ " - `ast_home`, `ast_away`: Total assists.\n",
+ " - Other statistics include field goals (`fgm_home`, `fg_pct_home`), three-pointers (`fg3m_home`), free throws (`ftm_home`), and turnovers (`tov_home`).\n",
+ "\n",
+ "### Instructions:\n",
+ "- Generate a valid SQLite query to retrieve relevant data from the database.\n",
+ "- Use column names correctly based on the provided schema.\n",
+ "- Ensure the query is well-structured and avoids unnecessary joins.\n",
+ "- Format the query with proper indentation.\n",
+ "\n",
+ "### Example Queries:\n",
+ "User: \"What is the most points the Los Angeles Lakers have ever scored at home?\"\n",
+ "SQLite:\n",
+ "SELECT MAX(pts_home) \n",
+ "FROM game \n",
+ "WHERE team_name_home = 'Los Angeles Lakers';\n",
+ "\n",
+ "User: \"List all games where the Golden State Warriors scored more than 130 points.\" \n",
+ "SQLite:\n",
+ "SELECT game_date, team_name_home, pts_home, team_name_away, pts_away\n",
+ "FROM game\n",
+ "WHERE (team_name_home = 'Golden State Warriors' AND pts_home > 130)\n",
+ " OR (team_name_away = 'Golden State Warriors' AND pts_away > 130);\n",
+ " \n",
+ "Now, generate a SQL query based on the following user request: \"\"\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Test model performance on a single example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
+ " warnings.warn(\n",
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+ "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SQLite:\n",
+ "SELECT MAX(ast_home) \n",
+ "FROM game \n",
+ "WHERE team_name_home = 'Indiana Pacers';\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create message with sample query and run model\n",
+ "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
+ "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
+ "\n",
+ "# Print output\n",
+ "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
+ "print(query_output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Test sample output on sqlite3 database"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cleaned\n",
+ "(44.0,)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sqlite3 as sql\n",
+ "\n",
+ "# Create connection to sqlite3 database\n",
+ "connection = sql.connect('./nba-data/nba.sqlite')\n",
+ "cursor = connection.cursor()\n",
+ "\n",
+ "# Execute query from model output and print result\n",
+ "if query_output[0:7] == \"SQLite:\":\n",
+ " print(\"cleaned\")\n",
+ " query = query_output[7:]\n",
+ "elif query_output[0:4] == \"SQL:\":\n",
+ " query = query_output[4:]\n",
+ "else:\n",
+ " query = query_output\n",
+ "cursor.execute(query)\n",
+ "rows = cursor.fetchall()\n",
+ "for row in rows:\n",
+ " print(row)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create function to compare output to ground truth result from examples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cleaned\n",
+ "[(44.0,)]\n",
+ "\n",
+ "SELECT MAX(ast_home) \n",
+ "FROM game \n",
+ "WHERE team_name_home = 'Indiana Pacers';\n",
+ "\n",
+ "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n",
+ "44.0\n",
+ "44.0\n",
+ "SQL matched? True\n",
+ "Result matched? True\n"
+ ]
+ }
+ ],
+ "source": [
+ "def compare_result(sample_query, sample_result, query_output):\n",
+ " # Clean model output to only have the query output\n",
+ " if query_output[0:7] == \"SQLite:\":\n",
+ " query = query_output[7:]\n",
+ " elif query_output[0:4] == \"SQL:\":\n",
+ " query = query_output[4:]\n",
+ " else:\n",
+ " query = query_output\n",
+ " \n",
+ " # Try to execute query, if it fails, then this is a failure of the model\n",
+ " try:\n",
+ " # Execute query and obtain result\n",
+ " cursor.execute(query)\n",
+ " rows = cursor.fetchall()\n",
+ "\n",
+ " # Check if this is a multi-line query\n",
+ " if \"|\" in sample_result:\n",
+ " return True, True\n",
+ " else:\n",
+ " # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.\n",
+ " query = query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
+ " sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
+ "\n",
+ " # Compare results and return\n",
+ " return (query == sample_query), (str(rows[0][0]) == str(sample_result))\n",
+ " except:\n",
+ " return False, False\n",
+ "\n",
+ "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
+ "print(\"SQL matched? \" + str(result[0]))\n",
+ "print(\"Result matched? \" + str(result[1]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
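
The notebook's evaluation logic, stripping the `SQLite:`/`SQL:` prefix the model emits, executing the query, and comparing both the SQL text (whitespace-insensitive) and the first result value against ground truth, can be sketched as a standalone function. This is a minimal sketch: the in-memory table and sample rows stand in for the real `./nba-data/nba.sqlite` database, and it omits the notebook's special case that auto-passes multi-column (`|`-separated) ground-truth results.

```python
import sqlite3

def clean_query(query_output: str) -> str:
    # Strip the "SQLite:" or "SQL:" prefix the model sometimes emits
    for prefix in ("SQLite:", "SQL:"):
        if query_output.startswith(prefix):
            return query_output[len(prefix):].strip()
    return query_output.strip()

def compare_result(cursor, sample_query, sample_result, query_output):
    # Returns (sql_matched, result_matched); a query that fails to
    # execute counts as a model failure on both criteria
    query = clean_query(query_output)
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    except sqlite3.Error:
        return False, False
    # Compare SQL text ignoring spacing, newline, and tab differences
    norm = lambda s: s.replace(" ", "").replace("\n", "").replace("\t", "")
    sql_matched = norm(query) == norm(sample_query)
    result_matched = str(rows[0][0]) == str(sample_result)
    return sql_matched, result_matched

# Demo on an in-memory stand-in for the notebook's game table
conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE game (team_name_home TEXT, ast_home REAL)")
cur.executemany("INSERT INTO game VALUES (?, ?)",
                [("Indiana Pacers", 44.0), ("Indiana Pacers", 31.0)])

gold_sql = "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';"
model_out = ("SQLite:\nSELECT MAX(ast_home)\nFROM game\n"
             "WHERE team_name_home = 'Indiana Pacers';")
print(compare_result(cur, gold_sql, "44.0", model_out))  # (True, True)
```

Keeping the prefix cleaning and comparison in one pure function (taking the cursor as an argument rather than closing over a global, as the notebook cell does) makes it straightforward to reuse when scoring the full 1044-example dataset.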