Delete models
Browse files- models/FullyConectedModels/Grid_SearchCV.ipynb +0 -1213
- models/FullyConectedModels/gridsearchcv/grid_model_1.csv +0 -49
- models/FullyConectedModels/model.py +0 -120
- models/FullyConectedModels/parseval.py +0 -83
- models/Parseval_Networks/README.md +0 -3
- models/Parseval_Networks/constraint.py +0 -81
- models/Parseval_Networks/convexity_constraint.py +0 -53
- models/Parseval_Networks/parsevalnet.py +0 -328
- models/README.md +0 -15
- models/_utility.py +0 -109
- models/wideresnet/wresnet.py +0 -329
models/FullyConectedModels/Grid_SearchCV.ipynb
DELETED
@@ -1,1213 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nbformat": 4,
|
3 |
-
"nbformat_minor": 0,
|
4 |
-
"metadata": {
|
5 |
-
"accelerator": "GPU",
|
6 |
-
"colab": {
|
7 |
-
"name": "Grid_SearchCV.ipynb",
|
8 |
-
"provenance": [],
|
9 |
-
"collapsed_sections": [],
|
10 |
-
"toc_visible": true
|
11 |
-
},
|
12 |
-
"kernelspec": {
|
13 |
-
"display_name": "Python 3",
|
14 |
-
"language": "python",
|
15 |
-
"name": "python3"
|
16 |
-
},
|
17 |
-
"language_info": {
|
18 |
-
"codemirror_mode": {
|
19 |
-
"name": "ipython",
|
20 |
-
"version": 3
|
21 |
-
},
|
22 |
-
"file_extension": ".py",
|
23 |
-
"mimetype": "text/x-python",
|
24 |
-
"name": "python",
|
25 |
-
"nbconvert_exporter": "python",
|
26 |
-
"pygments_lexer": "ipython3",
|
27 |
-
"version": "3.7.4"
|
28 |
-
}
|
29 |
-
},
|
30 |
-
"cells": [
|
31 |
-
{
|
32 |
-
"cell_type": "markdown",
|
33 |
-
"metadata": {
|
34 |
-
"id": "o7mJMiThKvtT"
|
35 |
-
},
|
36 |
-
"source": [
|
37 |
-
"# <font color=\"purple\"><b>Grid Search CV Algorithm for Fully Connected Networks</b></font>"
|
38 |
-
]
|
39 |
-
},
|
40 |
-
{
|
41 |
-
"cell_type": "markdown",
|
42 |
-
"metadata": {
|
43 |
-
"id": "kUrakNvqKvtU"
|
44 |
-
},
|
45 |
-
"source": [
|
46 |
-
"Using the grid search CV algorithm, the hyperparameters of this model are sought.\n",
|
47 |
-
"<li><b> Learning Rate:</b> 0.1, 0.01</li>\n",
|
48 |
-
"<li><b> Regularization Penalty:</b> 0, 0.01, 0.001, 0.0001</li>\n",
|
49 |
-
"<li><b> Batch Size:</b> 64, 128</li>\n",
|
50 |
-
"<li><b> Epochs:</b> 50, 100, 150</li>"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "markdown",
|
55 |
-
"metadata": {
|
56 |
-
"id": "9rFQbEcDKvtV"
|
57 |
-
},
|
58 |
-
"source": [
|
59 |
-
"## <font color=\"blue\">Import Libraries</font>"
|
60 |
-
]
|
61 |
-
},
|
62 |
-
{
|
63 |
-
"cell_type": "code",
|
64 |
-
"metadata": {
|
65 |
-
"id": "6nhsKKJZ02AK"
|
66 |
-
},
|
67 |
-
"source": [
|
68 |
-
"import gzip\n",
|
69 |
-
"import pickle\n",
|
70 |
-
"import numpy as np\n",
|
71 |
-
"import pandas as pd\n",
|
72 |
-
"import numpy as np\n",
|
73 |
-
"from sklearn.preprocessing import LabelEncoder\n",
|
74 |
-
"from tensorflow.keras.utils import to_categorical\n",
|
75 |
-
"from tensorflow.keras import backend as K\n",
|
76 |
-
"from itertools import product\n",
|
77 |
-
"from sklearn.model_selection import train_test_split\n",
|
78 |
-
"from sklearn.model_selection import KFold\n",
|
79 |
-
"\n",
|
80 |
-
"from tensorflow.keras.regularizers import l2\n",
|
81 |
-
"from tensorflow.keras.optimizers import SGD\n",
|
82 |
-
"import tensorflow\n",
|
83 |
-
"import json\n",
|
84 |
-
"import cv2\n",
|
85 |
-
"import io\n",
|
86 |
-
"from sklearn.metrics import accuracy_score\n",
|
87 |
-
"from sklearn.metrics import precision_score\n",
|
88 |
-
"from sklearn.metrics import recall_score\n",
|
89 |
-
"try:\n",
|
90 |
-
" to_unicode = unicode\n",
|
91 |
-
"except NameError:\n",
|
92 |
-
" to_unicode = str\n",
|
93 |
-
"from sklearn.preprocessing import LabelEncoder\n",
|
94 |
-
"from tensorflow.keras.utils import to_categorical"
|
95 |
-
],
|
96 |
-
"execution_count": null,
|
97 |
-
"outputs": []
|
98 |
-
},
|
99 |
-
{
|
100 |
-
"cell_type": "code",
|
101 |
-
"metadata": {
|
102 |
-
"id": "TdoFKvR-m04D"
|
103 |
-
},
|
104 |
-
"source": [
|
105 |
-
"!pip install hickle\n",
|
106 |
-
"import hickle as hkl"
|
107 |
-
],
|
108 |
-
"execution_count": null,
|
109 |
-
"outputs": []
|
110 |
-
},
|
111 |
-
{
|
112 |
-
"cell_type": "code",
|
113 |
-
"metadata": {
|
114 |
-
"id": "XqmnAXulmtWm"
|
115 |
-
},
|
116 |
-
"source": [
|
117 |
-
"data = hkl.load(\"data.hkl\")\n",
|
118 |
-
"X_train, X_test, Y_train, y_test = data['xtrain'], data['xtest'], data['ytrain'], data['ytest']\n",
|
119 |
-
"x_train, x_val, y_train, y_val = train_test_split(X_train, Y_train, test_size=0.1)"
|
120 |
-
],
|
121 |
-
"execution_count": null,
|
122 |
-
"outputs": []
|
123 |
-
},
|
124 |
-
{
|
125 |
-
"cell_type": "code",
|
126 |
-
"metadata": {
|
127 |
-
"id": "u4zwcqNAnEoE"
|
128 |
-
},
|
129 |
-
"source": [
|
130 |
-
"from tensorflow.data import Dataset\n",
|
131 |
-
"import tensorflow.keras as keras\n",
|
132 |
-
"from tensorflow.keras.optimizers import Adam\n",
|
133 |
-
"from tensorflow.keras.layers import Conv2D,Input,MaxPooling2D, Dense, Dropout, MaxPool1D, Flatten, AveragePooling1D, BatchNormalization\n",
|
134 |
-
"from tensorflow.keras import Model\n",
|
135 |
-
"import numpy as np\n",
|
136 |
-
"import tensorflow as tf\n",
|
137 |
-
"from tensorflow.keras.models import Sequential\n"
|
138 |
-
],
|
139 |
-
"execution_count": null,
|
140 |
-
"outputs": []
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "code",
|
144 |
-
"metadata": {
|
145 |
-
"id": "pW_D-_Dqm5mg"
|
146 |
-
},
|
147 |
-
"source": [
|
148 |
-
"def model_1(weight_decay):\n",
|
149 |
-
" model = Sequential()\n",
|
150 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1), kernel_regularizer=l2(weight_decay)))\n",
|
151 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
152 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
153 |
-
" model.add(BatchNormalization())\n",
|
154 |
-
" model.add(Flatten())\n",
|
155 |
-
" model.add(Dense(4, activation='softmax', kernel_regularizer=l2(weight_decay)))\n",
|
156 |
-
" return model\n",
|
157 |
-
"\n",
|
158 |
-
"\n",
|
159 |
-
"def model_2(weight_decay):\n",
|
160 |
-
" model = Sequential()\n",
|
161 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1), kernel_regularizer=l2(weight_decay)))\n",
|
162 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
163 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
164 |
-
" model.add(BatchNormalization())\n",
|
165 |
-
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu', kernel_regularizer=l2(weight_decay)))\n",
|
166 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
167 |
-
" model.add(BatchNormalization())\n",
|
168 |
-
" model.add(Flatten())\n",
|
169 |
-
" model.add(Dense(4, activation='softmax', kernel_regularizer=l2(weight_decay)))\n",
|
170 |
-
" return model\n",
|
171 |
-
"\n",
|
172 |
-
"\n",
|
173 |
-
"def model_3(weight_decay):\n",
|
174 |
-
" model = Sequential()\n",
|
175 |
-
" model.add(Conv2D(32, kernel_size=(3, 3),activation='relu', input_shape=(32, 32, 1),kernel_regularizer=l2(weight_decay)))\n",
|
176 |
-
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
177 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
178 |
-
" model.add(BatchNormalization())\n",
|
179 |
-
" model.add(Conv2D(128, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
180 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
181 |
-
" model.add(BatchNormalization())\n",
|
182 |
-
" model.add(Conv2D(256, kernel_size=(3, 3), activation='relu',kernel_regularizer=l2(weight_decay)))\n",
|
183 |
-
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
|
184 |
-
" model.add(BatchNormalization())\n",
|
185 |
-
" model.add(Flatten())\n",
|
186 |
-
" model.add(Dense(4, activation='softmax',kernel_regularizer=l2(weight_decay)))\n",
|
187 |
-
" return model\n"
|
188 |
-
],
|
189 |
-
"execution_count": null,
|
190 |
-
"outputs": []
|
191 |
-
},
|
192 |
-
{
|
193 |
-
"cell_type": "markdown",
|
194 |
-
"metadata": {
|
195 |
-
"id": "p9X1E2wybnZj"
|
196 |
-
},
|
197 |
-
"source": [
|
198 |
-
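"<font color=\"blue\"> The algorithm below runs a 3-fold cross-validated grid search: for every hyperparameter combination it trains <code>model_1</code> on each fold, evaluates it on the test set, and records the per-fold loss, accuracy, and class-averaged precision and recall in a CSV file. </font>"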
|
199 |
-
]
|
200 |
-
},
|
201 |
-
{
|
202 |
-
"cell_type": "code",
|
203 |
-
"metadata": {
|
204 |
-
"id": "V8U0Vk9t0l3V"
|
205 |
-
},
|
206 |
-
"source": [
|
207 |
-
"from sklearn.metrics import confusion_matrix\n",
|
208 |
-
"def encoded_label(y_predict):\n",
|
209 |
-
" y_list = [] \n",
|
210 |
-
" for y_hat in y_predict:\n",
|
211 |
-
" y_hat = np.argmax(y_hat)\n",
|
212 |
-
" y_list.append(to_categorical(y_hat))\n",
|
213 |
-
" return y_list\n",
|
214 |
-
"\n",
|
215 |
-
"\n",
|
216 |
-
"def KFold_GridSearchCV(input_dim, X, Y, X_test, y_test, combinations, filename=\"log.csv\", acc_loss_json=\"hist.json\"):\n",
|
217 |
-
" \"\"\"Summary: Grid Search CV for 3 Folds Cross Validation\n",
|
218 |
-
" \"\"\"\n",
|
219 |
-
" res_df = pd.DataFrame(columns=['momentum','learning rate','batch size',\n",
|
220 |
-
" 'loss1', 'acc1','loss2', 'acc2','loss3', 'acc3', 'widing factor',\n",
|
221 |
-
" 'prec1', 'prec2', 'prec3', 'recall1', 'recall2', 'recall3'])\n",
|
222 |
-
" generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(rotation_range=10,\n",
|
223 |
-
" width_shift_range=5./32,\n",
|
224 |
-
" height_shift_range=5./32,)\n",
|
225 |
-
" hist_dict_global = {}\n",
|
226 |
-
"\n",
|
227 |
-
" for i, combination in enumerate(combinations):\n",
|
228 |
-
" kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
229 |
-
" metrics_dict = {}\n",
|
230 |
-
" \n",
|
231 |
-
" for j, (train_index, test_index) in enumerate(kf.split(X)):\n",
|
232 |
-
" X_train, X_val = X[train_index], X[test_index]\n",
|
233 |
-
" y_train, y_val = Y[train_index], Y[test_index]\n",
|
234 |
-
" model = model_1(combination[2])\n",
|
235 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=combination[0])\n",
|
236 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
237 |
-
" hist = model.fit(generator.flow(X_train, y_train, batch_size=combination[1]), steps_per_epoch=len(X_train) // combination[1], epochs=combination[3],\n",
|
238 |
-
" validation_data=(X_val, y_val),\n",
|
239 |
-
" validation_steps=len(X_val) // combination[1],)\n",
|
240 |
-
" loss, acc = model.evaluate(X_test, y_test)\n",
|
241 |
-
" #yhat_classes = encoded_label( model.predict(X_test))\n",
|
242 |
-
" predict = model.predict(X_test)\n",
|
243 |
-
" yhat_classes = np.argmax(predict, axis=1)\n",
|
244 |
-
" print(yhat_classes)\n",
|
245 |
-
" print(y_test)\n",
|
246 |
-
" cm = confusion_matrix(np.argmax(y_test, axis=1), yhat_classes)\n",
|
247 |
-
" recall = np.diag(cm) / np.sum(cm, axis = 1)\n",
|
248 |
-
" precision = np.diag(cm) / np.sum(cm, axis = 0)\n",
|
249 |
-
"\n",
|
250 |
-
" recall_avg = np.mean(recall)\n",
|
251 |
-
" precision_avg = np.mean(precision)\n",
|
252 |
-
" metrics_dict[j+1] = {\"loss\": loss, \"acc\": acc, \"epoch_stopped\": combination[3], \"precision_avg\":precision_avg,\n",
|
253 |
-
" \"avg_recall\":recall_avg}\n",
|
254 |
-
" graph_loss_acc = {\"id\": i, \"com\":j+1, \"val_acc\":hist.history[\"val_accuracy\"], \"train_acc\":hist.history[\"accuracy\"],\n",
|
255 |
-
" \"val_loss\":hist.history[\"val_loss\"], \"train_loss\":hist.history[\"loss\"], \"epoch_stopped\": combination[3], 'learning rate': combination[0],\n",
|
256 |
-
" 'batch size': combination[1], 'reg_penalty': combination[2]}\n",
|
257 |
-
"\n",
|
258 |
-
" # Write JSON file\n",
|
259 |
-
" with io.open(acc_loss_json, 'a+', encoding='utf8') as outfile:\n",
|
260 |
-
" str_ = json.dumps(graph_loss_acc)\n",
|
261 |
-
" outfile.write(to_unicode(str_))\n",
|
262 |
-
" row = {'momentum': combination[4],'learning rate': combination[0],\n",
|
263 |
-
" 'batch size': combination[1],\n",
|
264 |
-
" 'reg_penalty': combination[2],\n",
|
265 |
-
" 'epoch_stopped': metrics_dict[1][\"epoch_stopped\"],\n",
|
266 |
-
" 'widing factor' : 1,\n",
|
267 |
-
" 'loss1': metrics_dict[1][\"loss\"],\n",
|
268 |
-
" 'acc1': metrics_dict[1][\"acc\"],\n",
|
269 |
-
" 'loss2': metrics_dict[2][\"loss\"],\n",
|
270 |
-
" 'acc2': metrics_dict[2][\"acc\"],\n",
|
271 |
-
" 'loss3': metrics_dict[3][\"loss\"],\n",
|
272 |
-
" 'acc3': metrics_dict[3][\"acc\"],\n",
|
273 |
-
" 'prec1':metrics_dict[1][\"precision_avg\"],\n",
|
274 |
-
" 'prec2':metrics_dict[2][\"precision_avg\"],\n",
|
275 |
-
" 'prec3':metrics_dict[3][\"precision_avg\"],\n",
|
276 |
-
" 'recall1':metrics_dict[1][\"avg_recall\"],\n",
|
277 |
-
" 'recall2':metrics_dict[2][\"avg_recall\"],\n",
|
278 |
-
" 'recall3':metrics_dict[3][\"avg_recall\"]}\n",
|
279 |
-
" res_df = res_df.append(row , ignore_index=True)\n",
|
280 |
-
" res_df.to_csv(filename, sep=\";\")"
|
281 |
-
],
|
282 |
-
"execution_count": null,
|
283 |
-
"outputs": []
|
284 |
-
},
|
285 |
-
{
|
286 |
-
"cell_type": "code",
|
287 |
-
"metadata": {
|
288 |
-
"id": "tgO8xJ0Fe93I"
|
289 |
-
},
|
290 |
-
"source": [
|
291 |
-
"\n",
|
292 |
-
"if __name__ == \"__main__\":\n",
|
293 |
-
" learning_rate = [0.1,0.01]\n",
|
294 |
-
" batch_size = [64,128]\n",
|
295 |
-
" reg_penalty = [0,0.01, 0.001, 0.0001]\n",
|
296 |
-
" epochs = [50,100,150]\n",
|
297 |
-
" momentum = [0.9]\n",
|
298 |
-
" in_dim = (32,32,1)\n",
|
299 |
-
" grid_result = \"grid_model_1.csv\"\n",
|
300 |
-
" acc_loss_json = \"history.json\"\n",
|
301 |
-
" # create list of all different parameter combinations\n",
|
302 |
-
" param_grid = dict(learning_rate = learning_rate, batch_size = batch_size, \n",
|
303 |
-
" reg_penalty = reg_penalty, epochs = epochs, momentum=momentum)\n",
|
304 |
-
" combinations = list(product(*param_grid.values()))\n",
|
305 |
-
" KFold_GridSearchCV(in_dim,X_train,Y_train,X_test, y_test, combinations, grid_result, acc_loss_json)"
|
306 |
-
],
|
307 |
-
"execution_count": null,
|
308 |
-
"outputs": []
|
309 |
-
},
|
310 |
-
{
|
311 |
-
"cell_type": "code",
|
312 |
-
"metadata": {
|
313 |
-
"id": "iVp5lMuhpAMW"
|
314 |
-
},
|
315 |
-
"source": [
|
316 |
-
"data = pd.read_csv(\"grid_model_1.csv\", sep=\";\")"
|
317 |
-
],
|
318 |
-
"execution_count": null,
|
319 |
-
"outputs": []
|
320 |
-
},
|
321 |
-
{
|
322 |
-
"cell_type": "code",
|
323 |
-
"metadata": {
|
324 |
-
"colab": {
|
325 |
-
"base_uri": "https://localhost:8080/",
|
326 |
-
"height": 244
|
327 |
-
},
|
328 |
-
"id": "rru6HHyrsgbl",
|
329 |
-
"outputId": "01496643-5865-41a6-85c2-f69242b5912b"
|
330 |
-
},
|
331 |
-
"source": [
|
332 |
-
"data.head(5)"
|
333 |
-
],
|
334 |
-
"execution_count": null,
|
335 |
-
"outputs": [
|
336 |
-
{
|
337 |
-
"output_type": "execute_result",
|
338 |
-
"data": {
|
339 |
-
"text/html": [
|
340 |
-
"<div>\n",
|
341 |
-
"<style scoped>\n",
|
342 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
343 |
-
" vertical-align: middle;\n",
|
344 |
-
" }\n",
|
345 |
-
"\n",
|
346 |
-
" .dataframe tbody tr th {\n",
|
347 |
-
" vertical-align: top;\n",
|
348 |
-
" }\n",
|
349 |
-
"\n",
|
350 |
-
" .dataframe thead th {\n",
|
351 |
-
" text-align: right;\n",
|
352 |
-
" }\n",
|
353 |
-
"</style>\n",
|
354 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
355 |
-
" <thead>\n",
|
356 |
-
" <tr style=\"text-align: right;\">\n",
|
357 |
-
" <th></th>\n",
|
358 |
-
" <th>momentum</th>\n",
|
359 |
-
" <th>learning rate</th>\n",
|
360 |
-
" <th>batch size</th>\n",
|
361 |
-
" <th>loss1</th>\n",
|
362 |
-
" <th>acc1</th>\n",
|
363 |
-
" <th>loss2</th>\n",
|
364 |
-
" <th>acc2</th>\n",
|
365 |
-
" <th>loss3</th>\n",
|
366 |
-
" <th>acc3</th>\n",
|
367 |
-
" <th>widing factor</th>\n",
|
368 |
-
" <th>prec1</th>\n",
|
369 |
-
" <th>prec2</th>\n",
|
370 |
-
" <th>prec3</th>\n",
|
371 |
-
" <th>recall1</th>\n",
|
372 |
-
" <th>recall2</th>\n",
|
373 |
-
" <th>recall3</th>\n",
|
374 |
-
" <th>epoch_stopped</th>\n",
|
375 |
-
" <th>reg_penalty</th>\n",
|
376 |
-
" </tr>\n",
|
377 |
-
" </thead>\n",
|
378 |
-
" <tbody>\n",
|
379 |
-
" <tr>\n",
|
380 |
-
" <th>0</th>\n",
|
381 |
-
" <td>0.9</td>\n",
|
382 |
-
" <td>0.1</td>\n",
|
383 |
-
" <td>64.0</td>\n",
|
384 |
-
" <td>1.078664</td>\n",
|
385 |
-
" <td>0.541012</td>\n",
|
386 |
-
" <td>1.087765</td>\n",
|
387 |
-
" <td>0.560209</td>\n",
|
388 |
-
" <td>1.140540</td>\n",
|
389 |
-
" <td>0.495637</td>\n",
|
390 |
-
" <td>1.0</td>\n",
|
391 |
-
" <td>0.545625</td>\n",
|
392 |
-
" <td>0.570303</td>\n",
|
393 |
-
" <td>0.540447</td>\n",
|
394 |
-
" <td>0.549296</td>\n",
|
395 |
-
" <td>0.561227</td>\n",
|
396 |
-
" <td>0.506039</td>\n",
|
397 |
-
" <td>50.0</td>\n",
|
398 |
-
" <td>0.00</td>\n",
|
399 |
-
" </tr>\n",
|
400 |
-
" <tr>\n",
|
401 |
-
" <th>1</th>\n",
|
402 |
-
" <td>0.9</td>\n",
|
403 |
-
" <td>0.1</td>\n",
|
404 |
-
" <td>64.0</td>\n",
|
405 |
-
" <td>1.090459</td>\n",
|
406 |
-
" <td>0.544503</td>\n",
|
407 |
-
" <td>1.052103</td>\n",
|
408 |
-
" <td>0.568935</td>\n",
|
409 |
-
" <td>0.991915</td>\n",
|
410 |
-
" <td>0.586387</td>\n",
|
411 |
-
" <td>1.0</td>\n",
|
412 |
-
" <td>0.567745</td>\n",
|
413 |
-
" <td>0.581882</td>\n",
|
414 |
-
" <td>0.587054</td>\n",
|
415 |
-
" <td>0.549632</td>\n",
|
416 |
-
" <td>0.571559</td>\n",
|
417 |
-
" <td>0.584749</td>\n",
|
418 |
-
" <td>100.0</td>\n",
|
419 |
-
" <td>0.00</td>\n",
|
420 |
-
" </tr>\n",
|
421 |
-
" <tr>\n",
|
422 |
-
" <th>2</th>\n",
|
423 |
-
" <td>0.9</td>\n",
|
424 |
-
" <td>0.1</td>\n",
|
425 |
-
" <td>64.0</td>\n",
|
426 |
-
" <td>1.203270</td>\n",
|
427 |
-
" <td>0.549738</td>\n",
|
428 |
-
" <td>1.024475</td>\n",
|
429 |
-
" <td>0.607330</td>\n",
|
430 |
-
" <td>1.099551</td>\n",
|
431 |
-
" <td>0.542757</td>\n",
|
432 |
-
" <td>1.0</td>\n",
|
433 |
-
" <td>0.597314</td>\n",
|
434 |
-
" <td>0.629707</td>\n",
|
435 |
-
" <td>0.583353</td>\n",
|
436 |
-
" <td>0.558664</td>\n",
|
437 |
-
" <td>0.612029</td>\n",
|
438 |
-
" <td>0.532126</td>\n",
|
439 |
-
" <td>150.0</td>\n",
|
440 |
-
" <td>0.00</td>\n",
|
441 |
-
" </tr>\n",
|
442 |
-
" <tr>\n",
|
443 |
-
" <th>3</th>\n",
|
444 |
-
" <td>0.9</td>\n",
|
445 |
-
" <td>0.1</td>\n",
|
446 |
-
" <td>64.0</td>\n",
|
447 |
-
" <td>1.331875</td>\n",
|
448 |
-
" <td>0.450262</td>\n",
|
449 |
-
" <td>1.357389</td>\n",
|
450 |
-
" <td>0.425829</td>\n",
|
451 |
-
" <td>1.136434</td>\n",
|
452 |
-
" <td>0.542757</td>\n",
|
453 |
-
" <td>1.0</td>\n",
|
454 |
-
" <td>0.520104</td>\n",
|
455 |
-
" <td>0.556441</td>\n",
|
456 |
-
" <td>0.541896</td>\n",
|
457 |
-
" <td>0.467984</td>\n",
|
458 |
-
" <td>0.441258</td>\n",
|
459 |
-
" <td>0.549730</td>\n",
|
460 |
-
" <td>50.0</td>\n",
|
461 |
-
" <td>0.01</td>\n",
|
462 |
-
" </tr>\n",
|
463 |
-
" <tr>\n",
|
464 |
-
" <th>4</th>\n",
|
465 |
-
" <td>0.9</td>\n",
|
466 |
-
" <td>0.1</td>\n",
|
467 |
-
" <td>64.0</td>\n",
|
468 |
-
" <td>1.286225</td>\n",
|
469 |
-
" <td>0.490401</td>\n",
|
470 |
-
" <td>1.218342</td>\n",
|
471 |
-
" <td>0.539267</td>\n",
|
472 |
-
" <td>1.207052</td>\n",
|
473 |
-
" <td>0.537522</td>\n",
|
474 |
-
" <td>1.0</td>\n",
|
475 |
-
" <td>0.552838</td>\n",
|
476 |
-
" <td>0.569478</td>\n",
|
477 |
-
" <td>0.586226</td>\n",
|
478 |
-
" <td>0.494587</td>\n",
|
479 |
-
" <td>0.541650</td>\n",
|
480 |
-
" <td>0.547532</td>\n",
|
481 |
-
" <td>100.0</td>\n",
|
482 |
-
" <td>0.01</td>\n",
|
483 |
-
" </tr>\n",
|
484 |
-
" </tbody>\n",
|
485 |
-
"</table>\n",
|
486 |
-
"</div>"
|
487 |
-
],
|
488 |
-
"text/plain": [
|
489 |
-
" momentum learning rate batch size ... recall3 epoch_stopped reg_penalty\n",
|
490 |
-
"0 0.9 0.1 64.0 ... 0.506039 50.0 0.00\n",
|
491 |
-
"1 0.9 0.1 64.0 ... 0.584749 100.0 0.00\n",
|
492 |
-
"2 0.9 0.1 64.0 ... 0.532126 150.0 0.00\n",
|
493 |
-
"3 0.9 0.1 64.0 ... 0.549730 50.0 0.01\n",
|
494 |
-
"4 0.9 0.1 64.0 ... 0.547532 100.0 0.01\n",
|
495 |
-
"\n",
|
496 |
-
"[5 rows x 18 columns]"
|
497 |
-
]
|
498 |
-
},
|
499 |
-
"metadata": {
|
500 |
-
"tags": []
|
501 |
-
},
|
502 |
-
"execution_count": 17
|
503 |
-
}
|
504 |
-
]
|
505 |
-
},
|
506 |
-
{
|
507 |
-
"cell_type": "code",
|
508 |
-
"metadata": {
|
509 |
-
"id": "LeVPJJQQt36n"
|
510 |
-
},
|
511 |
-
"source": [
|
512 |
-
"data[\"loss_mean\"] = (data[\"loss1\"]+data[\"loss2\"]+data[\"loss3\"])/3\n",
|
513 |
-
"data[\"acc_mean\"] = (data[\"acc1\"]+data[\"acc2\"]+data[\"acc3\"])/3"
|
514 |
-
],
|
515 |
-
"execution_count": null,
|
516 |
-
"outputs": []
|
517 |
-
},
|
518 |
-
{
|
519 |
-
"cell_type": "code",
|
520 |
-
"metadata": {
|
521 |
-
"id": "XPIKRbPMuUOr"
|
522 |
-
},
|
523 |
-
"source": [
|
524 |
-
"data['epoch'] = data['epoch_stopped']\n",
|
525 |
-
"data['weight_decay'] = data['reg_penalty']"
|
526 |
-
],
|
527 |
-
"execution_count": null,
|
528 |
-
"outputs": []
|
529 |
-
},
|
530 |
-
{
|
531 |
-
"cell_type": "code",
|
532 |
-
"metadata": {
|
533 |
-
"id": "5_2tMeO6x9Uq"
|
534 |
-
},
|
535 |
-
"source": [
|
536 |
-
"data['recall_mean'] = (data['recall1']+data['recall2']+data['recall3'])/3\n",
|
537 |
-
"data['prec_mean'] = (data['prec1']+data['prec2']+data['prec3'])/3"
|
538 |
-
],
|
539 |
-
"execution_count": null,
|
540 |
-
"outputs": []
|
541 |
-
},
|
542 |
-
{
|
543 |
-
"cell_type": "code",
|
544 |
-
"metadata": {
|
545 |
-
"id": "JA0eRctYuWl-"
|
546 |
-
},
|
547 |
-
"source": [
|
548 |
-
"column_list = [\"momentum\", \"learning rate\", \"epoch\",\"batch size\",\"weight_decay\",\"loss_mean\", \"acc_mean\",\"recall_mean\", \"prec_mean\"]"
|
549 |
-
],
|
550 |
-
"execution_count": null,
|
551 |
-
"outputs": []
|
552 |
-
},
|
553 |
-
{
|
554 |
-
"cell_type": "code",
|
555 |
-
"metadata": {
|
556 |
-
"colab": {
|
557 |
-
"base_uri": "https://localhost:8080/",
|
558 |
-
"height": 143
|
559 |
-
},
|
560 |
-
"id": "McsMaiLZuZaa",
|
561 |
-
"outputId": "d04090a8-b295-45ad-d64f-e5965741cd80"
|
562 |
-
},
|
563 |
-
"source": [
|
564 |
-
"data.sort_values(axis=0, by=\"loss_mean\", ascending=True)[column_list].head(3)"
|
565 |
-
],
|
566 |
-
"execution_count": null,
|
567 |
-
"outputs": [
|
568 |
-
{
|
569 |
-
"output_type": "execute_result",
|
570 |
-
"data": {
|
571 |
-
"text/html": [
|
572 |
-
"<div>\n",
|
573 |
-
"<style scoped>\n",
|
574 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
575 |
-
" vertical-align: middle;\n",
|
576 |
-
" }\n",
|
577 |
-
"\n",
|
578 |
-
" .dataframe tbody tr th {\n",
|
579 |
-
" vertical-align: top;\n",
|
580 |
-
" }\n",
|
581 |
-
"\n",
|
582 |
-
" .dataframe thead th {\n",
|
583 |
-
" text-align: right;\n",
|
584 |
-
" }\n",
|
585 |
-
"</style>\n",
|
586 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
587 |
-
" <thead>\n",
|
588 |
-
" <tr style=\"text-align: right;\">\n",
|
589 |
-
" <th></th>\n",
|
590 |
-
" <th>momentum</th>\n",
|
591 |
-
" <th>learning rate</th>\n",
|
592 |
-
" <th>epoch</th>\n",
|
593 |
-
" <th>batch size</th>\n",
|
594 |
-
" <th>weight_decay</th>\n",
|
595 |
-
" <th>loss_mean</th>\n",
|
596 |
-
" <th>acc_mean</th>\n",
|
597 |
-
" <th>recall_mean</th>\n",
|
598 |
-
" <th>prec_mean</th>\n",
|
599 |
-
" </tr>\n",
|
600 |
-
" </thead>\n",
|
601 |
-
" <tbody>\n",
|
602 |
-
" <tr>\n",
|
603 |
-
" <th>24</th>\n",
|
604 |
-
" <td>0.9</td>\n",
|
605 |
-
" <td>0.01</td>\n",
|
606 |
-
" <td>50.0</td>\n",
|
607 |
-
" <td>64.0</td>\n",
|
608 |
-
" <td>0.0000</td>\n",
|
609 |
-
" <td>1.012071</td>\n",
|
610 |
-
" <td>0.612565</td>\n",
|
611 |
-
" <td>0.614360</td>\n",
|
612 |
-
" <td>0.615845</td>\n",
|
613 |
-
" </tr>\n",
|
614 |
-
" <tr>\n",
|
615 |
-
" <th>34</th>\n",
|
616 |
-
" <td>0.9</td>\n",
|
617 |
-
" <td>0.01</td>\n",
|
618 |
-
" <td>100.0</td>\n",
|
619 |
-
" <td>64.0</td>\n",
|
620 |
-
" <td>0.0001</td>\n",
|
621 |
-
" <td>1.026681</td>\n",
|
622 |
-
" <td>0.614311</td>\n",
|
623 |
-
" <td>0.612435</td>\n",
|
624 |
-
" <td>0.619908</td>\n",
|
625 |
-
" </tr>\n",
|
626 |
-
" <tr>\n",
|
627 |
-
" <th>8</th>\n",
|
628 |
-
" <td>0.9</td>\n",
|
629 |
-
" <td>0.10</td>\n",
|
630 |
-
" <td>150.0</td>\n",
|
631 |
-
" <td>64.0</td>\n",
|
632 |
-
" <td>0.0010</td>\n",
|
633 |
-
" <td>1.027309</td>\n",
|
634 |
-
" <td>0.603258</td>\n",
|
635 |
-
" <td>0.601262</td>\n",
|
636 |
-
" <td>0.606730</td>\n",
|
637 |
-
" </tr>\n",
|
638 |
-
" </tbody>\n",
|
639 |
-
"</table>\n",
|
640 |
-
"</div>"
|
641 |
-
],
|
642 |
-
"text/plain": [
|
643 |
-
" momentum learning rate epoch ... acc_mean recall_mean prec_mean\n",
|
644 |
-
"24 0.9 0.01 50.0 ... 0.612565 0.614360 0.615845\n",
|
645 |
-
"34 0.9 0.01 100.0 ... 0.614311 0.612435 0.619908\n",
|
646 |
-
"8 0.9 0.10 150.0 ... 0.603258 0.601262 0.606730\n",
|
647 |
-
"\n",
|
648 |
-
"[3 rows x 9 columns]"
|
649 |
-
]
|
650 |
-
},
|
651 |
-
"metadata": {
|
652 |
-
"tags": []
|
653 |
-
},
|
654 |
-
"execution_count": 26
|
655 |
-
}
|
656 |
-
]
|
657 |
-
},
|
658 |
-
{
|
659 |
-
"cell_type": "code",
|
660 |
-
"metadata": {
|
661 |
-
"id": "hCgRAh43udBu"
|
662 |
-
},
|
663 |
-
"source": [
|
664 |
-
"data[\"loss_na\"] = data.loc[:,[\"loss1\",\"loss2\", \"loss3\"]].isnull().sum(1)"
|
665 |
-
],
|
666 |
-
"execution_count": null,
|
667 |
-
"outputs": []
|
668 |
-
},
|
669 |
-
{
|
670 |
-
"cell_type": "code",
|
671 |
-
"metadata": {
|
672 |
-
"colab": {
|
673 |
-
"base_uri": "https://localhost:8080/",
|
674 |
-
"height": 181
|
675 |
-
},
|
676 |
-
"id": "-ptzEmSHudkM",
|
677 |
-
"outputId": "72a8ca57-603a-451f-8702-962b4be4f91a"
|
678 |
-
},
|
679 |
-
"source": [
|
680 |
-
"data.head(3)"
|
681 |
-
],
|
682 |
-
"execution_count": null,
|
683 |
-
"outputs": [
|
684 |
-
{
|
685 |
-
"output_type": "execute_result",
|
686 |
-
"data": {
|
687 |
-
"text/html": [
|
688 |
-
"<div>\n",
|
689 |
-
"<style scoped>\n",
|
690 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
691 |
-
" vertical-align: middle;\n",
|
692 |
-
" }\n",
|
693 |
-
"\n",
|
694 |
-
" .dataframe tbody tr th {\n",
|
695 |
-
" vertical-align: top;\n",
|
696 |
-
" }\n",
|
697 |
-
"\n",
|
698 |
-
" .dataframe thead th {\n",
|
699 |
-
" text-align: right;\n",
|
700 |
-
" }\n",
|
701 |
-
"</style>\n",
|
702 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
703 |
-
" <thead>\n",
|
704 |
-
" <tr style=\"text-align: right;\">\n",
|
705 |
-
" <th></th>\n",
|
706 |
-
" <th>momentum</th>\n",
|
707 |
-
" <th>learning rate</th>\n",
|
708 |
-
" <th>batch size</th>\n",
|
709 |
-
" <th>loss1</th>\n",
|
710 |
-
" <th>acc1</th>\n",
|
711 |
-
" <th>loss2</th>\n",
|
712 |
-
" <th>acc2</th>\n",
|
713 |
-
" <th>loss3</th>\n",
|
714 |
-
" <th>acc3</th>\n",
|
715 |
-
" <th>widing factor</th>\n",
|
716 |
-
" <th>prec1</th>\n",
|
717 |
-
" <th>prec2</th>\n",
|
718 |
-
" <th>prec3</th>\n",
|
719 |
-
" <th>recall1</th>\n",
|
720 |
-
" <th>recall2</th>\n",
|
721 |
-
" <th>recall3</th>\n",
|
722 |
-
" <th>epoch_stopped</th>\n",
|
723 |
-
" <th>reg_penalty</th>\n",
|
724 |
-
" <th>loss_mean</th>\n",
|
725 |
-
" <th>acc_mean</th>\n",
|
726 |
-
" <th>epoch</th>\n",
|
727 |
-
" <th>weight_decay</th>\n",
|
728 |
-
" <th>loss_na</th>\n",
|
729 |
-
" </tr>\n",
|
730 |
-
" </thead>\n",
|
731 |
-
" <tbody>\n",
|
732 |
-
" <tr>\n",
|
733 |
-
" <th>0</th>\n",
|
734 |
-
" <td>0.9</td>\n",
|
735 |
-
" <td>0.1</td>\n",
|
736 |
-
" <td>64.0</td>\n",
|
737 |
-
" <td>1.078664</td>\n",
|
738 |
-
" <td>0.541012</td>\n",
|
739 |
-
" <td>1.087765</td>\n",
|
740 |
-
" <td>0.560209</td>\n",
|
741 |
-
" <td>1.140540</td>\n",
|
742 |
-
" <td>0.495637</td>\n",
|
743 |
-
" <td>1.0</td>\n",
|
744 |
-
" <td>0.545625</td>\n",
|
745 |
-
" <td>0.570303</td>\n",
|
746 |
-
" <td>0.540447</td>\n",
|
747 |
-
" <td>0.549296</td>\n",
|
748 |
-
" <td>0.561227</td>\n",
|
749 |
-
" <td>0.506039</td>\n",
|
750 |
-
" <td>50.0</td>\n",
|
751 |
-
" <td>0.0</td>\n",
|
752 |
-
" <td>1.102323</td>\n",
|
753 |
-
" <td>0.532286</td>\n",
|
754 |
-
" <td>50.0</td>\n",
|
755 |
-
" <td>0.0</td>\n",
|
756 |
-
" <td>0</td>\n",
|
757 |
-
" </tr>\n",
|
758 |
-
" <tr>\n",
|
759 |
-
" <th>1</th>\n",
|
760 |
-
" <td>0.9</td>\n",
|
761 |
-
" <td>0.1</td>\n",
|
762 |
-
" <td>64.0</td>\n",
|
763 |
-
" <td>1.090459</td>\n",
|
764 |
-
" <td>0.544503</td>\n",
|
765 |
-
" <td>1.052103</td>\n",
|
766 |
-
" <td>0.568935</td>\n",
|
767 |
-
" <td>0.991915</td>\n",
|
768 |
-
" <td>0.586387</td>\n",
|
769 |
-
" <td>1.0</td>\n",
|
770 |
-
" <td>0.567745</td>\n",
|
771 |
-
" <td>0.581882</td>\n",
|
772 |
-
" <td>0.587054</td>\n",
|
773 |
-
" <td>0.549632</td>\n",
|
774 |
-
" <td>0.571559</td>\n",
|
775 |
-
" <td>0.584749</td>\n",
|
776 |
-
" <td>100.0</td>\n",
|
777 |
-
" <td>0.0</td>\n",
|
778 |
-
" <td>1.044826</td>\n",
|
779 |
-
" <td>0.566609</td>\n",
|
780 |
-
" <td>100.0</td>\n",
|
781 |
-
" <td>0.0</td>\n",
|
782 |
-
" <td>0</td>\n",
|
783 |
-
" </tr>\n",
|
784 |
-
" <tr>\n",
|
785 |
-
" <th>2</th>\n",
|
786 |
-
" <td>0.9</td>\n",
|
787 |
-
" <td>0.1</td>\n",
|
788 |
-
" <td>64.0</td>\n",
|
789 |
-
" <td>1.203270</td>\n",
|
790 |
-
" <td>0.549738</td>\n",
|
791 |
-
" <td>1.024475</td>\n",
|
792 |
-
" <td>0.607330</td>\n",
|
793 |
-
" <td>1.099551</td>\n",
|
794 |
-
" <td>0.542757</td>\n",
|
795 |
-
" <td>1.0</td>\n",
|
796 |
-
" <td>0.597314</td>\n",
|
797 |
-
" <td>0.629707</td>\n",
|
798 |
-
" <td>0.583353</td>\n",
|
799 |
-
" <td>0.558664</td>\n",
|
800 |
-
" <td>0.612029</td>\n",
|
801 |
-
" <td>0.532126</td>\n",
|
802 |
-
" <td>150.0</td>\n",
|
803 |
-
" <td>0.0</td>\n",
|
804 |
-
" <td>1.109098</td>\n",
|
805 |
-
" <td>0.566608</td>\n",
|
806 |
-
" <td>150.0</td>\n",
|
807 |
-
" <td>0.0</td>\n",
|
808 |
-
" <td>0</td>\n",
|
809 |
-
" </tr>\n",
|
810 |
-
" </tbody>\n",
|
811 |
-
"</table>\n",
|
812 |
-
"</div>"
|
813 |
-
],
|
814 |
-
"text/plain": [
|
815 |
-
" momentum learning rate batch size ... epoch weight_decay loss_na\n",
|
816 |
-
"0 0.9 0.1 64.0 ... 50.0 0.0 0\n",
|
817 |
-
"1 0.9 0.1 64.0 ... 100.0 0.0 0\n",
|
818 |
-
"2 0.9 0.1 64.0 ... 150.0 0.0 0\n",
|
819 |
-
"\n",
|
820 |
-
"[3 rows x 23 columns]"
|
821 |
-
]
|
822 |
-
},
|
823 |
-
"metadata": {
|
824 |
-
"tags": []
|
825 |
-
},
|
826 |
-
"execution_count": 23
|
827 |
-
}
|
828 |
-
]
|
829 |
-
},
|
830 |
-
{
|
831 |
-
"cell_type": "code",
|
832 |
-
"metadata": {
|
833 |
-
"id": "YBBXsz_rzQl-"
|
834 |
-
},
|
835 |
-
"source": [
|
836 |
-
"generator = tensorflow.keras.preprocessing.image.ImageDataGenerator(rotation_range=10,\n",
|
837 |
-
" width_shift_range=5./32,\n",
|
838 |
-
" height_shift_range=5./32,)"
|
839 |
-
],
|
840 |
-
"execution_count": null,
|
841 |
-
"outputs": []
|
842 |
-
},
|
843 |
-
{
|
844 |
-
"cell_type": "code",
|
845 |
-
"metadata": {
|
846 |
-
"id": "ShIJn_53mawD"
|
847 |
-
},
|
848 |
-
"source": [
|
849 |
-
"kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
850 |
-
"result = []\n",
|
851 |
-
"for j, (train_index, test_index) in enumerate(kf.split(X_train)):\n",
|
852 |
-
" x_train, x_val = X_train[train_index], X_train[test_index]\n",
|
853 |
-
" y_train, y_val = Y_train[train_index], Y_train[test_index]\n",
|
854 |
-
" model = model_2(0)\n",
|
855 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
856 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
857 |
-
" hist = model.fit(generator.flow(x_train, y_train, batch_size=64), steps_per_epoch=len(x_train) //64 , epochs=50,\n",
|
858 |
-
" validation_data=(x_val, y_val),\n",
|
859 |
-
" validation_steps=len(x_val) //64 ,)\n",
|
860 |
-
"\n",
|
861 |
-
" test = model.evaluate(X_test, y_test)\n",
|
862 |
-
" result.append(test)"
|
863 |
-
],
|
864 |
-
"execution_count": null,
|
865 |
-
"outputs": []
|
866 |
-
},
|
867 |
-
{
|
868 |
-
"cell_type": "code",
|
869 |
-
"metadata": {
|
870 |
-
"colab": {
|
871 |
-
"base_uri": "https://localhost:8080/"
|
872 |
-
},
|
873 |
-
"id": "Rj6hnUVRzjX-",
|
874 |
-
"outputId": "41b04e1e-d756-437a-afff-ca9f3fcff46c"
|
875 |
-
},
|
876 |
-
"source": [
|
877 |
-
"mean_acc = (result[0][1]+result[1][1]+result[2][1])/3;mean_acc"
|
878 |
-
],
|
879 |
-
"execution_count": null,
|
880 |
-
"outputs": [
|
881 |
-
{
|
882 |
-
"output_type": "execute_result",
|
883 |
-
"data": {
|
884 |
-
"text/plain": [
|
885 |
-
"0.6305991808573405"
|
886 |
-
]
|
887 |
-
},
|
888 |
-
"metadata": {
|
889 |
-
"tags": []
|
890 |
-
},
|
891 |
-
"execution_count": 47
|
892 |
-
}
|
893 |
-
]
|
894 |
-
},
|
895 |
-
{
|
896 |
-
"cell_type": "code",
|
897 |
-
"metadata": {
|
898 |
-
"colab": {
|
899 |
-
"base_uri": "https://localhost:8080/"
|
900 |
-
},
|
901 |
-
"id": "v_AccCit1De3",
|
902 |
-
"outputId": "a6299f58-4389-4f39-f318-e2371e94f4e5"
|
903 |
-
},
|
904 |
-
"source": [
|
905 |
-
"mean_loss = (result[0][0]+result[1][0]+result[2][0])/3;mean_loss"
|
906 |
-
],
|
907 |
-
"execution_count": null,
|
908 |
-
"outputs": [
|
909 |
-
{
|
910 |
-
"output_type": "execute_result",
|
911 |
-
"data": {
|
912 |
-
"text/plain": [
|
913 |
-
"0.9891296029090881"
|
914 |
-
]
|
915 |
-
},
|
916 |
-
"metadata": {
|
917 |
-
"tags": []
|
918 |
-
},
|
919 |
-
"execution_count": 48
|
920 |
-
}
|
921 |
-
]
|
922 |
-
},
|
923 |
-
{
|
924 |
-
"cell_type": "code",
|
925 |
-
"metadata": {
|
926 |
-
"id": "8gdDigrPzYM5"
|
927 |
-
},
|
928 |
-
"source": [
|
929 |
-
"kf = KFold(n_splits=3, random_state=42, shuffle=False)\n",
|
930 |
-
"result_2 = []\n",
|
931 |
-
"for j, (train_index, test_index) in enumerate(kf.split(X_train)):\n",
|
932 |
-
" x_train, x_val = X_train[train_index], X_train[test_index]\n",
|
933 |
-
" y_train, y_val = Y_train[train_index], Y_train[test_index]\n",
|
934 |
-
" model = model_3(0.0001)\n",
|
935 |
-
" opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
936 |
-
" model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
937 |
-
" hist = model.fit(generator.flow(x_train, y_train, batch_size=64), steps_per_epoch=len(x_train) //64 , epochs=100,\n",
|
938 |
-
" validation_data=(x_val, y_val),\n",
|
939 |
-
" validation_steps=len(x_val) //64 ,)\n",
|
940 |
-
"\n",
|
941 |
-
" test = model.evaluate(X_test, y_test)\n",
|
942 |
-
" result_2.append(test)"
|
943 |
-
],
|
944 |
-
"execution_count": null,
|
945 |
-
"outputs": []
|
946 |
-
},
|
947 |
-
{
|
948 |
-
"cell_type": "code",
|
949 |
-
"metadata": {
|
950 |
-
"colab": {
|
951 |
-
"base_uri": "https://localhost:8080/"
|
952 |
-
},
|
953 |
-
"id": "FYVQINoA2F4l",
|
954 |
-
"outputId": "ac29f686-9b48-4193-bf01-4a17ab589d91"
|
955 |
-
},
|
956 |
-
"source": [
|
957 |
-
"mean_acc = (result_2[0][1]+result_2[1][1]+result_2[2][1])/3;mean_acc"
|
958 |
-
],
|
959 |
-
"execution_count": null,
|
960 |
-
"outputs": [
|
961 |
-
{
|
962 |
-
"output_type": "execute_result",
|
963 |
-
"data": {
|
964 |
-
"text/plain": [
|
965 |
-
"0.7108784119288126"
|
966 |
-
]
|
967 |
-
},
|
968 |
-
"metadata": {
|
969 |
-
"tags": []
|
970 |
-
},
|
971 |
-
"execution_count": 56
|
972 |
-
}
|
973 |
-
]
|
974 |
-
},
|
975 |
-
{
|
976 |
-
"cell_type": "code",
|
977 |
-
"metadata": {
|
978 |
-
"colab": {
|
979 |
-
"base_uri": "https://localhost:8080/"
|
980 |
-
},
|
981 |
-
"id": "guueg5Wo2IE8",
|
982 |
-
"outputId": "986e139c-1224-43a7-b8fc-59d57e72d3a8"
|
983 |
-
},
|
984 |
-
"source": [
|
985 |
-
"mean_loss = (result_2[0][0]+result_2[1][0]+result_2[2][0])/3;mean_loss"
|
986 |
-
],
|
987 |
-
"execution_count": null,
|
988 |
-
"outputs": [
|
989 |
-
{
|
990 |
-
"output_type": "execute_result",
|
991 |
-
"data": {
|
992 |
-
"text/plain": [
|
993 |
-
"0.9208946625391642"
|
994 |
-
]
|
995 |
-
},
|
996 |
-
"metadata": {
|
997 |
-
"tags": []
|
998 |
-
},
|
999 |
-
"execution_count": 57
|
1000 |
-
}
|
1001 |
-
]
|
1002 |
-
},
|
1003 |
-
{
|
1004 |
-
"cell_type": "code",
|
1005 |
-
"metadata": {
|
1006 |
-
"colab": {
|
1007 |
-
"base_uri": "https://localhost:8080/"
|
1008 |
-
},
|
1009 |
-
"id": "luS42Hil0H5R",
|
1010 |
-
"outputId": "91c28154-8014-42a9-fbf2-8935cb4c521a"
|
1011 |
-
},
|
1012 |
-
"source": [
|
1013 |
-
"mean_acc = (result_2[0][1]+result_2[1][1]+result_2[2][1])/3;mean_acc"
|
1014 |
-
],
|
1015 |
-
"execution_count": null,
|
1016 |
-
"outputs": [
|
1017 |
-
{
|
1018 |
-
"output_type": "execute_result",
|
1019 |
-
"data": {
|
1020 |
-
"text/plain": [
|
1021 |
-
"0.6783013343811035"
|
1022 |
-
]
|
1023 |
-
},
|
1024 |
-
"metadata": {
|
1025 |
-
"tags": []
|
1026 |
-
},
|
1027 |
-
"execution_count": 50
|
1028 |
-
}
|
1029 |
-
]
|
1030 |
-
},
|
1031 |
-
{
|
1032 |
-
"cell_type": "code",
|
1033 |
-
"metadata": {
|
1034 |
-
"colab": {
|
1035 |
-
"base_uri": "https://localhost:8080/"
|
1036 |
-
},
|
1037 |
-
"id": "F9Zg3pI61W_J",
|
1038 |
-
"outputId": "fe5f7861-58e9-47f6-8134-8872e494e268"
|
1039 |
-
},
|
1040 |
-
"source": [
|
1041 |
-
"mean_loss = (result_2[0][0]+result_2[1][0]+result_2[2][0])/3;mean_loss"
|
1042 |
-
],
|
1043 |
-
"execution_count": null,
|
1044 |
-
"outputs": [
|
1045 |
-
{
|
1046 |
-
"output_type": "execute_result",
|
1047 |
-
"data": {
|
1048 |
-
"text/plain": [
|
1049 |
-
"0.8747362097104391"
|
1050 |
-
]
|
1051 |
-
},
|
1052 |
-
"metadata": {
|
1053 |
-
"tags": []
|
1054 |
-
},
|
1055 |
-
"execution_count": 51
|
1056 |
-
}
|
1057 |
-
]
|
1058 |
-
},
|
1059 |
-
{
|
1060 |
-
"cell_type": "code",
|
1061 |
-
"metadata": {
|
1062 |
-
"colab": {
|
1063 |
-
"base_uri": "https://localhost:8080/"
|
1064 |
-
},
|
1065 |
-
"id": "OWx9Jej11f_V",
|
1066 |
-
"outputId": "a5b72e6c-de64-4bee-da66-d48a5e6e1870"
|
1067 |
-
},
|
1068 |
-
"source": [
|
1069 |
-
"model = model_1(0)\n",
|
1070 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
1071 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
1072 |
-
"model.summary()"
|
1073 |
-
],
|
1074 |
-
"execution_count": null,
|
1075 |
-
"outputs": [
|
1076 |
-
{
|
1077 |
-
"output_type": "stream",
|
1078 |
-
"text": [
|
1079 |
-
"Model: \"sequential_162\"\n",
|
1080 |
-
"_________________________________________________________________\n",
|
1081 |
-
"Layer (type) Output Shape Param # \n",
|
1082 |
-
"=================================================================\n",
|
1083 |
-
"conv2d_342 (Conv2D) (None, 30, 30, 32) 320 \n",
|
1084 |
-
"_________________________________________________________________\n",
|
1085 |
-
"conv2d_343 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
1086 |
-
"_________________________________________________________________\n",
|
1087 |
-
"max_pooling2d_183 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
1088 |
-
"_________________________________________________________________\n",
|
1089 |
-
"batch_normalization_183 (Bat (None, 14, 14, 64) 256 \n",
|
1090 |
-
"_________________________________________________________________\n",
|
1091 |
-
"flatten_162 (Flatten) (None, 12544) 0 \n",
|
1092 |
-
"_________________________________________________________________\n",
|
1093 |
-
"dense_162 (Dense) (None, 4) 50180 \n",
|
1094 |
-
"=================================================================\n",
|
1095 |
-
"Total params: 69,252\n",
|
1096 |
-
"Trainable params: 69,124\n",
|
1097 |
-
"Non-trainable params: 128\n",
|
1098 |
-
"_________________________________________________________________\n"
|
1099 |
-
],
|
1100 |
-
"name": "stdout"
|
1101 |
-
}
|
1102 |
-
]
|
1103 |
-
},
|
1104 |
-
{
|
1105 |
-
"cell_type": "code",
|
1106 |
-
"metadata": {
|
1107 |
-
"colab": {
|
1108 |
-
"base_uri": "https://localhost:8080/"
|
1109 |
-
},
|
1110 |
-
"id": "l8mg2Bup1lPJ",
|
1111 |
-
"outputId": "6a2348f9-f6d8-47d3-a5d7-9169d3906d76"
|
1112 |
-
},
|
1113 |
-
"source": [
|
1114 |
-
"model = model_2(0)\n",
|
1115 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
1116 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
1117 |
-
"model.summary()"
|
1118 |
-
],
|
1119 |
-
"execution_count": null,
|
1120 |
-
"outputs": [
|
1121 |
-
{
|
1122 |
-
"output_type": "stream",
|
1123 |
-
"text": [
|
1124 |
-
"Model: \"sequential_154\"\n",
|
1125 |
-
"_________________________________________________________________\n",
|
1126 |
-
"Layer (type) Output Shape Param # \n",
|
1127 |
-
"=================================================================\n",
|
1128 |
-
"conv2d_320 (Conv2D) (None, 30, 30, 32) 320 \n",
|
1129 |
-
"_________________________________________________________________\n",
|
1130 |
-
"conv2d_321 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
1131 |
-
"_________________________________________________________________\n",
|
1132 |
-
"max_pooling2d_166 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
1133 |
-
"_________________________________________________________________\n",
|
1134 |
-
"batch_normalization_166 (Bat (None, 14, 14, 64) 256 \n",
|
1135 |
-
"_________________________________________________________________\n",
|
1136 |
-
"conv2d_322 (Conv2D) (None, 12, 12, 128) 73856 \n",
|
1137 |
-
"_________________________________________________________________\n",
|
1138 |
-
"max_pooling2d_167 (MaxPoolin (None, 6, 6, 128) 0 \n",
|
1139 |
-
"_________________________________________________________________\n",
|
1140 |
-
"batch_normalization_167 (Bat (None, 6, 6, 128) 512 \n",
|
1141 |
-
"_________________________________________________________________\n",
|
1142 |
-
"flatten_154 (Flatten) (None, 4608) 0 \n",
|
1143 |
-
"_________________________________________________________________\n",
|
1144 |
-
"dense_154 (Dense) (None, 4) 18436 \n",
|
1145 |
-
"=================================================================\n",
|
1146 |
-
"Total params: 111,876\n",
|
1147 |
-
"Trainable params: 111,492\n",
|
1148 |
-
"Non-trainable params: 384\n",
|
1149 |
-
"_________________________________________________________________\n"
|
1150 |
-
],
|
1151 |
-
"name": "stdout"
|
1152 |
-
}
|
1153 |
-
]
|
1154 |
-
},
|
1155 |
-
{
|
1156 |
-
"cell_type": "code",
|
1157 |
-
"metadata": {
|
1158 |
-
"colab": {
|
1159 |
-
"base_uri": "https://localhost:8080/"
|
1160 |
-
},
|
1161 |
-
"id": "vZImk4zB1mxk",
|
1162 |
-
"outputId": "805f844e-1312-423a-b309-35944f11a210"
|
1163 |
-
},
|
1164 |
-
"source": [
|
1165 |
-
"model = model_3(0)\n",
|
1166 |
-
"opt = tensorflow.keras.optimizers.SGD(learning_rate=0.01)\n",
|
1167 |
-
"model.compile(loss='categorical_crossentropy', optimizer=opt, metrics= ['accuracy'])\n",
|
1168 |
-
"model.summary()"
|
1169 |
-
],
|
1170 |
-
"execution_count": null,
|
1171 |
-
"outputs": [
|
1172 |
-
{
|
1173 |
-
"output_type": "stream",
|
1174 |
-
"text": [
|
1175 |
-
"Model: \"sequential_155\"\n",
|
1176 |
-
"_________________________________________________________________\n",
|
1177 |
-
"Layer (type) Output Shape Param # \n",
|
1178 |
-
"=================================================================\n",
|
1179 |
-
"conv2d_323 (Conv2D) (None, 30, 30, 32) 320 \n",
|
1180 |
-
"_________________________________________________________________\n",
|
1181 |
-
"conv2d_324 (Conv2D) (None, 28, 28, 64) 18496 \n",
|
1182 |
-
"_________________________________________________________________\n",
|
1183 |
-
"max_pooling2d_168 (MaxPoolin (None, 14, 14, 64) 0 \n",
|
1184 |
-
"_________________________________________________________________\n",
|
1185 |
-
"batch_normalization_168 (Bat (None, 14, 14, 64) 256 \n",
|
1186 |
-
"_________________________________________________________________\n",
|
1187 |
-
"conv2d_325 (Conv2D) (None, 12, 12, 128) 73856 \n",
|
1188 |
-
"_________________________________________________________________\n",
|
1189 |
-
"max_pooling2d_169 (MaxPoolin (None, 6, 6, 128) 0 \n",
|
1190 |
-
"_________________________________________________________________\n",
|
1191 |
-
"batch_normalization_169 (Bat (None, 6, 6, 128) 512 \n",
|
1192 |
-
"_________________________________________________________________\n",
|
1193 |
-
"conv2d_326 (Conv2D) (None, 4, 4, 256) 295168 \n",
|
1194 |
-
"_________________________________________________________________\n",
|
1195 |
-
"max_pooling2d_170 (MaxPoolin (None, 2, 2, 256) 0 \n",
|
1196 |
-
"_________________________________________________________________\n",
|
1197 |
-
"batch_normalization_170 (Bat (None, 2, 2, 256) 1024 \n",
|
1198 |
-
"_________________________________________________________________\n",
|
1199 |
-
"flatten_155 (Flatten) (None, 1024) 0 \n",
|
1200 |
-
"_________________________________________________________________\n",
|
1201 |
-
"dense_155 (Dense) (None, 4) 4100 \n",
|
1202 |
-
"=================================================================\n",
|
1203 |
-
"Total params: 393,732\n",
|
1204 |
-
"Trainable params: 392,836\n",
|
1205 |
-
"Non-trainable params: 896\n",
|
1206 |
-
"_________________________________________________________________\n"
|
1207 |
-
],
|
1208 |
-
"name": "stdout"
|
1209 |
-
}
|
1210 |
-
]
|
1211 |
-
}
|
1212 |
-
]
|
1213 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/gridsearchcv/grid_model_1.csv
DELETED
@@ -1,49 +0,0 @@
|
|
1 |
-
;momentum;learning rate;batch size;loss1;acc1;loss2;acc2;loss3;acc3;widing factor;prec1;prec2;prec3;recall1;recall2;recall3;epoch_stopped;reg_penalty
|
2 |
-
0;0.9;0.1;64.0;1.0786635875701904;0.5410122275352478;1.087765097618103;0.5602094531059265;1.1405400037765503;0.4956369996070862;1.0;0.5456249725117649;0.5703025971923495;0.5404468631355052;0.549295511252033;0.561227419923072;0.5060386473429952;50.0;0.0
|
3 |
-
1;0.9;0.1;64.0;1.0904589891433716;0.5445026159286499;1.0521031618118286;0.5689354538917542;0.9919149279594421;0.5863874554634094;1.0;0.5677450027120325;0.5818822696361521;0.5870538777082881;0.5496315278923974;0.5715594791681748;0.5847485847485847;100.0;0.0
|
4 |
-
2;0.9;0.1;64.0;1.203269600868225;0.5497382283210754;1.0244745016098022;0.6073298454284668;1.0995512008666992;0.5427573919296265;1.0;0.5973140968206758;0.6297070150268066;0.5833525841963643;0.5586641619250315;0.6120292750727534;0.5321260864739126;150.0;0.0
|
5 |
-
3;0.9;0.1;64.0;1.3318754434585571;0.45026177167892456;1.3573890924453735;0.4258289635181427;1.1364340782165527;0.5427573919296265;1.0;0.520103685420754;0.5564409752732054;0.5418956991063357;0.4679843103756147;0.44125801734497383;0.5497298595124682;50.0;0.01
|
6 |
-
4;0.9;0.1;64.0;1.2862250804901123;0.49040138721466064;1.2183420658111572;0.5392670035362244;1.2070523500442505;0.5375218391418457;1.0;0.5528382317603805;0.5694776622208874;0.586225988700565;0.4945869348043261;0.5416503786069004;0.5475315747054877;100.0;0.01
|
7 |
-
5;0.9;0.1;64.0;1.1847248077392578;0.5445026159286499;1.4426966905593872;0.43106457591056824;1.2306617498397827;0.47993019223213196;1.0;0.5312498290025648;0.5339512821095386;0.5311977318393986;0.544515870602827;0.4474462735332301;0.49295993861211257;150.0;0.01
|
8 |
-
6;0.9;0.1;64.0;1.224439263343811;0.5095986127853394;1.1083790063858032;0.5584642291069031;1.042079210281372;0.584642231464386;1.0;0.5344176263987584;0.5580401772808269;0.5725199160342243;0.5072463768115942;0.564412640499597;0.5839504698200351;50.0;0.001
|
9 |
-
7;0.9;0.1;64.0;1.1233032941818237;0.5671902298927307;1.279549241065979;0.49389180541038513;1.0785572528839111;0.5811518430709839;1.0;0.5835013620894915;0.5600592746287815;0.5944705565842157;0.5647160810204288;0.4860936165283991;0.5825973543364849;100.0;0.001
|
10 |
-
8;0.9;0.1;64.0;1.0618174076080322;0.5898778438568115;1.0041369199752808;0.6178010702133179;1.0159738063812256;0.6020942330360413;1.0;0.5958448974789733;0.6197077143812308;0.6046378435659604;0.5848010684967206;0.6172076715554976;0.6017768463420637;150.0;0.001
|
11 |
-
9;0.9;0.1;64.0;1.1275826692581177;0.5095986127853394;1.037006139755249;0.547993004322052;1.1576757431030273;0.4712041914463043;1.0;0.5163513023182834;0.5527921101786539;0.4780335054814982;0.5184676434676434;0.5468667805624328;0.47072553050813926;50.0;0.0001
|
12 |
-
10;0.9;0.1;64.0;1.0660938024520874;0.5811518430709839;1.0170561075210571;0.5828970074653625;1.0196107625961304;0.5811518430709839;1.0;0.5900164662084766;0.5999640724915478;0.5950983436203666;0.5901320901320901;0.5829641373119634;0.5774551535421101;100.0;0.0001
|
13 |
-
11;0.9;0.1;64.0;1.0177899599075317;0.5881326198577881;1.0439162254333496;0.5654450058937073;1.1021149158477783;0.5759162306785583;1.0;0.6036141210870313;0.5902555674010898;0.6140637358081388;0.5845627802149542;0.561161664422534;0.5720825068651156;150.0;0.0001
|
14 |
-
12;0.9;0.1;128.0;1.0818946361541748;0.5253053903579712;1.0099669694900513;0.5968586206436157;1.0918453931808472;0.5410122275352478;1.0;0.521080315961989;0.6039058290508166;0.5263435769673295;0.5298210243862418;0.6008387747518182;0.5431760268716791;50.0;0.0
|
15 |
-
13;0.9;0.1;128.0;1.164839744567871;0.45724257826805115;0.9867958426475525;0.5933682322502136;1.0679926872253418;0.554973840713501;1.0;0.5123214368077382;0.5988966761697891;0.5643833075235515;0.4695654586958935;0.597261434217956;0.5518648018648019;100.0;0.0
|
16 |
-
14;0.9;0.1;128.0;1.113759160041809;0.5567190051078796;1.046524167060852;0.584642231464386;0.9935498833656311;0.5724258422851562;1.0;0.6122945895111208;0.5897342892147153;0.5820350700747354;0.553781966825445;0.58606007519051;0.5708771904424078;150.0;0.0
|
17 |
-
15;0.9;0.1;128.0;1.1370242834091187;0.5619546175003052;1.1299653053283691;0.5462478399276733;1.1480666399002075;0.5322862267494202;1.0;0.5756164609818162;0.5500235261699395;0.5222666160817157;0.5605360822752127;0.5444175389827564;0.53630849826502;50.0;0.01
|
18 |
-
16;0.9;0.1;128.0;1.1349637508392334;0.5636998414993286;1.1408828496932983;0.5410122275352478;1.3376970291137695;0.47993019223213196;1.0;0.5815890943000785;0.5523155608073429;0.5205896541198826;0.5641309173917869;0.5482174830000918;0.4919084538649756;100.0;0.01
|
19 |
-
17;0.9;0.1;128.0;1.1322417259216309;0.5881326198577881;1.1545387506484985;0.5602094531059265;1.2242604494094849;0.5654450058937073;1.0;0.599198600670369;0.5865846668946237;0.5603340273492765;0.5861801242236024;0.5607828162175988;0.5646859179467875;150.0;0.01
|
20 |
-
18;0.9;0.1;128.0;1.1220964193344116;0.5514833927154541;1.1653169393539429;0.5270506143569946;1.1562250852584839;0.5235602259635925;1.0;0.5715813891898507;0.5249991236833342;0.5223967978570353;0.5467980087545306;0.5330587287109027;0.5233751755490886;50.0;0.001
|
21 |
-
19;0.9;0.1;128.0;1.153067708015442;0.5881326198577881;1.1445388793945312;0.5602094531059265;1.1189345121383667;0.5357766151428223;1.0;0.6072767352919373;0.5604236873448876;0.5574203282247268;0.5862139068660808;0.5572374485417964;0.5373183579705318;100.0;0.001
|
22 |
-
20;0.9;0.1;128.0;1.2125698328018188;0.5410122275352478;1.1219922304153442;0.5602094531059265;1.0713446140289307;0.5828970074653625;1.0;0.5605981222253124;0.5931479015961776;0.5828110117010443;0.5311391507043681;0.5714448594883378;0.5890498390498391;150.0;0.001
|
23 |
-
21;0.9;0.1;128.0;1.1447523832321167;0.5567190051078796;1.1324081420898438;0.518324613571167;1.1090422868728638;0.49738219380378723;1.0;0.596030303030303;0.5490280922716101;0.5121340146227983;0.5638238573021181;0.5146260744086831;0.4972123287340679;50.0;0.0001
|
24 |
-
22;0.9;0.1;128.0;0.9825822710990906;0.584642231464386;1.108230471611023;0.5218150019645691;1.1356145143508911;0.5200698375701904;1.0;0.5948660954866036;0.5665316277437249;0.5656796449319814;0.5828724415680937;0.5286633656198874;0.5303995521386826;100.0;0.0001
|
25 |
-
23;0.9;0.1;128.0;1.0473191738128662;0.5951134562492371;1.0678406953811646;0.518324613571167;1.0169429779052734;0.5881326198577881;1.0;0.6219409183078248;0.5693268693732438;0.593293789725479;0.5926543263499786;0.5205706129619173;0.5907293189901885;150.0;0.0001
|
26 |
-
24;0.9;0.01;64.0;1.0598105192184448;0.5968586206436157;1.0257790088653564;0.6038394570350647;0.9506248235702515;0.6369982361793518;1.0;0.6051437947470779;0.602723696484381;0.6396683984470244;0.5915582002538524;0.609172831998919;0.6423491966970227;50.0;0.0
|
27 |
-
25;0.9;0.01;64.0;1.1066389083862305;0.6108202338218689;1.0695348978042603;0.6195462346076965;1.1887136697769165;0.5741710066795349;1.0;0.6467146374525526;0.6245149204071099;0.5971208662939392;0.6051062464105943;0.6201624462494028;0.5657778212126039;100.0;0.0
|
28 |
-
26;0.9;0.01;64.0;1.076154351234436;0.6404886841773987;1.1792618036270142;0.6073298454284668;1.4671285152435303;0.547993004322052;1.0;0.6529888091081401;0.6461774006977742;0.6066030792539256;0.6368685662163923;0.6003965840922363;0.537592841940668;150.0;0.0
|
29 |
-
27;0.9;0.01;64.0;1.2690242528915405;0.5968586206436157;1.2563925981521606;0.5968586206436157;1.2163962125778198;0.5968586206436157;1.0;0.5916521579476596;0.6248306369232546;0.5985189460592925;0.5966666184057489;0.6007983562331388;0.5971992982862548;50.0;0.01
|
30 |
-
28;0.9;0.01;64.0;1.116417646408081;0.6160558462142944;1.1469776630401611;0.6178010702133179;1.1717371940612793;0.584642231464386;1.0;0.6261048220941269;0.6335991775298133;0.610479317595543;0.6109440076831381;0.6120992534036012;0.5811808963982876;100.0;0.01
|
31 |
-
29;0.9;0.01;64.0;1.214686393737793;0.5811518430709839;1.2714550495147705;0.5584642291069031;1.2080868482589722;0.6073298454284668;1.0;0.6001486481081634;0.5946394895073237;0.6365187860446218;0.5804340586949283;0.5500978490108924;0.6034834730486904;150.0;0.01
|
32 |
-
30;0.9;0.01;64.0;1.1514617204666138;0.5881326198577881;1.0881938934326172;0.5881326198577881;1.053907871246338;0.6020942330360413;1.0;0.5834891320869433;0.6065228770652975;0.6088152818711143;0.5858362651840913;0.5805794447098794;0.6013455143889928;50.0;0.001
|
33 |
-
31;0.9;0.01;64.0;1.22362220287323;0.5968586206436157;1.0609272718429565;0.6300174593925476;1.1912355422973633;0.5968586206436157;1.0;0.621725687037072;0.6327141334647053;0.6186631707154095;0.5975027388070866;0.6298701298701299;0.5883030013464796;100.0;0.001
|
34 |
-
32;0.9;0.01;64.0;1.1824991703033447;0.5986038446426392;1.226974368095398;0.6090750694274902;1.2909026145935059;0.6038394570350647;1.0;0.6073992988417198;0.6445184510325356;0.619910049538771;0.5955710955710956;0.6017605582822975;0.5952151713021279;150.0;0.001
|
35 |
-
33;0.9;0.01;64.0;1.1622414588928223;0.5287958383560181;1.0493906736373901;0.5794066190719604;1.2138421535491943;0.5497382283210754;1.0;0.5809122116993566;0.6019885476001634;0.5910039233578294;0.5309117211291124;0.5840572471007254;0.5480895915678524;50.0;0.0001
|
36 |
-
34;0.9;0.01;64.0;1.0573968887329102;0.6073298454284668;1.0098820924758911;0.62129145860672;1.0127638578414917;0.614310622215271;1.0;0.6070338521099506;0.62878369303158;0.6239062815043207;0.603968495272843;0.6201449516666908;0.6131923631923631;100.0;0.0001
|
37 |
-
35;0.9;0.01;64.0;1.0984822511672974;0.657940685749054;1.0174990892410278;0.6561954617500305;1.1236447095870972;0.5986038446426392;1.0;0.6531590453460197;0.6671137560770346;0.5932537237439376;0.6610702099832534;0.6503381883816667;0.5962919930311235;150.0;0.0001
|
38 |
-
36;0.9;0.01;128.0;1.0621216297149658;0.584642231464386;1.0866228342056274;0.5497382283210754;1.073479413986206;0.5671902298927307;1.0;0.5977450132718186;0.5629235472659129;0.5687495323396486;0.5795900958944438;0.545400251921991;0.565356141443098;50.0;0.0
|
39 |
-
37;0.9;0.01;128.0;1.0547319650650024;0.6125654578208923;1.1377087831497192;0.5776614546775818;1.127797245979309;0.5636998414993286;1.0;0.6235207292967124;0.5720610223175946;0.5682047407187032;0.6082697495740974;0.5770853542592673;0.5593319723754506;100.0;0.0
|
40 |
-
38;0.9;0.01;128.0;0.9883608222007751;0.6265270709991455;1.316406488418579;0.5567190051078796;1.1691138744354248;0.5828970074653625;1.0;0.6431009815575375;0.5975566000144896;0.6062137814624874;0.624762918241179;0.5490433479563914;0.577417148069322;150.0;0.0
|
41 |
-
39;0.9;0.01;128.0;1.443315863609314;0.5253053903579712;1.3692501783370972;0.5828970074653625;1.342460036277771;0.5933682322502136;1.0;0.5716634678107281;0.597909057945944;0.5906145139282292;0.5164768806073153;0.57604231517275;0.5995459854155506;50.0;0.01
|
42 |
-
40;0.9;0.01;128.0;1.2021657228469849;0.6265270709991455;1.3414463996887207;0.5619546175003052;1.2604358196258545;0.5828970074653625;1.0;0.6285688058844148;0.6186485688650869;0.6176072056565729;0.6251604675517719;0.5610753980319197;0.5805716023107328;100.0;0.01
|
43 |
-
41;0.9;0.01;128.0;1.1669517755508423;0.6003490686416626;1.195756435394287;0.6020942330360413;1.1677254438400269;0.614310622215271;1.0;0.6203690104168971;0.6203570670428672;0.6432026808004327;0.5943404421665291;0.596520025867852;0.6156590993547515;150.0;0.01
|
44 |
-
42;0.9;0.01;128.0;1.1474894285202026;0.5497382283210754;1.1794898509979248;0.5357766151428223;1.0964933633804321;0.5759162306785583;1.0;0.5853768536604357;0.5617280741536463;0.5739665220002184;0.5477795151708195;0.5463829648612257;0.5779051866008388;50.0;0.001
|
45 |
-
43;0.9;0.01;128.0;1.253674864768982;0.5253053903579712;1.160112977027893;0.5776614546775818;1.1979146003723145;0.5602094531059265;1.0;0.5522893772893773;0.5946065751075158;0.609541492289036;0.5178848928848929;0.5693081073515855;0.5646515320428364;100.0;0.001
|
46 |
-
44;0.9;0.01;128.0;1.0433732271194458;0.62129145860672;1.1583247184753418;0.6055846214294434;1.2541570663452148;0.5584642291069031;1.0;0.6268407792270415;0.6419238559031413;0.5858269376442662;0.6190530484008745;0.5973609723609723;0.5483924288272114;150.0;0.001
|
47 |
-
45;0.9;0.01;128.0;1.1987037658691406;0.5375218391418457;1.1068445444107056;0.5445026159286499;1.1949487924575806;0.518324613571167;1.0;0.5710625395113904;0.5581101609004185;0.5326604175732083;0.5442896475505171;0.5378775813558422;0.5108786141394838;50.0;0.0001
|
48 |
-
46;0.9;0.01;128.0;1.1030073165893555;0.5863874554634094;1.0547511577606201;0.6195462346076965;1.066347599029541;0.5881326198577881;1.0;0.6079505660141936;0.6188059698281521;0.5895490996780354;0.5875187614318049;0.6249137336093857;0.5872448807231415;100.0;0.0001
|
49 |
-
47;0.9;0.01;128.0;1.132765293121338;0.5898778438568115;1.2393956184387207;0.5340313911437988;1.0497409105300903;0.5881326198577881;1.0;0.5911296865320181;0.566333885194723;0.6075930260346379;0.5854749115618681;0.5218344457474892;0.5897791821704865;150.0;0.0001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/model.py
DELETED
@@ -1,120 +0,0 @@
|
|
1 |
-
from tensorflow.data import Dataset
|
2 |
-
import tensorflow.keras as keras
|
3 |
-
from tensorflow.keras.optimizers import Adam
|
4 |
-
from tensorflow.keras.layers import (
|
5 |
-
Conv2D,
|
6 |
-
Input,
|
7 |
-
MaxPooling2D,
|
8 |
-
Dense,
|
9 |
-
Dropout,
|
10 |
-
MaxPool1D,
|
11 |
-
Flatten,
|
12 |
-
AveragePooling1D,
|
13 |
-
BatchNormalization,
|
14 |
-
)
|
15 |
-
from tensorflow.keras import Model
|
16 |
-
import numpy as np
|
17 |
-
import tensorflow as tf
|
18 |
-
from tensorflow.keras.models import Sequential
|
19 |
-
from tensorflow.keras.models import Model
|
20 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
21 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
22 |
-
from tensorflow.keras.layers import BatchNormalization
|
23 |
-
from tensorflow.keras.regularizers import l2
|
24 |
-
from tensorflow.keras import backend as K
|
25 |
-
from tensorflow.keras.optimizers import SGD
|
26 |
-
import warnings
|
27 |
-
|
28 |
-
warnings.filterwarnings("ignore")
|
29 |
-
|
30 |
-
|
31 |
-
def basemodel(weight_decay):
    """Build the baseline convolutional classifier.

    The network takes a 32x32 single-channel input, applies two
    convolutional layers followed by max pooling and batch normalization,
    and ends in a 4-unit softmax layer. The returned model is not compiled.

    Args:
        weight_decay: L2 regularization factor applied to the kernels of
            both Conv2D layers and of the final Dense layer.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output.
    """
    # 2 hidden layers
    inputs = Input(shape=(32, 32, 1))

    # Convolutional feature extractor: 32 then 64 filters, 3x3 kernels.
    features = Conv2D(
        filters=32,
        kernel_size=(3, 3),
        activation="relu",
        kernel_regularizer=l2(weight_decay),
    )(inputs)
    features = Conv2D(
        filters=64,
        kernel_size=(3, 3),
        activation="relu",
        kernel_regularizer=l2(weight_decay),
    )(features)
    features = MaxPooling2D(pool_size=(2, 2))(features)
    features = BatchNormalization()(features)

    # Classification head: flatten and project onto the 4 output classes.
    flat = Flatten()(features)
    outputs = Dense(
        units=4,
        activation="softmax",
        kernel_regularizer=l2(weight_decay),
    )(flat)

    return Model(inputs=inputs, outputs=outputs)
|
55 |
-
|
56 |
-
|
57 |
-
def model_2(weight_decay):
    """Build a three-convolution classifier.

    Layout: 3x3 conv (32) -> [3x3 conv (64) -> 2x2 max-pool -> batch norm]
    -> [3x3 conv (128) -> 2x2 max-pool -> batch norm] -> flatten ->
    4-unit softmax. The input is 32x32 with one channel.

    Args:
        weight_decay: L2 penalty factor applied to every convolution and
            dense kernel.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    inputs = Input(shape=(32, 32, 1))

    features = Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_regularizer=l2(weight_decay),
        activation="relu",
    )(inputs)

    # Each stage: convolution, spatial down-sampling, normalisation.
    for filters in (64, 128):
        features = Conv2D(
            filters,
            kernel_size=(3, 3),
            kernel_regularizer=l2(weight_decay),
            activation="relu",
        )(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
        features = BatchNormalization()(features)

    flat = Flatten()(features)
    predictions = Dense(
        4, kernel_regularizer=l2(weight_decay), activation="softmax"
    )(flat)

    return Model(inputs=inputs, outputs=predictions)
|
85 |
-
|
86 |
-
|
87 |
-
def model_3(weight_decay):
    """Build a four-convolution classifier.

    Layout: 3x3 conv (32) followed by three stages of
    [3x3 conv -> 2x2 max-pool -> batch norm] with 64, 128 and 256 filters,
    then flatten and a 4-unit softmax. The input is 32x32 with one channel.

    Args:
        weight_decay: L2 penalty factor applied to every convolution and
            dense kernel.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    inputs = Input(shape=(32, 32, 1))

    features = Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_regularizer=l2(weight_decay),
        activation="relu",
    )(inputs)

    # Each stage: convolution, spatial down-sampling, normalisation.
    for filters in (64, 128, 256):
        features = Conv2D(
            filters,
            kernel_size=(3, 3),
            kernel_regularizer=l2(weight_decay),
            activation="relu",
        )(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
        features = BatchNormalization()(features)

    flat = Flatten()(features)
    predictions = Dense(
        4, kernel_regularizer=l2(weight_decay), activation="softmax"
    )(flat)

    return Model(inputs=inputs, outputs=predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/FullyConectedModels/parseval.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
from tensorflow.data import Dataset
|
2 |
-
import tensorflow.keras as keras
|
3 |
-
from tensorflow.keras.optimizers import Adam
|
4 |
-
from tensorflow.keras.layers import (
|
5 |
-
Conv2D,
|
6 |
-
Input,
|
7 |
-
MaxPooling2D,
|
8 |
-
Dense,
|
9 |
-
Dropout,
|
10 |
-
MaxPool1D,
|
11 |
-
Flatten,
|
12 |
-
AveragePooling1D,
|
13 |
-
BatchNormalization,
|
14 |
-
)
|
15 |
-
from tensorflow.keras import Model
|
16 |
-
import numpy as np
|
17 |
-
import tensorflow as tf
|
18 |
-
from tensorflow.keras.models import Sequential
|
19 |
-
from tensorflow.keras.models import Model
|
20 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
21 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
22 |
-
from tensorflow.keras.layers import BatchNormalization
|
23 |
-
from tensorflow.keras.regularizers import l2
|
24 |
-
from tensorflow.keras import backend as K
|
25 |
-
from tensorflow.keras.optimizers import SGD
|
26 |
-
import warnings
|
27 |
-
from constraint import tight_frame
|
28 |
-
|
29 |
-
warnings.filterwarnings("ignore")
|
30 |
-
|
31 |
-
|
32 |
-
def model_parseval(weight_decay):
    """Build the four-convolution classifier with Parseval constraints.

    Same layer layout as a plain 32/64/128/256-filter CNN on a 32x32
    single-channel input, but every convolution kernel is orthogonally
    initialised and constrained with ``tight_frame(0.001)`` in addition to
    the L2 penalty. The final 4-unit softmax layer only carries the L2
    penalty.

    Args:
        weight_decay: L2 penalty factor applied to every convolution and
            dense kernel.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    inputs = Input(shape=(32, 32, 1))

    # First convolution; it also repeats the input shape, as the original did.
    features = Conv2D(
        32,
        kernel_size=(3, 3),
        activation="relu",
        input_shape=(32, 32, 1),
        kernel_regularizer=l2(weight_decay),
        kernel_constraint=tight_frame(0.001),
        kernel_initializer="Orthogonal",
    )(inputs)

    # The remaining convolutions are each followed by pooling and batch norm.
    for filters in (64, 128, 256):
        features = Conv2D(
            filters,
            kernel_size=(3, 3),
            activation="relu",
            kernel_regularizer=l2(weight_decay),
            kernel_initializer="Orthogonal",
            kernel_constraint=tight_frame(0.001),
        )(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
        features = BatchNormalization()(features)

    flat = Flatten()(features)
    predictions = Dense(
        4, activation="softmax", kernel_regularizer=l2(weight_decay)
    )(flat)

    return Model(inputs=inputs, outputs=predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/README.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
## ParsevalNetworks
|
2 |
-
* Orthogonality Constraint
|
3 |
-
* Convexity Constraint
|
|
|
|
|
|
|
|
models/Parseval_Networks/constraint.py
DELETED
@@ -1,81 +0,0 @@
|
|
1 |
-
from tensorflow.python.keras.constraints import Constraint
|
2 |
-
from tensorflow.python.ops import math_ops, array_ops
|
3 |
-
|
4 |
-
|
5 |
-
class TightFrame(Constraint):
    """
    Parseval (tight) frame constraint, as introduced in https://arxiv.org/abs/1704.08847

    Constrains the weight matrix to be a tight frame, so that the Lipschitz
    constant of the layer is <= 1. This increases the robustness of the network
    to adversarial noise.

    Warning: This constraint simply performs the update step on the weight matrix
    (or the unfolded weight matrix for convolutional layers). Thus, it does not
    handle the necessary scalings for convolutional layers.

    Args:
        scale (float): Retraction parameter (length of retraction step).
        num_passes (int): Number of retraction steps.

    Returns:
        Weight matrix after applying the constraint.
    """

    def __init__(self, scale, num_passes=1):
        """Store the retraction parameters.

        Args:
            scale (float): Retraction parameter (length of retraction step).
            num_passes (int, optional): Number of retraction steps applied
                each time the constraint is called. Defaults to 1.

        Raises:
            ValueError: If ``num_passes`` is smaller than 1.
        """
        if num_passes < 1:
            raise ValueError(
                "Number of passes cannot be non-positive! (got {})".format(num_passes)
            )
        self.scale = scale
        self.num_passes = num_passes

    def __call__(self, w):
        """Apply ``num_passes`` retraction steps to a kernel.

        Each step computes ``W <- (1 + scale) * W - scale * W @ (W^T @ W)``.

        Args:
            w: Kernel of a dense layer (rank 2) or a convolution (rank 4).

        Returns:
            The updated kernel, with the same shape as ``w``.
        """
        is_conv_kernel = len(w.shape) == 4

        # A rank-4 kernel is unfolded into a 2-D matrix whose columns are
        # indexed by the last axis, so that matmul can be applied.
        if is_conv_kernel:
            current = array_ops.reshape(w, (-1, w.shape[3]))
        else:
            current = w

        for _ in range(self.num_passes):
            gram = math_ops.matmul(current, current, transpose_a=True)
            # Every pass starts from the result of the previous pass. The
            # earlier code re-used the original weights here, so additional
            # passes were not successive retraction steps.
            current = (1 + self.scale) * current - self.scale * math_ops.matmul(
                current, gram
            )

        # Fold the matrix back into the original kernel shape.
        if is_conv_kernel:
            return array_ops.reshape(current, w.shape)
        return current

    def get_config(self):
        """Return the constructor arguments for serialisation."""
        return {"scale": self.scale, "num_passes": self.num_passes}
|
78 |
-
|
79 |
-
|
80 |
-
# Alias
|
81 |
-
tight_frame = TightFrame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/convexity_constraint.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
from tensorflow.python.ops import math_ops
|
2 |
-
from tensorflow.python.ops import variables
|
3 |
-
from tensorflow.python.framework import dtypes
|
4 |
-
import numpy as _np
|
5 |
-
|
6 |
-
|
7 |
-
def convex_add(input_layer, layer_3, initial_convex_par=0.5, trainable=False):
    """
    Return a convex combination of two tensors:

        lam * input_layer + (1 - lam) * layer_3

    ``lam`` is stored as ``sigmoid(p)`` of an unconstrained variable ``p``,
    so it stays inside the unit interval without needing a constraint
    during optimisation.

    Args:
        input_layer (tf.Tensor): First operand of the combination.
        layer_3 (tf.Tensor): Second operand of the combination.
        initial_convex_par (float): Initial value of ``lam``. Must be
            in [0, 1].
        trainable (bool): Whether the underlying variable is created as
            trainable.

    Returns:
        tf.Tensor: Result of the convex combination.

    Raises:
        ValueError: If ``initial_convex_par`` lies outside [0, 1].
    """
    # NOTE(review): the variable is created outside any Keras layer; confirm
    # that the surrounding model actually tracks it when trainable=True.

    # Reject illegal initial values up front.
    if initial_convex_par < 0:
        raise ValueError("Convex parameter must be >=0")
    if not (initial_convex_par <= 1):
        raise ValueError("Convex parameter must be <=1")

    # Map the requested lam to the pre-sigmoid value p.
    if initial_convex_par == 0:
        # sigmoid(-16) is about 32-bit round-off, i.e. practically 0
        start_p = -16
    elif initial_convex_par == 1:
        # sigmoid(16) is practically 1, by the same argument
        start_p = 16
    else:
        # Inverse of the sigmoid (logit)
        start_p = -_np.log(1 / initial_convex_par - 1)

    p = variables.Variable(
        initial_value=start_p, dtype=dtypes.float32, trainable=trainable
    )
    weight = math_ops.sigmoid(p)

    return input_layer * weight + (1 - weight) * layer_3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/Parseval_Networks/parsevalnet.py
DELETED
@@ -1,328 +0,0 @@
|
|
1 |
-
from tensorflow.keras.models import Model
|
2 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
3 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
4 |
-
from tensorflow.keras.layers import BatchNormalization
|
5 |
-
from tensorflow.keras.regularizers import l2
|
6 |
-
from tensorflow.keras import backend as K
|
7 |
-
from tensorflow.keras.optimizers import SGD
|
8 |
-
import warnings
|
9 |
-
from constraint import tight_frame
|
10 |
-
from convexity_constraint import convex_add
|
11 |
-
|
12 |
-
warnings.filterwarnings("ignore")
|
13 |
-
|
14 |
-
|
15 |
-
class ParsevalNetwork(Model):
    """Builder for a wide residual network with Parseval constraints.

    Every convolution is orthogonally initialised, L2-regularised and
    constrained with ``tight_frame(0.001)``; identity residual blocks merge
    their branches with ``convex_add`` instead of a plain sum. Call
    ``create_wide_residual_network`` to obtain the actual Keras model.
    """

    def __init__(
        self,
        input_dim,
        weight_decay,
        momentum,
        nb_classes=4,
        N=2,
        k=1,
        dropout=0.0,
        verbose=1,
    ):
        """Store the hyper-parameters of the network.

        Args:
            input_dim (tuple): Shape of one input sample.
            weight_decay (float): L2 penalty factor for all kernels.
            momentum: Accepted for interface compatibility; not used by
                this class.
            nb_classes (int, optional): Number of output classes. Defaults to 4.
            N (int, optional): Number of blocks per stage (one expanding
                block plus ``N - 1`` identity blocks). Defaults to 2.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate inside the identity
                blocks; disabled when 0. Defaults to 0.0.
            verbose (int, optional): When truthy, print a summary line
                after the model has been built. Defaults to 1.
        """
        # Initialise the Keras base class before assigning attributes, as
        # subclassed models are expected to do.
        super().__init__()
        self.weight_decay = weight_decay
        self.input_dim = input_dim
        self.nb_classes = nb_classes
        self.N = N
        self.k = k
        self.dropout = dropout
        self.verbose = verbose

    def _parseval_conv(self, filters, kernel_size=(3, 3), strides=(1, 1)):
        """Create a bias-free, 'same'-padded convolution with Parseval settings."""
        return Convolution2D(
            filters,
            kernel_size,
            padding="same",
            strides=strides,
            kernel_initializer="orthogonal",
            kernel_regularizer=l2(self.weight_decay),
            kernel_constraint=tight_frame(0.001),
            use_bias=False,
        )

    def _bn_relu(self, x):
        """Apply batch normalisation over the channel axis, then ReLU."""
        channel_axis = 1 if K.image_data_format() == "channels_first" else -1
        # NOTE(review): momentum=0.1 keeps only 10% of the running statistics
        # per update in Keras; confirm this is intended.
        x = BatchNormalization(
            axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(x)
        return Activation("relu")(x)

    def _convex_residual_block(self, input, filters, dropout):
        """Pre-activation block: two 3x3 convolutions merged with a convex add."""
        x = self._bn_relu(input)
        x = self._parseval_conv(filters)(x)

        if dropout > 0.0:
            x = Dropout(dropout)(x)

        x = self._bn_relu(x)
        x = self._parseval_conv(filters)(x)

        return convex_add(input, x, initial_convex_par=0.5, trainable=True)

    def initial_conv(self, input):
        """Apply the stem: a 16-filter 3x3 convolution, batch norm and ReLU.

        Args:
            input: Input tensor of the network.

        Returns:
            Output tensor of the stem.
        """
        x = self._parseval_conv(16)(input)
        return self._bn_relu(x)

    def expand_conv(self, init, base, k, strides=(1, 1)):
        """Apply the block that opens a stage and changes the channel count.

        The main branch is conv -> batch norm -> ReLU -> conv; the shortcut
        is a 1x1 convolution. Both branches are summed.

        Args:
            init: Input tensor of the block.
            base (int): Base number of filters of the stage.
            k (int): Width multiplier; the block outputs ``base * k`` channels.
            strides (tuple, optional): Strides of the first main-branch
                convolution and of the shortcut. Defaults to (1, 1).

        Returns:
            Output tensor of the block.
        """
        x = self._parseval_conv(base * k, strides=strides)(init)
        x = self._bn_relu(x)
        x = self._parseval_conv(base * k)(x)

        skip = self._parseval_conv(base * k, (1, 1), strides=strides)(init)

        return Add()([x, skip])

    def conv1_block(self, input, k=1, dropout=0.0):
        """Apply an identity block of the first stage (``16 * k`` filters).

        Args:
            input: Input tensor of the block.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            Convex combination of ``input`` and the residual branch.
        """
        return self._convex_residual_block(input, 16 * k, dropout)

    def conv2_block(self, input, k=1, dropout=0.0):
        """Apply an identity block of the second stage (``32 * k`` filters).

        Args:
            input: Input tensor of the block.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            Convex combination of ``input`` and the residual branch.
        """
        return self._convex_residual_block(input, 32 * k, dropout)

    def conv3_block(self, input, k=1, dropout=0.0):
        """Apply an identity block of the third stage (``64 * k`` filters).

        Args:
            input: Input tensor of the block.
            k (int, optional): Width multiplier. Defaults to 1.
            dropout (float, optional): Dropout rate between the two
                convolutions; disabled when 0. Defaults to 0.0.

        Returns:
            Convex combination of ``input`` and the residual branch.
        """
        return self._convex_residual_block(input, 64 * k, dropout)

    def create_wide_residual_network(self):
        """Assemble the complete network from the stored hyper-parameters.

        Returns:
            Model: Keras model mapping an ``input_dim`` tensor to
            ``nb_classes`` softmax outputs.
        """
        ip = Input(shape=self.input_dim)

        x = self.initial_conv(ip)
        nb_conv = 4

        # (base filters, strides of the expanding block, identity block)
        stages = (
            (16, (1, 1), self.conv1_block),
            (32, (2, 2), self.conv2_block),
            (64, (2, 2), self.conv3_block),
        )
        for base, strides, identity_block in stages:
            x = self.expand_conv(x, base, self.k, strides=strides)
            nb_conv += 2

            for _ in range(self.N - 1):
                x = identity_block(x, self.k, self.dropout)
                nb_conv += 2

            x = self._bn_relu(x)

        # NOTE(review): the 8x8 pooling window presumably assumes 32x32
        # inputs (two stride-2 stages) - confirm for other input sizes.
        x = AveragePooling2D((8, 8))(x)
        x = Flatten()(x)

        x = Dense(
            self.nb_classes,
            kernel_regularizer=l2(self.weight_decay),
            activation="softmax",
        )(x)

        model = Model(ip, x)

        if self.verbose:
            print("Parseval Network-%d-%d created." % (nb_conv, self.k))
        return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/README.md
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
## Models
|
2 |
-
|
3 |
-
````
|
4 |
-
├── Parseval_network
|
5 |
-
│ ├── __init__.py
|
6 |
-
│ └── Parseval_resnet.py
|
7 |
-
├── Parseval_Networks_OC
|
8 |
-
│ ├── constraint.py
|
9 |
-
│ ├── parsnet_oc.py
|
10 |
-
│ └── README.md
|
11 |
-
├── README.md
|
12 |
-
├── _utility.py
|
13 |
-
└── wideresnet
|
14 |
-
└── wresnet.py
|
15 |
-
````
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/_utility.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
from tensorflow.keras.callbacks import LearningRateScheduler
|
2 |
-
|
3 |
-
# Define configuration parameters
|
4 |
-
import math
|
5 |
-
import cleverhans
|
6 |
-
from cleverhans.tf2.attacks.fast_gradient_method import fast_gradient_method
|
7 |
-
import tensorflow as tf
|
8 |
-
|
9 |
-
import numpy as np
|
10 |
-
|
11 |
-
|
12 |
-
def step_decay(epoch):
    """Return the learning rate for ``epoch`` under a stepwise schedule.

    The rate starts at 0.1 and is multiplied by 0.1 at epochs 10, 20, 30
    and 40; from epoch 40 onwards it stays constant.

    Args:
        epoch (int): epoch number

    Returns:
        lrate(float): new learning rate
    """
    base_rate = 0.1
    decay = 0.1

    # Every boundary already passed contributes one power of the decay factor.
    for exponent, boundary in enumerate((10, 20, 30, 40)):
        if epoch < boundary:
            return base_rate * math.pow(decay, exponent)
    return base_rate * math.pow(decay, 4)
|
34 |
-
|
35 |
-
|
36 |
-
def step_decay_conv(epoch):
    """Return the learning rate for ``epoch`` for convolutional networks.

    The rate starts at 0.01 and is multiplied by 0.1 at epochs 10, 20, 30
    and 40; from epoch 40 onwards it stays constant.

    Args:
        epoch (int): epoch number

    Returns:
        lrate(float): new learning rate
    """
    base_rate = 0.01
    decay = 0.1

    # Every boundary already passed contributes one power of the decay factor.
    for exponent, boundary in enumerate((10, 20, 30, 40)):
        if epoch < boundary:
            return base_rate * math.pow(decay, exponent)
    return base_rate * math.pow(decay, 4)
|
58 |
-
|
59 |
-
|
60 |
-
def print_test(model, X_adv, X_test, y_test, epsilon):
    """
    Evaluate ``model`` on adversarial inputs and print the results.

    Prints the loss and accuracy obtained on ``X_adv`` together with the
    signal-to-noise ratio (in dB) between the clean inputs ``X_test`` and
    the perturbation ``X_test - X_adv``.

    Returns:
        The ``(loss, acc)`` pair reported by ``model.evaluate``.
    """
    scores = model.evaluate(X_adv, y_test)
    loss, acc = scores
    print("epsilon: {} and test evaluation : {}, {}".format(epsilon, loss, acc))

    # SNR in decibels: norm of the clean signal over norm of the perturbation.
    signal_norm = np.linalg.norm(X_test)
    noise_norm = np.linalg.norm(X_test - X_adv)
    SNR = 20 * np.log10(signal_norm / noise_norm)
    print("SNR: {}".format(SNR))

    return loss, acc
|
69 |
-
|
70 |
-
|
71 |
-
def get_adversarial_examples(pretrained_model, X_true, y_true, epsilon):
    """
    Craft untargeted fast-gradient-sign-method adversarial examples.

    Each sample of ``X_true`` is attacked on its own with an L-infinity
    perturbation of size ``epsilon``; the true class is taken as the argmax
    of the matching row of ``y_true``.

    The attack requires the model to output the logits.
    NOTE(review): the attacked model ends at ``layers[-1].output``; if that
    layer applies a softmax these are probabilities, not logits - confirm.

    Returns:
        np.ndarray: the adversarial samples, each reshaped to (32, 32, 1).
    """
    attacked_model = tf.keras.Model(
        pretrained_model.input, pretrained_model.layers[-1].output
    )

    adversarial_samples = []
    for idx in range(len(X_true)):
        # Batch of one sample, as expected by the attack.
        sample = tf.convert_to_tensor(X_true[idx].reshape((1, 32, 32)))
        # One-hot row -> integer class label of shape (1,).
        label = np.reshape(np.argmax(y_true[idx]), (1,)).astype("int64")

        perturbed = fast_gradient_method(
            attacked_model,
            sample,
            epsilon,
            np.inf,
            y=label,
            targeted=False,
        )
        adversarial_samples.append(np.array(perturbed).reshape(32, 32, 1))

    return np.array(adversarial_samples)
|
106 |
-
|
107 |
-
|
108 |
-
lrate_conv = LearningRateScheduler(step_decay_conv)
|
109 |
-
lrate = LearningRateScheduler(step_decay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/wideresnet/wresnet.py
DELETED
@@ -1,329 +0,0 @@
|
|
1 |
-
from tensorflow.keras.models import Model
|
2 |
-
from tensorflow.keras.layers import Input, Add, Activation, Dropout, Flatten, Dense
|
3 |
-
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
|
4 |
-
from tensorflow.keras.layers import BatchNormalization
|
5 |
-
from tensorflow.keras.regularizers import l2
|
6 |
-
from tensorflow.keras import backend as K
|
7 |
-
from tensorflow.keras.optimizers import SGD
|
8 |
-
import warnings
|
9 |
-
|
10 |
-
warnings.filterwarnings("ignore")
|
11 |
-
|
12 |
-
|
13 |
-
class WideResidualNetwork(object):
|
14 |
-
def __init__(
|
15 |
-
self,
|
16 |
-
input_dim,
|
17 |
-
weight_decay,
|
18 |
-
momentum,
|
19 |
-
nb_classes=100,
|
20 |
-
N=2,
|
21 |
-
k=1,
|
22 |
-
dropout=0.0,
|
23 |
-
verbose=1,
|
24 |
-
):
|
25 |
-
"""[Assign the initial parameters of the wide residual network]
|
26 |
-
|
27 |
-
Args:
|
28 |
-
weight_decay ([float]): [description]
|
29 |
-
input_dim ([tuple]): [input dimension]
|
30 |
-
nb_classes (int, optional): [output class]. Defaults to 100.
|
31 |
-
N (int, optional): [the number of blocks]. Defaults to 2.
|
32 |
-
k (int, optional): [network width]. Defaults to 1.
|
33 |
-
dropout (float, optional): [dropout value to prevent overfitting]. Defaults to 0.0.
|
34 |
-
verbose (int, optional): [description]. Defaults to 1.
|
35 |
-
|
36 |
-
Returns:
|
37 |
-
[Model]: [wideresnet]
|
38 |
-
"""
|
39 |
-
self.weight_decay = weight_decay
|
40 |
-
self.input_dim = input_dim
|
41 |
-
self.nb_classes = nb_classes
|
42 |
-
self.N = N
|
43 |
-
self.k = k
|
44 |
-
self.dropout = dropout
|
45 |
-
self.verbose = verbose
|
46 |
-
|
47 |
-
def initial_conv(self, input):
|
48 |
-
"""[summary]
|
49 |
-
|
50 |
-
Args:
|
51 |
-
input ([type]): [description]
|
52 |
-
|
53 |
-
Returns:
|
54 |
-
[type]: [description]
|
55 |
-
"""
|
56 |
-
x = Convolution2D(
|
57 |
-
16,
|
58 |
-
(3, 3),
|
59 |
-
padding="same",
|
60 |
-
kernel_initializer="he_normal",
|
61 |
-
kernel_regularizer=l2(self.weight_decay),
|
62 |
-
use_bias=False,
|
63 |
-
)(input)
|
64 |
-
|
65 |
-
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
|
66 |
-
|
67 |
-
x = BatchNormalization(
|
68 |
-
axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
|
69 |
-
)(x)
|
70 |
-
x = Activation("relu")(x)
|
71 |
-
return x
|
72 |
-
|
73 |
-
def expand_conv(self, init, base, k, strides=(1, 1)):
|
74 |
-
"""[summary]
|
75 |
-
|
76 |
-
Args:
|
77 |
-
init ([type]): [description]
|
78 |
-
base ([type]): [description]
|
79 |
-
k ([type]): [description]
|
80 |
-
strides (tuple, optional): [description]. Defaults to (1, 1).
|
81 |
-
|
82 |
-
Returns:
|
83 |
-
[type]: [description]
|
84 |
-
"""
|
85 |
-
x = Convolution2D(
|
86 |
-
base * k,
|
87 |
-
(3, 3),
|
88 |
-
padding="same",
|
89 |
-
strides=strides,
|
90 |
-
kernel_initializer="he_normal",
|
91 |
-
kernel_regularizer=l2(self.weight_decay),
|
92 |
-
use_bias=False,
|
93 |
-
)(init)
|
94 |
-
|
95 |
-
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
|
96 |
-
|
97 |
-
x = BatchNormalization(
|
98 |
-
axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
|
99 |
-
)(x)
|
100 |
-
x = Activation("relu")(x)
|
101 |
-
|
102 |
-
x = Convolution2D(
|
103 |
-
base * k,
|
104 |
-
(3, 3),
|
105 |
-
padding="same",
|
106 |
-
kernel_initializer="he_normal",
|
107 |
-
kernel_regularizer=l2(self.weight_decay),
|
108 |
-
use_bias=False,
|
109 |
-
)(x)
|
110 |
-
|
111 |
-
skip = Convolution2D(
|
112 |
-
base * k,
|
113 |
-
(1, 1),
|
114 |
-
padding="same",
|
115 |
-
strides=strides,
|
116 |
-
kernel_initializer="he_normal",
|
117 |
-
kernel_regularizer=l2(self.weight_decay),
|
118 |
-
use_bias=False,
|
119 |
-
)(init)
|
120 |
-
|
121 |
-
m = Add()([x, skip])
|
122 |
-
|
123 |
-
return m
|
124 |
-
|
125 |
-
def _identity_residual_block(self, inputs, filters, dropout=0.0):
    """Build a BN-ReLU-conv residual unit with an identity shortcut.

    Shared implementation behind ``conv1_block``, ``conv2_block`` and
    ``conv3_block``, which differ only in their filter count.

    Layer order: BN -> ReLU -> 3x3 conv -> (optional Dropout) ->
    BN -> ReLU -> 3x3 conv, then the result is added to ``inputs``.

    Args:
        inputs: Input tensor; it is also used unchanged as the shortcut.
        filters: Number of filters for both 3x3 convolutions.
        dropout (float, optional): Dropout rate applied between the two
            convolutions; no Dropout layer is created when it is not
            greater than 0. Defaults to 0.0.

    Returns:
        The output tensor of the ``Add`` layer.
    """
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1

    def bn_relu_conv(tensor):
        # One pre-activation step: normalize, activate, then convolve.
        tensor = BatchNormalization(
            axis=channel_axis,
            momentum=0.1,
            epsilon=1e-5,
            gamma_initializer="uniform",
        )(tensor)
        tensor = Activation("relu")(tensor)
        return Convolution2D(
            filters,
            (3, 3),
            padding="same",
            kernel_initializer="he_normal",
            kernel_regularizer=l2(self.weight_decay),
            use_bias=False,
        )(tensor)

    x = bn_relu_conv(inputs)

    if dropout > 0.0:
        x = Dropout(dropout)(x)

    x = bn_relu_conv(x)

    return Add()([inputs, x])

def conv1_block(self, input, k=1, dropout=0.0):
    """Build a residual unit with ``16 * k`` filters.

    Args:
        input: Input tensor; also used as the identity shortcut.
        k (int, optional): Multiplier applied to the 16 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate between the two
            convolutions; disabled when not greater than 0.
            Defaults to 0.0.

    Returns:
        The output tensor of the residual unit.
    """
    return self._identity_residual_block(input, 16 * k, dropout)

def conv2_block(self, input, k=1, dropout=0.0):
    """Build a residual unit with ``32 * k`` filters.

    The unconditional debug print of the channel axis that used to run on
    every call has been removed.

    Args:
        input: Input tensor; also used as the identity shortcut.
        k (int, optional): Multiplier applied to the 32 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate between the two
            convolutions; disabled when not greater than 0.
            Defaults to 0.0.

    Returns:
        The output tensor of the residual unit.
    """
    return self._identity_residual_block(input, 32 * k, dropout)

def conv3_block(self, input, k=1, dropout=0.0):
    """Build a residual unit with ``64 * k`` filters.

    Args:
        input: Input tensor; also used as the identity shortcut.
        k (int, optional): Multiplier applied to the 64 base filters.
            Defaults to 1.
        dropout (float, optional): Dropout rate between the two
            convolutions; disabled when not greater than 0.
            Defaults to 0.0.

    Returns:
        The output tensor of the residual unit.
    """
    return self._identity_residual_block(input, 64 * k, dropout)
|
265 |
-
|
266 |
-
def create_wide_residual_network(self):
    """Assemble the wide residual network and return it as a ``Model``.

    The network is the initial convolution followed by three stages with
    16, 32 and 64 base filters. Each stage is one ``expand_conv`` unit
    (stride 2 for the second and third stage), ``self.N - 1`` residual
    blocks, then batch normalization and ReLU. The head is 8x8 average
    pooling, flatten and a softmax ``Dense`` layer with
    ``self.nb_classes`` units.

    Returns:
        Model: The wide residual network mapping the input of shape
        ``self.input_dim`` to the softmax output.
    """
    bn_axis = 1 if K.image_data_format() == "channels_first" else -1

    inputs = Input(shape=self.input_dim)

    features = self.initial_conv(inputs)
    conv_count = 4  # running count, reported only in the verbose message

    # (base filters, strides of the expanding unit, residual block builder)
    stage_plan = (
        (16, (1, 1), self.conv1_block),
        (32, (2, 2), self.conv2_block),
        (64, (2, 2), self.conv3_block),
    )

    for base_filters, stage_strides, residual_block in stage_plan:
        features = self.expand_conv(
            features, base_filters, self.k, strides=stage_strides
        )
        conv_count += 2

        for _ in range(self.N - 1):
            features = residual_block(features, self.k, self.dropout)
            conv_count += 2

        features = BatchNormalization(
            axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer="uniform"
        )(features)
        features = Activation("relu")(features)

    # Classification head.
    features = AveragePooling2D((8, 8))(features)
    features = Flatten()(features)
    outputs = Dense(
        self.nb_classes,
        kernel_regularizer=l2(self.weight_decay),
        activation="softmax",
    )(features)

    model = Model(inputs, outputs)

    if self.verbose:
        print("Wide Residual Network-%d-%d created." % (nb_conv := conv_count, self.k)) if False else print("Wide Residual Network-%d-%d created." % (conv_count, self.k))
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|