licesma committed on
Commit 860fee5 · 1 Parent(s): 9a3c9fb

Notebooks for GPT evaluation

__pycache__/rag_metadata.cpython-311.pyc ADDED
Binary file (3.73 kB).
 
chat_gpt_3.5.ipynb ADDED
@@ -0,0 +1,424 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cf4403ec",
+ "metadata": {},
+ "source": [
+ "# Notebook to evaluate ChatGPT Performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f708eaa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import warnings\n",
+ "import sqlite3 as sql\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "from huggingface_hub import snapshot_download\n",
+ "import sys\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "83a1bd00",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"OPENAI_API_KEY\"] = \"<key>\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b3a647bf",
+ "metadata": {},
+ "source": [
+ "## Set up path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "996e282d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "is_google_colab=False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "5d96087b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "current_path = \"./\"\n",
+ "\n",
+ "def get_path(rel_path):\n",
+ "    return os.path.join(current_path, rel_path)\n",
+ "\n",
+ "if is_google_colab:\n",
+ "    hugging_face_path = snapshot_download(\n",
+ "        repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
+ "        repo_type=\"model\",\n",
+ "        allow_patterns=[\"src/*\", \"train-data/*\", \"deepseek-coder-1.3b-instruct/*\", \"nba-data/*\"],\n",
+ "    )\n",
+ "    sys.path.append(hugging_face_path)\n",
+ "    current_path = hugging_face_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "483da9f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'./nba-data/nba.sqlite'"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_path('nba-data/nba.sqlite')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "5cc9f19f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total dataset examples: 1044\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "# Establish a database connection once (adjust the DB path as needed)\n",
+ "connection = sql.connect(get_path('nba-data/nba.sqlite'))\n",
+ "cursor = connection.cursor()\n",
+ "\n",
+ "# ------------------------------\n",
+ "# Load dataset and print summary\n",
+ "# ------------------------------\n",
+ "df = pd.read_csv(get_path(\"train-data/expanded_sql_train.tsv\"), sep='\\t')\n",
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
+ "print(\"\\n\")\n",
+ "\n",
+ "# ------------------------------\n",
+ "# Load tokenizer and model\n",
+ "# ------------------------------\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f2d859d8",
+ "metadata": {},
+ "source": [
+ "## Define compare result function for evaluation process"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "a5295234",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src.evaluation.compare_result import compare_result\n",
+ "from src.rag.table_retriever import retrieve_doc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0a89a468",
+ "metadata": {},
+ "source": [
+ "## Create evaluation loop for ChatGPT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "e580dda8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "client = OpenAI()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "69707ee7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ------------------------------\n",
+ "# Function to evaluate the model on a given dataset\n",
+ "# ------------------------------\n",
+ "\n",
+ "from src.prompts.prompt import input_text\n",
+ "def run_evaluation(nba_df, title):\n",
+ "    counter = 0\n",
+ "    num_valid = 0\n",
+ "    num_sql_matched = 0\n",
+ "    num_result_matched = 0\n",
+ "    for index, row in nba_df.iterrows():\n",
+ "        # Retrieve relevant schema chunks via RAG\n",
+ "\n",
+ "        response = client.chat.completions.create(\n",
+ "            model=\"gpt-3.5-turbo\",\n",
+ "            messages=[\n",
+ "                {\"role\": \"user\", \"content\": input_text + row[\"natural_query\"]}\n",
+ "            ]\n",
+ "        )\n",
+ "\n",
+ "        # Decode the model output.\n",
+ "        generated_query = response.choices[0].message.content\n",
+ "\n",
+ "        # Clean generated query: remove any prefix and truncate after first semicolon.\n",
+ "        if generated_query.startswith(\"SQLite:\"):\n",
+ "            clean_query = generated_query[len(\"SQLite:\"):].strip()\n",
+ "        elif generated_query.startswith(\"SQL:\"):\n",
+ "            clean_query = generated_query[len(\"SQL:\"):].strip()\n",
+ "        else:\n",
+ "            clean_query = generated_query.strip()\n",
+ "\n",
+ "        semicolon_idx = clean_query.find(\";\")\n",
+ "        if semicolon_idx != -1:\n",
+ "            clean_query = clean_query[:semicolon_idx+1]\n",
+ "\n",
+ "        # Execute the cleaned query on the SQLite DB to obtain the actual result.\n",
+ "        \"\"\"\n",
+ "        try:\n",
+ "            cursor.execute(clean_query)\n",
+ "            rows = cursor.fetchall()\n",
+ "            if rows and isinstance(rows[0], (tuple, list)) and len(rows[0]) > 0:\n",
+ "                actual_result = rows[0][0]\n",
+ "            elif rows:\n",
+ "                actual_result = rows[0]\n",
+ "            else:\n",
+ "                actual_result = \"\"\n",
+ "        except Exception as e:\n",
+ "            actual_result = \"Error executing query: \" + str(e)\n",
+ "        \"\"\"\n",
+ "\n",
+ "        # Compare the ground truth query and expected result to the generated query and actual result.\n",
+ "        valid, sql_matched, result_matched = compare_result(cursor, row[\"sql_query\"], row[\"result\"], generated_query)\n",
+ "        \"\"\"\n",
+ "        print(\"=============================================\")\n",
+ "        print(f\"Overall Valid: {valid}\")\n",
+ "        print(f\"SQL Query Matched: {sql_matched}\")\n",
+ "        print(f\"Result Matched: {result_matched}\")\n",
+ "        print(\"=============================================\\n\")\n",
+ "\n",
+ "        # Print debug output.\n",
+ "        print(\"----- Ground Truth SQL Query -----\")\n",
+ "        print(row[\"sql_query\"])\n",
+ "        print(\"------------------------------------\\n\")\n",
+ "        print(\"----- Model Generated SQL Query -----\")\n",
+ "        print(generated_query)\n",
+ "        print(\"---------------------------------------\\n\")\n",
+ "\n",
+ "        print(\"----- Expected Result -----\")\n",
+ "        print(row[\"result\"])\n",
+ "        print(\"----- Actual DB Result -----\")\n",
+ "        print(actual_result)\n",
+ "        print(\"-------------------------------------------------\\n\")\n",
+ "        \"\"\"\n",
+ "        if valid:\n",
+ "            num_valid += 1\n",
+ "        if sql_matched:\n",
+ "            num_sql_matched += 1\n",
+ "        if result_matched:\n",
+ "            num_result_matched += 1\n",
+ "\n",
+ "        counter += 1\n",
+ "\n",
+ "        # CONTROL ITERS\n",
+ "        # if counter == 2:\n",
+ "        #     break\n",
+ "\n",
+ "        if counter % 50 == 0:\n",
+ "            print(\"Completed \" + str(counter))\n",
+ "\n",
+ "    print(\"\\n\" + title + \" results:\")\n",
+ "    print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
+ "    print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
+ "    print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))\n",
+ "    print(\"Dataset length: \" + str(len(nba_df)))\n",
+ "    print(\"-------------------\")\n",
+ "    print(\"Num queries tested: \", counter)\n",
+ "    print(\"Num correct queries: \", num_result_matched)\n",
+ "    print(\"Acc: \", (num_result_matched / counter)*100)\n",
+ "    print(\"-------------------\")\n",
+ "    "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "0c3fdc3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def run(nba_df, title):\n",
+ "    counter = 0\n",
+ "    num_valid = 0\n",
+ "    num_sql_matched = 0\n",
+ "    num_result_matched = 0\n",
+ "    for index, row in nba_df.iterrows():\n",
+ "        print(row['natural_query'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8bff68e0",
+ "metadata": {},
+ "source": [
+ "## Run ChatGPT evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "ce291e30",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Completed 50\n",
+ "Completed 100\n",
+ "Completed 150\n",
+ "Completed 200\n",
+ "Completed 250\n",
+ "Completed 300\n",
+ "Completed 350\n",
+ "Completed 400\n",
+ "Completed 450\n",
+ "Completed 500\n",
+ "Completed 550\n",
+ "Completed 600\n",
+ "Completed 650\n",
+ "Completed 700\n",
+ "Completed 750\n",
+ "Completed 800\n",
+ "Completed 850\n",
+ "Completed 900\n",
+ "Completed 950\n",
+ "Completed 1000\n",
+ "\n",
+ "All training data results:\n",
+ "Percent valid: 0.8630268199233716\n",
+ "Percent SQLite matched: 0.20114942528735633\n",
+ "Percent result matched: 0.6293103448275862\n",
+ "Dataset length: 1044\n",
+ "-------------------\n",
+ "Num queries tested: 1044\n",
+ "Num correct queries: 657\n",
+ "Acc: 62.93103448275862\n",
+ "-------------------\n",
+ "Dataset length: 1044\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ------------------------------\n",
+ "# Run evaluation on the full training dataset\n",
+ "# ------------------------------\n",
+ "run_evaluation(df, \"All training data\")\n",
+ "print(\"Dataset length: \" + str(len(df)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b21994fa",
+ "metadata": {},
+ "source": [
+ "## Run RAG evaluation on small query dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2d12248",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Completed 50\n",
+ "Completed 100\n",
+ "Completed 150\n",
+ "Completed 200\n",
+ "\n",
+ "Less than 90 results:\n",
+ "Percent valid: 0.8979591836734694\n",
+ "Percent SQLite matched: 0.37551020408163266\n",
+ "Percent result matched: 0.7061224489795919\n",
+ "Dataset length: 245\n",
+ "-------------------\n",
+ "Num queries tested: 245\n",
+ "Num correct queries: 173\n",
+ "Acc: 70.61224489795919\n",
+ "-------------------\n",
+ "Dataset length: 245\n"
+ ]
+ }
+ ],
+ "source": [
+ "less_than_90_df = pd.read_csv(get_path(\"train-data/less_than_90.tsv\"), sep='\\t')\n",
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
+ "print(\"Dataset length: \" + str(len(less_than_90_df)))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "CSCI544",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
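
Both notebooks inline the same post-processing of the model output: strip an optional "SQLite:" or "SQL:" prefix, then truncate at the first semicolon. A minimal standalone sketch of that step, assuming a hypothetical helper name clean_generated_query (not part of the repo):

def clean_generated_query(generated_query: str) -> str:
    # Hypothetical helper restating the notebooks' cleaning logic.
    clean_query = generated_query.strip()
    # Drop an optional "SQLite:" or "SQL:" prefix emitted by the model.
    for prefix in ("SQLite:", "SQL:"):
        if clean_query.startswith(prefix):
            clean_query = clean_query[len(prefix):].strip()
            break
    # Keep everything up to and including the first semicolon.
    semicolon_idx = clean_query.find(";")
    if semicolon_idx != -1:
        clean_query = clean_query[:semicolon_idx + 1]
    return clean_query

assert clean_generated_query("SQLite: SELECT COUNT(*) FROM game; -- note") == "SELECT COUNT(*) FROM game;"
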
chat_gpt_4.ipynb ADDED
@@ -0,0 +1,435 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cf4403ec",
+ "metadata": {},
+ "source": [
+ "# Notebook to evaluate ChatGPT Performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f708eaa",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/anaconda3/envs/CSCI544/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ "  from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import warnings\n",
+ "import sqlite3 as sql\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+ "from huggingface_hub import snapshot_download\n",
+ "import sys\n",
+ "import os\n",
+ "import openai\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "83a1bd00",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ[\"OPENAI_API_KEY\"] = \"<key>\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b3a647bf",
+ "metadata": {},
+ "source": [
+ "## Set up path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "996e282d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "is_google_colab=False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "5d96087b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "current_path = \"./\"\n",
+ "\n",
+ "def get_path(rel_path):\n",
+ "    return os.path.join(current_path, rel_path)\n",
+ "\n",
+ "if is_google_colab:\n",
+ "    hugging_face_path = snapshot_download(\n",
+ "        repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
+ "        repo_type=\"model\",\n",
+ "        allow_patterns=[\"src/*\", \"train-data/*\", \"deepseek-coder-1.3b-instruct/*\", \"nba-data/*\"],\n",
+ "    )\n",
+ "    sys.path.append(hugging_face_path)\n",
+ "    current_path = hugging_face_path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "483da9f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'./nba-data/nba.sqlite'"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_path('nba-data/nba.sqlite')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5cc9f19f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total dataset examples: 1044\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "# Establish a database connection once (adjust the DB path as needed)\n",
+ "connection = sql.connect(get_path('nba-data/nba.sqlite'))\n",
+ "cursor = connection.cursor()\n",
+ "\n",
+ "# ------------------------------\n",
+ "# Load dataset and print summary\n",
+ "# ------------------------------\n",
+ "df = pd.read_csv(get_path(\"train-data/expanded_sql_train.tsv\"), sep='\\t')\n",
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
+ "print(\"\\n\")\n",
+ "\n",
+ "# ------------------------------\n",
+ "# Load tokenizer and model\n",
+ "# ------------------------------\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f2d859d8",
+ "metadata": {},
+ "source": [
+ "## Define compare result function for evaluation process"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "a5295234",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src.evaluation.compare_result import compare_result\n",
+ "from src.rag.table_retriever import retrieve_doc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0a89a468",
+ "metadata": {},
+ "source": [
+ "## Create evaluation loop for ChatGPT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "e580dda8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "client = OpenAI()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "69707ee7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ------------------------------\n",
+ "# Function to evaluate the model on a given dataset\n",
+ "# ------------------------------\n",
+ "\n",
+ "from src.prompts.prompt import input_text\n",
+ "def run_evaluation(nba_df, title):\n",
+ "    counter = 0\n",
+ "    num_valid = 0\n",
+ "    num_sql_matched = 0\n",
+ "    num_result_matched = 0\n",
+ "    for index, row in nba_df.iterrows():\n",
+ "        # Retrieve relevant schema chunks via RAG\n",
+ "\n",
+ "        response = client.chat.completions.create(\n",
+ "            model=\"gpt-4-turbo\",\n",
+ "            messages=[\n",
+ "                {\"role\": \"user\", \"content\": input_text + row[\"natural_query\"]}\n",
+ "            ]\n",
+ "        )\n",
+ "\n",
+ "        # Decode the model output.\n",
+ "        generated_query = response.choices[0].message.content\n",
+ "\n",
+ "        # Clean generated query: remove any prefix and truncate after first semicolon.\n",
+ "        if generated_query.startswith(\"SQLite:\"):\n",
+ "            clean_query = generated_query[len(\"SQLite:\"):].strip()\n",
+ "        elif generated_query.startswith(\"SQL:\"):\n",
+ "            clean_query = generated_query[len(\"SQL:\"):].strip()\n",
+ "        else:\n",
+ "            clean_query = generated_query.strip()\n",
+ "\n",
+ "        semicolon_idx = clean_query.find(\";\")\n",
+ "        if semicolon_idx != -1:\n",
+ "            clean_query = clean_query[:semicolon_idx+1]\n",
+ "\n",
+ "        # Execute the cleaned query on the SQLite DB to obtain the actual result.\n",
+ "        \"\"\"\n",
+ "        try:\n",
+ "            cursor.execute(clean_query)\n",
+ "            rows = cursor.fetchall()\n",
+ "            if rows and isinstance(rows[0], (tuple, list)) and len(rows[0]) > 0:\n",
+ "                actual_result = rows[0][0]\n",
+ "            elif rows:\n",
+ "                actual_result = rows[0]\n",
+ "            else:\n",
+ "                actual_result = \"\"\n",
+ "        except Exception as e:\n",
+ "            actual_result = \"Error executing query: \" + str(e)\n",
+ "        \"\"\"\n",
+ "\n",
+ "        # Compare the ground truth query and expected result to the generated query and actual result.\n",
+ "        valid, sql_matched, result_matched = compare_result(cursor, row[\"sql_query\"], row[\"result\"], generated_query)\n",
+ "        \"\"\"\n",
+ "        print(\"=============================================\")\n",
+ "        print(f\"Overall Valid: {valid}\")\n",
+ "        print(f\"SQL Query Matched: {sql_matched}\")\n",
+ "        print(f\"Result Matched: {result_matched}\")\n",
+ "        print(\"=============================================\\n\")\n",
+ "\n",
+ "        # Print debug output.\n",
+ "        print(\"----- Ground Truth SQL Query -----\")\n",
+ "        print(row[\"sql_query\"])\n",
+ "        print(\"------------------------------------\\n\")\n",
+ "        print(\"----- Model Generated SQL Query -----\")\n",
+ "        print(generated_query)\n",
+ "        print(\"---------------------------------------\\n\")\n",
+ "\n",
+ "        print(\"----- Expected Result -----\")\n",
+ "        print(row[\"result\"])\n",
+ "        print(\"----- Actual DB Result -----\")\n",
+ "        print(actual_result)\n",
+ "        print(\"-------------------------------------------------\\n\")\n",
+ "        \"\"\"\n",
+ "        if valid:\n",
+ "            num_valid += 1\n",
+ "        if sql_matched:\n",
+ "            num_sql_matched += 1\n",
+ "        if result_matched:\n",
+ "            num_result_matched += 1\n",
+ "\n",
+ "        counter += 1\n",
+ "\n",
+ "        # CONTROL ITERS\n",
+ "        # if counter == 2:\n",
+ "        #     break\n",
+ "\n",
+ "        if counter % 50 == 0:\n",
+ "            print(\"Completed \" + str(counter))\n",
+ "\n",
+ "    print(\"\\n\" + title + \" results:\")\n",
+ "    print(\"Percent valid: \" + str(num_valid / len(nba_df)))\n",
+ "    print(\"Percent SQLite matched: \" + str(num_sql_matched / len(nba_df)))\n",
+ "    print(\"Percent result matched: \" + str(num_result_matched / len(nba_df)))\n",
+ "    print(\"Dataset length: \" + str(len(nba_df)))\n",
+ "    print(\"-------------------\")\n",
+ "    print(\"Num queries tested: \", counter)\n",
+ "    print(\"Num correct queries: \", num_result_matched)\n",
+ "    print(\"Acc: \", (num_result_matched / counter)*100)\n",
+ "    print(\"-------------------\")\n",
+ "    "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "0c3fdc3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def run(nba_df, title):\n",
+ "    counter = 0\n",
+ "    num_valid = 0\n",
+ "    num_sql_matched = 0\n",
+ "    num_result_matched = 0\n",
+ "    for index, row in nba_df.iterrows():\n",
+ "        print(row['natural_query'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8bff68e0",
+ "metadata": {},
+ "source": [
+ "## Run ChatGPT evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "ce291e30",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Completed 50\n",
+ "Completed 100\n",
+ "Completed 150\n",
+ "Completed 200\n",
+ "Completed 250\n",
+ "Completed 300\n",
+ "Completed 350\n",
+ "Completed 400\n",
+ "Completed 450\n",
+ "Completed 500\n",
+ "Completed 550\n",
+ "Completed 600\n",
+ "Completed 650\n",
+ "Completed 700\n",
+ "Completed 750\n",
+ "Completed 800\n",
+ "Completed 850\n",
+ "Completed 900\n",
+ "Completed 950\n",
+ "Completed 1000\n",
+ "\n",
+ "All training data results:\n",
+ "Percent valid: 0.9521072796934866\n",
+ "Percent SQLite matched: 0.2260536398467433\n",
+ "Percent result matched: 0.7758620689655172\n",
+ "Dataset length: 1044\n",
+ "-------------------\n",
+ "Num queries tested: 1044\n",
+ "Num correct queries: 810\n",
+ "Acc: 77.58620689655173\n",
+ "-------------------\n",
+ "Dataset length: 1044\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ------------------------------\n",
+ "# Run evaluation on the full training dataset\n",
+ "# ------------------------------\n",
+ "run_evaluation(df, \"All training data\")\n",
+ "print(\"Dataset length: \" + str(len(df)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b21994fa",
+ "metadata": {},
+ "source": [
+ "## Run RAG evaluation on small query dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c2d12248",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Completed 50\n",
+ "Completed 100\n",
+ "Completed 150\n",
+ "Completed 200\n",
+ "\n",
+ "Less than 90 results:\n",
+ "Percent valid: 0.8979591836734694\n",
+ "Percent SQLite matched: 0.37551020408163266\n",
+ "Percent result matched: 0.7061224489795919\n",
+ "Dataset length: 245\n",
+ "-------------------\n",
+ "Num queries tested: 245\n",
+ "Num correct queries: 173\n",
+ "Acc: 70.61224489795919\n",
+ "-------------------\n",
+ "Dataset length: 245\n"
+ ]
+ }
+ ],
+ "source": [
+ "less_than_90_df = pd.read_csv(get_path(\"train-data/less_than_90.tsv\"), sep='\\t')\n",
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
+ "print(\"Dataset length: \" + str(len(less_than_90_df)))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "CSCI544",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
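
The percentages printed in both notebooks follow directly from the three counters kept by run_evaluation. A quick sanity check of the GPT-4-turbo full-run summary, using the counter values copied from the cell output above (the script itself is illustrative):

# Recompute the printed summary from the counters in the output above.
num_tested = 1044         # counter after iterating the full training set
num_result_matched = 810  # generated queries whose DB result matched

print("Percent result matched:", num_result_matched / num_tested)  # 0.7758620689655172
print("Acc:", num_result_matched / num_tested * 100)               # 77.58620689655173
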
src/evaluation/__pycache__/compare_result.cpython-311.pyc CHANGED
Binary files a/src/evaluation/__pycache__/compare_result.cpython-311.pyc and b/src/evaluation/__pycache__/compare_result.cpython-311.pyc differ
 
src/rag/__pycache__/table_retriever.cpython-311.pyc ADDED
Binary file (8.28 kB).
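
The recompiled compare_result bytecode above belongs to the helper both notebooks import from src.evaluation.compare_result. Judging only from its call site in run_evaluation, its interface is roughly the sketch below; the body is a naive placeholder inferred from usage, not the actual implementation:

import sqlite3

def compare_result_sketch(cursor: sqlite3.Cursor, gt_query: str,
                          expected_result: str, generated_query: str):
    # Returns (valid, sql_matched, result_matched), as unpacked in run_evaluation.
    try:
        rows = cursor.execute(generated_query).fetchall()
    except Exception:
        return False, False, False  # generated query failed to execute
    # Placeholder comparisons; the real helper likely normalizes queries/results.
    sql_matched = generated_query.strip().rstrip(";") == gt_query.strip().rstrip(";")
    result_matched = str(rows) == str(expected_result)
    return True, sql_matched, result_matched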