Ioana-Gidiuta committed on
Commit
ddf92c5
·
verified ·
1 Parent(s): 6a58055

Upload 3 files

Browse files
Ensemble_method.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Preprocess_Data.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Run_Ensamble.ipynb ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "collapsed_sections": [
8
+ "j1mFIyblvH0-"
9
+ ],
10
+ "gpuType": "T4"
11
+ },
12
+ "kernelspec": {
13
+ "name": "python3",
14
+ "display_name": "Python 3"
15
+ },
16
+ "language_info": {
17
+ "name": "python"
18
+ },
19
+ "accelerator": "GPU",
20
+ "widgets": {
21
+ "application/vnd.jupyter.widget-state+json": {
22
+ "7d71deee94c64080badb10c59e3ac49b": {
23
+ "model_module": "@jupyter-widgets/controls",
24
+ "model_name": "HBoxModel",
25
+ "model_module_version": "1.5.0",
26
+ "state": {
27
+ "_dom_classes": [],
28
+ "_model_module": "@jupyter-widgets/controls",
29
+ "_model_module_version": "1.5.0",
30
+ "_model_name": "HBoxModel",
31
+ "_view_count": null,
32
+ "_view_module": "@jupyter-widgets/controls",
33
+ "_view_module_version": "1.5.0",
34
+ "_view_name": "HBoxView",
35
+ "box_style": "",
36
+ "children": [
37
+ "IPY_MODEL_013d6d22ebe04ebda019d9bf2e7e135d",
38
+ "IPY_MODEL_dd04535976844f3b82bd58020d30acb7",
39
+ "IPY_MODEL_a02104842c0a4b24adfc413f59d84fc0"
40
+ ],
41
+ "layout": "IPY_MODEL_94cb76a011a94602b4192f8e6e7142a4"
42
+ }
43
+ },
44
+ "013d6d22ebe04ebda019d9bf2e7e135d": {
45
+ "model_module": "@jupyter-widgets/controls",
46
+ "model_name": "HTMLModel",
47
+ "model_module_version": "1.5.0",
48
+ "state": {
49
+ "_dom_classes": [],
50
+ "_model_module": "@jupyter-widgets/controls",
51
+ "_model_module_version": "1.5.0",
52
+ "_model_name": "HTMLModel",
53
+ "_view_count": null,
54
+ "_view_module": "@jupyter-widgets/controls",
55
+ "_view_module_version": "1.5.0",
56
+ "_view_name": "HTMLView",
57
+ "description": "",
58
+ "description_tooltip": null,
59
+ "layout": "IPY_MODEL_80d6e5f3592f440db67765ee24e83ccc",
60
+ "placeholder": "​",
61
+ "style": "IPY_MODEL_9b8c527888bd4721a6a4387797fc7ba0",
62
+ "value": "ensemble_model.pkl: 100%"
63
+ }
64
+ },
65
+ "dd04535976844f3b82bd58020d30acb7": {
66
+ "model_module": "@jupyter-widgets/controls",
67
+ "model_name": "FloatProgressModel",
68
+ "model_module_version": "1.5.0",
69
+ "state": {
70
+ "_dom_classes": [],
71
+ "_model_module": "@jupyter-widgets/controls",
72
+ "_model_module_version": "1.5.0",
73
+ "_model_name": "FloatProgressModel",
74
+ "_view_count": null,
75
+ "_view_module": "@jupyter-widgets/controls",
76
+ "_view_module_version": "1.5.0",
77
+ "_view_name": "ProgressView",
78
+ "bar_style": "success",
79
+ "description": "",
80
+ "description_tooltip": null,
81
+ "layout": "IPY_MODEL_ed88c4a913d849a1a8222668aca3b1a0",
82
+ "max": 278892144,
83
+ "min": 0,
84
+ "orientation": "horizontal",
85
+ "style": "IPY_MODEL_52921704988a4cd49196eaab6a858268",
86
+ "value": 278892144
87
+ }
88
+ },
89
+ "a02104842c0a4b24adfc413f59d84fc0": {
90
+ "model_module": "@jupyter-widgets/controls",
91
+ "model_name": "HTMLModel",
92
+ "model_module_version": "1.5.0",
93
+ "state": {
94
+ "_dom_classes": [],
95
+ "_model_module": "@jupyter-widgets/controls",
96
+ "_model_module_version": "1.5.0",
97
+ "_model_name": "HTMLModel",
98
+ "_view_count": null,
99
+ "_view_module": "@jupyter-widgets/controls",
100
+ "_view_module_version": "1.5.0",
101
+ "_view_name": "HTMLView",
102
+ "description": "",
103
+ "description_tooltip": null,
104
+ "layout": "IPY_MODEL_256010fb44964c1fabc7d96d5ea4c644",
105
+ "placeholder": "​",
106
+ "style": "IPY_MODEL_91b2880b6eea4ba7b6a96a97c3fb8060",
107
+ "value": " 279M/279M [00:12<00:00, 23.6MB/s]"
108
+ }
109
+ },
110
+ "94cb76a011a94602b4192f8e6e7142a4": {
111
+ "model_module": "@jupyter-widgets/base",
112
+ "model_name": "LayoutModel",
113
+ "model_module_version": "1.2.0",
114
+ "state": {
115
+ "_model_module": "@jupyter-widgets/base",
116
+ "_model_module_version": "1.2.0",
117
+ "_model_name": "LayoutModel",
118
+ "_view_count": null,
119
+ "_view_module": "@jupyter-widgets/base",
120
+ "_view_module_version": "1.2.0",
121
+ "_view_name": "LayoutView",
122
+ "align_content": null,
123
+ "align_items": null,
124
+ "align_self": null,
125
+ "border": null,
126
+ "bottom": null,
127
+ "display": null,
128
+ "flex": null,
129
+ "flex_flow": null,
130
+ "grid_area": null,
131
+ "grid_auto_columns": null,
132
+ "grid_auto_flow": null,
133
+ "grid_auto_rows": null,
134
+ "grid_column": null,
135
+ "grid_gap": null,
136
+ "grid_row": null,
137
+ "grid_template_areas": null,
138
+ "grid_template_columns": null,
139
+ "grid_template_rows": null,
140
+ "height": null,
141
+ "justify_content": null,
142
+ "justify_items": null,
143
+ "left": null,
144
+ "margin": null,
145
+ "max_height": null,
146
+ "max_width": null,
147
+ "min_height": null,
148
+ "min_width": null,
149
+ "object_fit": null,
150
+ "object_position": null,
151
+ "order": null,
152
+ "overflow": null,
153
+ "overflow_x": null,
154
+ "overflow_y": null,
155
+ "padding": null,
156
+ "right": null,
157
+ "top": null,
158
+ "visibility": null,
159
+ "width": null
160
+ }
161
+ },
162
+ "80d6e5f3592f440db67765ee24e83ccc": {
163
+ "model_module": "@jupyter-widgets/base",
164
+ "model_name": "LayoutModel",
165
+ "model_module_version": "1.2.0",
166
+ "state": {
167
+ "_model_module": "@jupyter-widgets/base",
168
+ "_model_module_version": "1.2.0",
169
+ "_model_name": "LayoutModel",
170
+ "_view_count": null,
171
+ "_view_module": "@jupyter-widgets/base",
172
+ "_view_module_version": "1.2.0",
173
+ "_view_name": "LayoutView",
174
+ "align_content": null,
175
+ "align_items": null,
176
+ "align_self": null,
177
+ "border": null,
178
+ "bottom": null,
179
+ "display": null,
180
+ "flex": null,
181
+ "flex_flow": null,
182
+ "grid_area": null,
183
+ "grid_auto_columns": null,
184
+ "grid_auto_flow": null,
185
+ "grid_auto_rows": null,
186
+ "grid_column": null,
187
+ "grid_gap": null,
188
+ "grid_row": null,
189
+ "grid_template_areas": null,
190
+ "grid_template_columns": null,
191
+ "grid_template_rows": null,
192
+ "height": null,
193
+ "justify_content": null,
194
+ "justify_items": null,
195
+ "left": null,
196
+ "margin": null,
197
+ "max_height": null,
198
+ "max_width": null,
199
+ "min_height": null,
200
+ "min_width": null,
201
+ "object_fit": null,
202
+ "object_position": null,
203
+ "order": null,
204
+ "overflow": null,
205
+ "overflow_x": null,
206
+ "overflow_y": null,
207
+ "padding": null,
208
+ "right": null,
209
+ "top": null,
210
+ "visibility": null,
211
+ "width": null
212
+ }
213
+ },
214
+ "9b8c527888bd4721a6a4387797fc7ba0": {
215
+ "model_module": "@jupyter-widgets/controls",
216
+ "model_name": "DescriptionStyleModel",
217
+ "model_module_version": "1.5.0",
218
+ "state": {
219
+ "_model_module": "@jupyter-widgets/controls",
220
+ "_model_module_version": "1.5.0",
221
+ "_model_name": "DescriptionStyleModel",
222
+ "_view_count": null,
223
+ "_view_module": "@jupyter-widgets/base",
224
+ "_view_module_version": "1.2.0",
225
+ "_view_name": "StyleView",
226
+ "description_width": ""
227
+ }
228
+ },
229
+ "ed88c4a913d849a1a8222668aca3b1a0": {
230
+ "model_module": "@jupyter-widgets/base",
231
+ "model_name": "LayoutModel",
232
+ "model_module_version": "1.2.0",
233
+ "state": {
234
+ "_model_module": "@jupyter-widgets/base",
235
+ "_model_module_version": "1.2.0",
236
+ "_model_name": "LayoutModel",
237
+ "_view_count": null,
238
+ "_view_module": "@jupyter-widgets/base",
239
+ "_view_module_version": "1.2.0",
240
+ "_view_name": "LayoutView",
241
+ "align_content": null,
242
+ "align_items": null,
243
+ "align_self": null,
244
+ "border": null,
245
+ "bottom": null,
246
+ "display": null,
247
+ "flex": null,
248
+ "flex_flow": null,
249
+ "grid_area": null,
250
+ "grid_auto_columns": null,
251
+ "grid_auto_flow": null,
252
+ "grid_auto_rows": null,
253
+ "grid_column": null,
254
+ "grid_gap": null,
255
+ "grid_row": null,
256
+ "grid_template_areas": null,
257
+ "grid_template_columns": null,
258
+ "grid_template_rows": null,
259
+ "height": null,
260
+ "justify_content": null,
261
+ "justify_items": null,
262
+ "left": null,
263
+ "margin": null,
264
+ "max_height": null,
265
+ "max_width": null,
266
+ "min_height": null,
267
+ "min_width": null,
268
+ "object_fit": null,
269
+ "object_position": null,
270
+ "order": null,
271
+ "overflow": null,
272
+ "overflow_x": null,
273
+ "overflow_y": null,
274
+ "padding": null,
275
+ "right": null,
276
+ "top": null,
277
+ "visibility": null,
278
+ "width": null
279
+ }
280
+ },
281
+ "52921704988a4cd49196eaab6a858268": {
282
+ "model_module": "@jupyter-widgets/controls",
283
+ "model_name": "ProgressStyleModel",
284
+ "model_module_version": "1.5.0",
285
+ "state": {
286
+ "_model_module": "@jupyter-widgets/controls",
287
+ "_model_module_version": "1.5.0",
288
+ "_model_name": "ProgressStyleModel",
289
+ "_view_count": null,
290
+ "_view_module": "@jupyter-widgets/base",
291
+ "_view_module_version": "1.2.0",
292
+ "_view_name": "StyleView",
293
+ "bar_color": null,
294
+ "description_width": ""
295
+ }
296
+ },
297
+ "256010fb44964c1fabc7d96d5ea4c644": {
298
+ "model_module": "@jupyter-widgets/base",
299
+ "model_name": "LayoutModel",
300
+ "model_module_version": "1.2.0",
301
+ "state": {
302
+ "_model_module": "@jupyter-widgets/base",
303
+ "_model_module_version": "1.2.0",
304
+ "_model_name": "LayoutModel",
305
+ "_view_count": null,
306
+ "_view_module": "@jupyter-widgets/base",
307
+ "_view_module_version": "1.2.0",
308
+ "_view_name": "LayoutView",
309
+ "align_content": null,
310
+ "align_items": null,
311
+ "align_self": null,
312
+ "border": null,
313
+ "bottom": null,
314
+ "display": null,
315
+ "flex": null,
316
+ "flex_flow": null,
317
+ "grid_area": null,
318
+ "grid_auto_columns": null,
319
+ "grid_auto_flow": null,
320
+ "grid_auto_rows": null,
321
+ "grid_column": null,
322
+ "grid_gap": null,
323
+ "grid_row": null,
324
+ "grid_template_areas": null,
325
+ "grid_template_columns": null,
326
+ "grid_template_rows": null,
327
+ "height": null,
328
+ "justify_content": null,
329
+ "justify_items": null,
330
+ "left": null,
331
+ "margin": null,
332
+ "max_height": null,
333
+ "max_width": null,
334
+ "min_height": null,
335
+ "min_width": null,
336
+ "object_fit": null,
337
+ "object_position": null,
338
+ "order": null,
339
+ "overflow": null,
340
+ "overflow_x": null,
341
+ "overflow_y": null,
342
+ "padding": null,
343
+ "right": null,
344
+ "top": null,
345
+ "visibility": null,
346
+ "width": null
347
+ }
348
+ },
349
+ "91b2880b6eea4ba7b6a96a97c3fb8060": {
350
+ "model_module": "@jupyter-widgets/controls",
351
+ "model_name": "DescriptionStyleModel",
352
+ "model_module_version": "1.5.0",
353
+ "state": {
354
+ "_model_module": "@jupyter-widgets/controls",
355
+ "_model_module_version": "1.5.0",
356
+ "_model_name": "DescriptionStyleModel",
357
+ "_view_count": null,
358
+ "_view_module": "@jupyter-widgets/base",
359
+ "_view_module_version": "1.2.0",
360
+ "_view_name": "StyleView",
361
+ "description_width": ""
362
+ }
363
+ }
364
+ }
365
+ }
366
+ },
367
+ "cells": [
368
+ {
369
+ "cell_type": "markdown",
370
+ "source": [
371
+ "# Imports and Classes"
372
+ ],
373
+ "metadata": {
374
+ "id": "j1mFIyblvH0-"
375
+ }
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "source": [
380
+ "!pip install huggingface-hub\n",
381
+ "!pip install datasets > delete.txt"
382
+ ],
383
+ "metadata": {
384
+ "colab": {
385
+ "base_uri": "https://localhost:8080/"
386
+ },
387
+ "id": "N0MB1E-_udRx",
388
+ "outputId": "3b240262-b2e5-43f5-d037-c8459ab5ad1f"
389
+ },
390
+ "execution_count": 14,
391
+ "outputs": [
392
+ {
393
+ "output_type": "stream",
394
+ "name": "stdout",
395
+ "text": [
396
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (0.26.3)\n",
397
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (3.16.1)\n",
398
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2024.10.0)\n",
399
+ "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (24.2)\n",
400
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (6.0.2)\n",
401
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2.32.3)\n",
402
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.66.6)\n",
403
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.12.2)\n",
404
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.4.0)\n",
405
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.10)\n",
406
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2.2.3)\n",
407
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2024.8.30)\n",
408
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
409
+ "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n",
410
+ "\u001b[0m"
411
+ ]
412
+ }
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 24,
418
+ "metadata": {
419
+ "id": "0MwL3yauuB8m"
420
+ },
421
+ "outputs": [],
422
+ "source": [
423
+ "import torch\n",
424
+ "import pickle\n",
425
+ "from huggingface_hub import hf_hub_download\n",
426
+ "from datasets import load_dataset, Image\n",
427
+ "import torch\n",
428
+ "from torch import nn, optim\n",
429
+ "from torch.utils.data import DataLoader, Dataset\n",
430
+ "import numpy as np"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "source": [
436
+ "# change runtime type to GPU\n",
437
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
438
+ "print(device)"
439
+ ],
440
+ "metadata": {
441
+ "colab": {
442
+ "base_uri": "https://localhost:8080/"
443
+ },
444
+ "id": "6saYtLslw95c",
445
+ "outputId": "e7d16fae-e5dd-452e-c80b-ed823258a322"
446
+ },
447
+ "execution_count": 3,
448
+ "outputs": [
449
+ {
450
+ "output_type": "stream",
451
+ "name": "stdout",
452
+ "text": [
453
+ "cuda\n"
454
+ ]
455
+ }
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "source": [
461
+ "class CNNModel1(nn.Module):\n",
462
+ " def __init__(self, num_outputs=2):\n",
463
+ " super(CNNModel1, self).__init__()\n",
464
+ " self.features = nn.Sequential(\n",
465
+ " nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
466
+ " nn.ReLU(inplace=True),\n",
467
+ " nn.MaxPool2d(kernel_size=3, stride=2),\n",
468
+ " nn.BatchNorm2d(64),\n",
469
+ " nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
470
+ " nn.ReLU(inplace=True),\n",
471
+ " nn.MaxPool2d(kernel_size=3, stride=2),\n",
472
+ " nn.BatchNorm2d(192),\n",
473
+ " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
474
+ " nn.ReLU(inplace=True),\n",
475
+ " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
476
+ " nn.ReLU(inplace=True),\n",
477
+ " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
478
+ " nn.ReLU(inplace=True),\n",
479
+ " nn.MaxPool2d(kernel_size=3, stride=2)\n",
480
+ " )\n",
481
+ " self.classifier = nn.Sequential(\n",
482
+ " nn.Dropout(),\n",
483
+ " nn.Linear(256 * 6 * 6, 4096),\n",
484
+ " nn.ReLU(inplace=True),\n",
485
+ " nn.Dropout(),\n",
486
+ " nn.Linear(4096, 4096),\n",
487
+ " nn.ReLU(inplace=True),\n",
488
+ " nn.Linear(4096, num_outputs)\n",
489
+ " )\n",
490
+ "\n",
491
+ " def forward(self, x):\n",
492
+ " x = self.features(x)\n",
493
+ " x = x.view(x.size(0), -1)\n",
494
+ " x = self.classifier(x)\n",
495
+ " return x"
496
+ ],
497
+ "metadata": {
498
+ "id": "CPaE895Quyrt"
499
+ },
500
+ "execution_count": 4,
501
+ "outputs": []
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "source": [
506
+ "class ResidualBlock(nn.Module):\n",
507
+ " def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
508
+ " super(ResidualBlock, self).__init__()\n",
509
+ " self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)\n",
510
+ " self.bn1 = nn.BatchNorm2d(out_channels)\n",
511
+ " self.relu = nn.ReLU(inplace=True)\n",
512
+ " self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
513
+ " self.bn2 = nn.BatchNorm2d(out_channels)\n",
514
+ " self.downsample = downsample\n",
515
+ "\n",
516
+ " def forward(self, x):\n",
517
+ " identity = x\n",
518
+ " if self.downsample:\n",
519
+ " identity = self.downsample(x)\n",
520
+ " out = self.conv1(x)\n",
521
+ " out = self.bn1(out)\n",
522
+ " out = self.relu(out)\n",
523
+ " out = self.conv2(out)\n",
524
+ " out = self.bn2(out)\n",
525
+ " out += identity\n",
526
+ " out = self.relu(out)\n",
527
+ " return out\n",
528
+ "\n",
529
+ "class CNNModel2(nn.Module):\n",
530
+ " def __init__(self, num_outputs=2):\n",
531
+ " super(CNNModel2, self).__init__()\n",
532
+ " self.in_channels = 64\n",
533
+ " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n",
534
+ " self.bn1 = nn.BatchNorm2d(64)\n",
535
+ " self.relu = nn.ReLU(inplace=True)\n",
536
+ " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
537
+ "\n",
538
+ " self.layer1 = self._make_layer(64, 2, stride=1)\n",
539
+ " self.layer2 = self._make_layer(128, 2, stride=2)\n",
540
+ " self.layer3 = self._make_layer(256, 2, stride=2)\n",
541
+ " self.layer4 = self._make_layer(512, 2, stride=2)\n",
542
+ "\n",
543
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
544
+ " self.fc = nn.Linear(512, num_outputs)\n",
545
+ "\n",
546
+ " def _make_layer(self, out_channels, blocks, stride):\n",
547
+ " downsample = None\n",
548
+ " if stride != 1 or self.in_channels != out_channels:\n",
549
+ " downsample = nn.Sequential(\n",
550
+ " nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),\n",
551
+ " nn.BatchNorm2d(out_channels)\n",
552
+ " )\n",
553
+ " layers = []\n",
554
+ " layers.append(ResidualBlock(self.in_channels, out_channels, stride, downsample))\n",
555
+ " self.in_channels = out_channels\n",
556
+ " for _ in range(1, blocks):\n",
557
+ " layers.append(ResidualBlock(out_channels, out_channels))\n",
558
+ " return nn.Sequential(*layers)\n",
559
+ "\n",
560
+ " def forward(self, x):\n",
561
+ " x = self.conv1(x)\n",
562
+ " x = self.bn1(x)\n",
563
+ " x = self.relu(x)\n",
564
+ " x = self.maxpool(x)\n",
565
+ " x = self.layer1(x)\n",
566
+ " x = self.layer2(x)\n",
567
+ " x = self.layer3(x)\n",
568
+ " x = self.layer4(x)\n",
569
+ " x = self.avgpool(x)\n",
570
+ " x = x.view(x.size(0), -1)\n",
571
+ " x = self.fc(x)\n",
572
+ " return x"
573
+ ],
574
+ "metadata": {
575
+ "id": "BqqtP1WtuTF0"
576
+ },
577
+ "execution_count": 5,
578
+ "outputs": []
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "source": [
583
+ "class InceptionModule(nn.Module):\n",
584
+ " def __init__(self, in_channels, ch1x1, ch3x3_reduce, ch3x3, ch5x5_reduce, ch5x5, pool_proj):\n",
585
+ " super(InceptionModule, self).__init__()\n",
586
+ " self.branch1 = nn.Sequential(\n",
587
+ " nn.Conv2d(in_channels, ch1x1, kernel_size=1),\n",
588
+ " nn.ReLU(inplace=True)\n",
589
+ " )\n",
590
+ " self.branch2 = nn.Sequential(\n",
591
+ " nn.Conv2d(in_channels, ch3x3_reduce, kernel_size=1),\n",
592
+ " nn.ReLU(inplace=True),\n",
593
+ " nn.Conv2d(ch3x3_reduce, ch3x3, kernel_size=3, padding=1),\n",
594
+ " nn.ReLU(inplace=True)\n",
595
+ " )\n",
596
+ " self.branch3 = nn.Sequential(\n",
597
+ " nn.Conv2d(in_channels, ch5x5_reduce, kernel_size=1),\n",
598
+ " nn.ReLU(inplace=True),\n",
599
+ " nn.Conv2d(ch5x5_reduce, ch5x5, kernel_size=5, padding=2),\n",
600
+ " nn.ReLU(inplace=True)\n",
601
+ " )\n",
602
+ " self.branch4 = nn.Sequential(\n",
603
+ " nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n",
604
+ " nn.Conv2d(in_channels, pool_proj, kernel_size=1),\n",
605
+ " nn.ReLU(inplace=True)\n",
606
+ " )\n",
607
+ "\n",
608
+ " def forward(self, x):\n",
609
+ " branch1 = self.branch1(x)\n",
610
+ " branch2 = self.branch2(x)\n",
611
+ " branch3 = self.branch3(x)\n",
612
+ " branch4 = self.branch4(x)\n",
613
+ " outputs = torch.cat([branch1, branch2, branch3, branch4], 1)\n",
614
+ " return outputs\n",
615
+ "\n",
616
+ "class CNNModel3(nn.Module):\n",
617
+ " def __init__(self, num_outputs=2):\n",
618
+ " super(CNNModel3, self).__init__()\n",
619
+ " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n",
620
+ " self.maxpool1 = nn.MaxPool2d(3, stride=2)\n",
621
+ " self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)\n",
622
+ " self.maxpool2 = nn.MaxPool2d(3, stride=2)\n",
623
+ "\n",
624
+ " self.inception3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)\n",
625
+ " self.inception3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)\n",
626
+ " self.maxpool3 = nn.MaxPool2d(3, stride=2)\n",
627
+ "\n",
628
+ " self.inception4a = InceptionModule(480, 192, 96, 208, 16, 48, 64)\n",
629
+ " self.inception4b = InceptionModule(512, 160, 112, 224, 24, 64, 64)\n",
630
+ " self.maxpool4 = nn.MaxPool2d(3, stride=2)\n",
631
+ "\n",
632
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
633
+ " self.dropout = nn.Dropout(0.4)\n",
634
+ " self.fc = nn.Linear(512, num_outputs)\n",
635
+ "\n",
636
+ " def forward(self, x):\n",
637
+ " x = self.conv1(x)\n",
638
+ " x = self.maxpool1(x)\n",
639
+ " x = self.conv2(x)\n",
640
+ " x = self.maxpool2(x)\n",
641
+ " x = self.inception3a(x)\n",
642
+ " x = self.inception3b(x)\n",
643
+ " x = self.maxpool3(x)\n",
644
+ " x = self.inception4a(x)\n",
645
+ " x = self.inception4b(x)\n",
646
+ " x = self.maxpool4(x)\n",
647
+ " x = self.avgpool(x)\n",
648
+ " x = x.view(x.size(0), -1)\n",
649
+ " x = self.dropout(x)\n",
650
+ " x = self.fc(x)\n",
651
+ " return x"
652
+ ],
653
+ "metadata": {
654
+ "id": "LSc0Yyzau5y1"
655
+ },
656
+ "execution_count": 6,
657
+ "outputs": []
658
+ },
659
+ {
660
+ "cell_type": "code",
661
+ "source": [
662
+ "from torch.utils.data import Dataset\n",
663
+ "class GPSImageDataset(Dataset):\n",
664
+ " def __init__(self, hf_dataset, transform, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):\n",
665
+ " self.hf_dataset = hf_dataset\n",
666
+ " self.transform = transform\n",
667
+ "\n",
668
+ " # Normalize the latitude and longitude\n",
669
+ " self.latitudes = np.array(hf_dataset['Latitude'])\n",
670
+ " self.longitudes = np.array(hf_dataset['Longitude'])\n",
671
+ " self.latitude_mean = lat_mean if lat_mean is not None else self.latitudes.mean()\n",
672
+ " self.latitude_std = lat_std if lat_std is not None else self.latitudes.std()\n",
673
+ " self.longitude_mean = lon_mean if lon_mean is not None else self.longitudes.mean()\n",
674
+ " self.longitude_std = lon_std if lon_std is not None else self.longitudes.std()\n",
675
+ "\n",
676
+ " self.normalized_latitudes = (self.latitudes - self.latitude_mean) / self.latitude_std\n",
677
+ " self.normalized_longitudes = (self.longitudes - self.longitude_mean) / self.longitude_std\n",
678
+ "\n",
679
+ " def __len__(self):\n",
680
+ " return len(self.hf_dataset)\n",
681
+ "\n",
682
+ " def __getitem__(self, idx):\n",
683
+ " image = self.hf_dataset[idx]['image']\n",
684
+ " latitude = self.normalized_latitudes[idx]\n",
685
+ " longitude = self.normalized_longitudes[idx]\n",
686
+ "\n",
687
+ " if self.transform:\n",
688
+ " image = self.transform(image)\n",
689
+ "\n",
690
+ " return image, torch.tensor([latitude, longitude], dtype=torch.float)"
691
+ ],
692
+ "metadata": {
693
+ "id": "QXwlWCWazwGB"
694
+ },
695
+ "execution_count": 20,
696
+ "outputs": []
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "source": [
701
+ "from torchvision import transforms, models\n",
702
+ "transform = transforms.Compose([\n",
703
+ " transforms.RandomResizedCrop(224),\n",
704
+ " transforms.RandomHorizontalFlip(),\n",
705
+ " transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),\n",
706
+ " transforms.ToTensor(),\n",
707
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
708
+ "])\n",
709
+ "\n",
710
+ "inference_transform = transforms.Compose([\n",
711
+ " transforms.Resize((224, 224)),\n",
712
+ " transforms.ToTensor(),\n",
713
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
714
+ "])"
715
+ ],
716
+ "metadata": {
717
+ "id": "Jt8ZJtCM0MEl"
718
+ },
719
+ "execution_count": 22,
720
+ "outputs": []
721
+ },
722
+ {
723
+ "cell_type": "markdown",
724
+ "source": [
725
+ "# Loading the Pickle and Running on Public Dataset"
726
+ ],
727
+ "metadata": {
728
+ "id": "DMHQWN_qvOwU"
729
+ }
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "source": [
734
+ "!huggingface-cli login\n",
735
+ "# use appropriate token"
736
+ ],
737
+ "metadata": {
738
+ "id": "hMz1QFksv-Dt"
739
+ },
740
+ "execution_count": null,
741
+ "outputs": []
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "source": [
746
+ "pickle_file_path = hf_hub_download(repo_id= \"CIS-5190-CIA/Ensamble\", filename=\"ensemble_model.pkl\")"
747
+ ],
748
+ "metadata": {
749
+ "colab": {
750
+ "base_uri": "https://localhost:8080/",
751
+ "height": 153,
752
+ "referenced_widgets": [
753
+ "7d71deee94c64080badb10c59e3ac49b",
754
+ "013d6d22ebe04ebda019d9bf2e7e135d",
755
+ "dd04535976844f3b82bd58020d30acb7",
756
+ "a02104842c0a4b24adfc413f59d84fc0",
757
+ "94cb76a011a94602b4192f8e6e7142a4",
758
+ "80d6e5f3592f440db67765ee24e83ccc",
759
+ "9b8c527888bd4721a6a4387797fc7ba0",
760
+ "ed88c4a913d849a1a8222668aca3b1a0",
761
+ "52921704988a4cd49196eaab6a858268",
762
+ "256010fb44964c1fabc7d96d5ea4c644",
763
+ "91b2880b6eea4ba7b6a96a97c3fb8060"
764
+ ]
765
+ },
766
+ "id": "U--CT_wUwKzr",
767
+ "outputId": "8b0c7205-6de6-4794-8da7-0e917874532e"
768
+ },
769
+ "execution_count": 9,
770
+ "outputs": [
771
+ {
772
+ "output_type": "stream",
773
+ "name": "stderr",
774
+ "text": [
775
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
776
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
777
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
778
+ "You will be able to reuse this secret in all of your notebooks.\n",
779
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
780
+ " warnings.warn(\n"
781
+ ]
782
+ },
783
+ {
784
+ "output_type": "display_data",
785
+ "data": {
786
+ "text/plain": [
787
+ "ensemble_model.pkl: 0%| | 0.00/279M [00:00<?, ?B/s]"
788
+ ],
789
+ "application/vnd.jupyter.widget-view+json": {
790
+ "version_major": 2,
791
+ "version_minor": 0,
792
+ "model_id": "7d71deee94c64080badb10c59e3ac49b"
793
+ }
794
+ },
795
+ "metadata": {}
796
+ }
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "source": [
802
+ "with open(pickle_file_path, \"rb\") as f:\n",
803
+ " ensemble_model = pickle.load(f)\n",
804
+ "\n",
805
+ "model1 = CNNModel1(num_outputs=2) # Adapted AlexNet\n",
806
+ "model2 = CNNModel2(num_outputs=2) # Adapted ResNet\n",
807
+ "model3 = CNNModel3(num_outputs=2) # Adapted GoogLeNet\n",
808
+ "\n",
809
+ "model1.load_state_dict(ensemble_model[\"RNNModel1\"])\n",
810
+ "model2.load_state_dict(ensemble_model[\"RNNModel2\"])\n",
811
+ "model3.load_state_dict(ensemble_model[\"RNNModel3\"])\n",
812
+ "\n",
813
+ "model1.to(device)\n",
814
+ "model2.to(device)\n",
815
+ "model3.to(device)\n",
816
+ "\n",
817
+ "model1.eval()\n",
818
+ "model2.eval()\n",
819
+ "model3.eval()"
820
+ ],
821
+ "metadata": {
822
+ "colab": {
823
+ "base_uri": "https://localhost:8080/"
824
+ },
825
+ "id": "dUSraz_wweEw",
826
+ "outputId": "bca9cc95-2d6b-4172-a6e9-faeb85e9f2d4"
827
+ },
828
+ "execution_count": 30,
829
+ "outputs": [
830
+ {
831
+ "output_type": "execute_result",
832
+ "data": {
833
+ "text/plain": [
834
+ "CNNModel3(\n",
835
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
836
+ " (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
837
+ " (conv2): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
838
+ " (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
839
+ " (inception3a): InceptionModule(\n",
840
+ " (branch1): Sequential(\n",
841
+ " (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))\n",
842
+ " (1): ReLU(inplace=True)\n",
843
+ " )\n",
844
+ " (branch2): Sequential(\n",
845
+ " (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))\n",
846
+ " (1): ReLU(inplace=True)\n",
847
+ " (2): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
848
+ " (3): ReLU(inplace=True)\n",
849
+ " )\n",
850
+ " (branch3): Sequential(\n",
851
+ " (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))\n",
852
+ " (1): ReLU(inplace=True)\n",
853
+ " (2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
854
+ " (3): ReLU(inplace=True)\n",
855
+ " )\n",
856
+ " (branch4): Sequential(\n",
857
+ " (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
858
+ " (1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))\n",
859
+ " (2): ReLU(inplace=True)\n",
860
+ " )\n",
861
+ " )\n",
862
+ " (inception3b): InceptionModule(\n",
863
+ " (branch1): Sequential(\n",
864
+ " (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
865
+ " (1): ReLU(inplace=True)\n",
866
+ " )\n",
867
+ " (branch2): Sequential(\n",
868
+ " (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))\n",
869
+ " (1): ReLU(inplace=True)\n",
870
+ " (2): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
871
+ " (3): ReLU(inplace=True)\n",
872
+ " )\n",
873
+ " (branch3): Sequential(\n",
874
+ " (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n",
875
+ " (1): ReLU(inplace=True)\n",
876
+ " (2): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
877
+ " (3): ReLU(inplace=True)\n",
878
+ " )\n",
879
+ " (branch4): Sequential(\n",
880
+ " (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
881
+ " (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))\n",
882
+ " (2): ReLU(inplace=True)\n",
883
+ " )\n",
884
+ " )\n",
885
+ " (maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
886
+ " (inception4a): InceptionModule(\n",
887
+ " (branch1): Sequential(\n",
888
+ " (0): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))\n",
889
+ " (1): ReLU(inplace=True)\n",
890
+ " )\n",
891
+ " (branch2): Sequential(\n",
892
+ " (0): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))\n",
893
+ " (1): ReLU(inplace=True)\n",
894
+ " (2): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
895
+ " (3): ReLU(inplace=True)\n",
896
+ " )\n",
897
+ " (branch3): Sequential(\n",
898
+ " (0): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))\n",
899
+ " (1): ReLU(inplace=True)\n",
900
+ " (2): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
901
+ " (3): ReLU(inplace=True)\n",
902
+ " )\n",
903
+ " (branch4): Sequential(\n",
904
+ " (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
905
+ " (1): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))\n",
906
+ " (2): ReLU(inplace=True)\n",
907
+ " )\n",
908
+ " )\n",
909
+ " (inception4b): InceptionModule(\n",
910
+ " (branch1): Sequential(\n",
911
+ " (0): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))\n",
912
+ " (1): ReLU(inplace=True)\n",
913
+ " )\n",
914
+ " (branch2): Sequential(\n",
915
+ " (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))\n",
916
+ " (1): ReLU(inplace=True)\n",
917
+ " (2): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
918
+ " (3): ReLU(inplace=True)\n",
919
+ " )\n",
920
+ " (branch3): Sequential(\n",
921
+ " (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))\n",
922
+ " (1): ReLU(inplace=True)\n",
923
+ " (2): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
924
+ " (3): ReLU(inplace=True)\n",
925
+ " )\n",
926
+ " (branch4): Sequential(\n",
927
+ " (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n",
928
+ " (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))\n",
929
+ " (2): ReLU(inplace=True)\n",
930
+ " )\n",
931
+ " )\n",
932
+ " (maxpool4): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
933
+ " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
934
+ " (dropout): Dropout(p=0.4, inplace=False)\n",
935
+ " (fc): Linear(in_features=512, out_features=2, bias=True)\n",
936
+ ")"
937
+ ]
938
+ },
939
+ "metadata": {},
940
+ "execution_count": 30
941
+ }
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "code",
946
+ "source": [
947
+ "def ensemble_predict(models, dataloader):\n",
948
+ "    \"\"\"\n",
949
+ "    Average the outputs of several models over a single dataloader.\n",
950
+ "    Args:\n",
951
+ "        models: Iterable of models; each one is called on every image batch.\n",
952
+ "        dataloader: Iterable of (images, target) batches; the targets are ignored.\n",
953
+ "    Returns:\n",
954
+ "        Mean over models of each model's batch outputs concatenated along dim 0.\n",
955
+ "    \"\"\"\n",
956
+ "    per_model = []\n",
957
+ "    # inference only, so gradient tracking is switched off for every forward pass\n",
958
+ "    with torch.no_grad():\n",
959
+ "        for net in models:\n",
960
+ "            # one complete pass over the dataloader per model, batches kept in order\n",
961
+ "            batch_outputs = [net(batch.to(device)) for batch, _ in dataloader]\n",
962
+ "            per_model.append(torch.cat(batch_outputs, dim=0))\n",
963
+ "\n",
964
+ "    # stack along a new leading model axis, then reduce over that axis\n",
965
+ "    stacked = torch.stack(per_model, dim=0)\n",
966
+ "    ensemble_output = stacked.mean(dim=0)\n",
967
+ "    return ensemble_output"
968
+ ],
969
+ "metadata": {
970
+ "id": "kOOXX1c1xr2J"
971
+ },
972
+ "execution_count": 31,
973
+ "outputs": []
974
+ },
975
+ {
976
+ "cell_type": "code",
977
+ "source": [
978
+ "models = [model1, model2, model3]\n",
979
+ "\n",
980
+ "## UPDATE THIS WITH THE ACTUAL TESTING DATASET --> THIS IS THE ONLY VALUE YOU\n",
981
+ "## NEED TO UPDATE\n",
982
+ "\n",
983
+ "dataset_test = load_dataset(\"gydou/released_img\")"
984
+ ],
985
+ "metadata": {
986
+ "id": "x2gDbdjMx5Pl"
987
+ },
988
+ "execution_count": 32,
989
+ "outputs": []
990
+ },
991
+ {
992
+ "cell_type": "code",
993
+ "source": [
994
+ "# Read the two coordinate columns directly instead of materializing every full row just to pick one field.\n",
995
+ "latitudes = np.array(dataset_test['train']['Latitude'])\n",
996
+ "longitudes = np.array(dataset_test['train']['Longitude'])\n",
997
+ "# NOTE(review): these statistics are computed from the evaluation data itself; confirm they\n",
998
+ "# match the mean/std the models' targets were normalized with during training.\n",
999
+ "lat_mean, lat_std = latitudes.mean(), latitudes.std()\n",
1000
+ "lon_mean, lon_std = longitudes.mean(), longitudes.std()\n",
1001
+ "\n",
1002
+ "val_dataset = GPSImageDataset(\n",
1003
+ "    hf_dataset=dataset_test['train'],\n",
1004
+ "    transform=inference_transform,\n",
1005
+ "    lat_mean=lat_mean,\n",
1006
+ "    lat_std=lat_std,\n",
1007
+ "    lon_mean=lon_mean,\n",
1008
+ "    lon_std=lon_std\n",
1009
+ ")\n",
1010
+ "\n",
1011
+ "val_dataloader = DataLoader(\n",
1012
+ "    val_dataset,\n",
1013
+ "    batch_size=32,\n",
1014
+ "    shuffle=False,  # keep batch order stable so predictions line up with targets when scored\n",
1015
+ "    num_workers=4\n",
1016
+ ")\n",
1017
+ "\n",
1018
+ "predictions = ensemble_predict(models, dataloader=val_dataloader)"
1019
+ ],
1020
+ "metadata": {
1021
+ "id": "Q8rQ-jCHzceg"
1022
+ },
1023
+ "execution_count": 33,
1024
+ "outputs": []
1025
+ },
1026
+ {
1027
+ "cell_type": "code",
1028
+ "source": [
1029
+ "from geopy.distance import geodesic\n",
1030
+ "def compute_rmse_in_meters(predictions, dataloader, lat_mean, lon_mean, lat_std, lon_std):\n",
1031
+ "    \"\"\"Return the RMSE, in meters, of the geodesic error between de-normalized predictions and dataloader targets.\"\"\"\n",
1032
+ "    total_loss = 0.0\n",
1033
+ "    total_samples = 0  # samples scored so far; also the row offset of the next batch in predictions\n",
1034
+ "    predictions_denorm = predictions.cpu().numpy() * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1035
+ "    for _, gps_coords in dataloader:\n",
1036
+ "        gps_coords = gps_coords.cpu().numpy()\n",
1037
+ "        # targets get the same scale-and-shift as the predictions above\n",
1038
+ "        actuals_denorm = gps_coords * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1039
+ "        # slice by running offset so a short final batch still lines up with its own predictions\n",
1040
+ "        batch_preds = predictions_denorm[total_samples:total_samples + len(gps_coords)]\n",
1041
+ "        for pred, actual in zip(batch_preds, actuals_denorm):\n",
1042
+ "            distance = geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters\n",
1043
+ "            total_loss += distance ** 2\n",
1044
+ "\n",
1045
+ "        total_samples += len(gps_coords)\n",
1046
+ "\n",
1047
+ "    rmse = np.sqrt(total_loss / total_samples)\n",
1048
+ "    return rmse\n",
1049
+ "\n",
1050
+ "rmse = compute_rmse_in_meters(predictions, val_dataloader, lat_mean, lon_mean, lat_std, lon_std)\n",
1051
+ "\n",
1052
+ "print(f\"Root Mean Squared Error (meters): {rmse:.2f}\")"
1053
+ ],
1054
+ "metadata": {
1055
+ "colab": {
1056
+ "base_uri": "https://localhost:8080/"
1057
+ },
1058
+ "id": "_jyPRsLi2mZ3",
1059
+ "outputId": "f71f0c8d-5de6-4bc5-82b6-53cc97ce6bee"
1060
+ },
1061
+ "execution_count": 36,
1062
+ "outputs": [
1063
+ {
1064
+ "output_type": "stream",
1065
+ "name": "stdout",
1066
+ "text": [
1067
+ "Root Mean Squared Error (meters): 102.03\n"
1068
+ ]
1069
+ }
1070
+ ]
1071
+ }
1072
+ ]
1073
+ }