colab support part two
Browse files- test_pretrained.ipynb +17 -13
test_pretrained.ipynb
CHANGED
@@ -9,7 +9,7 @@
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
-
"execution_count":
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
@@ -19,13 +19,14 @@
|
|
19 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
20 |
"import torch\n",
|
21 |
"import sys\n",
|
|
|
22 |
"import sqlite3 as sql\n",
|
23 |
"from huggingface_hub import snapshot_download"
|
24 |
]
|
25 |
},
|
26 |
{
|
27 |
"cell_type": "code",
|
28 |
-
"execution_count":
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
@@ -34,22 +35,25 @@
|
|
34 |
},
|
35 |
{
|
36 |
"cell_type": "code",
|
37 |
-
"execution_count":
|
38 |
"metadata": {},
|
39 |
"outputs": [],
|
40 |
"source": [
|
|
|
|
|
41 |
"if is_google_colab:\n",
|
42 |
" hugging_face_path = snapshot_download(\n",
|
43 |
" repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
|
44 |
" repo_type=\"model\", \n",
|
45 |
" allow_patterns=[\"src/*\"], \n",
|
46 |
" )\n",
|
47 |
-
" sys.path.append(hugging_face_path)"
|
|
|
48 |
]
|
49 |
},
|
50 |
{
|
51 |
"cell_type": "code",
|
52 |
-
"execution_count":
|
53 |
"metadata": {},
|
54 |
"outputs": [],
|
55 |
"source": [
|
@@ -66,7 +70,7 @@
|
|
66 |
},
|
67 |
{
|
68 |
"cell_type": "code",
|
69 |
-
"execution_count":
|
70 |
"metadata": {},
|
71 |
"outputs": [
|
72 |
{
|
@@ -76,15 +80,15 @@
|
|
76 |
"Total dataset examples: 1044\n",
|
77 |
"\n",
|
78 |
"\n",
|
79 |
-
"
|
80 |
-
"SELECT
|
81 |
-
"
|
82 |
]
|
83 |
}
|
84 |
],
|
85 |
"source": [
|
86 |
"# Load dataset and check length\n",
|
87 |
-
"df = pd.read_csv(\"
|
88 |
"print(\"Total dataset examples: \" + str(len(df)))\n",
|
89 |
"print(\"\\n\")\n",
|
90 |
"\n",
|
@@ -126,7 +130,7 @@
|
|
126 |
},
|
127 |
{
|
128 |
"cell_type": "code",
|
129 |
-
"execution_count":
|
130 |
"metadata": {},
|
131 |
"outputs": [
|
132 |
{
|
@@ -159,7 +163,7 @@
|
|
159 |
},
|
160 |
{
|
161 |
"cell_type": "code",
|
162 |
-
"execution_count":
|
163 |
"metadata": {},
|
164 |
"outputs": [
|
165 |
{
|
@@ -202,7 +206,7 @@
|
|
202 |
},
|
203 |
{
|
204 |
"cell_type": "code",
|
205 |
-
"execution_count":
|
206 |
"metadata": {},
|
207 |
"outputs": [
|
208 |
{
|
|
|
9 |
},
|
10 |
{
|
11 |
"cell_type": "code",
|
12 |
+
"execution_count": 31,
|
13 |
"metadata": {},
|
14 |
"outputs": [],
|
15 |
"source": [
|
|
|
19 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
20 |
"import torch\n",
|
21 |
"import sys\n",
|
22 |
+
"import os\n",
|
23 |
"import sqlite3 as sql\n",
|
24 |
"from huggingface_hub import snapshot_download"
|
25 |
]
|
26 |
},
|
27 |
{
|
28 |
"cell_type": "code",
|
29 |
+
"execution_count": 32,
|
30 |
"metadata": {},
|
31 |
"outputs": [],
|
32 |
"source": [
|
|
|
35 |
},
|
36 |
{
|
37 |
"cell_type": "code",
|
38 |
+
"execution_count": 33,
|
39 |
"metadata": {},
|
40 |
"outputs": [],
|
41 |
"source": [
|
42 |
+
"current_path = \"./\"\n",
|
43 |
+
"\n",
|
44 |
"if is_google_colab:\n",
|
45 |
" hugging_face_path = snapshot_download(\n",
|
46 |
" repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
|
47 |
" repo_type=\"model\", \n",
|
48 |
" allow_patterns=[\"src/*\"], \n",
|
49 |
" )\n",
|
50 |
+
" sys.path.append(hugging_face_path)\n",
|
51 |
+
" current_path = hugging_face_path"
|
52 |
]
|
53 |
},
|
54 |
{
|
55 |
"cell_type": "code",
|
56 |
+
"execution_count": 34,
|
57 |
"metadata": {},
|
58 |
"outputs": [],
|
59 |
"source": [
|
|
|
70 |
},
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
+
"execution_count": 36,
|
74 |
"metadata": {},
|
75 |
"outputs": [
|
76 |
{
|
|
|
80 |
"Total dataset examples: 1044\n",
|
81 |
"\n",
|
82 |
"\n",
|
83 |
+
"How many points did the Phoenix Suns score in the highest scoring away game they played?\n",
|
84 |
+
"SELECT MAX(pts_away) FROM game WHERE team_abbreviation_away = 'PHX';\n",
|
85 |
+
"161.0\n"
|
86 |
]
|
87 |
}
|
88 |
],
|
89 |
"source": [
|
90 |
"# Load dataset and check length\n",
|
91 |
+
"df = pd.read_csv(os.path.join(current_path, \"train-data/sql_train.tsv\"), sep=\"\\t\")\n",
|
92 |
"print(\"Total dataset examples: \" + str(len(df)))\n",
|
93 |
"print(\"\\n\")\n",
|
94 |
"\n",
|
|
|
130 |
},
|
131 |
{
|
132 |
"cell_type": "code",
|
133 |
+
"execution_count": 28,
|
134 |
"metadata": {},
|
135 |
"outputs": [
|
136 |
{
|
|
|
163 |
},
|
164 |
{
|
165 |
"cell_type": "code",
|
166 |
+
"execution_count": 29,
|
167 |
"metadata": {},
|
168 |
"outputs": [
|
169 |
{
|
|
|
206 |
},
|
207 |
{
|
208 |
"cell_type": "code",
|
209 |
+
"execution_count": 12,
|
210 |
"metadata": {},
|
211 |
"outputs": [
|
212 |
{
|