licesma committed on
Commit
c397e97
·
1 Parent(s): bfad6ce

colab support part two

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +17 -13
test_pretrained.ipynb CHANGED
@@ -9,7 +9,7 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": 22,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
@@ -19,13 +19,14 @@
19
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
20
  "import torch\n",
21
  "import sys\n",
 
22
  "import sqlite3 as sql\n",
23
  "from huggingface_hub import snapshot_download"
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": 23,
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
@@ -34,22 +35,25 @@
34
  },
35
  {
36
  "cell_type": "code",
37
- "execution_count": 24,
38
  "metadata": {},
39
  "outputs": [],
40
  "source": [
 
 
41
  "if is_google_colab:\n",
42
  " hugging_face_path = snapshot_download(\n",
43
  " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
44
  " repo_type=\"model\", \n",
45
  " allow_patterns=[\"src/*\"], \n",
46
  " )\n",
47
- " sys.path.append(hugging_face_path)"
 
48
  ]
49
  },
50
  {
51
  "cell_type": "code",
52
- "execution_count": 25,
53
  "metadata": {},
54
  "outputs": [],
55
  "source": [
@@ -66,7 +70,7 @@
66
  },
67
  {
68
  "cell_type": "code",
69
- "execution_count": 2,
70
  "metadata": {},
71
  "outputs": [
72
  {
@@ -76,15 +80,15 @@
76
  "Total dataset examples: 1044\n",
77
  "\n",
78
  "\n",
79
- "Which team had the largest lead in a single game in the 2001 season?\n",
80
- "SELECT g.team_name_home AS team, os.largest_lead_home AS lead FROM other_stats os JOIN game g ON os.game_id = g.game_id WHERE g.season_id = '22001' ORDER BY os.largest_lead_home DESC LIMIT 1;\n",
81
- "Portland Trail Blazers|47\n"
82
  ]
83
  }
84
  ],
85
  "source": [
86
  "# Load dataset and check length\n",
87
- "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
88
  "print(\"Total dataset examples: \" + str(len(df)))\n",
89
  "print(\"\\n\")\n",
90
  "\n",
@@ -126,7 +130,7 @@
126
  },
127
  {
128
  "cell_type": "code",
129
- "execution_count": 5,
130
  "metadata": {},
131
  "outputs": [
132
  {
@@ -159,7 +163,7 @@
159
  },
160
  {
161
  "cell_type": "code",
162
- "execution_count": 17,
163
  "metadata": {},
164
  "outputs": [
165
  {
@@ -202,7 +206,7 @@
202
  },
203
  {
204
  "cell_type": "code",
205
- "execution_count": null,
206
  "metadata": {},
207
  "outputs": [
208
  {
 
9
  },
10
  {
11
  "cell_type": "code",
12
+ "execution_count": 31,
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
 
19
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
20
  "import torch\n",
21
  "import sys\n",
22
+ "import os\n",
23
  "import sqlite3 as sql\n",
24
  "from huggingface_hub import snapshot_download"
25
  ]
26
  },
27
  {
28
  "cell_type": "code",
29
+ "execution_count": 32,
30
  "metadata": {},
31
  "outputs": [],
32
  "source": [
 
35
  },
36
  {
37
  "cell_type": "code",
38
+ "execution_count": 33,
39
  "metadata": {},
40
  "outputs": [],
41
  "source": [
42
+ "current_path = \"./\"\n",
43
+ "\n",
44
  "if is_google_colab:\n",
45
  " hugging_face_path = snapshot_download(\n",
46
  " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
47
  " repo_type=\"model\", \n",
48
  " allow_patterns=[\"src/*\"], \n",
49
  " )\n",
50
+ " sys.path.append(hugging_face_path)\n",
51
+ " current_path = hugging_face_path"
52
  ]
53
  },
54
  {
55
  "cell_type": "code",
56
+ "execution_count": 34,
57
  "metadata": {},
58
  "outputs": [],
59
  "source": [
 
70
  },
71
  {
72
  "cell_type": "code",
73
+ "execution_count": 36,
74
  "metadata": {},
75
  "outputs": [
76
  {
 
80
  "Total dataset examples: 1044\n",
81
  "\n",
82
  "\n",
83
+ "How many points did the Phoenix Suns score in the highest scoring away game they played?\n",
84
+ "SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'PHX';\n",
85
+ "161.0\n"
86
  ]
87
  }
88
  ],
89
  "source": [
90
  "# Load dataset and check length\n",
91
+ "df = pd.read_csv(os.path.join(current_path, \"train-data/sql_train.tsv\"), sep=\"\\t\")\n",
92
  "print(\"Total dataset examples: \" + str(len(df)))\n",
93
  "print(\"\\n\")\n",
94
  "\n",
 
130
  },
131
  {
132
  "cell_type": "code",
133
+ "execution_count": 28,
134
  "metadata": {},
135
  "outputs": [
136
  {
 
163
  },
164
  {
165
  "cell_type": "code",
166
+ "execution_count": 29,
167
  "metadata": {},
168
  "outputs": [
169
  {
 
206
  },
207
  {
208
  "cell_type": "code",
209
+ "execution_count": 12,
210
  "metadata": {},
211
  "outputs": [
212
  {