Spaces:
Runtime error
Runtime error
Fix linear layer
Browse files- cnn.py +1 -1
- models/void.pth +0 -0
- models/void_20230512_225714.pth +0 -0
- models/void_20230513_021517.pth +0 -0
- notebooks/playground.ipynb +185 -36
cnn.py
CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
-
self.linear = nn.Linear(128 * 5 * 11,
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
+
self.linear = nn.Linear(128 * 5 * 11, 3)
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
models/void.pth
CHANGED
Binary files a/models/void.pth and b/models/void.pth differ
|
|
models/void_20230512_225714.pth
ADDED
Binary file (477 kB). View file
|
|
models/void_20230513_021517.pth
ADDED
Binary file (477 kB). View file
|
|
notebooks/playground.ipynb
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
-
"id": "
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
@@ -14,7 +14,7 @@
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
-
"id": "
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
@@ -25,7 +25,7 @@
|
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
"execution_count": 86,
|
28 |
-
"id": "
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
@@ -38,7 +38,7 @@
|
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
"execution_count": 85,
|
41 |
-
"id": "
|
42 |
"metadata": {},
|
43 |
"outputs": [],
|
44 |
"source": [
|
@@ -49,7 +49,7 @@
|
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
"execution_count": 78,
|
52 |
-
"id": "
|
53 |
"metadata": {},
|
54 |
"outputs": [
|
55 |
{
|
@@ -71,7 +71,7 @@
|
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
"execution_count": 109,
|
74 |
-
"id": "
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
@@ -87,7 +87,7 @@
|
|
87 |
{
|
88 |
"cell_type": "code",
|
89 |
"execution_count": 110,
|
90 |
-
"id": "
|
91 |
"metadata": {},
|
92 |
"outputs": [
|
93 |
{
|
@@ -108,7 +108,7 @@
|
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
"execution_count": 111,
|
111 |
-
"id": "
|
112 |
"metadata": {},
|
113 |
"outputs": [
|
114 |
{
|
@@ -136,7 +136,7 @@
|
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
"execution_count": 112,
|
139 |
-
"id": "
|
140 |
"metadata": {},
|
141 |
"outputs": [
|
142 |
{
|
@@ -156,36 +156,19 @@
|
|
156 |
},
|
157 |
{
|
158 |
"cell_type": "code",
|
159 |
-
"execution_count":
|
160 |
-
"id": "
|
161 |
"metadata": {},
|
162 |
-
"outputs": [
|
163 |
-
{
|
164 |
-
"ename": "RuntimeError",
|
165 |
-
"evalue": "mat1 and mat2 shapes cannot be multiplied (2x2560 and 7040x10)",
|
166 |
-
"output_type": "error",
|
167 |
-
"traceback": [
|
168 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
169 |
-
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
170 |
-
"Cell \u001b[0;32mIn[113], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m cnn \u001b[38;5;241m=\u001b[39m CNNetwork()\n\u001b[0;32m----> 2\u001b[0m \u001b[43msummary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcnn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m64\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m44\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
|
171 |
-
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torchsummary/torchsummary.py:72\u001b[0m, in \u001b[0;36msummary\u001b[0;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[1;32m 68\u001b[0m model\u001b[38;5;241m.\u001b[39mapply(register_hook)\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# make a forward pass\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;66;03m# print(x.shape)\u001b[39;00m\n\u001b[0;32m---> 72\u001b[0m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;66;03m# remove these hooks\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m h \u001b[38;5;129;01min\u001b[39;00m hooks:\n",
|
172 |
-
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
173 |
-
"File \u001b[0;32m~/ml-sandbox/VoID/notebooks/../cnn.py:64\u001b[0m, in \u001b[0;36mCNNetwork.forward\u001b[0;34m(self, input_data)\u001b[0m\n\u001b[1;32m 62\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv4(x)\n\u001b[1;32m 63\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mflatten(x)\n\u001b[0;32m---> 64\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m predictions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msoftmax(logits)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m predictions\n",
|
174 |
-
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/module.py:1538\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1535\u001b[0m bw_hook \u001b[38;5;241m=\u001b[39m hooks\u001b[38;5;241m.\u001b[39mBackwardHook(\u001b[38;5;28mself\u001b[39m, full_backward_hooks, backward_pre_hooks)\n\u001b[1;32m 1536\u001b[0m args \u001b[38;5;241m=\u001b[39m bw_hook\u001b[38;5;241m.\u001b[39msetup_input_hook(args)\n\u001b[0;32m-> 1538\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks:\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m hook_id, hook \u001b[38;5;129;01min\u001b[39;00m (\n\u001b[1;32m 1541\u001b[0m \u001b[38;5;241m*\u001b[39m_global_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m 1542\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks\u001b[38;5;241m.\u001b[39mitems(),\n\u001b[1;32m 1543\u001b[0m ):\n",
|
175 |
-
"File \u001b[0;32m~/anaconda3/envs/void/lib/python3.9/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
176 |
-
"\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x2560 and 7040x10)"
|
177 |
-
]
|
178 |
-
}
|
179 |
-
],
|
180 |
"source": [
|
181 |
"cnn = CNNetwork()\n",
|
182 |
-
"summary(cnn, (1, 64, 44))"
|
183 |
]
|
184 |
},
|
185 |
{
|
186 |
"cell_type": "code",
|
187 |
"execution_count": 114,
|
188 |
-
"id": "
|
189 |
"metadata": {},
|
190 |
"outputs": [
|
191 |
{
|
@@ -206,7 +189,7 @@
|
|
206 |
{
|
207 |
"cell_type": "code",
|
208 |
"execution_count": 115,
|
209 |
-
"id": "
|
210 |
"metadata": {},
|
211 |
"outputs": [
|
212 |
{
|
@@ -227,7 +210,7 @@
|
|
227 |
{
|
228 |
"cell_type": "code",
|
229 |
"execution_count": 116,
|
230 |
-
"id": "
|
231 |
"metadata": {},
|
232 |
"outputs": [
|
233 |
{
|
@@ -255,7 +238,7 @@
|
|
255 |
{
|
256 |
"cell_type": "code",
|
257 |
"execution_count": 117,
|
258 |
-
"id": "
|
259 |
"metadata": {},
|
260 |
"outputs": [],
|
261 |
"source": [
|
@@ -266,7 +249,7 @@
|
|
266 |
{
|
267 |
"cell_type": "code",
|
268 |
"execution_count": 107,
|
269 |
-
"id": "
|
270 |
"metadata": {},
|
271 |
"outputs": [
|
272 |
{
|
@@ -284,10 +267,176 @@
|
|
284 |
"now.strftime(\"%Y%m%d_%H%M%S\")"
|
285 |
]
|
286 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
{
|
288 |
"cell_type": "code",
|
289 |
"execution_count": null,
|
290 |
-
"id": "
|
291 |
"metadata": {},
|
292 |
"outputs": [],
|
293 |
"source": []
|
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
+
"id": "9db7bd27",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
+
"id": "72b076a5",
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
|
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
"execution_count": 86,
|
28 |
+
"id": "391c8ebe",
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
|
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
"execution_count": 85,
|
41 |
+
"id": "0f0b166a",
|
42 |
"metadata": {},
|
43 |
"outputs": [],
|
44 |
"source": [
|
|
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
"execution_count": 78,
|
52 |
+
"id": "b690f559",
|
53 |
"metadata": {},
|
54 |
"outputs": [
|
55 |
{
|
|
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
"execution_count": 109,
|
74 |
+
"id": "5b4cac66",
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
|
|
87 |
{
|
88 |
"cell_type": "code",
|
89 |
"execution_count": 110,
|
90 |
+
"id": "55928782",
|
91 |
"metadata": {},
|
92 |
"outputs": [
|
93 |
{
|
|
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
"execution_count": 111,
|
111 |
+
"id": "296fc1d0",
|
112 |
"metadata": {},
|
113 |
"outputs": [
|
114 |
{
|
|
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
"execution_count": 112,
|
139 |
+
"id": "b921ef42",
|
140 |
"metadata": {},
|
141 |
"outputs": [
|
142 |
{
|
|
|
156 |
},
|
157 |
{
|
158 |
"cell_type": "code",
|
159 |
+
"execution_count": 144,
|
160 |
+
"id": "83671781",
|
161 |
"metadata": {},
|
162 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
"source": [
|
164 |
"cnn = CNNetwork()\n",
|
165 |
+
"# summary(cnn, (1, 64, 44))"
|
166 |
]
|
167 |
},
|
168 |
{
|
169 |
"cell_type": "code",
|
170 |
"execution_count": 114,
|
171 |
+
"id": "5a12b59f",
|
172 |
"metadata": {},
|
173 |
"outputs": [
|
174 |
{
|
|
|
189 |
{
|
190 |
"cell_type": "code",
|
191 |
"execution_count": 115,
|
192 |
+
"id": "4845de38",
|
193 |
"metadata": {},
|
194 |
"outputs": [
|
195 |
{
|
|
|
210 |
{
|
211 |
"cell_type": "code",
|
212 |
"execution_count": 116,
|
213 |
+
"id": "51c03aaf",
|
214 |
"metadata": {},
|
215 |
"outputs": [
|
216 |
{
|
|
|
238 |
{
|
239 |
"cell_type": "code",
|
240 |
"execution_count": 117,
|
241 |
+
"id": "ba6b88ee",
|
242 |
"metadata": {},
|
243 |
"outputs": [],
|
244 |
"source": [
|
|
|
249 |
{
|
250 |
"cell_type": "code",
|
251 |
"execution_count": 107,
|
252 |
+
"id": "a6046ccf",
|
253 |
"metadata": {},
|
254 |
"outputs": [
|
255 |
{
|
|
|
267 |
"now.strftime(\"%Y%m%d_%H%M%S\")"
|
268 |
]
|
269 |
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"execution_count": 145,
|
273 |
+
"id": "d7789a04",
|
274 |
+
"metadata": {},
|
275 |
+
"outputs": [
|
276 |
+
{
|
277 |
+
"data": {
|
278 |
+
"text/plain": [
|
279 |
+
"<All keys matched successfully>"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 145,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"cnn.load_state_dict(torch.load(\"../models/void_20230512_225714.pth\"))"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 151,
|
294 |
+
"id": "a6030b42",
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"x, y = dataset[10]"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": 152,
|
304 |
+
"id": "78352b6b",
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"labels = dataset._labels"
|
309 |
+
]
|
310 |
+
},
|
311 |
+
{
|
312 |
+
"cell_type": "code",
|
313 |
+
"execution_count": 153,
|
314 |
+
"id": "b8cc2162",
|
315 |
+
"metadata": {},
|
316 |
+
"outputs": [],
|
317 |
+
"source": [
|
318 |
+
"input = x.unsqueeze_(0) "
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": 154,
|
324 |
+
"id": "845ecea4",
|
325 |
+
"metadata": {},
|
326 |
+
"outputs": [],
|
327 |
+
"source": [
|
328 |
+
"def predict(model, input, target, class_mapping):\n",
|
329 |
+
" model.eval()\n",
|
330 |
+
" with torch.no_grad():\n",
|
331 |
+
" predictions = model(input)\n",
|
332 |
+
" print(predictions)\n",
|
333 |
+
" predicted_index = predictions[0].argmax(0)\n",
|
334 |
+
" predicted = class_mapping[predicted_index]\n",
|
335 |
+
" expected = class_mapping[target]\n",
|
336 |
+
" return predicted, expected"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": 155,
|
342 |
+
"id": "eb8d1e55",
|
343 |
+
"metadata": {},
|
344 |
+
"outputs": [
|
345 |
+
{
|
346 |
+
"name": "stdout",
|
347 |
+
"output_type": "stream",
|
348 |
+
"text": [
|
349 |
+
"tensor([[1.0000e+00, 1.3728e-20, 2.8026e-44]])\n"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"data": {
|
354 |
+
"text/plain": [
|
355 |
+
"('aman', 'aman')"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
"execution_count": 155,
|
359 |
+
"metadata": {},
|
360 |
+
"output_type": "execute_result"
|
361 |
+
}
|
362 |
+
],
|
363 |
+
"source": [
|
364 |
+
"predict(cnn, input, y, labels)"
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"cell_type": "code",
|
369 |
+
"execution_count": 156,
|
370 |
+
"id": "5d58683e",
|
371 |
+
"metadata": {},
|
372 |
+
"outputs": [
|
373 |
+
{
|
374 |
+
"data": {
|
375 |
+
"text/plain": [
|
376 |
+
"tensor([[[[0.0259, 0.1384, 0.0784, ..., 0.0000, 0.0000, 0.0000],\n",
|
377 |
+
" [0.0334, 0.1320, 0.0701, ..., 0.0000, 0.0000, 0.0000],\n",
|
378 |
+
" [0.0481, 0.0324, 0.0545, ..., 0.0000, 0.0000, 0.0000],\n",
|
379 |
+
" ...,\n",
|
380 |
+
" [0.2665, 0.3647, 0.3147, ..., 0.0000, 0.0000, 0.0000],\n",
|
381 |
+
" [0.2710, 0.3796, 0.2160, ..., 0.0000, 0.0000, 0.0000],\n",
|
382 |
+
" [0.1950, 0.2607, 0.1905, ..., 0.0000, 0.0000, 0.0000]]]])"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
"execution_count": 156,
|
386 |
+
"metadata": {},
|
387 |
+
"output_type": "execute_result"
|
388 |
+
}
|
389 |
+
],
|
390 |
+
"source": [
|
391 |
+
"input"
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": 157,
|
397 |
+
"id": "b0af5b69",
|
398 |
+
"metadata": {},
|
399 |
+
"outputs": [
|
400 |
+
{
|
401 |
+
"data": {
|
402 |
+
"text/plain": [
|
403 |
+
"torch.Size([1, 1, 64, 157])"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
"execution_count": 157,
|
407 |
+
"metadata": {},
|
408 |
+
"output_type": "execute_result"
|
409 |
+
}
|
410 |
+
],
|
411 |
+
"source": [
|
412 |
+
"input.shape"
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": 158,
|
418 |
+
"id": "28c0768a",
|
419 |
+
"metadata": {},
|
420 |
+
"outputs": [
|
421 |
+
{
|
422 |
+
"data": {
|
423 |
+
"text/plain": [
|
424 |
+
"torch.Size([1, 1, 64, 157])"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
"execution_count": 158,
|
428 |
+
"metadata": {},
|
429 |
+
"output_type": "execute_result"
|
430 |
+
}
|
431 |
+
],
|
432 |
+
"source": [
|
433 |
+
"x.shape"
|
434 |
+
]
|
435 |
+
},
|
436 |
{
|
437 |
"cell_type": "code",
|
438 |
"execution_count": null,
|
439 |
+
"id": "c5817d01",
|
440 |
"metadata": {},
|
441 |
"outputs": [],
|
442 |
"source": []
|