DeanGumas commited on
Commit
0405efb
·
1 Parent(s): ef32c53

Updating compare result function in src directory, fixing small issues in test_rag prompt, adding test_rag notebook

Browse files
Files changed (3) hide show
  1. src/evaluation/compare_result.py +18 -4
  2. test_rag.ipynb +262 -0
  3. test_rag.py +3 -3
src/evaluation/compare_result.py CHANGED
@@ -1,13 +1,27 @@
1
  import math
2
 
3
- def compare_result(cursor, sample_query, sample_result, query_output):
4
  # Clean model output to only have the query output
5
- if query_output[0:7] == "SQLite:":
 
 
 
 
6
  query = query_output[7:]
 
 
 
 
7
  elif query_output[0:4] == "SQL:":
8
  query = query_output[4:]
9
  else:
10
  query = query_output
 
 
 
 
 
 
11
 
12
  # Try to execute query, if it fails, then this is a failure of the model
13
  try:
@@ -47,7 +61,7 @@ def compare_result(cursor, sample_query, sample_result, query_output):
47
  if math.isclose(float(r), float(res), abs_tol=0.5):
48
  return True, query_match, True
49
  except:
50
- if r in res or res in r:
51
  return True, query_match, True
52
 
53
  # Check if the model returned a sum of examples as opposed to the whole thing
@@ -90,4 +104,4 @@ def compare_result(cursor, sample_query, sample_result, query_output):
90
  # Compare results and return
91
  return True, query_match, result
92
  except:
93
- return False, False, False
 
1
  import math
2
 
3
+ def compare_result(sample_query, sample_result, query_output):
4
  # Clean model output to only have the query output
5
+ if query_output[0:8] == "SQLite:\n":
6
+ query = query_output[8:]
7
+ elif query_output[0:8] == "SQLite: ":
8
+ query = query_output[8:]
9
+ elif query_output[0:7] == "SQLite:":
10
  query = query_output[7:]
11
+ elif query_output[0:5] == "SQL:\n":
12
+ query = query_output[5:]
13
+ elif query_output[0:5] == "SQL: ":
14
+ query = query_output[5:]
15
  elif query_output[0:4] == "SQL:":
16
  query = query_output[4:]
17
  else:
18
  query = query_output
19
+
20
+ # Clean any excess text after the query semicolon
21
+ for i in range(len(query)):
22
+ if query[i] == ";":
23
+ query = query[:i+1]
24
+ break
25
 
26
  # Try to execute query, if it fails, then this is a failure of the model
27
  try:
 
61
  if math.isclose(float(r), float(res), abs_tol=0.5):
62
  return True, query_match, True
63
  except:
64
+ if str(r) in res or res in str(r):
65
  return True, query_match, True
66
 
67
  # Check if the model returned a sum of examples as opposed to the whole thing
 
104
  # Compare results and return
105
  return True, query_match, result
106
  except:
107
+ return False, False, False
test_rag.ipynb ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9ba5b9ac",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Notebook to evaluate RAG performance"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "b7c75665",
14
+ "metadata": {},
15
+ "source": [
16
+ "## Create RAG document store"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "id": "d589714b",
23
+ "metadata": {},
24
+ "outputs": [
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
30
+ " from .autonotebook import tqdm as notebook_tqdm\n"
31
+ ]
32
+ },
33
+ {
34
+ "name": "stdout",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "WARNING:tensorflow:From c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tf_keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
38
+ "\n",
39
+ "Total dataset examples: 1044\n",
40
+ "\n",
41
+ "\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "import pandas as pd\n",
47
+ "import warnings\n",
48
+ "import torch\n",
49
+ "import time\n",
50
+ "import math\n",
51
+ "import sqlite3 as sql\n",
52
+ "\n",
53
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
54
+ "from rag_metadata import SQLMetadataRetriever\n",
55
+ "\n",
56
+ "warnings.filterwarnings(\"ignore\")\n",
57
+ "\n",
58
+ "# Establish a database connection once (adjust the DB path as needed)\n",
59
+ "connection = sql.connect('./nba-data/nba.sqlite')\n",
60
+ "cursor = connection.cursor()\n",
61
+ "\n",
62
+ "# ------------------------------\n",
63
+ "# Load dataset and print summary\n",
64
+ "# ------------------------------\n",
65
+ "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
66
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
67
+ "print(\"\\n\")\n",
68
+ "\n",
69
+ "# ------------------------------\n",
70
+ "# Load tokenizer and model\n",
71
+ "# ------------------------------\n",
72
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
73
+ "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
74
+ "model = AutoModelForCausalLM.from_pretrained(\n",
75
+ " \"./deepseek-coder-1.3b-instruct\",\n",
76
+ " torch_dtype=torch.bfloat16,\n",
77
+ " device_map=device\n",
78
+ ")\n",
79
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
80
+ "\n",
81
+ "# ------------------------------\n",
82
+ "# Initialize RAG retriever and load schema metadata\n",
83
+ "# ------------------------------\n",
84
+ "retriever = SQLMetadataRetriever()\n",
85
+ "\n",
86
+ "metadata_docs = [\n",
87
+ " '''team Table\n",
88
+ "Stores information about NBA teams.\n",
89
+ "CREATE TABLE IF NOT EXISTS \"team\" (\n",
90
+ " \"id\" TEXT PRIMARY KEY, -- Unique identifier for the team\n",
91
+ " \"full_name\" TEXT, -- Full official name of the team (e.g., \"Los Angeles Lakers\")\n",
92
+ " \"abbreviation\" TEXT, -- Shortened team name (e.g., \"LAL\")\n",
93
+ " \"nickname\" TEXT, -- Commonly used nickname for the team (e.g., \"Lakers\")\n",
94
+ " \"city\" TEXT, -- City where the team is based\n",
95
+ " \"state\" TEXT, -- State where the team is located\n",
96
+ " \"year_founded\" REAL -- Year the team was established\n",
97
+ ");''',\n",
98
+ " '''game Table\n",
99
+ "Contains detailed statistics for each NBA game, including home and away team performance.\n",
100
+ "CREATE TABLE IF NOT EXISTS \"game\" (\n",
101
+ " \"season_id\" TEXT, -- Season identifier, formatted as \"2YYYY\" (e.g., \"21970\" for the 1970 season)\n",
102
+ " \"team_id_home\" TEXT, -- ID of the home team (matches \"id\" in team table)\n",
103
+ " \"team_abbreviation_home\" TEXT, -- Abbreviation of the home team\n",
104
+ " \"team_name_home\" TEXT, -- Full name of the home team\n",
105
+ " \"game_id\" TEXT PRIMARY KEY, -- Unique identifier for the game\n",
106
+ " \"game_date\" TIMESTAMP, -- Date the game was played (YYYY-MM-DD format)\n",
107
+ " \"matchup_home\" TEXT, -- Matchup details including opponent (e.g., \"LAL vs. BOS\")\n",
108
+ " \"wl_home\" TEXT, -- \"W\" if the home team won, \"L\" if they lost\n",
109
+ " \"min\" INTEGER, -- Total minutes played in the game\n",
110
+ " \"fgm_home\" REAL, -- Field goals made by the home team\n",
111
+ " \"fga_home\" REAL, -- Field goals attempted by the home team\n",
112
+ " \"fg_pct_home\" REAL, -- Field goal percentage of the home team\n",
113
+ " \"fg3m_home\" REAL, -- Three-point field goals made by the home team\n",
114
+ " \"fg3a_home\" REAL, -- Three-point attempts by the home team\n",
115
+ " \"fg3_pct_home\" REAL, -- Three-point field goal percentage of the home team\n",
116
+ " \"ftm_home\" REAL, -- Free throws made by the home team\n",
117
+ " \"fta_home\" REAL, -- Free throws attempted by the home team\n",
118
+ " \"ft_pct_home\" REAL, -- Free throw percentage of the home team\n",
119
+ " \"oreb_home\" REAL, -- Offensive rebounds by the home team\n",
120
+ " \"dreb_home\" REAL, -- Defensive rebounds by the home team\n",
121
+ " \"reb_home\" REAL, -- Total rebounds by the home team\n",
122
+ " \"ast_home\" REAL, -- Assists by the home team\n",
123
+ " \"stl_home\" REAL, -- Steals by the home team\n",
124
+ " \"blk_home\" REAL, -- Blocks by the home team\n",
125
+ " \"tov_home\" REAL, -- Turnovers by the home team\n",
126
+ " \"pf_home\" REAL, -- Personal fouls by the home team\n",
127
+ " \"pts_home\" REAL, -- Total points scored by the home team\n",
128
+ " \"plus_minus_home\" INTEGER, -- Plus/minus rating for the home team\n",
129
+ " \"video_available_home\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
130
+ " \"team_id_away\" TEXT, -- ID of the away team\n",
131
+ " \"team_abbreviation_away\" TEXT, -- Abbreviation of the away team\n",
132
+ " \"team_name_away\" TEXT, -- Full name of the away team\n",
133
+ " \"matchup_away\" TEXT, -- Matchup details from the away team’s perspective\n",
134
+ " \"wl_away\" TEXT, -- \"W\" if the away team won, \"L\" if they lost\n",
135
+ " \"fgm_away\" REAL, -- Field goals made by the away team\n",
136
+ " \"fga_away\" REAL, -- Field goals attempted by the away team\n",
137
+ " \"fg_pct_away\" REAL, -- Field goal percentage of the away team\n",
138
+ " \"fg3m_away\" REAL, -- Three-point field goals made by the away team\n",
139
+ " \"fg3a_away\" REAL, -- Three-point attempts by the away team\n",
140
+ " \"fg3_pct_away\" REAL, -- Three-point field goal percentage of the away team\n",
141
+ " \"ftm_away\" REAL, -- Free throws made by the away team\n",
142
+ " \"fta_away\" REAL, -- Free throws attempted by the away team\n",
143
+ " \"ft_pct_away\" REAL, -- Free throw percentage of the away team\n",
144
+ " \"oreb_away\" REAL, -- Offensive rebounds by the away team\n",
145
+ " \"dreb_away\" REAL, -- Defensive rebounds by the away team\n",
146
+ " \"reb_away\" REAL, -- Total rebounds by the away team\n",
147
+ " \"ast_away\" REAL, -- Assists by the away team\n",
148
+ " \"stl_away\" REAL, -- Steals by the away team\n",
149
+ " \"blk_away\" REAL, -- Blocks by the away team\n",
150
+ " \"tov_away\" REAL, -- Turnovers by the away team\n",
151
+ " \"pf_away\" REAL, -- Personal fouls by the away team\n",
152
+ " \"pts_away\" REAL, -- Total points scored by the away team\n",
153
+ " \"plus_minus_away\" INTEGER, -- Plus/minus rating for the away team\n",
154
+ " \"video_available_away\" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)\n",
155
+ " \"season_type\" TEXT -- Regular season or playoffs\n",
156
+ ");\n",
157
+ "''',\n",
158
+ " '''other_stats Table\n",
159
+ "Stores additional statistics, linked to the game table via game_id.\n",
160
+ "CREATE TABLE IF NOT EXISTS \"other_stats\" (\n",
161
+ " \"game_id\" TEXT, -- Unique game identifier, matches id column from game table\n",
162
+ " \"league_id\" TEXT, -- League identifier\n",
163
+ " \"team_id_home\" TEXT, -- Home team identifier\n",
164
+ " \"team_abbreviation_home\" TEXT, -- Home team abbreviation\n",
165
+ " \"team_city_home\" TEXT, -- Home team city\n",
166
+ " \"pts_paint_home\" INTEGER, -- Points in the paint by the home team\n",
167
+ " \"pts_2nd_chance_home\" INTEGER, -- Second chance points by the home team\n",
168
+ " \"pts_fb_home\" INTEGER, -- Fast break points by the home team\n",
169
+ " \"largest_lead_home\" INTEGER,-- Largest lead by the home team\n",
170
+ " \"lead_changes\" INTEGER, -- Number of lead changes \n",
171
+ " \"times_tied\" INTEGER, -- Number of times the score was tied\n",
172
+ " \"team_turnovers_home\" INTEGER, -- Home team turnovers\n",
173
+ " \"total_turnovers_home\" INTEGER, -- Total turnovers by the home team\n",
174
+ " \"team_rebounds_home\" INTEGER, -- Home team rebounds\n",
175
+ " \"pts_off_to_home\" INTEGER, -- Points off turnovers by the home team\n",
176
+ " \"team_id_away\" TEXT, -- Away team identifier\n",
177
+ " \"team_abbreviation_away\" TEXT, -- Away team abbreviation\n",
178
+ " \"pts_paint_away\" INTEGER, -- Points in the paint by the away team\n",
179
+ " \"pts_2nd_chance_away\" INTEGER, -- Second chance points by the away team\n",
180
+ " \"pts_fb_away\" INTEGER, -- Fast break points by the away team\n",
181
+ " \"largest_lead_away\" INTEGER,-- Largest lead by the away team\n",
182
+ " \"team_turnovers_away\" INTEGER, -- Away team turnovers\n",
183
+ " \"total_turnovers_away\" INTEGER, -- Total turnovers by the away team\n",
184
+ " \"team_rebounds_away\" INTEGER, -- Away team rebounds\n",
185
+ " \"pts_off_to_away\" INTEGER -- Points off turnovers by the away team\n",
186
+ ");\n",
187
+ "''',\n",
188
+ " '''Team Name Information\n",
189
+ "In plaintext user questions, only the full team names will be used, but in the queries you may use either full names or abbreviations.\n",
190
+ "Full names are used with the game table, while abbreviations should be used with the other_stats table.\n",
191
+ "Team names and abbreviations (separated by |):\n",
192
+ "Atlanta Hawks|ATL, Boston Celtics|BOS, Cleveland Cavaliers|CLE, New Orleans Pelicans|NOP,\n",
193
+ "Chicago Bulls|CHI, Dallas Mavericks|DAL, Denver Nuggets|DEN, Golden State Warriors|GSW,\n",
194
+ "Houston Rockets|HOU, Los Angeles Clippers|LAC, Los Angeles Lakers|LAL, Miami Heat|MIA,\n",
195
+ "Milwaukee Bucks|MIL, Minnesota Timberwolves|MIN, Brooklyn Nets|BKN, New York Knicks|NYK,\n",
196
+ "Orlando Magic|ORL, Indiana Pacers|IND, Philadelphia 76ers|PHI, Phoenix Suns|PHX,\n",
197
+ "Portland Trail Blazers|POR, Sacramento Kings|SAC, San Antonio Spurs|SAS,\n",
198
+ "Oklahoma City Thunder|OKC, Toronto Raptors|TOR, Utah Jazz|UTA, Memphis Grizzlies|MEM,\n",
199
+ "Washington Wizards|WAS, Detroit Pistons|DET, Charlotte Hornets|CHA\n",
200
+ "'''\n",
201
+ "]\n",
202
+ "\n",
203
+ "retriever.add_documents(metadata_docs)"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "id": "499d2745",
209
+ "metadata": {},
210
+ "source": [
211
+ "## Define compare result function for evaluation process"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "id": "268561cd",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "\n"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "markdown",
226
+ "id": "e7393ccb",
227
+ "metadata": {},
228
+ "source": [
229
+ "## Evaluate RAG model on single training example"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "500f003b",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": []
239
+ }
240
+ ],
241
+ "metadata": {
242
+ "kernelspec": {
243
+ "display_name": "Python 3",
244
+ "language": "python",
245
+ "name": "python3"
246
+ },
247
+ "language_info": {
248
+ "codemirror_mode": {
249
+ "name": "ipython",
250
+ "version": 3
251
+ },
252
+ "file_extension": ".py",
253
+ "mimetype": "text/x-python",
254
+ "name": "python",
255
+ "nbconvert_exporter": "python",
256
+ "pygments_lexer": "ipython3",
257
+ "version": "3.12.6"
258
+ }
259
+ },
260
+ "nbformat": 4,
261
+ "nbformat_minor": 5
262
+ }
test_rag.py CHANGED
@@ -282,15 +282,15 @@ def run_evaluation(nba_df, title):
282
 
283
  # Build the prompt with instructions, schema, examples, and current request.
284
  input_text = f"""
285
- You are an AI assistant that generates SQL queries for an NBA database based on user questions.
286
 
287
  ### Relevant Schema:
288
  {schema_block}
289
 
290
  ### Instructions:
291
- - Generate a valid SQL query to retrieve relevant data from the database.
292
  - Use column names correctly based on the provided schema.
293
- - Output only the SQL query as plain text.
294
 
295
  ### Example Queries:
296
  Use team_name_home and team_name_away to match teams to the game table.
 
282
 
283
  # Build the prompt with instructions, schema, examples, and current request.
284
  input_text = f"""
285
+ You are an AI assistant that generates SQLite queries for an NBA database based on user questions.
286
 
287
  ### Relevant Schema:
288
  {schema_block}
289
 
290
  ### Instructions:
291
+ - Generate a valid SQLite query to retrieve relevant data from the database.
292
  - Use column names correctly based on the provided schema.
293
+ - Output only the SQLite query as plain text.
294
 
295
  ### Example Queries:
296
  Use team_name_home and team_name_away to match teams to the game table.