amanmibra commited on
Commit
b1f510e
·
1 Parent(s): c63e93b

Fix linear layer

Browse files
cnn.py CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
52
  nn.MaxPool2d(kernel_size=2)
53
  )
54
  self.flatten = nn.Flatten()
55
- self.linear = nn.Linear(128 * 5 * 11, 10)
56
  self.softmax = nn.Softmax(dim=1)
57
 
58
  def forward(self, input_data):
 
52
  nn.MaxPool2d(kernel_size=2)
53
  )
54
  self.flatten = nn.Flatten()
55
+ self.linear = nn.Linear(128 * 5 * 11, 3)
56
  self.softmax = nn.Softmax(dim=1)
57
 
58
  def forward(self, input_data):
models/void.pth CHANGED
Binary files a/models/void.pth and b/models/void.pth differ
 
models/void_20230512_225714.pth ADDED
Binary file (477 kB). View file
 
models/void_20230513_021517.pth ADDED
Binary file (477 kB). View file
 
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
3
  {
4
  "cell_type": "code",
5
  "execution_count": 8,
6
- "id": "aa1911f5",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
@@ -14,7 +14,7 @@
14
  {
15
  "cell_type": "code",
16
  "execution_count": 10,
17
- "id": "67260d6a",
18
  "metadata": {},
19
  "outputs": [],
20
  "source": [
@@ -25,7 +25,7 @@
25
  {
26
  "cell_type": "code",
27
  "execution_count": 86,
28
- "id": "45a193b9",
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
@@ -38,7 +38,7 @@
38
  {
39
  "cell_type": "code",
40
  "execution_count": 85,
41
- "id": "4e2c5825",
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
@@ -49,7 +49,7 @@
49
  {
50
  "cell_type": "code",
51
  "execution_count": 78,
52
- "id": "1b14521b",
53
  "metadata": {},
54
  "outputs": [
55
  {
@@ -71,7 +71,7 @@
71
  {
72
  "cell_type": "code",
73
  "execution_count": 109,
74
- "id": "d0d6fcd7",
75
  "metadata": {},
76
  "outputs": [],
77
  "source": [
@@ -87,7 +87,7 @@
87
  {
88
  "cell_type": "code",
89
  "execution_count": 110,
90
- "id": "e39172ab",
91
  "metadata": {},
92
  "outputs": [
93
  {
@@ -108,7 +108,7 @@
108
  {
109
  "cell_type": "code",
110
  "execution_count": 111,
111
- "id": "ef5746b3",
112
  "metadata": {},
113
  "outputs": [
114
  {
@@ -136,7 +136,7 @@
136
  {
137
  "cell_type": "code",
138
  "execution_count": 112,
139
- "id": "a4f08bd4",
140
  "metadata": {},
141
  "outputs": [
142
  {
@@ -156,36 +156,19 @@
156
  },
157
  {
158
  "cell_type": "code",
159
- "execution_count": 113,
160
- "id": "2b58063c",
161
  "metadata": {},
162
- "outputs": [
163
- {
164
- "ename": "RuntimeError",
165
- "evalue": "mat1 and mat2 shapes cannot be multiplied (2x2560 and 7040x10)",
166
- "output_type": "error",
167
- "traceback": [
168
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
169
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
170
- "Cell \u001b[0;32mIn[113], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m cnn \u001b[38;5;241m=\u001b[39m CNNetwork()\n\u001b[0;32m----> 2\u001b[0m \u001b[43msummary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcnn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m44\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
171
- "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torchsummary/torchsummary.py:72\u001b[0m, in \u001b[0;36msummary\u001b[0;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[1;32m 68\u001b[0m model\u001b[38;5;241m.\u001b[39mapply(register_hook)\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# make a forward pass\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;66;03m# print(x.shape)\u001b[39;00m\n\u001b[0;32m---> 72\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;66;03m# remove these hooks\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m h \u001b[38;5;129;01min\u001b[39;00m hooks:\n",
172
- "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
173
- "File \u001b[0;32m~/ml-sandbox/VoID/notebooks/../cnn.py:64\u001b[0m, in \u001b[0;36mCNNetwork.forward\u001b[0;34m(self, input_data)\u001b[0m\n\u001b[1;32m 62\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv4(x)\n\u001b[1;32m 63\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten(x)\n\u001b[0;32m---> 64\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m predictions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msoftmax(logits)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m predictions\n",
174
- "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/module.py:1538\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1535\u001b[0m bw_hook \u001b[38;5;241m=\u001b[39m hooks\u001b[38;5;241m.\u001b[39mBackwardHook(\u001b[38;5;28mself\u001b[39m, full_backward_hooks, backward_pre_hooks)\n\u001b[1;32m 1536\u001b[0m args \u001b[38;5;241m=\u001b[39m bw_hook\u001b[38;5;241m.\u001b[39msetup_input_hook(args)\n\u001b[0;32m-> 1538\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks:\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m hook_id, hook \u001b[38;5;129;01min\u001b[39;00m (\n\u001b[1;32m 1541\u001b[0m \u001b[38;5;241m*\u001b[39m_global_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m 1543\u001b[0m ):\n",
175
- "File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
176
- "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x2560 and 7040x10)"
177
- ]
178
- }
179
- ],
180
  "source": [
181
  "cnn = CNNetwork()\n",
182
- "summary(cnn, (1, 64, 44))"
183
  ]
184
  },
185
  {
186
  "cell_type": "code",
187
  "execution_count": 114,
188
- "id": "70264920",
189
  "metadata": {},
190
  "outputs": [
191
  {
@@ -206,7 +189,7 @@
206
  {
207
  "cell_type": "code",
208
  "execution_count": 115,
209
- "id": "9383e1bb",
210
  "metadata": {},
211
  "outputs": [
212
  {
@@ -227,7 +210,7 @@
227
  {
228
  "cell_type": "code",
229
  "execution_count": 116,
230
- "id": "6d0cb06a",
231
  "metadata": {},
232
  "outputs": [
233
  {
@@ -255,7 +238,7 @@
255
  {
256
  "cell_type": "code",
257
  "execution_count": 117,
258
- "id": "e07e35f7",
259
  "metadata": {},
260
  "outputs": [],
261
  "source": [
@@ -266,7 +249,7 @@
266
  {
267
  "cell_type": "code",
268
  "execution_count": 107,
269
- "id": "ef2eddad",
270
  "metadata": {},
271
  "outputs": [
272
  {
@@ -284,10 +267,176 @@
284
  "now.strftime(\"%Y%m%d_%H%M%S\")"
285
  ]
286
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  {
288
  "cell_type": "code",
289
  "execution_count": null,
290
- "id": "c665d0cf",
291
  "metadata": {},
292
  "outputs": [],
293
  "source": []
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 8,
6
+ "id": "9db7bd27",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
 
14
  {
15
  "cell_type": "code",
16
  "execution_count": 10,
17
+ "id": "72b076a5",
18
  "metadata": {},
19
  "outputs": [],
20
  "source": [
 
25
  {
26
  "cell_type": "code",
27
  "execution_count": 86,
28
+ "id": "391c8ebe",
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
 
38
  {
39
  "cell_type": "code",
40
  "execution_count": 85,
41
+ "id": "0f0b166a",
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
 
49
  {
50
  "cell_type": "code",
51
  "execution_count": 78,
52
+ "id": "b690f559",
53
  "metadata": {},
54
  "outputs": [
55
  {
 
71
  {
72
  "cell_type": "code",
73
  "execution_count": 109,
74
+ "id": "5b4cac66",
75
  "metadata": {},
76
  "outputs": [],
77
  "source": [
 
87
  {
88
  "cell_type": "code",
89
  "execution_count": 110,
90
+ "id": "55928782",
91
  "metadata": {},
92
  "outputs": [
93
  {
 
108
  {
109
  "cell_type": "code",
110
  "execution_count": 111,
111
+ "id": "296fc1d0",
112
  "metadata": {},
113
  "outputs": [
114
  {
 
136
  {
137
  "cell_type": "code",
138
  "execution_count": 112,
139
+ "id": "b921ef42",
140
  "metadata": {},
141
  "outputs": [
142
  {
 
156
  },
157
  {
158
  "cell_type": "code",
159
+ "execution_count": 144,
160
+ "id": "83671781",
161
  "metadata": {},
162
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  "source": [
164
  "cnn = CNNetwork()\n",
165
+ "# summary(cnn, (1, 64, 44))"
166
  ]
167
  },
168
  {
169
  "cell_type": "code",
170
  "execution_count": 114,
171
+ "id": "5a12b59f",
172
  "metadata": {},
173
  "outputs": [
174
  {
 
189
  {
190
  "cell_type": "code",
191
  "execution_count": 115,
192
+ "id": "4845de38",
193
  "metadata": {},
194
  "outputs": [
195
  {
 
210
  {
211
  "cell_type": "code",
212
  "execution_count": 116,
213
+ "id": "51c03aaf",
214
  "metadata": {},
215
  "outputs": [
216
  {
 
238
  {
239
  "cell_type": "code",
240
  "execution_count": 117,
241
+ "id": "ba6b88ee",
242
  "metadata": {},
243
  "outputs": [],
244
  "source": [
 
249
  {
250
  "cell_type": "code",
251
  "execution_count": 107,
252
+ "id": "a6046ccf",
253
  "metadata": {},
254
  "outputs": [
255
  {
 
267
  "now.strftime(\"%Y%m%d_%H%M%S\")"
268
  ]
269
  },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 145,
273
+ "id": "d7789a04",
274
+ "metadata": {},
275
+ "outputs": [
276
+ {
277
+ "data": {
278
+ "text/plain": [
279
+ "<All keys matched successfully>"
280
+ ]
281
+ },
282
+ "execution_count": 145,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ }
286
+ ],
287
+ "source": [
288
+ "cnn.load_state_dict(torch.load(\"../models/void_20230512_225714.pth\"))"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 151,
294
+ "id": "a6030b42",
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "x, y = dataset[10]"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 152,
304
+ "id": "78352b6b",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "labels = dataset._labels"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": 153,
314
+ "id": "b8cc2162",
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "input = x.unsqueeze_(0) "
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 154,
324
+ "id": "845ecea4",
325
+ "metadata": {},
326
+ "outputs": [],
327
+ "source": [
328
+ "def predict(model, input, target, class_mapping):\n",
329
+ " model.eval()\n",
330
+ " with torch.no_grad():\n",
331
+ " predictions = model(input)\n",
332
+ " print(predictions)\n",
333
+ " predicted_index = predictions[0].argmax(0)\n",
334
+ " predicted = class_mapping[predicted_index]\n",
335
+ " expected = class_mapping[target]\n",
336
+ " return predicted, expected"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": 155,
342
+ "id": "eb8d1e55",
343
+ "metadata": {},
344
+ "outputs": [
345
+ {
346
+ "name": "stdout",
347
+ "output_type": "stream",
348
+ "text": [
349
+ "tensor([[1.0000e+00, 1.3728e-20, 2.8026e-44]])\n"
350
+ ]
351
+ },
352
+ {
353
+ "data": {
354
+ "text/plain": [
355
+ "('aman', 'aman')"
356
+ ]
357
+ },
358
+ "execution_count": 155,
359
+ "metadata": {},
360
+ "output_type": "execute_result"
361
+ }
362
+ ],
363
+ "source": [
364
+ "predict(cnn, input, y, labels)"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": 156,
370
+ "id": "5d58683e",
371
+ "metadata": {},
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/plain": [
376
+ "tensor([[[[0.0259, 0.1384, 0.0784, ..., 0.0000, 0.0000, 0.0000],\n",
377
+ " [0.0334, 0.1320, 0.0701, ..., 0.0000, 0.0000, 0.0000],\n",
378
+ " [0.0481, 0.0324, 0.0545, ..., 0.0000, 0.0000, 0.0000],\n",
379
+ " ...,\n",
380
+ " [0.2665, 0.3647, 0.3147, ..., 0.0000, 0.0000, 0.0000],\n",
381
+ " [0.2710, 0.3796, 0.2160, ..., 0.0000, 0.0000, 0.0000],\n",
382
+ " [0.1950, 0.2607, 0.1905, ..., 0.0000, 0.0000, 0.0000]]]])"
383
+ ]
384
+ },
385
+ "execution_count": 156,
386
+ "metadata": {},
387
+ "output_type": "execute_result"
388
+ }
389
+ ],
390
+ "source": [
391
+ "input"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 157,
397
+ "id": "b0af5b69",
398
+ "metadata": {},
399
+ "outputs": [
400
+ {
401
+ "data": {
402
+ "text/plain": [
403
+ "torch.Size([1, 1, 64, 157])"
404
+ ]
405
+ },
406
+ "execution_count": 157,
407
+ "metadata": {},
408
+ "output_type": "execute_result"
409
+ }
410
+ ],
411
+ "source": [
412
+ "input.shape"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 158,
418
+ "id": "28c0768a",
419
+ "metadata": {},
420
+ "outputs": [
421
+ {
422
+ "data": {
423
+ "text/plain": [
424
+ "torch.Size([1, 1, 64, 157])"
425
+ ]
426
+ },
427
+ "execution_count": 158,
428
+ "metadata": {},
429
+ "output_type": "execute_result"
430
+ }
431
+ ],
432
+ "source": [
433
+ "x.shape"
434
+ ]
435
+ },
436
  {
437
  "cell_type": "code",
438
  "execution_count": null,
439
+ "id": "c5817d01",
440
  "metadata": {},
441
  "outputs": [],
442
  "source": []