Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +65 -0
- models/.ipynb_checkpoints/Untitled-checkpoint.ipynb +766 -0
- models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb +305 -0
- models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb +485 -0
- models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb +372 -0
- models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb +483 -0
- models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb +323 -0
- models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb +500 -0
- models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb +460 -0
- models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb +326 -0
- models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb +3 -0
- models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb +273 -0
- models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb +0 -0
- models/.ipynb_checkpoints/resnet_model-checkpoint.py +160 -0
- models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py +166 -0
- models/.ipynb_checkpoints/train-checkpoint.py +105 -0
- models/.ipynb_checkpoints/train-mask-8-checkpoint.py +103 -0
- models/.ipynb_checkpoints/train-mask-checkpoint.py +104 -0
- models/.ipynb_checkpoints/utils-checkpoint.py +393 -0
- models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py +65 -0
- models/HITS-FEB-10.zip +3 -0
- models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png +3 -0
- models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739230556_9.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739230556_9.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739231399_1.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739231399_1.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739231802_11.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739231802_11.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_13.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_13.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_14.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739234628_14.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739235333_29.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739235333_29.png +3 -0
- models/HITS-FEB-10/hit_100000000_1739235841_12.npy +3 -0
- models/HITS-FEB-10/hit_100000000_1739235841_12.png +3 -0
- models/HITS-FEB-10/hit_50233055_1739232802_29.npy +3 -0
- models/HITS-FEB-10/hit_50233055_1739232802_29.png +3 -0
- models/HITS-FEB-10/hit_52111435_1739229641_28.npy +3 -0
- models/HITS-FEB-10/hit_52111435_1739229641_28.png +3 -0
- models/HITS-FEB-10/hit_52550001_1739233595_4.npy +3 -0
- models/HITS-FEB-10/hit_52550001_1739233595_4.png +3 -0
.gitattributes
CHANGED
@@ -36,3 +36,68 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
36 |
accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
37 |
accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
38 |
accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
37 |
accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
38 |
accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb filter=lfs diff=lfs merge=lfs -text
|
40 |
+
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
models/HITS-FEB-10/hit_100000000_1739230556_9.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
models/HITS-FEB-10/hit_100000000_1739231399_1.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
models/HITS-FEB-10/hit_100000000_1739231802_11.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
models/HITS-FEB-10/hit_100000000_1739234628_13.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
models/HITS-FEB-10/hit_100000000_1739234628_14.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
models/HITS-FEB-10/hit_100000000_1739235333_29.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
models/HITS-FEB-10/hit_100000000_1739235841_12.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
models/HITS-FEB-10/hit_50233055_1739232802_29.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
models/HITS-FEB-10/hit_52111435_1739229641_28.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
models/HITS-FEB-10/hit_52550001_1739233595_4.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
models/HITS-FEB-10/hit_57096732_1739234611_11.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
models/HITS-FEB-10/hit_57521253_1739232651_11.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
models/HITS-FEB-10/hit_58032264_1739232672_22.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
models/HITS-FEB-10/hit_58165746_1739230560_10.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
models/HITS-FEB-10/hit_62177701_1739230031_23.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
models/HITS-FEB-10/hit_64237575_1739233604_2.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
models/HITS-FEB-10/hit_67249737_1739231769_23.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
models/HITS-FEB-10/hit_71882680_1739230648_29.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
models/HITS-FEB-10/hit_72677566_1739232113_26.png filter=lfs diff=lfs merge=lfs -text
|
61 |
+
models/HITS-FEB-10/hit_74160848_1739234611_5.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
models/HITS-FEB-10/hit_75109552_1739231790_10.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
models/HITS-FEB-10/hit_79640130_1739231950_5.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
models/HITS-FEB-10/hit_81910572_1739231764_29.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
models/HITS-FEB-10/hit_83296520_1739233906_31.png filter=lfs diff=lfs merge=lfs -text
|
66 |
+
models/HITS-FEB-10/hit_84171229_1739231886_5.png filter=lfs diff=lfs merge=lfs -text
|
67 |
+
models/HITS-FEB-10/hit_84411238_1739233784_3.png filter=lfs diff=lfs merge=lfs -text
|
68 |
+
models/HITS-FEB-10/hit_87957837_1739232059_29.png filter=lfs diff=lfs merge=lfs -text
|
69 |
+
models/HITS-FEB-10/hit_88699241_1739235027_30.png filter=lfs diff=lfs merge=lfs -text
|
70 |
+
models/HITS-FEB-10/hit_90233018_1739235740_0.png filter=lfs diff=lfs merge=lfs -text
|
71 |
+
models/HITS-FEB-10/hit_93808281_1739233104_14.png filter=lfs diff=lfs merge=lfs -text
|
72 |
+
models/HITS-FEB-10/hit_93821122_1739232374_4.png filter=lfs diff=lfs merge=lfs -text
|
73 |
+
models/HITS-FEB-10/hit_94705507_1739231215_6.png filter=lfs diff=lfs merge=lfs -text
|
74 |
+
models/HITS-FEB-10/hit_95400645_1739230724_21.png filter=lfs diff=lfs merge=lfs -text
|
75 |
+
models/HITS-FEB-10/hit_95544329_1739233784_17.png filter=lfs diff=lfs merge=lfs -text
|
76 |
+
models/HITS-FEB-10/hit_96119644_1739232285_1.png filter=lfs diff=lfs merge=lfs -text
|
77 |
+
models/HITS-FEB-10/hit_96369222_1739233852_3.png filter=lfs diff=lfs merge=lfs -text
|
78 |
+
models/HITS-FEB-10/hit_96470689_1739233843_18.png filter=lfs diff=lfs merge=lfs -text
|
79 |
+
models/HITS-FEB-10/hit_96471079_1739232857_0.png filter=lfs diff=lfs merge=lfs -text
|
80 |
+
models/HITS-FEB-10/hit_96497133_1739233692_27.png filter=lfs diff=lfs merge=lfs -text
|
81 |
+
models/HITS-FEB-10/hit_98139322_1739231500_5.png filter=lfs diff=lfs merge=lfs -text
|
82 |
+
models/HITS-FEB-10/hit_98582385_1739232894_12.png filter=lfs diff=lfs merge=lfs -text
|
83 |
+
models/HITS-FEB-10/hit_98697207_1739229930_16.png filter=lfs diff=lfs merge=lfs -text
|
84 |
+
models/HITS-FEB-10/hit_99172221_1739232365_29.png filter=lfs diff=lfs merge=lfs -text
|
85 |
+
models/HITS-FEB-10/hit_99314646_1739233667_1.png filter=lfs diff=lfs merge=lfs -text
|
86 |
+
models/HITS-FEB-10/hit_99756914_1739230207_8.png filter=lfs diff=lfs merge=lfs -text
|
87 |
+
models/HITS-FEB-10/hit_99939211_1739233705_24.png filter=lfs diff=lfs merge=lfs -text
|
88 |
+
models/HITS-FEB-10/hit_99972041_1739234066_9.png filter=lfs diff=lfs merge=lfs -text
|
89 |
+
models/HITS-FEB-10/hit_99977277_1739231773_10.png filter=lfs diff=lfs merge=lfs -text
|
90 |
+
models/HITS-FEB-10/hit_99986237_1739234058_8.png filter=lfs diff=lfs merge=lfs -text
|
91 |
+
models/HITS-FEB-10/hit_99998032_1739232348_29.png filter=lfs diff=lfs merge=lfs -text
|
92 |
+
models/HITS-FEB-10/hit_99999287_1739233700_3.png filter=lfs diff=lfs merge=lfs -text
|
93 |
+
models/HITS-FEB-10/hit_99999351_1739235476_0.png filter=lfs diff=lfs merge=lfs -text
|
94 |
+
models/HITS-FEB-10/hit_99999979_1739232399_15.png filter=lfs diff=lfs merge=lfs -text
|
95 |
+
models/combined_frb_detections.pdf filter=lfs diff=lfs merge=lfs -text
|
96 |
+
models/combined_frb_detections.png filter=lfs diff=lfs merge=lfs -text
|
97 |
+
models/hits.png filter=lfs diff=lfs merge=lfs -text
|
98 |
+
models/hits_crab.pdf filter=lfs diff=lfs merge=lfs -text
|
99 |
+
models/models_mask/accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
|
100 |
+
models/models_mask/accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
|
101 |
+
models/models_mask/accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
|
102 |
+
models/recover_crab.ipynb filter=lfs diff=lfs merge=lfs -text
|
103 |
+
models/recover_new_crab-debug.ipynb filter=lfs diff=lfs merge=lfs -text
|
models/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 10,
|
6 |
+
"id": "5577ffee-a5c9-4648-8849-95c2c7ebcebe",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
11 |
+
"from utils_batched_preproc import transform_batched, preproc_flip\n",
|
12 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
13 |
+
"import torch\n",
|
14 |
+
"import numpy as np\n",
|
15 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
16 |
+
"import torch\n",
|
17 |
+
"import torch.nn as nn\n",
|
18 |
+
"import torch.optim as optim\n",
|
19 |
+
"import tqdm \n",
|
20 |
+
"import torch.nn.functional as F\n",
|
21 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
22 |
+
"import pickle\n",
|
23 |
+
"import torch\n",
|
24 |
+
"from functorch import vmap"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 3,
|
30 |
+
"id": "f1180d60-83e7-47ca-aa09-58d26af3c706",
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"# def renorm_batched(data):\n",
|
35 |
+
"# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
|
36 |
+
"# std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
|
37 |
+
"# standardized_data = (data - mean) / std\n",
|
38 |
+
"# return standardized_data\n",
|
39 |
+
"\n",
|
40 |
+
"# def transform_batched(data):\n",
|
41 |
+
"# copy_data = data.detach().clone()\n",
|
42 |
+
"# rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std\n",
|
43 |
+
"# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean\n",
|
44 |
+
"# masks_rms = [-1, 5]\n",
|
45 |
+
" \n",
|
46 |
+
"# # Prepare the new_data tensor\n",
|
47 |
+
"# num_masks = len(masks_rms) + 1\n",
|
48 |
+
"# new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)\n",
|
49 |
+
"\n",
|
50 |
+
"# # First layer: Apply renorm(log10(copy_data + epsilon))\n",
|
51 |
+
"# new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
|
52 |
+
"# for i, scale in enumerate(masks_rms, start=1):\n",
|
53 |
+
"# copy_data = data.detach().clone()\n",
|
54 |
+
" \n",
|
55 |
+
"# # Apply masking based on the scale\n",
|
56 |
+
"# if scale < 0:\n",
|
57 |
+
"# ind = copy_data < abs(scale) * rms + mean\n",
|
58 |
+
"# else:\n",
|
59 |
+
"# ind = copy_data > scale * rms + mean\n",
|
60 |
+
"# copy_data[ind] = 0\n",
|
61 |
+
" \n",
|
62 |
+
"# # Renormalize and log10 transform\n",
|
63 |
+
"# new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
|
64 |
+
" \n",
|
65 |
+
"# # Convert to float32\n",
|
66 |
+
"# new_data = new_data.type(torch.float32)\n",
|
67 |
+
"\n",
|
68 |
+
"# # Chunk along the last dimension and stack\n",
|
69 |
+
"# slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing\n",
|
70 |
+
"# new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1\n",
|
71 |
+
"# new_data = torch.swapaxes(new_data, 0,1)\n",
|
72 |
+
"# # Reshape into final format\n",
|
73 |
+
"# new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions\n",
|
74 |
+
"# return new_data\n",
|
75 |
+
"\n",
|
76 |
+
"\n"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "code",
|
81 |
+
"execution_count": 11,
|
82 |
+
"id": "81cc81a8-ecef-43ef-a5cd-c35765384812",
|
83 |
+
"metadata": {
|
84 |
+
"scrolled": true
|
85 |
+
},
|
86 |
+
"outputs": [
|
87 |
+
{
|
88 |
+
"name": "stdout",
|
89 |
+
"output_type": "stream",
|
90 |
+
"text": [
|
91 |
+
"num params encoder 50840\n"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"name": "stderr",
|
96 |
+
"output_type": "stream",
|
97 |
+
"text": [
|
98 |
+
"/tmp/ipykernel_19147/1680389579.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
99 |
+
" model.load_state_dict(torch.load(model_path))\n"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"data": {
|
104 |
+
"text/plain": [
|
105 |
+
"DataParallel(\n",
|
106 |
+
" (module): ResNet(\n",
|
107 |
+
" (relu): ReLU()\n",
|
108 |
+
" (conv1): Sequential(\n",
|
109 |
+
" (0): Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
|
110 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
111 |
+
" (2): ReLU()\n",
|
112 |
+
" )\n",
|
113 |
+
" (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
114 |
+
" (layer0): Sequential(\n",
|
115 |
+
" (0): ResidualBlock(\n",
|
116 |
+
" (conv1): Sequential(\n",
|
117 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
118 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
119 |
+
" (2): ReLU()\n",
|
120 |
+
" )\n",
|
121 |
+
" (conv2): Sequential(\n",
|
122 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
123 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
124 |
+
" )\n",
|
125 |
+
" (relu): ReLU()\n",
|
126 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
127 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
128 |
+
" )\n",
|
129 |
+
" (1): ResidualBlock(\n",
|
130 |
+
" (conv1): Sequential(\n",
|
131 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
132 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
133 |
+
" (2): ReLU()\n",
|
134 |
+
" )\n",
|
135 |
+
" (conv2): Sequential(\n",
|
136 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
137 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
138 |
+
" )\n",
|
139 |
+
" (relu): ReLU()\n",
|
140 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
141 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
142 |
+
" )\n",
|
143 |
+
" (2): ResidualBlock(\n",
|
144 |
+
" (conv1): Sequential(\n",
|
145 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
146 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
147 |
+
" (2): ReLU()\n",
|
148 |
+
" )\n",
|
149 |
+
" (conv2): Sequential(\n",
|
150 |
+
" (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
151 |
+
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
152 |
+
" )\n",
|
153 |
+
" (relu): ReLU()\n",
|
154 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
155 |
+
" (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
156 |
+
" )\n",
|
157 |
+
" )\n",
|
158 |
+
" (layer1): Sequential(\n",
|
159 |
+
" (0): ResidualBlock(\n",
|
160 |
+
" (conv1): Sequential(\n",
|
161 |
+
" (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
|
162 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
163 |
+
" (2): ReLU()\n",
|
164 |
+
" )\n",
|
165 |
+
" (conv2): Sequential(\n",
|
166 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
167 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
168 |
+
" )\n",
|
169 |
+
" (downsample): Sequential(\n",
|
170 |
+
" (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))\n",
|
171 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
172 |
+
" )\n",
|
173 |
+
" (relu): ReLU()\n",
|
174 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
175 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
176 |
+
" )\n",
|
177 |
+
" (1): ResidualBlock(\n",
|
178 |
+
" (conv1): Sequential(\n",
|
179 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
180 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
181 |
+
" (2): ReLU()\n",
|
182 |
+
" )\n",
|
183 |
+
" (conv2): Sequential(\n",
|
184 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
185 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
186 |
+
" )\n",
|
187 |
+
" (relu): ReLU()\n",
|
188 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
189 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
190 |
+
" )\n",
|
191 |
+
" (2): ResidualBlock(\n",
|
192 |
+
" (conv1): Sequential(\n",
|
193 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
194 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
195 |
+
" (2): ReLU()\n",
|
196 |
+
" )\n",
|
197 |
+
" (conv2): Sequential(\n",
|
198 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
199 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
200 |
+
" )\n",
|
201 |
+
" (relu): ReLU()\n",
|
202 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
203 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
204 |
+
" )\n",
|
205 |
+
" (3): ResidualBlock(\n",
|
206 |
+
" (conv1): Sequential(\n",
|
207 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
208 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
209 |
+
" (2): ReLU()\n",
|
210 |
+
" )\n",
|
211 |
+
" (conv2): Sequential(\n",
|
212 |
+
" (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
213 |
+
" (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
214 |
+
" )\n",
|
215 |
+
" (relu): ReLU()\n",
|
216 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
217 |
+
" (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
218 |
+
" )\n",
|
219 |
+
" )\n",
|
220 |
+
" (layer2): Sequential(\n",
|
221 |
+
" (0): ResidualBlock(\n",
|
222 |
+
" (conv1): Sequential(\n",
|
223 |
+
" (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
|
224 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
225 |
+
" (2): ReLU()\n",
|
226 |
+
" )\n",
|
227 |
+
" (conv2): Sequential(\n",
|
228 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
229 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
230 |
+
" )\n",
|
231 |
+
" (downsample): Sequential(\n",
|
232 |
+
" (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))\n",
|
233 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
234 |
+
" )\n",
|
235 |
+
" (relu): ReLU()\n",
|
236 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
237 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
238 |
+
" )\n",
|
239 |
+
" (1): ResidualBlock(\n",
|
240 |
+
" (conv1): Sequential(\n",
|
241 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
242 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
243 |
+
" (2): ReLU()\n",
|
244 |
+
" )\n",
|
245 |
+
" (conv2): Sequential(\n",
|
246 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
247 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
248 |
+
" )\n",
|
249 |
+
" (relu): ReLU()\n",
|
250 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
251 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
252 |
+
" )\n",
|
253 |
+
" (2): ResidualBlock(\n",
|
254 |
+
" (conv1): Sequential(\n",
|
255 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
256 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
257 |
+
" (2): ReLU()\n",
|
258 |
+
" )\n",
|
259 |
+
" (conv2): Sequential(\n",
|
260 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
261 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
262 |
+
" )\n",
|
263 |
+
" (relu): ReLU()\n",
|
264 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
265 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
266 |
+
" )\n",
|
267 |
+
" (3): ResidualBlock(\n",
|
268 |
+
" (conv1): Sequential(\n",
|
269 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
270 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
271 |
+
" (2): ReLU()\n",
|
272 |
+
" )\n",
|
273 |
+
" (conv2): Sequential(\n",
|
274 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
275 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
276 |
+
" )\n",
|
277 |
+
" (relu): ReLU()\n",
|
278 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
279 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
280 |
+
" )\n",
|
281 |
+
" (4): ResidualBlock(\n",
|
282 |
+
" (conv1): Sequential(\n",
|
283 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
284 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
285 |
+
" (2): ReLU()\n",
|
286 |
+
" )\n",
|
287 |
+
" (conv2): Sequential(\n",
|
288 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
289 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
290 |
+
" )\n",
|
291 |
+
" (relu): ReLU()\n",
|
292 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
293 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
294 |
+
" )\n",
|
295 |
+
" (5): ResidualBlock(\n",
|
296 |
+
" (conv1): Sequential(\n",
|
297 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
298 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
299 |
+
" (2): ReLU()\n",
|
300 |
+
" )\n",
|
301 |
+
" (conv2): Sequential(\n",
|
302 |
+
" (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
303 |
+
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
304 |
+
" )\n",
|
305 |
+
" (relu): ReLU()\n",
|
306 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
307 |
+
" (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
308 |
+
" )\n",
|
309 |
+
" )\n",
|
310 |
+
" (layer3): Sequential(\n",
|
311 |
+
" (0): ResidualBlock(\n",
|
312 |
+
" (conv1): Sequential(\n",
|
313 |
+
" (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
314 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
315 |
+
" (2): ReLU()\n",
|
316 |
+
" )\n",
|
317 |
+
" (conv2): Sequential(\n",
|
318 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
319 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
320 |
+
" )\n",
|
321 |
+
" (downsample): Sequential(\n",
|
322 |
+
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
|
323 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
324 |
+
" )\n",
|
325 |
+
" (relu): ReLU()\n",
|
326 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
327 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
328 |
+
" )\n",
|
329 |
+
" (1): ResidualBlock(\n",
|
330 |
+
" (conv1): Sequential(\n",
|
331 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
332 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
333 |
+
" (2): ReLU()\n",
|
334 |
+
" )\n",
|
335 |
+
" (conv2): Sequential(\n",
|
336 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
337 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
338 |
+
" )\n",
|
339 |
+
" (relu): ReLU()\n",
|
340 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
341 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
342 |
+
" )\n",
|
343 |
+
" (2): ResidualBlock(\n",
|
344 |
+
" (conv1): Sequential(\n",
|
345 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
346 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
347 |
+
" (2): ReLU()\n",
|
348 |
+
" )\n",
|
349 |
+
" (conv2): Sequential(\n",
|
350 |
+
" (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
351 |
+
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
352 |
+
" )\n",
|
353 |
+
" (relu): ReLU()\n",
|
354 |
+
" (dropout1): Dropout(p=0.5, inplace=False)\n",
|
355 |
+
" (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
356 |
+
" )\n",
|
357 |
+
" )\n",
|
358 |
+
" (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
|
359 |
+
" (fc): Linear(in_features=39424, out_features=2, bias=True)\n",
|
360 |
+
" (dropout1): Dropout(p=0.3, inplace=False)\n",
|
361 |
+
" (encoder): Sequential(\n",
|
362 |
+
" (0): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
363 |
+
" (1): ReLU(inplace=True)\n",
|
364 |
+
" (2): Dropout(p=0.3, inplace=False)\n",
|
365 |
+
" (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
366 |
+
" (4): ReLU(inplace=True)\n",
|
367 |
+
" (5): Dropout(p=0.3, inplace=False)\n",
|
368 |
+
" (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
369 |
+
" (7): ReLU(inplace=True)\n",
|
370 |
+
" (8): Dropout(p=0.3, inplace=False)\n",
|
371 |
+
" (9): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
372 |
+
" (10): Sigmoid()\n",
|
373 |
+
" )\n",
|
374 |
+
" )\n",
|
375 |
+
")"
|
376 |
+
]
|
377 |
+
},
|
378 |
+
"execution_count": 11,
|
379 |
+
"metadata": {},
|
380 |
+
"output_type": "execute_result"
|
381 |
+
}
|
382 |
+
],
|
383 |
+
"source": [
|
384 |
+
"model_path = 'models/model-47-99.125.pt'\n",
|
385 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
386 |
+
"\n",
|
387 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=2).to(device)\n",
|
388 |
+
"model = nn.DataParallel(model)\n",
|
389 |
+
"model = model.to(device)\n",
|
390 |
+
"model.load_state_dict(torch.load(model_path))\n",
|
391 |
+
"model.eval()"
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": 12,
|
397 |
+
"id": "58b8c338-df2f-4ef0-92cf-409c9f034cab",
|
398 |
+
"metadata": {},
|
399 |
+
"outputs": [
|
400 |
+
{
|
401 |
+
"name": "stderr",
|
402 |
+
"output_type": "stream",
|
403 |
+
"text": [
|
404 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
405 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"name": "stdout",
|
410 |
+
"output_type": "stream",
|
411 |
+
"text": [
|
412 |
+
"tensor([[ 4.1780, -4.1750],\n",
|
413 |
+
" [ 4.6414, -4.6303],\n",
|
414 |
+
" [ 5.0103, -5.0162],\n",
|
415 |
+
" [ 4.8273, -4.8311],\n",
|
416 |
+
" [ 4.8523, -4.8661],\n",
|
417 |
+
" [ 4.8855, -4.9074],\n",
|
418 |
+
" [ 4.4973, -4.5213],\n",
|
419 |
+
" [ 5.5996, -5.6192],\n",
|
420 |
+
" [ 4.7929, -4.8116],\n",
|
421 |
+
" [ 5.5999, -5.5925],\n",
|
422 |
+
" [ 4.7918, -4.7998],\n",
|
423 |
+
" [ 4.0914, -4.0766],\n",
|
424 |
+
" [ 0.7072, -0.6955],\n",
|
425 |
+
" [ 4.7136, -4.7234],\n",
|
426 |
+
" [ 5.3918, -5.4307],\n",
|
427 |
+
" [ 4.5491, -4.5524],\n",
|
428 |
+
" [ 4.5412, -4.5391],\n",
|
429 |
+
" [ 4.6264, -4.6137],\n",
|
430 |
+
" [ 3.9378, -3.9300],\n",
|
431 |
+
" [ 5.0673, -5.0792],\n",
|
432 |
+
" [ 5.7389, -5.7330],\n",
|
433 |
+
" [ 5.2259, -5.2326],\n",
|
434 |
+
" [ 5.3856, -5.4036],\n",
|
435 |
+
" [ 5.0781, -5.1232],\n",
|
436 |
+
" [ 5.2432, -5.2584],\n",
|
437 |
+
" [ 5.8163, -5.8209],\n",
|
438 |
+
" [ 4.7730, -4.7823],\n",
|
439 |
+
" [ 5.1320, -5.1657],\n",
|
440 |
+
" [ 5.6486, -5.6485],\n",
|
441 |
+
" [ 3.7626, -3.7674],\n",
|
442 |
+
" [ 4.1834, -4.1797],\n",
|
443 |
+
" [ 4.4452, -4.4566]], device='cuda:0', grad_fn=<GatherBackward>)\n"
|
444 |
+
]
|
445 |
+
}
|
446 |
+
],
|
447 |
+
"source": [
|
448 |
+
"test_in = abs(torch.randn(32, 192, 2048).to(device))\n",
|
449 |
+
"results = []\n",
|
450 |
+
"for i in range(32):\n",
|
451 |
+
" results.append(transform(test_in[i,:,:]))\n",
|
452 |
+
"intermediate = torch.stack(results).cuda()\n",
|
453 |
+
"out = model(intermediate)\n",
|
454 |
+
"test_in.cpu().detach().numpy().tofile(\"input.bin\")\n",
|
455 |
+
"intermediate.cpu().detach().numpy().tofile(\"intermediate.bin\")\n",
|
456 |
+
"out.cpu().detach().numpy().tofile(\"output.bin\")\n",
|
457 |
+
"print(out)"
|
458 |
+
]
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"cell_type": "code",
|
462 |
+
"execution_count": 13,
|
463 |
+
"id": "ad56299a-44e4-4d6b-afcc-18a5f4cf0138",
|
464 |
+
"metadata": {},
|
465 |
+
"outputs": [
|
466 |
+
{
|
467 |
+
"ename": "NameError",
|
468 |
+
"evalue": "name 'preproc_flip' is not defined",
|
469 |
+
"output_type": "error",
|
470 |
+
"traceback": [
|
471 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
472 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
473 |
+
"Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m preproc_model \u001b[38;5;241m=\u001b[39m preproc_flip()\n\u001b[1;32m 2\u001b[0m Convert_ONNX(preproc_model,\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodels_mask/preproc_flip.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m, input_data_mock\u001b[38;5;241m=\u001b[39mtest_in\u001b[38;5;241m.\u001b[39mto(device))\n",
|
474 |
+
"\u001b[0;31mNameError\u001b[0m: name 'preproc_flip' is not defined"
|
475 |
+
]
|
476 |
+
}
|
477 |
+
],
|
478 |
+
"source": [
|
479 |
+
"preproc_model = preproc_flip()\n",
|
480 |
+
"Convert_ONNX(preproc_model,f'models_mask/preproc_flip.onnx', input_data_mock=test_in.to(device))\n",
|
481 |
+
"# Convert_ONNX(model.module,f'models_mask/model_test.onnx', input_data_mock=intermediate.to(device))"
|
482 |
+
]
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"cell_type": "code",
|
486 |
+
"execution_count": 7,
|
487 |
+
"id": "30e84a9b-0d4f-4cb2-a92b-2e3f0b2ccb20",
|
488 |
+
"metadata": {},
|
489 |
+
"outputs": [
|
490 |
+
{
|
491 |
+
"data": {
|
492 |
+
"text/plain": [
|
493 |
+
"torch.Size([32, 192, 2048])"
|
494 |
+
]
|
495 |
+
},
|
496 |
+
"execution_count": 7,
|
497 |
+
"metadata": {},
|
498 |
+
"output_type": "execute_result"
|
499 |
+
}
|
500 |
+
],
|
501 |
+
"source": [
|
502 |
+
"test_in.shape"
|
503 |
+
]
|
504 |
+
},
|
505 |
+
{
|
506 |
+
"cell_type": "code",
|
507 |
+
"execution_count": 13,
|
508 |
+
"id": "1bb26727-7914-470e-bb48-43d7ee81cb50",
|
509 |
+
"metadata": {},
|
510 |
+
"outputs": [
|
511 |
+
{
|
512 |
+
"data": {
|
513 |
+
"text/plain": [
|
514 |
+
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
|
515 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
516 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
517 |
+
" ...,\n",
|
518 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
519 |
+
" [0., 0., 0., ..., 0., 0., 0.],\n",
|
520 |
+
" [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
"execution_count": 13,
|
524 |
+
"metadata": {},
|
525 |
+
"output_type": "execute_result"
|
526 |
+
}
|
527 |
+
],
|
528 |
+
"source": [
|
529 |
+
"import torch\n",
|
530 |
+
"torch.flip(test_in[0,:,:], dims = (0,)) - torch.flipud(test_in[0,:,:])"
|
531 |
+
]
|
532 |
+
},
|
533 |
+
{
|
534 |
+
"cell_type": "code",
|
535 |
+
"execution_count": 29,
|
536 |
+
"id": "aeaaab90-6a2a-4851-a1ca-28c54a446573",
|
537 |
+
"metadata": {},
|
538 |
+
"outputs": [
|
539 |
+
{
|
540 |
+
"name": "stdout",
|
541 |
+
"output_type": "stream",
|
542 |
+
"text": [
|
543 |
+
"tensor(float)\n",
|
544 |
+
"torch.float32\n",
|
545 |
+
"Input Name: modelInput\n",
|
546 |
+
"Output Name: modelOutput\n",
|
547 |
+
"[array([[ 4.3262615, -4.3409047],\n",
|
548 |
+
" [ 4.9648395, -4.968621 ],\n",
|
549 |
+
" [ 5.5126643, -5.522872 ],\n",
|
550 |
+
" [ 4.7735534, -4.8004475],\n",
|
551 |
+
" [ 4.0924144, -4.112945 ],\n",
|
552 |
+
" [ 4.588802 , -4.6043544],\n",
|
553 |
+
" [ 4.6231914, -4.617625 ],\n",
|
554 |
+
" [ 5.229881 , -5.2555394],\n",
|
555 |
+
" [ 4.877381 , -4.882144 ],\n",
|
556 |
+
" [ 5.2514744, -5.2786503],\n",
|
557 |
+
" [ 4.2948875, -4.3169603],\n",
|
558 |
+
" [ 4.5997186, -4.6177607],\n",
|
559 |
+
" [ 4.9509926, -4.9685597],\n",
|
560 |
+
" [ 4.933158 , -4.9568825],\n",
|
561 |
+
" [ 4.747336 , -4.7639017],\n",
|
562 |
+
" [ 5.020595 , -5.0202913],\n",
|
563 |
+
" [ 4.914437 , -4.9206715],\n",
|
564 |
+
" [ 5.193108 , -5.1925435],\n",
|
565 |
+
" [ 4.5233765, -4.512763 ],\n",
|
566 |
+
" [ 4.7573333, -4.762632 ],\n",
|
567 |
+
" [ 5.268702 , -5.2838397],\n",
|
568 |
+
" [ 4.857734 , -4.8605857],\n",
|
569 |
+
" [ 5.1886744, -5.2047734],\n",
|
570 |
+
" [ 5.512568 , -5.5503583],\n",
|
571 |
+
" [ 5.320961 , -5.344709 ],\n",
|
572 |
+
" [ 4.1023226, -4.1073256],\n",
|
573 |
+
" [ 5.17857 , -5.185736 ],\n",
|
574 |
+
" [ 4.997028 , -4.9933476],\n",
|
575 |
+
" [ 4.771303 , -4.767269 ],\n",
|
576 |
+
" [ 5.312805 , -5.3265243],\n",
|
577 |
+
" [ 5.0030336, -5.0492 ],\n",
|
578 |
+
" [ 5.429731 , -5.4249325]], dtype=float32)]\n"
|
579 |
+
]
|
580 |
+
}
|
581 |
+
],
|
582 |
+
"source": [
|
583 |
+
"import onnxruntime as ort\n",
|
584 |
+
"import onnx\n",
|
585 |
+
"\n",
|
586 |
+
"# Path to your ONNX model\n",
|
587 |
+
"model_path = \"models/model-47-99.125.onnx\"\n",
|
588 |
+
"\n",
|
589 |
+
"# Load the ONNX model\n",
|
590 |
+
"session = ort.InferenceSession(model_path)\n",
|
591 |
+
"\n",
|
592 |
+
"# Get input and output details\n",
|
593 |
+
"input_name = session.get_inputs()[0].name\n",
|
594 |
+
"output_name = session.get_outputs()[0].name\n",
|
595 |
+
"\n",
|
596 |
+
"print(session.get_inputs()[0].type)\n",
|
597 |
+
"print(test_in.dtype)\n",
|
598 |
+
"\n",
|
599 |
+
"print(f\"Input Name: {input_name}\")\n",
|
600 |
+
"print(f\"Output Name: {output_name}\")\n",
|
601 |
+
"\n",
|
602 |
+
"# Example Input Data (Replace with your actual input data)\n",
|
603 |
+
"import numpy as np\n",
|
604 |
+
"\n",
|
605 |
+
"# Perform inference\n",
|
606 |
+
"outputs = session.run([output_name], {input_name: intermediate.cpu().numpy()})\n",
|
607 |
+
"print(outputs)\n",
|
608 |
+
"\n",
|
609 |
+
"onnx_model = onnx.load(model_path)"
|
610 |
+
]
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"cell_type": "code",
|
614 |
+
"execution_count": 30,
|
615 |
+
"id": "f250739d-4c8a-4752-964a-d0b929c396f4",
|
616 |
+
"metadata": {},
|
617 |
+
"outputs": [],
|
618 |
+
"source": [
|
619 |
+
"# import onnxruntime as ort\n",
|
620 |
+
"# import onnx\n",
|
621 |
+
"\n",
|
622 |
+
"# # Path to your ONNX model\n",
|
623 |
+
"# model_path = \"models_mask/preproc_test.onnx\"\n",
|
624 |
+
"\n",
|
625 |
+
"# # Load the ONNX model\n",
|
626 |
+
"# session = ort.InferenceSession(model_path)\n",
|
627 |
+
"\n",
|
628 |
+
"# # Get input and output details\n",
|
629 |
+
"# input_name = session.get_inputs()[0].name\n",
|
630 |
+
"# output_name = session.get_outputs()[0].name\n",
|
631 |
+
"\n",
|
632 |
+
"# print(session.get_inputs()[0].type)\n",
|
633 |
+
"# print(test_in.dtype)\n",
|
634 |
+
"\n",
|
635 |
+
"# print(f\"Input Name: {input_name}\")\n",
|
636 |
+
"# print(f\"Output Name: {output_name}\")\n",
|
637 |
+
"\n",
|
638 |
+
"# # Example Input Data (Replace with your actual input data)\n",
|
639 |
+
"# import numpy as np\n",
|
640 |
+
"\n",
|
641 |
+
"# # Perform inference\n",
|
642 |
+
"# outputs = session.run([output_name], {input_name: test_in.cpu().numpy()})\n",
|
643 |
+
"# print(\"Model Output:\", outputs)\n",
|
644 |
+
"\n",
|
645 |
+
"# onnx_model = onnx.load(model_path)"
|
646 |
+
]
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"cell_type": "code",
|
650 |
+
"execution_count": 8,
|
651 |
+
"id": "24fed4e7-4838-44cc-9c3a-0862bdbe173a",
|
652 |
+
"metadata": {},
|
653 |
+
"outputs": [
|
654 |
+
{
|
655 |
+
"data": {
|
656 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAey0lEQVR4nO3df0xd9f3H8deFymXMcpURL8WCzG124o/L5Jd0dpblboQ6snZZxvaHItm6ZcFFc6NL+w+4rJMsMUiynIXtmyDZr8gaIy5zqanXH/gDQwvFVfEXjhkWvZc26r3luoBezvePxau0UHvhlvs5nOcjuX/ccw/nvDkhl2fuPedej23btgAAAAyRk+0BAAAAPok4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGCUTdkeIF2Li4t66623tHnzZnk8nmyPAwAAzoFt2zp16pRKS0uVk3P210YcFydvvfWWysrKsj0GAABYhZmZGW3duvWs6zgmTizLkmVZ+vDDDyX975crLCzM8lQAAOBcxONxlZWVafPmzZ+6rsdp360Tj8fl8/kUi8WIEwAAHCKd/9+cEAsAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMQpwAAACjECcAAMAoxAkAADBKVuJkenpajY2Nqqys1DXXXKNEIpGNMQAAgIE2ZWOnt956qw4cOKAdO3bonXfekdfrzcYYy7vbd9r9WHbmAADApdY9Tl566SVdcMEF2rFjhySpqKhovUcAAAAGS/ttneHhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo6mHnv99dd14YUXqqWlRdddd53uueeeNf0CAABgY0k7ThKJhAKBgCzLWvbxwcFBhUIhdXV1aXx8XIFAQE1NTZqdnZUkffjhh3r66af129/+ViMjIzp8+LAOHz68tt8CAABsGGnHSXNzsw4cOKA9e/Ys+3hPT4/27t2r9vZ2VVZWqq+vTwUFBerv75ckXXrppaqpqVFZWZm8Xq927dqliYmJFfc3Pz+veDy+5AYAADaujF6ts7CwoLGxMQWDwY93kJOjYDCokZERSVJtba1mZ2f17rvvanFxUcPDw7ryyitX3GZ3d7d8Pl/qVlZWlsmRAQCAYTIaJydPnlQymZTf71+y3O/3KxKJSJI2bdqke+65R1/72td07bXX6ktf+pK+9a1vrbjN/fv3KxaLpW4zMzOZHBkAABgmK5cSNzc3q7m5+ZzW9Xq9Zl1qDAAAzquMvnJSXFys3NxcRaPRJcuj0ahKSkrWtG3LslRZWana2to1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw2E1NDSsadsdHR2anJzUkSNH1jomAAAwWNpv68zNzWlqaip1f3p6WhMTEyoqKlJ5eblCoZDa2tpUU1Ojuro69fb2KpFIqL29PaODAwCAjSntODl69KgaGxtT90OhkCSpra1NAwMDam1t1YkTJ9TZ2alIJKKqqiodOnTojJNkAQAAluOxbdvO9hDnwrIsWZalZDKp1157TbFYTIWFhZnfEd+tAwBAxsXjcfl8vnP6/52VbyVeDc45AQDAHRwTJwAAwB0cEydcSgwAgDs4Jk54WwcAAHdwTJwAAAB3IE4AAIBRiBMAAGAUx8QJJ8QCAOAOjokTTogFAMAdHBMnAADAHYgTAABgFOIEAAAYxTFxwgmxAAC4g2PihBNiAQBwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAURwTJ1ytAwCAOzgmTrhaBwAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFMfECR/CBgCAOzgmTvgQNgAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwsfXAwDgDo6JEz6+Hg
AAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABhlUzZ2WlFRocLCQuXk5Ojiiy/WE088kY0xAACAgbISJ5L03HPP6cILL8zW7gEAgKF4WwcAABgl7TgZHh5WS0uLSktL5fF4NDQ0dMY6lmWpoqJC+fn5qq+v1+jo6JLHPR6PbrzxRtXW1urPf/7zqocHAAAbT9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsbGqdZ555RmNjY/rb3/6me+65R//85z9X/xsAAIANJe04aW5u1oEDB7Rnz55lH+/p6dHevXvV3t6uyspK9fX1qaCgQP39/al1Lr30UknSli1btGvXLo2Pj6+4v/n5ecXj8SU3AACwcWX0nJOFhQWNjY0pGAx+vIOcHAWDQY2MjEj63ysvp06dkiTNzc3p8ccf11VXXbXiNru7u+Xz+VK3srKyTI4MAAAMk9E4OXnypJLJpPx+/5Llfr9fkUhEkhSNRnXDDTcoEAjo+uuv1y233KLa2toVt7l//37FYrHUbWZmJpMjAwAAw6z7pcSXX365XnjhhXNe3+v1yuv1nseJAACASTL6yklxcbFyc3MVjUaXLI9GoyopKVnTti3LUmVl5VlfZQEAAM6X0TjJy8tTdXW1wuFwatni4qLC4bAaGhrWtO2Ojg5NTk7qyJEjax0TAAAYLO23debm5jQ1NZW6Pz09rYmJCRUVFam8vFyhUEhtbW2qqalRXV2dent7lUgk1N7entHBAQDAxpR2nBw9elSNjY2p+6FQSJLU1tamgYEBtba26sSJE+rs7FQkElFVVZUOHTp0xkmy6bIsS5ZlKZlMrmk7AADAbB7btu1sD5GOeDwun8+nWCymwsLCzO/gbt9p92OZ3wcAAC6Tzv9vvlsHAAAYhTgBAABGcUyccCkxAADu4Jg44VJiAADcwTFxAgAA3IE4AQAARnFMnHDOCQAA7uCYOOGcEwAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTHxAlX6wAA4A6OiROu1gEAwB0cEycAAMAdiBMAAGAU4gQAABiFOAEAAEZxTJxwtQ4AAO7gmDjhah0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEZxTJzwOScAALiDY+KEzzkBAMAdHBMnAADAHYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTHxAnfrQMAgDs4Jk74bh0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABglKzFyfvvv6/LLrtMd955Z7ZGAAAABspanPzqV7/S9ddfn63dAwAAQ2UlTl5//XW98soram5uzsbuAQCAwdKOk+HhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo4uefzOO+9Ud3f3qocGAAAbV9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsrCTp4Ycf1hVXXKErrrhibZMDAIANaVO6P9Dc3HzWt2N6enq0d+9etbe3S5L6+vr0yCOPqL+/X/v27dPzzz+vBx54QAcPHtTc3Jw++OADFRYWqrOzc9ntzc/Pa35+PnU/Ho+nOzIAAHCQjJ5zsrCwoLGxMQWDwY93kJOjYDCokZERSVJ3d7dmZmb073//W/fee6/27t27Yph8tL7P50vdysrKMjkyAAAwTEbj5OTJk0omk/L7/UuW+/1+RSKRVW1z//79isViqdvMzEwmRgUAAIZK+22dTLr11ls/dR2v1yuv13v+hwEAAEbI6CsnxcXFys3NVTQaXbI8Go2qpKRkTdu2LEuVlZWqra1d03YAAIDZMhoneXl5qq6uVj
gcTi1bXFxUOBxWQ0PDmrbd0dGhyclJHTlyZK1jAgAAg6X9ts7c3JympqZS96enpzUxMaGioiKVl5crFAqpra1NNTU1qqurU29vrxKJROrqHQAAgLNJO06OHj2qxsbG1P1QKCRJamtr08DAgFpbW3XixAl1dnYqEomoqqpKhw4dOuMk2XRZliXLspRMJte0HQAAYDaPbdt2todIRzwel8/nUywWU2FhYeZ3cLfvtPuxzO8DAACXSef/d9a++A8AAGA5xAkAADCKY+KES4kBAHAHx8QJlxIDAOAOjokTAADgDsQJAAAwimPihHNOAABwB8fECeecAADgDo6JEwAA4A7ECQAAMApxAgAAjOKYOOGEWAAA3MExccIJsQAAuINj4gQAALgDcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oRLiQEAcAfHxAmXEgMA4A6OiRMAAOAOxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oTPOQEAwB0cEyd8zgkAAO7gmDgBAADuQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjOKYOOHj6wEAcAfHxAkfXw8AgDs4Jk4AAIA7ECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMsu5x8t5776mmpkZVVVW6+uqr9X//93/rPQIAADDYpvXe4ebNmzU8PKyCggIlEgldffXV+s53vqPPfe5z6z0KAAAw0Lq/cpKbm6uCggJJ0vz8vGzblm3b6z0GAAAwVNpxMjw8rJaWFpWWlsrj8WhoaOiMdSzLUkVFhfLz81VfX6/R0dElj7/33nsKBALaunWr7rrrLhUXF6/6FwAAABtL2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Oxsap2LLrpIL7zwgqanp/WXv/xF0Wh09b8BAADYUNKOk+bmZh04cEB79uxZ9vGenh7t3btX7e3tqqysVF9fnwoKCtTf33/Gun6/X4FAQE8//fSK+5ufn1c8Hl9yAwAAG1dGzzlZWFjQ2NiYgsHgxzvIyVEwGNTIyIgkKRqN6tSpU5KkWCym4eFhbdu2bcVtdnd3y+fzpW5lZWWZHBkAABgmo3Fy8uRJJZNJ+f3+Jcv9fr8ikYgk6c0339SOHTsUCAS0Y8cO/exnP9M111yz4jb379+vWCyWus3MzGRyZAAAYJh1v5S4rq5OExMT57y+1+uV1+s9fwMBAACjZPSVk+LiYuXm5p5xgms0GlVJScmatm1ZliorK1VbW7um7QAAALNlNE7y8vJUXV2tcDicWra4uKhwOKyGhoY1bbujo0OTk5M6cuTIWscEAAAGS/ttnbm5OU1NTaXuT09Pa2JiQkVFRSovL1coFFJbW5tqampUV1en3t5eJRIJtbe3Z3RwAACwMaUdJ0ePHlVjY2PqfigUkiS1tbVpYGBAra2tOnHihDo7OxWJRFRVVaVDhw6dcZIsAADAcjy2Qz473rIsWZalZDKp1157TbFYTIWFhZnf0d2+0+7HMr8PAABcJh6Py+fzndP/73X/bp3V4pwTAADcwTFxAgAA3MExccKlxAAAuINj4oS3dQAAcAfHxAkAAHAH4gQAABiFOAEAAEZxTJxwQiwAAO7gmDjhhFgAANzBMXECAADcgTgBAABGIU4AAIBRHBMnnBALAIA7OCZOOCEWAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYxTFxwtU6AAC4g2PihKt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwuecAADgDo6JEz7nBAAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfLcOAADu4Jg44bt1AABwB8fECQAAcA
fiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFHWPU5mZma0c+dOVVZW6tprr9XBgwfXewQAAGCwTeu+w02b1Nvbq6qqKkUiEVVXV2vXrl367Gc/u96jAAAAA617nGzZskVbtmyRJJWUlKi4uFjvvPMOcQIAACSt4m2d4eFhtbS0qLS0VB6PR0NDQ2esY1mWKioqlJ+fr/r6eo2Oji67rbGxMSWTSZWVlaU9OAAA2JjSjpNEIqFAICDLspZ9fHBwUKFQSF1dXRofH1cgEFBTU5NmZ2eXrPfOO+/olltu0e9///vVTQ4AADaktN/WaW5uVnNz84qP9/T0aO/evWpvb5ck9fX16ZFHHlF/f7/27dsnSZqfn9fu3bu1b98+bd++/az7m5+f1/z8fOp+PB5Pd2QAAOAgGb1aZ2FhQWNjYwoGgx/vICdHwWBQIyMjkiTbtnXrrbfq61//um6++eZP3WZ3d7d8Pl/qxltAAABsbBmNk5MnTyqZTMrv9y9Z7vf7FYlEJEnPPvusBgcHNTQ0pKqqKlVVVen48eMrbnP//v2KxWKp28zMTCZHBgAAhln3q3VuuOEGLS4unvP6Xq9XXq/3PE4EAABMktFXToqLi5Wbm6toNLpkeTQaVUlJyZq2bVmWKisrVVtbu6btAAAAs2U0TvLy8lRdXa1wOJxatri4qHA4rIaGhjVtu6OjQ5OTkzpy5MhaxwQAAAZL+22dubk5TU1Npe5PT09rYmJCRUVFKi8vVygUUltbm2pqalRXV6fe3l4lEonU1TsAAABnk3acHD16VI2Njan7oVBIktTW1qaBgQG1trbqxIkT6uzsVCQSUVVVlQ4dOnTGSbLpsixLlmUpmUyuaTsAAMBsHtu27WwPkY54PC6fz6dYLKbCwsLM7+Bu32n3Y5nfBwAALpPO/+91/1ZiAACAsyFOAACAURwTJ1xKDACAOzgmTriUGAAAd3BMnAAAAHcgTgAAgFEcEyeccwIAgDs4Jk445wQAAHdwTJwAAAB3IE4AAIBRiBMAAGAU4gQAABjFMXHC1ToAALiDY+KEq3UAAHAHx8QJAABwB+IEAAAYhTgBAABGIU4AAIBRHBMnXK0DAIA7OCZOuFoHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfM4JAADu4Jg44XNOAABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABjFMXHCd+sAAOAOjokTvlsHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGCUrcbJnzx5dfPHF+u53v5uN3QMAAINlJU5uv/12/eEPf8jGrgEAgOGyEic7d+7U5s2bs7FrAABguLTjZHh4WC0tLSotLZXH49HQ0NAZ61iWpYqKCuXn56u+vl6jo6OZmBUAALhA2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Ozsqgacn59XPB5fcgMAABtX2nHS3NysAwcOaM+ePcs+3tPTo71796q9vV2VlZXq6+tTQUGB+vv7VzVgd3e3fD5f6lZWVraq7QAAAGfI6DknCwsLGhsbUzAY/HgHOTkKBoMaGRlZ1Tb379+vWCyWus3MzGRqXAAAYKBNmdzYyZMnlUwm5ff7lyz3+/165ZVXUveDwaBeeOEFJRIJbd26VQcPHlRDQ8Oy2/R6vfJ6vZkcEwAAGCyjcXKuHnvssbR/xrIsWZalZDJ5HiYCAMCl7vYtsyy2/nN8Qkbf1ikuLlZubq6i0eiS5dFoVCUlJWvadkdHhyYnJ3XkyJE1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw+EV37YBAAD4pLTf1p
mbm9PU1FTq/vT0tCYmJlRUVKTy8nKFQiG1tbWppqZGdXV16u3tVSKRUHt7e0YHBwAAG1PacXL06FE1Njam7odCIUlSW1ubBgYG1NraqhMnTqizs1ORSERVVVU6dOjQGSfJpotzTgAAcAePbdt2todIRzwel8/nUywWU2FhYeZ3cPqJQVk+KQgAgPNqnU6ITef/d1a+WwcAAGAlxAkAADCKY+LEsixVVlaqtrY226MAAIDzyDFxwuecAADgDo6JEwAA4A7ECQAAMIpj4oRzTgAAcAfHxAnnnAAA4A6OiRMAAOAOxAkAADBK2t+tk20ffdp+PB4/PzuYP+3T/M/XfgAAMMHp//ek8/K/76P/2+fyrTmO+W6dj774b2FhQW+88Ua2xwEAAKswMzOjrVu3nnUdx8TJRxYXF/XWW29p8+bN8ng8Gd12PB5XWVmZZmZmzs+XCroIxzJzOJaZw7HMHI5l5rjlWNq2rVOnTqm0tFQ5OWc/q8Rxb+vk5OR8anGtVWFh4Yb+A1lPHMvM4VhmDscycziWmeOGY+nzLfMNyMvghFgAAGAU4gQAABiFOPkEr9errq4ueb3ebI/ieBzLzOFYZg7HMnM4lpnDsTyT406IBQAAGxuvnAAAAKMQJwAAwCjECQAAMApxAgAAjOK6OLEsSxUVFcrPz1d9fb1GR0fPuv7Bgwf15S9/Wfn5+brmmmv0j3/8Y50mNV86x3JgYEAej2fJLT8/fx2nNdPw8LBaWlpUWloqj8ejoaGhT/2ZJ598Utddd528Xq+++MUvamBg4LzP6QTpHssnn3zyjL9Jj8ejSCSyPgMbrLu7W7W1tdq8ebMuueQS7d69W6+++uqn/hzPl2dazbHk+dJlcTI4OKhQKKSuri6Nj48rEAioqalJs7Ozy67/3HPP6Qc/+IF++MMf6tixY9q9e7d2796tF198cZ0nN0+6x1L636cfvv3226nbm2++uY4TmymRSCgQCMiyrHNaf3p6WjfddJMaGxs1MTGhO+64Qz/60Y/06KOPnudJzZfusfzIq6++uuTv8pJLLjlPEzrHU089pY6ODj3//PM6fPiwPvjgA33zm99UIpFY8Wd4vlzeao6lxPOlbBepq6uzOzo6UveTyaRdWlpqd3d3L7v+9773Pfumm25asqy+vt7+yU9+cl7ndIJ0j+X9999v+3y+dZrOmSTZDz300FnX+fnPf25fddVVS5a1trbaTU1N53Ey5zmXY/nEE0/Ykux33313XWZystnZWVuS/dRTT624Ds+X5+ZcjiXPl7btmldOFhYWNDY2pmAwmFqWk5OjYDCokZGRZX9mZGRkyfqS1NTUtOL6brGaYylJc3Nzuuyyy1RWVqZvf/vbeumll9Zj3A2Fv8nMq6qq0pYtW/SNb3xDzz77bLbHMVIsFpMkFRUVrbgOf5vn5lyOpcTzpWvi5OTJk0omk/L7/UuW+/3+Fd9jjkQiaa3vFqs5ltu2bVN/f78efvhh/elPf9Li4qK2b9+u//znP+sx8oax0t9kPB7Xf//73yxN5UxbtmxRX1+fHnzwQT344IMqKyvTzp07NT4+nu3RjLK4uKg77rhDX/3qV3X11VevuB7Pl5/uXI8lz5cO/FZiOFNDQ4MaGhpS97dv364rr7xSv/vd7/TLX/4yi5PBrbZt26Zt27al7m/fvl1vvPGG7rvvPv3xj3/M4mRm6ejo0Isvvqhnnnkm26M43rkeS54vXfTKSXFxsXJzcxWNRpcsj0ajKikpWfZnSkpK0lrfLVZzLE93wQUX6Ctf+YqmpqbOx4gb1kp/k4WFhfrMZz6Tpak2jrq6Ov4mP+G2227T3//+dz3xxBPaunXrWdfl+fLs0jmWp3Pj86Vr4iQvL0/V1dUKh8OpZYuLiwqHw0sK9ZMaGhqWrC9Jhw8fXnF9t1jNsTxdMpnU8ePHtWXLlvM15obE3+T5NTExwd+kJNu2ddttt+mhhx7S448/rs9//vOf+jP8bS5vNcfydK58vs
z2Gbnr6YEHHrC9Xq89MDBgT05O2j/+8Y/tiy66yI5EIrZt2/bNN99s79u3L7X+s88+a2/atMm+99577Zdfftnu6uqyL7jgAvv48ePZ+hWMke6x/MUvfmE/+uij9htvvGGPjY3Z3//+9+38/Hz7pZdeytavYIRTp07Zx44ds48dO2ZLsnt6euxjx47Zb775pm3btr1v3z775ptvTq3/r3/9yy4oKLDvuusu++WXX7Yty7Jzc3PtQ4cOZetXMEa6x/K+++6zh4aG7Ndff90+fvy4ffvtt9s5OTn2Y489lq1fwRg//elPbZ/PZz/55JP222+/nbq9//77qXV4vjw3qzmWPF/atqvixLZt+ze/+Y1dXl5u5+Xl2XV1dfbzzz+feuzGG2+029ralqz/17/+1b7iiivsvLw8+6qrrrIfeeSRdZ7YXOkcyzvuuCO1rt/vt3ft2mWPj49nYWqzfHQ56+m3j45dW1ubfeONN57xM1VVVXZeXp59+eWX2/fff/+6z22idI/lr3/9a/sLX/iCnZ+fbxcVFdk7d+60H3/88ewMb5jljqOkJX9rPF+em9UcS54vbdtj27a9fq/TAAAAnJ1rzjkBAADOQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwyv8D6KAeY7AISbEAAAAASUVORK5CYII=",
|
657 |
+
"text/plain": [
|
658 |
+
"<Figure size 640x480 with 1 Axes>"
|
659 |
+
]
|
660 |
+
},
|
661 |
+
"metadata": {},
|
662 |
+
"output_type": "display_data"
|
663 |
+
}
|
664 |
+
],
|
665 |
+
"source": [
|
666 |
+
"import matplotlib.pyplot as plt\n",
|
667 |
+
"%matplotlib inline\n",
|
668 |
+
"plt.hist(abs(intermediate-outputs[0]).ravel(), bins = 100)\n",
|
669 |
+
"plt.yscale('log')\n",
|
670 |
+
"plt.show()"
|
671 |
+
]
|
672 |
+
},
|
673 |
+
{
|
674 |
+
"cell_type": "code",
|
675 |
+
"execution_count": 15,
|
676 |
+
"id": "71cb219e-b91a-4629-99f6-00db786903c7",
|
677 |
+
"metadata": {},
|
678 |
+
"outputs": [
|
679 |
+
{
|
680 |
+
"data": {
|
681 |
+
"text/plain": [
|
682 |
+
"tensor([1.1902e-03, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00,\n",
|
683 |
+
" 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00])"
|
684 |
+
]
|
685 |
+
},
|
686 |
+
"execution_count": 15,
|
687 |
+
"metadata": {},
|
688 |
+
"output_type": "execute_result"
|
689 |
+
}
|
690 |
+
],
|
691 |
+
"source": [
|
692 |
+
"torch.sort(abs(intermediate-outputs[0]).ravel())[0][-10:]"
|
693 |
+
]
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"cell_type": "code",
|
697 |
+
"execution_count": null,
|
698 |
+
"id": "92ba5920-5451-4bb1-af0e-5ea987841ab1",
|
699 |
+
"metadata": {},
|
700 |
+
"outputs": [],
|
701 |
+
"source": [
|
702 |
+
"import onnxruntime as ort\n",
|
703 |
+
"\n",
|
704 |
+
"session_options = ort.SessionOptions()\n",
|
705 |
+
"session_options.log_severity_level = 0 # Verbose logging\n",
|
706 |
+
"session = ort.InferenceSession(\"models_mask/preproc_test.onnx\", sess_options=session_options)"
|
707 |
+
]
|
708 |
+
},
|
709 |
+
{
|
710 |
+
"cell_type": "code",
|
711 |
+
"execution_count": null,
|
712 |
+
"id": "3277b343-245d-4ac8-a91c-373061dcbf53",
|
713 |
+
"metadata": {},
|
714 |
+
"outputs": [],
|
715 |
+
"source": [
|
716 |
+
"import matplotlib.pyplot as plt\n",
|
717 |
+
"%matplotlib inline\n",
|
718 |
+
"plt.imshow(outputs[0][0,8,:,:])\n",
|
719 |
+
"plt.show()"
|
720 |
+
]
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"cell_type": "code",
|
724 |
+
"execution_count": null,
|
725 |
+
"id": "67e99ef8-e49a-4037-a818-244555b0bdc5",
|
726 |
+
"metadata": {},
|
727 |
+
"outputs": [],
|
728 |
+
"source": [
|
729 |
+
"import onnx\n",
|
730 |
+
"\n",
|
731 |
+
"# Path to your ONNX model\n",
|
732 |
+
"model_path = \"models/model-47-99.125.onnx\"\n",
|
733 |
+
"\n",
|
734 |
+
"# Load the ONNX model\n",
|
735 |
+
"onnx_model = onnx.load(model_path)\n",
|
736 |
+
"\n",
|
737 |
+
"# Check the model for validity\n",
|
738 |
+
"onnx.checker.check_model(onnx_model)\n",
|
739 |
+
"\n",
|
740 |
+
"# Print model graph structure (optional)\n",
|
741 |
+
"print(onnx.helper.printable_graph(onnx_model.graph))\n"
|
742 |
+
]
|
743 |
+
}
|
744 |
+
],
|
745 |
+
"metadata": {
|
746 |
+
"kernelspec": {
|
747 |
+
"display_name": "Python 3 (ipykernel)",
|
748 |
+
"language": "python",
|
749 |
+
"name": "python3"
|
750 |
+
},
|
751 |
+
"language_info": {
|
752 |
+
"codemirror_mode": {
|
753 |
+
"name": "ipython",
|
754 |
+
"version": 3
|
755 |
+
},
|
756 |
+
"file_extension": ".py",
|
757 |
+
"mimetype": "text/x-python",
|
758 |
+
"name": "python",
|
759 |
+
"nbconvert_exporter": "python",
|
760 |
+
"pygments_lexer": "ipython3",
|
761 |
+
"version": "3.11.9"
|
762 |
+
}
|
763 |
+
},
|
764 |
+
"nbformat": 4,
|
765 |
+
"nbformat_minor": 5
|
766 |
+
}
|
models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"6\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
" 0%| | 0/48 [00:22<?, ?it/s]\n"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"ename": "AttributeError",
|
29 |
+
"evalue": "Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n",
|
30 |
+
"output_type": "error",
|
31 |
+
"traceback": [
|
32 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
33 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
34 |
+
"Cell \u001b[0;32mIn[1], line 50\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[0;32m---> 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 51\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmax(outputs, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 52\u001b[0m results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mextend(outputs\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mtolist())\n",
|
35 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
36 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
37 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:186\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 185\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 186\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparallel_apply(replicas, inputs, module_kwargs)\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
|
38 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:201\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 201\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parallel_apply(replicas, inputs, kwargs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(replicas)])\n",
|
39 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:108\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 106\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m--> 108\u001b[0m output\u001b[38;5;241m.\u001b[39mreraise()\n\u001b[1;32m 109\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
|
40 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/_utils.py:706\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 703\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 704\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 706\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
|
41 |
+
"\u001b[0;31mAttributeError\u001b[0m: Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n"
|
42 |
+
]
|
43 |
+
}
|
44 |
+
],
|
45 |
+
"source": [
|
46 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
47 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
48 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
49 |
+
"from tqdm import tqdm\n",
|
50 |
+
"import torch\n",
|
51 |
+
"import numpy as np\n",
|
52 |
+
"from resnet_model import ResidualBlock, ResNet\n",
|
53 |
+
"import torch\n",
|
54 |
+
"import torch.nn as nn\n",
|
55 |
+
"import torch.optim as optim\n",
|
56 |
+
"from tqdm import tqdm \n",
|
57 |
+
"import torch.nn.functional as F\n",
|
58 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
59 |
+
"import pickle\n",
|
60 |
+
"\n",
|
61 |
+
"torch.manual_seed(1)\n",
|
62 |
+
"# torch.manual_seed(42)\n",
|
63 |
+
"\n",
|
64 |
+
"\n",
|
65 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
66 |
+
"num_gpus = torch.cuda.device_count()\n",
|
67 |
+
"print(num_gpus)\n",
|
68 |
+
"\n",
|
69 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
70 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
71 |
+
"\n",
|
72 |
+
"num_classes = 2\n",
|
73 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
74 |
+
"\n",
|
75 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
76 |
+
"model = nn.DataParallel(model)\n",
|
77 |
+
"model = model.to(device)\n",
|
78 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
79 |
+
"print(\"num params \",params)\n",
|
80 |
+
"\n",
|
81 |
+
"model_1 = 'models/model-23-99.045.pt'\n",
|
82 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
83 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
84 |
+
"model = model.eval()\n",
|
85 |
+
"\n",
|
86 |
+
"# eval\n",
|
87 |
+
"val_loss = 0.0\n",
|
88 |
+
"correct_valid = 0\n",
|
89 |
+
"total = 0\n",
|
90 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
91 |
+
"model.eval()\n",
|
92 |
+
"with torch.no_grad():\n",
|
93 |
+
" for images, labels in tqdm(testloader):\n",
|
94 |
+
" inputs, labels = images.to(device), labels\n",
|
95 |
+
" outputs = model(inputs)\n",
|
96 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
97 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
98 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
99 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
100 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
101 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
102 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
103 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
104 |
+
" total += labels[0].size(0)\n",
|
105 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
106 |
+
"# Calculate training accuracy after each epoch\n",
|
107 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
108 |
+
"print(\"===========================\")\n",
|
109 |
+
"print('accuracy: ', val_accuracy)\n",
|
110 |
+
"print(\"===========================\")\n",
|
111 |
+
"\n",
|
112 |
+
"import pickle\n",
|
113 |
+
"\n",
|
114 |
+
"# Pickle the dictionary to a file\n",
|
115 |
+
"with open('models/test_42.pkl', 'wb') as f:\n",
|
116 |
+
" pickle.dump(results, f)"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": null,
|
122 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
127 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
128 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
129 |
+
"from tqdm import tqdm\n",
|
130 |
+
"import torch\n",
|
131 |
+
"import numpy as np\n",
|
132 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
133 |
+
"import torch\n",
|
134 |
+
"import torch.nn as nn\n",
|
135 |
+
"import torch.optim as optim\n",
|
136 |
+
"from tqdm import tqdm \n",
|
137 |
+
"import torch.nn.functional as F\n",
|
138 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
139 |
+
"import pickle\n",
|
140 |
+
"\n",
|
141 |
+
"torch.manual_seed(1)\n",
|
142 |
+
"# torch.manual_seed(42)\n",
|
143 |
+
"\n",
|
144 |
+
"\n",
|
145 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
146 |
+
"num_gpus = torch.cuda.device_count()\n",
|
147 |
+
"print(num_gpus)\n",
|
148 |
+
"\n",
|
149 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
150 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
151 |
+
"\n",
|
152 |
+
"num_classes = 2\n",
|
153 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
154 |
+
"\n",
|
155 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
156 |
+
"model = nn.DataParallel(model)\n",
|
157 |
+
"model = model.to(device)\n",
|
158 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
159 |
+
"print(\"num params \",params)\n",
|
160 |
+
"\n",
|
161 |
+
"\n",
|
162 |
+
"model_1 = 'models/model-14-98.005.pt'\n",
|
163 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
164 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
165 |
+
"model = model.eval()\n",
|
166 |
+
"\n",
|
167 |
+
"# eval\n",
|
168 |
+
"val_loss = 0.0\n",
|
169 |
+
"correct_valid = 0\n",
|
170 |
+
"total = 0\n",
|
171 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
172 |
+
"model.eval()\n",
|
173 |
+
"with torch.no_grad():\n",
|
174 |
+
" for images, labels in tqdm(testloader):\n",
|
175 |
+
" inputs, labels = images.to(device), labels\n",
|
176 |
+
" outputs = model(inputs)\n",
|
177 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
178 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
179 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
180 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
181 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
182 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
183 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
184 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
185 |
+
" total += labels[0].size(0)\n",
|
186 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
187 |
+
" \n",
|
188 |
+
"# Calculate training accuracy after each epoch\n",
|
189 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
190 |
+
"print(\"===========================\")\n",
|
191 |
+
"print('accuracy: ', val_accuracy)\n",
|
192 |
+
"print(\"===========================\")\n",
|
193 |
+
"\n",
|
194 |
+
"import pickle\n",
|
195 |
+
"\n",
|
196 |
+
"# Pickle the dictionary to a file\n",
|
197 |
+
"with open('models/test_1.pkl', 'wb') as f:\n",
|
198 |
+
" pickle.dump(results, f)"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": null,
|
204 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
205 |
+
"metadata": {},
|
206 |
+
"outputs": [],
|
207 |
+
"source": [
|
208 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
209 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
210 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
211 |
+
"from tqdm import tqdm\n",
|
212 |
+
"import torch\n",
|
213 |
+
"import numpy as np\n",
|
214 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
215 |
+
"import torch\n",
|
216 |
+
"import torch.nn as nn\n",
|
217 |
+
"import torch.optim as optim\n",
|
218 |
+
"from tqdm import tqdm \n",
|
219 |
+
"import torch.nn.functional as F\n",
|
220 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
221 |
+
"import pickle\n",
|
222 |
+
"\n",
|
223 |
+
"torch.manual_seed(1)\n",
|
224 |
+
"# torch.manual_seed(42)\n",
|
225 |
+
"\n",
|
226 |
+
"\n",
|
227 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
228 |
+
"num_gpus = torch.cuda.device_count()\n",
|
229 |
+
"print(num_gpus)\n",
|
230 |
+
"\n",
|
231 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
232 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
233 |
+
"\n",
|
234 |
+
"num_classes = 2\n",
|
235 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
236 |
+
"\n",
|
237 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
238 |
+
"model = nn.DataParallel(model)\n",
|
239 |
+
"model = model.to(device)\n",
|
240 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
241 |
+
"print(\"num params \",params)\n",
|
242 |
+
"\n",
|
243 |
+
"\n",
|
244 |
+
"model_1 = 'models/model-28-98.955.pt'\n",
|
245 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
246 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
247 |
+
"model = model.eval()\n",
|
248 |
+
"\n",
|
249 |
+
"# eval\n",
|
250 |
+
"val_loss = 0.0\n",
|
251 |
+
"correct_valid = 0\n",
|
252 |
+
"total = 0\n",
|
253 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
254 |
+
"model.eval()\n",
|
255 |
+
"with torch.no_grad():\n",
|
256 |
+
" for images, labels in tqdm(testloader):\n",
|
257 |
+
" inputs, labels = images.to(device), labels\n",
|
258 |
+
" outputs = model(inputs)\n",
|
259 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
260 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
261 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
262 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
263 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
264 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
265 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
266 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
267 |
+
" total += labels[0].size(0)\n",
|
268 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
269 |
+
" \n",
|
270 |
+
"# Calculate training accuracy after each epoch\n",
|
271 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
272 |
+
"print(\"===========================\")\n",
|
273 |
+
"print('accuracy: ', val_accuracy)\n",
|
274 |
+
"print(\"===========================\")\n",
|
275 |
+
"\n",
|
276 |
+
"import pickle\n",
|
277 |
+
"\n",
|
278 |
+
"# Pickle the dictionary to a file\n",
|
279 |
+
"with open('models/test_7109.pkl', 'wb') as f:\n",
|
280 |
+
" pickle.dump(results, f)"
|
281 |
+
]
|
282 |
+
}
|
283 |
+
],
|
284 |
+
"metadata": {
|
285 |
+
"kernelspec": {
|
286 |
+
"display_name": "Python 3 (ipykernel)",
|
287 |
+
"language": "python",
|
288 |
+
"name": "python3"
|
289 |
+
},
|
290 |
+
"language_info": {
|
291 |
+
"codemirror_mode": {
|
292 |
+
"name": "ipython",
|
293 |
+
"version": 3
|
294 |
+
},
|
295 |
+
"file_extension": ".py",
|
296 |
+
"mimetype": "text/x-python",
|
297 |
+
"name": "python",
|
298 |
+
"nbconvert_exporter": "python",
|
299 |
+
"pygments_lexer": "ipython3",
|
300 |
+
"version": "3.11.9"
|
301 |
+
}
|
302 |
+
},
|
303 |
+
"nbformat": 4,
|
304 |
+
"nbformat_minor": 5
|
305 |
+
}
|
models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"2\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:39<00:00, 1.21it/s]"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"===========================\n",
|
32 |
+
"accuracy: 98.82\n",
|
33 |
+
"===========================\n",
|
34 |
+
"False Positive Rate: 0.010\n",
|
35 |
+
"Precision: 0.990\n",
|
36 |
+
"Recall: 0.986\n",
|
37 |
+
"F1 Score: 0.988\n"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"name": "stderr",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"\n"
|
45 |
+
]
|
46 |
+
}
|
47 |
+
],
|
48 |
+
"source": [
|
49 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
50 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
51 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
52 |
+
"from tqdm import tqdm\n",
|
53 |
+
"import torch\n",
|
54 |
+
"import numpy as np\n",
|
55 |
+
"from resnet_model import ResidualBlock, ResNet\n",
|
56 |
+
"import torch\n",
|
57 |
+
"import torch.nn as nn\n",
|
58 |
+
"import torch.optim as optim\n",
|
59 |
+
"from tqdm import tqdm \n",
|
60 |
+
"import torch.nn.functional as F\n",
|
61 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
62 |
+
"import pickle\n",
|
63 |
+
"\n",
|
64 |
+
"torch.manual_seed(1)\n",
|
65 |
+
"# torch.manual_seed(42)\n",
|
66 |
+
"\n",
|
67 |
+
"\n",
|
68 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
69 |
+
"num_gpus = torch.cuda.device_count()\n",
|
70 |
+
"print(num_gpus)\n",
|
71 |
+
"\n",
|
72 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
73 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
74 |
+
"\n",
|
75 |
+
"num_classes = 2\n",
|
76 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
77 |
+
"\n",
|
78 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
79 |
+
"model = nn.DataParallel(model)\n",
|
80 |
+
"model = model.to(device)\n",
|
81 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
82 |
+
"print(\"num params \",params)\n",
|
83 |
+
"\n",
|
84 |
+
"model_1 = 'models/model-23-99.045.pt'\n",
|
85 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
86 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
87 |
+
"model = model.eval()\n",
|
88 |
+
"\n",
|
89 |
+
"# eval\n",
|
90 |
+
"val_loss = 0.0\n",
|
91 |
+
"correct_valid = 0\n",
|
92 |
+
"total = 0\n",
|
93 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
94 |
+
"model.eval()\n",
|
95 |
+
"with torch.no_grad():\n",
|
96 |
+
" for images, labels in tqdm(testloader):\n",
|
97 |
+
" inputs, labels = images.to(device), labels\n",
|
98 |
+
" outputs = model(inputs)\n",
|
99 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
100 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
101 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
102 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
103 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
104 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
105 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
106 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
107 |
+
" total += labels[0].size(0)\n",
|
108 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
109 |
+
"# Calculate training accuracy after each epoch\n",
|
110 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
111 |
+
"print(\"===========================\")\n",
|
112 |
+
"print('accuracy: ', val_accuracy)\n",
|
113 |
+
"print(\"===========================\")\n",
|
114 |
+
"\n",
|
115 |
+
"import pickle\n",
|
116 |
+
"\n",
|
117 |
+
"# Pickle the dictionary to a file\n",
|
118 |
+
"with open('models/test_42.pkl', 'wb') as f:\n",
|
119 |
+
" pickle.dump(results, f)\n",
|
120 |
+
"\n",
|
121 |
+
"\n",
|
122 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
123 |
+
"from sklearn.metrics import confusion_matrix\n",
|
124 |
+
"\n",
|
125 |
+
"# Example binary labels\n",
|
126 |
+
"true = results['true'] # ground truth\n",
|
127 |
+
"pred = results['pred'] # predicted\n",
|
128 |
+
"\n",
|
129 |
+
"# Compute metrics\n",
|
130 |
+
"precision = precision_score(true, pred)\n",
|
131 |
+
"recall = recall_score(true, pred)\n",
|
132 |
+
"f1 = f1_score(true, pred)\n",
|
133 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
134 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
135 |
+
"\n",
|
136 |
+
"# Compute FPR\n",
|
137 |
+
"fpr = fp / (fp + tn)\n",
|
138 |
+
"\n",
|
139 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
140 |
+
"\n",
|
141 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
142 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
143 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": 2,
|
149 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
150 |
+
"metadata": {},
|
151 |
+
"outputs": [
|
152 |
+
{
|
153 |
+
"name": "stdout",
|
154 |
+
"output_type": "stream",
|
155 |
+
"text": [
|
156 |
+
"2\n",
|
157 |
+
"num params encoder 50840\n",
|
158 |
+
"num params 21496282\n"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"name": "stderr",
|
163 |
+
"output_type": "stream",
|
164 |
+
"text": [
|
165 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
166 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
167 |
+
"100%|███████████████████████████████████████████| 48/48 [00:51<00:00, 1.07s/it]"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"name": "stdout",
|
172 |
+
"output_type": "stream",
|
173 |
+
"text": [
|
174 |
+
"===========================\n",
|
175 |
+
"accuracy: 97.185\n",
|
176 |
+
"===========================\n",
|
177 |
+
"False Positive Rate: 0.038\n",
|
178 |
+
"Precision: 0.963\n",
|
179 |
+
"Recall: 0.981\n",
|
180 |
+
"F1 Score: 0.972\n"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"name": "stderr",
|
185 |
+
"output_type": "stream",
|
186 |
+
"text": [
|
187 |
+
"\n"
|
188 |
+
]
|
189 |
+
}
|
190 |
+
],
|
191 |
+
"source": [
|
192 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
193 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
194 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
195 |
+
"from tqdm import tqdm\n",
|
196 |
+
"import torch\n",
|
197 |
+
"import numpy as np\n",
|
198 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
199 |
+
"import torch\n",
|
200 |
+
"import torch.nn as nn\n",
|
201 |
+
"import torch.optim as optim\n",
|
202 |
+
"from tqdm import tqdm \n",
|
203 |
+
"import torch.nn.functional as F\n",
|
204 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
205 |
+
"import pickle\n",
|
206 |
+
"\n",
|
207 |
+
"torch.manual_seed(1)\n",
|
208 |
+
"# torch.manual_seed(42)\n",
|
209 |
+
"\n",
|
210 |
+
"\n",
|
211 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
212 |
+
"num_gpus = torch.cuda.device_count()\n",
|
213 |
+
"print(num_gpus)\n",
|
214 |
+
"\n",
|
215 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
216 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
217 |
+
"\n",
|
218 |
+
"num_classes = 2\n",
|
219 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
220 |
+
"\n",
|
221 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
222 |
+
"model = nn.DataParallel(model)\n",
|
223 |
+
"model = model.to(device)\n",
|
224 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
225 |
+
"print(\"num params \",params)\n",
|
226 |
+
"\n",
|
227 |
+
"\n",
|
228 |
+
"model_1 = 'models/model-14-98.005.pt'\n",
|
229 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
230 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
231 |
+
"model = model.eval()\n",
|
232 |
+
"\n",
|
233 |
+
"# eval\n",
|
234 |
+
"val_loss = 0.0\n",
|
235 |
+
"correct_valid = 0\n",
|
236 |
+
"total = 0\n",
|
237 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
238 |
+
"model.eval()\n",
|
239 |
+
"with torch.no_grad():\n",
|
240 |
+
" for images, labels in tqdm(testloader):\n",
|
241 |
+
" inputs, labels = images.to(device), labels\n",
|
242 |
+
" outputs = model(inputs)\n",
|
243 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
244 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
245 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
246 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
247 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
248 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
249 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
250 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
251 |
+
" total += labels[0].size(0)\n",
|
252 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
253 |
+
" \n",
|
254 |
+
"# Calculate training accuracy after each epoch\n",
|
255 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
256 |
+
"print(\"===========================\")\n",
|
257 |
+
"print('accuracy: ', val_accuracy)\n",
|
258 |
+
"print(\"===========================\")\n",
|
259 |
+
"\n",
|
260 |
+
"import pickle\n",
|
261 |
+
"\n",
|
262 |
+
"# Pickle the dictionary to a file\n",
|
263 |
+
"with open('models/test_1.pkl', 'wb') as f:\n",
|
264 |
+
" pickle.dump(results, f)\n",
|
265 |
+
"\n",
|
266 |
+
"\n",
|
267 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
268 |
+
"from sklearn.metrics import confusion_matrix\n",
|
269 |
+
"\n",
|
270 |
+
"# Example binary labels\n",
|
271 |
+
"true = results['true'] # ground truth\n",
|
272 |
+
"pred = results['pred'] # predicted\n",
|
273 |
+
"\n",
|
274 |
+
"# Compute metrics\n",
|
275 |
+
"precision = precision_score(true, pred)\n",
|
276 |
+
"recall = recall_score(true, pred)\n",
|
277 |
+
"f1 = f1_score(true, pred)\n",
|
278 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
279 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
280 |
+
"\n",
|
281 |
+
"# Compute FPR\n",
|
282 |
+
"fpr = fp / (fp + tn)\n",
|
283 |
+
"\n",
|
284 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
285 |
+
"\n",
|
286 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
287 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
288 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 3,
|
294 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [
|
297 |
+
{
|
298 |
+
"name": "stdout",
|
299 |
+
"output_type": "stream",
|
300 |
+
"text": [
|
301 |
+
"2\n",
|
302 |
+
"num params encoder 50840\n",
|
303 |
+
"num params 21496282\n"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"name": "stderr",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
311 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
312 |
+
"100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.11s/it]"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"name": "stdout",
|
317 |
+
"output_type": "stream",
|
318 |
+
"text": [
|
319 |
+
"===========================\n",
|
320 |
+
"accuracy: 98.455\n",
|
321 |
+
"===========================\n",
|
322 |
+
"False Positive Rate: 0.010\n",
|
323 |
+
"Precision: 0.990\n",
|
324 |
+
"Recall: 0.979\n",
|
325 |
+
"F1 Score: 0.984\n"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"name": "stderr",
|
330 |
+
"output_type": "stream",
|
331 |
+
"text": [
|
332 |
+
"\n"
|
333 |
+
]
|
334 |
+
}
|
335 |
+
],
|
336 |
+
"source": [
|
337 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
338 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
339 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
340 |
+
"from tqdm import tqdm\n",
|
341 |
+
"import torch\n",
|
342 |
+
"import numpy as np\n",
|
343 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
344 |
+
"import torch\n",
|
345 |
+
"import torch.nn as nn\n",
|
346 |
+
"import torch.optim as optim\n",
|
347 |
+
"from tqdm import tqdm \n",
|
348 |
+
"import torch.nn.functional as F\n",
|
349 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
350 |
+
"import pickle\n",
|
351 |
+
"\n",
|
352 |
+
"torch.manual_seed(1)\n",
|
353 |
+
"# torch.manual_seed(42)\n",
|
354 |
+
"\n",
|
355 |
+
"\n",
|
356 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
357 |
+
"num_gpus = torch.cuda.device_count()\n",
|
358 |
+
"print(num_gpus)\n",
|
359 |
+
"\n",
|
360 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
361 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
362 |
+
"\n",
|
363 |
+
"num_classes = 2\n",
|
364 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
365 |
+
"\n",
|
366 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
367 |
+
"model = nn.DataParallel(model)\n",
|
368 |
+
"model = model.to(device)\n",
|
369 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
370 |
+
"print(\"num params \",params)\n",
|
371 |
+
"\n",
|
372 |
+
"\n",
|
373 |
+
"model_1 = 'models/model-28-98.955.pt'\n",
|
374 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
375 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
376 |
+
"model = model.eval()\n",
|
377 |
+
"\n",
|
378 |
+
"# eval\n",
|
379 |
+
"val_loss = 0.0\n",
|
380 |
+
"correct_valid = 0\n",
|
381 |
+
"total = 0\n",
|
382 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
383 |
+
"model.eval()\n",
|
384 |
+
"with torch.no_grad():\n",
|
385 |
+
" for images, labels in tqdm(testloader):\n",
|
386 |
+
" inputs, labels = images.to(device), labels\n",
|
387 |
+
" outputs = model(inputs)\n",
|
388 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
389 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
390 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
391 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
392 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
393 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
394 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
395 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
396 |
+
" total += labels[0].size(0)\n",
|
397 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
398 |
+
" \n",
|
399 |
+
"# Calculate training accuracy after each epoch\n",
|
400 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
401 |
+
"print(\"===========================\")\n",
|
402 |
+
"print('accuracy: ', val_accuracy)\n",
|
403 |
+
"print(\"===========================\")\n",
|
404 |
+
"\n",
|
405 |
+
"import pickle\n",
|
406 |
+
"\n",
|
407 |
+
"# Pickle the dictionary to a file\n",
|
408 |
+
"with open('models/test_7109.pkl', 'wb') as f:\n",
|
409 |
+
" pickle.dump(results, f)\n",
|
410 |
+
"\n",
|
411 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
412 |
+
"from sklearn.metrics import confusion_matrix\n",
|
413 |
+
"\n",
|
414 |
+
"# Example binary labels\n",
|
415 |
+
"true = results['true'] # ground truth\n",
|
416 |
+
"pred = results['pred'] # predicted\n",
|
417 |
+
"\n",
|
418 |
+
"# Compute metrics\n",
|
419 |
+
"precision = precision_score(true, pred)\n",
|
420 |
+
"recall = recall_score(true, pred)\n",
|
421 |
+
"f1 = f1_score(true, pred)\n",
|
422 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
423 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
424 |
+
"\n",
|
425 |
+
"# Compute FPR\n",
|
426 |
+
"fpr = fp / (fp + tn)\n",
|
427 |
+
"\n",
|
428 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
429 |
+
"\n",
|
430 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
431 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
432 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": 4,
|
438 |
+
"id": "ad4ef08f-6a9b-495f-bb93-0a13251f0adb",
|
439 |
+
"metadata": {},
|
440 |
+
"outputs": [
|
441 |
+
{
|
442 |
+
"name": "stdout",
|
443 |
+
"output_type": "stream",
|
444 |
+
"text": [
|
445 |
+
"98.15333333333332 0.7007416705811665\n",
|
446 |
+
"0.9803333333333333 0.0009428090415820641\n",
|
447 |
+
"0.9813333333333333 0.006798692684790386\n",
|
448 |
+
"0.019333333333333334 0.013199326582148887\n"
|
449 |
+
]
|
450 |
+
}
|
451 |
+
],
|
452 |
+
"source": [
|
453 |
+
"# acc\n",
|
454 |
+
"print(np.mean([98.82,97.185,98.455]), np.std([98.82,97.185,98.455]))\n",
|
455 |
+
"# recall\n",
|
456 |
+
"print(np.mean([0.981,0.981, 0.979]), np.std([0.981,0.981, 0.979]))\n",
|
457 |
+
"# f1\n",
|
458 |
+
"print(np.mean([0.988,0.972,0.984]),np.std([0.988,0.972,0.984]))\n",
|
459 |
+
"# fp\n",
|
460 |
+
"print(np.mean([0.010,0.038,0.010]),np.std([0.010,0.038,0.010]))"
|
461 |
+
]
|
462 |
+
}
|
463 |
+
],
|
464 |
+
"metadata": {
|
465 |
+
"kernelspec": {
|
466 |
+
"display_name": "Python 3 (ipykernel)",
|
467 |
+
"language": "python",
|
468 |
+
"name": "python3"
|
469 |
+
},
|
470 |
+
"language_info": {
|
471 |
+
"codemirror_mode": {
|
472 |
+
"name": "ipython",
|
473 |
+
"version": 3
|
474 |
+
},
|
475 |
+
"file_extension": ".py",
|
476 |
+
"mimetype": "text/x-python",
|
477 |
+
"name": "python",
|
478 |
+
"nbconvert_exporter": "python",
|
479 |
+
"pygments_lexer": "ipython3",
|
480 |
+
"version": "3.11.9"
|
481 |
+
}
|
482 |
+
},
|
483 |
+
"nbformat": 4,
|
484 |
+
"nbformat_minor": 5
|
485 |
+
}
|
models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"6\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
"100%|███████████████████████████████████████████| 48/48 [01:43<00:00, 2.16s/it]"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"===========================\n",
|
32 |
+
"accuracy: 98.94\n",
|
33 |
+
"===========================\n"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"name": "stderr",
|
38 |
+
"output_type": "stream",
|
39 |
+
"text": [
|
40 |
+
"\n"
|
41 |
+
]
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"source": [
|
45 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
46 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
47 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
48 |
+
"from tqdm import tqdm\n",
|
49 |
+
"import torch\n",
|
50 |
+
"import numpy as np\n",
|
51 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
52 |
+
"import torch\n",
|
53 |
+
"import torch.nn as nn\n",
|
54 |
+
"import torch.optim as optim\n",
|
55 |
+
"from tqdm import tqdm \n",
|
56 |
+
"import torch.nn.functional as F\n",
|
57 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
58 |
+
"import pickle\n",
|
59 |
+
"\n",
|
60 |
+
"torch.manual_seed(1)\n",
|
61 |
+
"# torch.manual_seed(42)\n",
|
62 |
+
"\n",
|
63 |
+
"\n",
|
64 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
65 |
+
"num_gpus = torch.cuda.device_count()\n",
|
66 |
+
"print(num_gpus)\n",
|
67 |
+
"\n",
|
68 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
69 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
70 |
+
"\n",
|
71 |
+
"num_classes = 2\n",
|
72 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
73 |
+
"\n",
|
74 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
75 |
+
"model = nn.DataParallel(model)\n",
|
76 |
+
"model = model.to(device)\n",
|
77 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
78 |
+
"print(\"num params \",params)\n",
|
79 |
+
"\n",
|
80 |
+
"model_1 = 'models_8/model-25-99.31_7109.pt'\n",
|
81 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
82 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
83 |
+
"model = model.eval()\n",
|
84 |
+
"\n",
|
85 |
+
"# eval\n",
|
86 |
+
"val_loss = 0.0\n",
|
87 |
+
"correct_valid = 0\n",
|
88 |
+
"total = 0\n",
|
89 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
90 |
+
"model.eval()\n",
|
91 |
+
"with torch.no_grad():\n",
|
92 |
+
" for images, labels in tqdm(testloader):\n",
|
93 |
+
" inputs, labels = images.to(device), labels\n",
|
94 |
+
" outputs = model(inputs, return_mask = True)\n",
|
95 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
96 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
97 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
98 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
99 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
100 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
101 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
102 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
103 |
+
" total += labels[0].size(0)\n",
|
104 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
105 |
+
" \n",
|
106 |
+
" \n",
|
107 |
+
"# Calculate training accuracy after each epoch\n",
|
108 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
109 |
+
"print(\"===========================\")\n",
|
110 |
+
"print('accuracy: ', val_accuracy)\n",
|
111 |
+
"print(\"===========================\")\n",
|
112 |
+
"\n",
|
113 |
+
"import pickle\n",
|
114 |
+
"\n",
|
115 |
+
"# Pickle the dictionary to a file\n",
|
116 |
+
"with open('models_8/test_7109.pkl', 'wb') as f:\n",
|
117 |
+
" pickle.dump(results, f)"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": 2,
|
123 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [
|
126 |
+
{
|
127 |
+
"name": "stdout",
|
128 |
+
"output_type": "stream",
|
129 |
+
"text": [
|
130 |
+
"6\n",
|
131 |
+
"num params encoder 50840\n",
|
132 |
+
"num params 21496282\n"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"name": "stderr",
|
137 |
+
"output_type": "stream",
|
138 |
+
"text": [
|
139 |
+
"100%|██████████████████��████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "stdout",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"===========================\n",
|
147 |
+
"accuracy: 99.17\n",
|
148 |
+
"===========================\n"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"name": "stderr",
|
153 |
+
"output_type": "stream",
|
154 |
+
"text": [
|
155 |
+
"\n"
|
156 |
+
]
|
157 |
+
}
|
158 |
+
],
|
159 |
+
"source": [
|
160 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
161 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
162 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
163 |
+
"from tqdm import tqdm\n",
|
164 |
+
"import torch\n",
|
165 |
+
"import numpy as np\n",
|
166 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
167 |
+
"import torch\n",
|
168 |
+
"import torch.nn as nn\n",
|
169 |
+
"import torch.optim as optim\n",
|
170 |
+
"from tqdm import tqdm \n",
|
171 |
+
"import torch.nn.functional as F\n",
|
172 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
173 |
+
"import pickle\n",
|
174 |
+
"\n",
|
175 |
+
"torch.manual_seed(1)\n",
|
176 |
+
"# torch.manual_seed(42)\n",
|
177 |
+
"\n",
|
178 |
+
"\n",
|
179 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
180 |
+
"num_gpus = torch.cuda.device_count()\n",
|
181 |
+
"print(num_gpus)\n",
|
182 |
+
"\n",
|
183 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
184 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
185 |
+
"\n",
|
186 |
+
"num_classes = 2\n",
|
187 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
188 |
+
"\n",
|
189 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
190 |
+
"model = nn.DataParallel(model)\n",
|
191 |
+
"model = model.to(device)\n",
|
192 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
193 |
+
"print(\"num params \",params)\n",
|
194 |
+
"\n",
|
195 |
+
"\n",
|
196 |
+
"model_1 = 'models_8/model-44-99.445_42.pt'\n",
|
197 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
198 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
199 |
+
"model = model.eval()\n",
|
200 |
+
"\n",
|
201 |
+
"# eval\n",
|
202 |
+
"val_loss = 0.0\n",
|
203 |
+
"correct_valid = 0\n",
|
204 |
+
"total = 0\n",
|
205 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
206 |
+
"model.eval()\n",
|
207 |
+
"with torch.no_grad():\n",
|
208 |
+
" for images, labels in tqdm(testloader):\n",
|
209 |
+
" inputs, labels = images.to(device), labels\n",
|
210 |
+
" outputs = model(inputs, return_mask = True)\n",
|
211 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
212 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
213 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
214 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
215 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
216 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
217 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
218 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
219 |
+
" total += labels[0].size(0)\n",
|
220 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
221 |
+
" \n",
|
222 |
+
"# Calculate training accuracy after each epoch\n",
|
223 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
224 |
+
"print(\"===========================\")\n",
|
225 |
+
"print('accuracy: ', val_accuracy)\n",
|
226 |
+
"print(\"===========================\")\n",
|
227 |
+
"\n",
|
228 |
+
"import pickle\n",
|
229 |
+
"\n",
|
230 |
+
"# Pickle the dictionary to a file\n",
|
231 |
+
"with open('models_8/test_42.pkl', 'wb') as f:\n",
|
232 |
+
" pickle.dump(results, f)"
|
233 |
+
]
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"cell_type": "code",
|
237 |
+
"execution_count": 3,
|
238 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
239 |
+
"metadata": {},
|
240 |
+
"outputs": [
|
241 |
+
{
|
242 |
+
"name": "stdout",
|
243 |
+
"output_type": "stream",
|
244 |
+
"text": [
|
245 |
+
"6\n",
|
246 |
+
"num params encoder 50840\n",
|
247 |
+
"num params 21496282\n"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"name": "stderr",
|
252 |
+
"output_type": "stream",
|
253 |
+
"text": [
|
254 |
+
"100%|███████████████████████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"name": "stdout",
|
259 |
+
"output_type": "stream",
|
260 |
+
"text": [
|
261 |
+
"===========================\n",
|
262 |
+
"accuracy: 99.035\n",
|
263 |
+
"===========================\n"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"name": "stderr",
|
268 |
+
"output_type": "stream",
|
269 |
+
"text": [
|
270 |
+
"\n"
|
271 |
+
]
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
276 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
277 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
278 |
+
"from tqdm import tqdm\n",
|
279 |
+
"import torch\n",
|
280 |
+
"import numpy as np\n",
|
281 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
282 |
+
"import torch\n",
|
283 |
+
"import torch.nn as nn\n",
|
284 |
+
"import torch.optim as optim\n",
|
285 |
+
"from tqdm import tqdm \n",
|
286 |
+
"import torch.nn.functional as F\n",
|
287 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
288 |
+
"import pickle\n",
|
289 |
+
"\n",
|
290 |
+
"torch.manual_seed(1)\n",
|
291 |
+
"# torch.manual_seed(42)\n",
|
292 |
+
"\n",
|
293 |
+
"\n",
|
294 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
295 |
+
"num_gpus = torch.cuda.device_count()\n",
|
296 |
+
"print(num_gpus)\n",
|
297 |
+
"\n",
|
298 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
299 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
300 |
+
"\n",
|
301 |
+
"num_classes = 2\n",
|
302 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
303 |
+
"\n",
|
304 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
305 |
+
"model = nn.DataParallel(model)\n",
|
306 |
+
"model = model.to(device)\n",
|
307 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
308 |
+
"print(\"num params \",params)\n",
|
309 |
+
"\n",
|
310 |
+
"\n",
|
311 |
+
"model_1 = 'models_8/model-43-99.355_1.pt'\n",
|
312 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
313 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
314 |
+
"model = model.eval()\n",
|
315 |
+
"\n",
|
316 |
+
"# eval\n",
|
317 |
+
"val_loss = 0.0\n",
|
318 |
+
"correct_valid = 0\n",
|
319 |
+
"total = 0\n",
|
320 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
321 |
+
"model.eval()\n",
|
322 |
+
"with torch.no_grad():\n",
|
323 |
+
" for images, labels in tqdm(testloader):\n",
|
324 |
+
" inputs, labels = images.to(device), labels\n",
|
325 |
+
" outputs = model(inputs, return_mask = True)\n",
|
326 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
327 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
328 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
329 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
330 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
331 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
332 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
333 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
334 |
+
" total += labels[0].size(0)\n",
|
335 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
336 |
+
" \n",
|
337 |
+
"# Calculate training accuracy after each epoch\n",
|
338 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
339 |
+
"print(\"===========================\")\n",
|
340 |
+
"print('accuracy: ', val_accuracy)\n",
|
341 |
+
"print(\"===========================\")\n",
|
342 |
+
"\n",
|
343 |
+
"import pickle\n",
|
344 |
+
"\n",
|
345 |
+
"# Pickle the dictionary to a file\n",
|
346 |
+
"with open('models_8/test_1.pkl', 'wb') as f:\n",
|
347 |
+
" pickle.dump(results, f)"
|
348 |
+
]
|
349 |
+
}
|
350 |
+
],
|
351 |
+
"metadata": {
|
352 |
+
"kernelspec": {
|
353 |
+
"display_name": "Python 3 (ipykernel)",
|
354 |
+
"language": "python",
|
355 |
+
"name": "python3"
|
356 |
+
},
|
357 |
+
"language_info": {
|
358 |
+
"codemirror_mode": {
|
359 |
+
"name": "ipython",
|
360 |
+
"version": 3
|
361 |
+
},
|
362 |
+
"file_extension": ".py",
|
363 |
+
"mimetype": "text/x-python",
|
364 |
+
"name": "python",
|
365 |
+
"nbconvert_exporter": "python",
|
366 |
+
"pygments_lexer": "ipython3",
|
367 |
+
"version": "3.11.9"
|
368 |
+
}
|
369 |
+
},
|
370 |
+
"nbformat": 4,
|
371 |
+
"nbformat_minor": 5
|
372 |
+
}
|
models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"2\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:44<00:00, 1.09it/s]"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"===========================\n",
|
32 |
+
"accuracy: 98.94\n",
|
33 |
+
"===========================\n",
|
34 |
+
"False Positive Rate: 0.004\n",
|
35 |
+
"Precision: 0.996\n",
|
36 |
+
"Recall: 0.983\n",
|
37 |
+
"F1 Score: 0.989\n"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"name": "stderr",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"\n"
|
45 |
+
]
|
46 |
+
}
|
47 |
+
],
|
48 |
+
"source": [
|
49 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
50 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
51 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
52 |
+
"from tqdm import tqdm\n",
|
53 |
+
"import torch\n",
|
54 |
+
"import numpy as np\n",
|
55 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
56 |
+
"import torch\n",
|
57 |
+
"import torch.nn as nn\n",
|
58 |
+
"import torch.optim as optim\n",
|
59 |
+
"from tqdm import tqdm \n",
|
60 |
+
"import torch.nn.functional as F\n",
|
61 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
62 |
+
"import pickle\n",
|
63 |
+
"\n",
|
64 |
+
"torch.manual_seed(1)\n",
|
65 |
+
"# torch.manual_seed(42)\n",
|
66 |
+
"\n",
|
67 |
+
"\n",
|
68 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
69 |
+
"num_gpus = torch.cuda.device_count()\n",
|
70 |
+
"print(num_gpus)\n",
|
71 |
+
"\n",
|
72 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
73 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
74 |
+
"\n",
|
75 |
+
"num_classes = 2\n",
|
76 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
77 |
+
"\n",
|
78 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
79 |
+
"model = nn.DataParallel(model)\n",
|
80 |
+
"model = model.to(device)\n",
|
81 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
82 |
+
"print(\"num params \",params)\n",
|
83 |
+
"\n",
|
84 |
+
"model_1 = 'models_8/model-25-99.31_7109.pt'\n",
|
85 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
86 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
87 |
+
"model = model.eval()\n",
|
88 |
+
"\n",
|
89 |
+
"# eval\n",
|
90 |
+
"val_loss = 0.0\n",
|
91 |
+
"correct_valid = 0\n",
|
92 |
+
"total = 0\n",
|
93 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
94 |
+
"model.eval()\n",
|
95 |
+
"with torch.no_grad():\n",
|
96 |
+
" for images, labels in tqdm(testloader):\n",
|
97 |
+
" inputs, labels = images.to(device), labels\n",
|
98 |
+
" outputs = model(inputs, return_mask = True)\n",
|
99 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
100 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
101 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
102 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
103 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
104 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
105 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
106 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
107 |
+
" total += labels[0].size(0)\n",
|
108 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
109 |
+
" \n",
|
110 |
+
" \n",
|
111 |
+
"# Calculate training accuracy after each epoch\n",
|
112 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
113 |
+
"print(\"===========================\")\n",
|
114 |
+
"print('accuracy: ', val_accuracy)\n",
|
115 |
+
"print(\"===========================\")\n",
|
116 |
+
"\n",
|
117 |
+
"import pickle\n",
|
118 |
+
"\n",
|
119 |
+
"# Pickle the dictionary to a file\n",
|
120 |
+
"with open('models_8/test_7109.pkl', 'wb') as f:\n",
|
121 |
+
" pickle.dump(results, f)\n",
|
122 |
+
"\n",
|
123 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
124 |
+
"from sklearn.metrics import confusion_matrix\n",
|
125 |
+
"\n",
|
126 |
+
"# Example binary labels\n",
|
127 |
+
"true = results['true'] # ground truth\n",
|
128 |
+
"pred = results['pred'] # predicted\n",
|
129 |
+
"\n",
|
130 |
+
"# Compute metrics\n",
|
131 |
+
"precision = precision_score(true, pred)\n",
|
132 |
+
"recall = recall_score(true, pred)\n",
|
133 |
+
"f1 = f1_score(true, pred)\n",
|
134 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
135 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
136 |
+
"\n",
|
137 |
+
"# Compute FPR\n",
|
138 |
+
"fpr = fp / (fp + tn)\n",
|
139 |
+
"\n",
|
140 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
141 |
+
"\n",
|
142 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
143 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
144 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": 3,
|
150 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [
|
153 |
+
{
|
154 |
+
"name": "stdout",
|
155 |
+
"output_type": "stream",
|
156 |
+
"text": [
|
157 |
+
"2\n",
|
158 |
+
"num params encoder 50840\n",
|
159 |
+
"num params 21496282\n"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"name": "stderr",
|
164 |
+
"output_type": "stream",
|
165 |
+
"text": [
|
166 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
167 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
168 |
+
"100%|███████████████████████████████████████████| 48/48 [00:41<00:00, 1.15it/s]"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"name": "stdout",
|
173 |
+
"output_type": "stream",
|
174 |
+
"text": [
|
175 |
+
"===========================\n",
|
176 |
+
"accuracy: 99.17\n",
|
177 |
+
"===========================\n",
|
178 |
+
"False Positive Rate: 0.004\n",
|
179 |
+
"Precision: 0.995\n",
|
180 |
+
"Recall: 0.988\n",
|
181 |
+
"F1 Score: 0.992\n"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"name": "stderr",
|
186 |
+
"output_type": "stream",
|
187 |
+
"text": [
|
188 |
+
"\n"
|
189 |
+
]
|
190 |
+
}
|
191 |
+
],
|
192 |
+
"source": [
|
193 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
194 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
195 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
196 |
+
"from tqdm import tqdm\n",
|
197 |
+
"import torch\n",
|
198 |
+
"import numpy as np\n",
|
199 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
200 |
+
"import torch\n",
|
201 |
+
"import torch.nn as nn\n",
|
202 |
+
"import torch.optim as optim\n",
|
203 |
+
"from tqdm import tqdm \n",
|
204 |
+
"import torch.nn.functional as F\n",
|
205 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
206 |
+
"import pickle\n",
|
207 |
+
"\n",
|
208 |
+
"torch.manual_seed(1)\n",
|
209 |
+
"# torch.manual_seed(42)\n",
|
210 |
+
"\n",
|
211 |
+
"\n",
|
212 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
213 |
+
"num_gpus = torch.cuda.device_count()\n",
|
214 |
+
"print(num_gpus)\n",
|
215 |
+
"\n",
|
216 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
217 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
218 |
+
"\n",
|
219 |
+
"num_classes = 2\n",
|
220 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
221 |
+
"\n",
|
222 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
223 |
+
"model = nn.DataParallel(model)\n",
|
224 |
+
"model = model.to(device)\n",
|
225 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
226 |
+
"print(\"num params \",params)\n",
|
227 |
+
"\n",
|
228 |
+
"\n",
|
229 |
+
"model_1 = 'models_8/model-44-99.445_42.pt'\n",
|
230 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
231 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
232 |
+
"model = model.eval()\n",
|
233 |
+
"\n",
|
234 |
+
"# eval\n",
|
235 |
+
"val_loss = 0.0\n",
|
236 |
+
"correct_valid = 0\n",
|
237 |
+
"total = 0\n",
|
238 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
239 |
+
"model.eval()\n",
|
240 |
+
"with torch.no_grad():\n",
|
241 |
+
" for images, labels in tqdm(testloader):\n",
|
242 |
+
" inputs, labels = images.to(device), labels\n",
|
243 |
+
" outputs = model(inputs, return_mask = True)\n",
|
244 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
245 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
246 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
247 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
248 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
249 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
250 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
251 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
252 |
+
" total += labels[0].size(0)\n",
|
253 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
254 |
+
" \n",
|
255 |
+
"# Calculate training accuracy after each epoch\n",
|
256 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
257 |
+
"print(\"===========================\")\n",
|
258 |
+
"print('accuracy: ', val_accuracy)\n",
|
259 |
+
"print(\"===========================\")\n",
|
260 |
+
"\n",
|
261 |
+
"import pickle\n",
|
262 |
+
"\n",
|
263 |
+
"# Pickle the dictionary to a file\n",
|
264 |
+
"with open('models_8/test_42.pkl', 'wb') as f:\n",
|
265 |
+
" pickle.dump(results, f)\n",
|
266 |
+
"\n",
|
267 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix\n",
|
268 |
+
"\n",
|
269 |
+
"# Example binary labels\n",
|
270 |
+
"true = results['true'] # ground truth\n",
|
271 |
+
"pred = results['pred'] # predicted\n",
|
272 |
+
"\n",
|
273 |
+
"# Compute metrics\n",
|
274 |
+
"precision = precision_score(true, pred)\n",
|
275 |
+
"recall = recall_score(true, pred)\n",
|
276 |
+
"f1 = f1_score(true, pred)\n",
|
277 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
278 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
279 |
+
"\n",
|
280 |
+
"# Compute FPR\n",
|
281 |
+
"fpr = fp / (fp + tn)\n",
|
282 |
+
"\n",
|
283 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
284 |
+
"\n",
|
285 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
286 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
287 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"cell_type": "code",
|
292 |
+
"execution_count": 4,
|
293 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"name": "stdout",
|
298 |
+
"output_type": "stream",
|
299 |
+
"text": [
|
300 |
+
"2\n",
|
301 |
+
"num params encoder 50840\n",
|
302 |
+
"num params 21496282\n"
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"name": "stderr",
|
307 |
+
"output_type": "stream",
|
308 |
+
"text": [
|
309 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
310 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
311 |
+
"100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.12s/it]"
|
312 |
+
]
|
313 |
+
},
|
314 |
+
{
|
315 |
+
"name": "stdout",
|
316 |
+
"output_type": "stream",
|
317 |
+
"text": [
|
318 |
+
"===========================\n",
|
319 |
+
"accuracy: 99.035\n",
|
320 |
+
"===========================\n",
|
321 |
+
"False Positive Rate: 0.010\n",
|
322 |
+
"Precision: 0.990\n",
|
323 |
+
"Recall: 0.990\n",
|
324 |
+
"F1 Score: 0.990\n"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"name": "stderr",
|
329 |
+
"output_type": "stream",
|
330 |
+
"text": [
|
331 |
+
"\n"
|
332 |
+
]
|
333 |
+
}
|
334 |
+
],
|
335 |
+
"source": [
|
336 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
337 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
338 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
339 |
+
"from tqdm import tqdm\n",
|
340 |
+
"import torch\n",
|
341 |
+
"import numpy as np\n",
|
342 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
343 |
+
"import torch\n",
|
344 |
+
"import torch.nn as nn\n",
|
345 |
+
"import torch.optim as optim\n",
|
346 |
+
"from tqdm import tqdm \n",
|
347 |
+
"import torch.nn.functional as F\n",
|
348 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
349 |
+
"import pickle\n",
|
350 |
+
"\n",
|
351 |
+
"torch.manual_seed(1)\n",
|
352 |
+
"# torch.manual_seed(42)\n",
|
353 |
+
"\n",
|
354 |
+
"\n",
|
355 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
356 |
+
"num_gpus = torch.cuda.device_count()\n",
|
357 |
+
"print(num_gpus)\n",
|
358 |
+
"\n",
|
359 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
360 |
+
"test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
|
361 |
+
"\n",
|
362 |
+
"num_classes = 2\n",
|
363 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
364 |
+
"\n",
|
365 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
366 |
+
"model = nn.DataParallel(model)\n",
|
367 |
+
"model = model.to(device)\n",
|
368 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
369 |
+
"print(\"num params \",params)\n",
|
370 |
+
"\n",
|
371 |
+
"\n",
|
372 |
+
"model_1 = 'models_8/model-43-99.355_1.pt'\n",
|
373 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
374 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
375 |
+
"model = model.eval()\n",
|
376 |
+
"\n",
|
377 |
+
"# eval\n",
|
378 |
+
"val_loss = 0.0\n",
|
379 |
+
"correct_valid = 0\n",
|
380 |
+
"total = 0\n",
|
381 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
382 |
+
"model.eval()\n",
|
383 |
+
"with torch.no_grad():\n",
|
384 |
+
" for images, labels in tqdm(testloader):\n",
|
385 |
+
" inputs, labels = images.to(device), labels\n",
|
386 |
+
" outputs = model(inputs, return_mask = True)\n",
|
387 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
388 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
389 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
390 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
391 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
392 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
393 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
394 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
395 |
+
" total += labels[0].size(0)\n",
|
396 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
397 |
+
" \n",
|
398 |
+
"# Calculate training accuracy after each epoch\n",
|
399 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
400 |
+
"print(\"===========================\")\n",
|
401 |
+
"print('accuracy: ', val_accuracy)\n",
|
402 |
+
"print(\"===========================\")\n",
|
403 |
+
"\n",
|
404 |
+
"import pickle\n",
|
405 |
+
"\n",
|
406 |
+
"# Pickle the dictionary to a file\n",
|
407 |
+
"with open('models_8/test_1.pkl', 'wb') as f:\n",
|
408 |
+
" pickle.dump(results, f)\n",
|
409 |
+
"\n",
|
410 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix\n",
|
411 |
+
"\n",
|
412 |
+
"# Example binary labels\n",
|
413 |
+
"true = results['true'] # ground truth\n",
|
414 |
+
"pred = results['pred'] # predicted\n",
|
415 |
+
"\n",
|
416 |
+
"# Compute metrics\n",
|
417 |
+
"precision = precision_score(true, pred)\n",
|
418 |
+
"recall = recall_score(true, pred)\n",
|
419 |
+
"f1 = f1_score(true, pred)\n",
|
420 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
421 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
422 |
+
"\n",
|
423 |
+
"# Compute FPR\n",
|
424 |
+
"fpr = fp / (fp + tn)\n",
|
425 |
+
"\n",
|
426 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
427 |
+
"\n",
|
428 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
429 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
430 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"cell_type": "code",
|
435 |
+
"execution_count": 7,
|
436 |
+
"id": "8444ced5-686c-4c6e-bb02-8c4a45ff8af9",
|
437 |
+
"metadata": {},
|
438 |
+
"outputs": [
|
439 |
+
{
|
440 |
+
"name": "stdout",
|
441 |
+
"output_type": "stream",
|
442 |
+
"text": [
|
443 |
+
"99.04833333333333 0.09436925111261553\n",
|
444 |
+
"0.9936666666666666 0.002624669291337273\n",
|
445 |
+
"0.9903333333333334 0.0012472191289246482\n",
|
446 |
+
"0.006000000000000001 0.00282842712474619\n"
|
447 |
+
]
|
448 |
+
}
|
449 |
+
],
|
450 |
+
"source": [
|
451 |
+
"# acc\n",
|
452 |
+
"print(np.mean([98.94,99.17, 99.035 ]), np.std([98.94,99.17, 99.035 ]))\n",
|
453 |
+
"# recall\n",
|
454 |
+
"print(np.mean([0.996,0.988, 0.990]), np.std([0.996,0.988, 0.990]))\n",
|
455 |
+
"# f1\n",
|
456 |
+
"print(np.mean([0.989,0.992,0.990 ]),np.std([0.989,0.992,0.990 ]))\n",
|
457 |
+
"# fp\n",
|
458 |
+
"print(np.mean([0.004,0.004,0.010]),np.std([0.004,0.004,0.010]))\n"
|
459 |
+
]
|
460 |
+
}
|
461 |
+
],
|
462 |
+
"metadata": {
|
463 |
+
"kernelspec": {
|
464 |
+
"display_name": "Python 3 (ipykernel)",
|
465 |
+
"language": "python",
|
466 |
+
"name": "python3"
|
467 |
+
},
|
468 |
+
"language_info": {
|
469 |
+
"codemirror_mode": {
|
470 |
+
"name": "ipython",
|
471 |
+
"version": 3
|
472 |
+
},
|
473 |
+
"file_extension": ".py",
|
474 |
+
"mimetype": "text/x-python",
|
475 |
+
"name": "python",
|
476 |
+
"nbconvert_exporter": "python",
|
477 |
+
"pygments_lexer": "ipython3",
|
478 |
+
"version": "3.11.9"
|
479 |
+
}
|
480 |
+
},
|
481 |
+
"nbformat": 4,
|
482 |
+
"nbformat_minor": 5
|
483 |
+
}
|
models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"6\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
" 8%|███▋ | 4/48 [02:22<25:11, 34.35s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7efcbc3f67d0>>\n",
|
25 |
+
"Traceback (most recent call last):\n",
|
26 |
+
" File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 775, in _clean_thread_parent_frames\n",
|
27 |
+
" def _clean_thread_parent_frames(\n",
|
28 |
+
"\n",
|
29 |
+
"KeyboardInterrupt: \n",
|
30 |
+
" 10%|████▌ | 5/48 [02:53<24:56, 34.79s/it]\n"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"ename": "RuntimeError",
|
35 |
+
"evalue": "DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly",
|
36 |
+
"output_type": "error",
|
37 |
+
"traceback": [
|
38 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
39 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
40 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1131\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1131\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_queue\u001b[38;5;241m.\u001b[39mget(timeout\u001b[38;5;241m=\u001b[39mtimeout)\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
|
41 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/queues.py:122\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# unserialize the data after having released the lock\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _ForkingPickler\u001b[38;5;241m.\u001b[39mloads(res)\n",
|
42 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/multiprocessing/reductions.py:496\u001b[0m, in \u001b[0;36mrebuild_storage_fd\u001b[0;34m(cls, df, size)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrebuild_storage_fd\u001b[39m(\u001b[38;5;28mcls\u001b[39m, df, size):\n\u001b[0;32m--> 496\u001b[0m fd \u001b[38;5;241m=\u001b[39m df\u001b[38;5;241m.\u001b[39mdetach()\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
|
43 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:57\u001b[0m, in \u001b[0;36mDupFd.detach\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m'''Get the fd. This should only be called once.'''\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _resource_sharer\u001b[38;5;241m.\u001b[39mget_connection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_id) \u001b[38;5;28;01mas\u001b[39;00m conn:\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m reduction\u001b[38;5;241m.\u001b[39mrecv_handle(conn)\n",
|
44 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:86\u001b[0m, in \u001b[0;36m_ResourceSharer.get_connection\u001b[0;34m(ident)\u001b[0m\n\u001b[1;32m 85\u001b[0m address, key \u001b[38;5;241m=\u001b[39m ident\n\u001b[0;32m---> 86\u001b[0m c \u001b[38;5;241m=\u001b[39m Client(address, authkey\u001b[38;5;241m=\u001b[39mprocess\u001b[38;5;241m.\u001b[39mcurrent_process()\u001b[38;5;241m.\u001b[39mauthkey)\n\u001b[1;32m 87\u001b[0m c\u001b[38;5;241m.\u001b[39msend((key, os\u001b[38;5;241m.\u001b[39mgetpid()))\n",
|
45 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:519\u001b[0m, in \u001b[0;36mClient\u001b[0;34m(address, family, authkey)\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 519\u001b[0m c \u001b[38;5;241m=\u001b[39m SocketClient(address)\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m authkey \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(authkey, \u001b[38;5;28mbytes\u001b[39m):\n",
|
46 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:647\u001b[0m, in \u001b[0;36mSocketClient\u001b[0;34m(address)\u001b[0m\n\u001b[1;32m 646\u001b[0m s\u001b[38;5;241m.\u001b[39msetblocking(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 647\u001b[0m s\u001b[38;5;241m.\u001b[39mconnect(address)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Connection(s\u001b[38;5;241m.\u001b[39mdetach())\n",
|
47 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory",
|
48 |
+
"\nThe above exception was the direct cause of the following exception:\n",
|
49 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
50 |
+
"Cell \u001b[0;32mIn[1], line 48\u001b[0m\n\u001b[1;32m 46\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[1;32m 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
51 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
|
52 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_data()\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
|
53 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1327\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1327\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_data()\n\u001b[1;32m 1328\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
|
54 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1293\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1290\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1292\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1293\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_try_get_data()\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
|
55 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1144\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1142\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1143\u001b[0m pids_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(w\u001b[38;5;241m.\u001b[39mpid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[0;32m-> 1144\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) exited unexpectedly\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 1145\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue\u001b[38;5;241m.\u001b[39mEmpty):\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
|
56 |
+
"\u001b[0;31mRuntimeError\u001b[0m: DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly"
|
57 |
+
]
|
58 |
+
}
|
59 |
+
],
|
60 |
+
"source": [
|
61 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
62 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
63 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
64 |
+
"from tqdm import tqdm\n",
|
65 |
+
"import torch\n",
|
66 |
+
"import numpy as np\n",
|
67 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
68 |
+
"import torch\n",
|
69 |
+
"import torch.nn as nn\n",
|
70 |
+
"import torch.optim as optim\n",
|
71 |
+
"from tqdm import tqdm \n",
|
72 |
+
"import torch.nn.functional as F\n",
|
73 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
74 |
+
"import pickle\n",
|
75 |
+
"\n",
|
76 |
+
"torch.manual_seed(1)\n",
|
77 |
+
"# torch.manual_seed(42)\n",
|
78 |
+
"\n",
|
79 |
+
"\n",
|
80 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
81 |
+
"num_gpus = torch.cuda.device_count()\n",
|
82 |
+
"print(num_gpus)\n",
|
83 |
+
"\n",
|
84 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
85 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
86 |
+
"\n",
|
87 |
+
"num_classes = 2\n",
|
88 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
89 |
+
"\n",
|
90 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
91 |
+
"model = nn.DataParallel(model)\n",
|
92 |
+
"model = model.to(device)\n",
|
93 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
94 |
+
"print(\"num params \",params)\n",
|
95 |
+
"\n",
|
96 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
97 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
98 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
99 |
+
"model = model.eval()\n",
|
100 |
+
"\n",
|
101 |
+
"# eval\n",
|
102 |
+
"val_loss = 0.0\n",
|
103 |
+
"correct_valid = 0\n",
|
104 |
+
"total = 0\n",
|
105 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
106 |
+
"model.eval()\n",
|
107 |
+
"with torch.no_grad():\n",
|
108 |
+
" for images, labels in tqdm(testloader):\n",
|
109 |
+
" inputs, labels = images.to(device), labels\n",
|
110 |
+
" outputs = model(inputs, return_mask = True)\n",
|
111 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
112 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
113 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
114 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
115 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
116 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
117 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
118 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
119 |
+
" total += labels[0].size(0)\n",
|
120 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
121 |
+
" \n",
|
122 |
+
"# Calculate training accuracy after each epoch\n",
|
123 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
124 |
+
"print(\"===========================\")\n",
|
125 |
+
"print('accuracy: ', val_accuracy)\n",
|
126 |
+
"print(\"===========================\")\n",
|
127 |
+
"\n",
|
128 |
+
"import pickle\n",
|
129 |
+
"\n",
|
130 |
+
"# Pickle the dictionary to a file\n",
|
131 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
132 |
+
" pickle.dump(results, f)"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": null,
|
138 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
143 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
144 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
145 |
+
"from tqdm import tqdm\n",
|
146 |
+
"import torch\n",
|
147 |
+
"import numpy as np\n",
|
148 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
149 |
+
"import torch\n",
|
150 |
+
"import torch.nn as nn\n",
|
151 |
+
"import torch.optim as optim\n",
|
152 |
+
"from tqdm import tqdm \n",
|
153 |
+
"import torch.nn.functional as F\n",
|
154 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
155 |
+
"import pickle\n",
|
156 |
+
"\n",
|
157 |
+
"torch.manual_seed(1)\n",
|
158 |
+
"# torch.manual_seed(42)\n",
|
159 |
+
"\n",
|
160 |
+
"\n",
|
161 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
162 |
+
"num_gpus = torch.cuda.device_count()\n",
|
163 |
+
"print(num_gpus)\n",
|
164 |
+
"\n",
|
165 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
166 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
167 |
+
"\n",
|
168 |
+
"num_classes = 2\n",
|
169 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
170 |
+
"\n",
|
171 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
172 |
+
"model = nn.DataParallel(model)\n",
|
173 |
+
"model = model.to(device)\n",
|
174 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
175 |
+
"print(\"num params \",params)\n",
|
176 |
+
"\n",
|
177 |
+
"\n",
|
178 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
179 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
180 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
181 |
+
"model = model.eval()\n",
|
182 |
+
"\n",
|
183 |
+
"# eval\n",
|
184 |
+
"val_loss = 0.0\n",
|
185 |
+
"correct_valid = 0\n",
|
186 |
+
"total = 0\n",
|
187 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
188 |
+
"model.eval()\n",
|
189 |
+
"with torch.no_grad():\n",
|
190 |
+
" for images, labels in tqdm(testloader):\n",
|
191 |
+
" inputs, labels = images.to(device), labels\n",
|
192 |
+
" outputs = model(inputs, return_mask = True)\n",
|
193 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
194 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
195 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
196 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
197 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
198 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
199 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
200 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
201 |
+
" total += labels[0].size(0)\n",
|
202 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
203 |
+
" \n",
|
204 |
+
" \n",
|
205 |
+
"# Calculate training accuracy after each epoch\n",
|
206 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
207 |
+
"print(\"===========================\")\n",
|
208 |
+
"print('accuracy: ', val_accuracy)\n",
|
209 |
+
"print(\"===========================\")\n",
|
210 |
+
"\n",
|
211 |
+
"import pickle\n",
|
212 |
+
"\n",
|
213 |
+
"# Pickle the dictionary to a file\n",
|
214 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
215 |
+
" pickle.dump(results, f)"
|
216 |
+
]
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"execution_count": null,
|
221 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
226 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
227 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
228 |
+
"from tqdm import tqdm\n",
|
229 |
+
"import torch\n",
|
230 |
+
"import numpy as np\n",
|
231 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
232 |
+
"import torch\n",
|
233 |
+
"import torch.nn as nn\n",
|
234 |
+
"import torch.optim as optim\n",
|
235 |
+
"from tqdm import tqdm \n",
|
236 |
+
"import torch.nn.functional as F\n",
|
237 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
238 |
+
"import pickle\n",
|
239 |
+
"\n",
|
240 |
+
"torch.manual_seed(1)\n",
|
241 |
+
"# torch.manual_seed(42)\n",
|
242 |
+
"\n",
|
243 |
+
"\n",
|
244 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
245 |
+
"num_gpus = torch.cuda.device_count()\n",
|
246 |
+
"print(num_gpus)\n",
|
247 |
+
"\n",
|
248 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
249 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
250 |
+
"\n",
|
251 |
+
"num_classes = 2\n",
|
252 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
253 |
+
"\n",
|
254 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
255 |
+
"model = nn.DataParallel(model)\n",
|
256 |
+
"model = model.to(device)\n",
|
257 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
258 |
+
"print(\"num params \",params)\n",
|
259 |
+
"\n",
|
260 |
+
"\n",
|
261 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
262 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
263 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
264 |
+
"model = model.eval()\n",
|
265 |
+
"\n",
|
266 |
+
"# eval\n",
|
267 |
+
"val_loss = 0.0\n",
|
268 |
+
"correct_valid = 0\n",
|
269 |
+
"total = 0\n",
|
270 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
271 |
+
"model.eval()\n",
|
272 |
+
"with torch.no_grad():\n",
|
273 |
+
" for images, labels in tqdm(testloader):\n",
|
274 |
+
" inputs, labels = images.to(device), labels\n",
|
275 |
+
" outputs = model(inputs, return_mask = True)\n",
|
276 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
277 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
278 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
279 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
280 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
281 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
282 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
283 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
284 |
+
" total += labels[0].size(0)\n",
|
285 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
286 |
+
" \n",
|
287 |
+
" \n",
|
288 |
+
"# Calculate training accuracy after each epoch\n",
|
289 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
290 |
+
"print(\"===========================\")\n",
|
291 |
+
"print('accuracy: ', val_accuracy)\n",
|
292 |
+
"print(\"===========================\")\n",
|
293 |
+
"\n",
|
294 |
+
"import pickle\n",
|
295 |
+
"\n",
|
296 |
+
"# Pickle the dictionary to a file\n",
|
297 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
298 |
+
" pickle.dump(results, f)"
|
299 |
+
]
|
300 |
+
}
|
301 |
+
],
|
302 |
+
"metadata": {
|
303 |
+
"kernelspec": {
|
304 |
+
"display_name": "Python 3 (ipykernel)",
|
305 |
+
"language": "python",
|
306 |
+
"name": "python3"
|
307 |
+
},
|
308 |
+
"language_info": {
|
309 |
+
"codemirror_mode": {
|
310 |
+
"name": "ipython",
|
311 |
+
"version": 3
|
312 |
+
},
|
313 |
+
"file_extension": ".py",
|
314 |
+
"mimetype": "text/x-python",
|
315 |
+
"name": "python",
|
316 |
+
"nbconvert_exporter": "python",
|
317 |
+
"pygments_lexer": "ipython3",
|
318 |
+
"version": "3.11.9"
|
319 |
+
}
|
320 |
+
},
|
321 |
+
"nbformat": 4,
|
322 |
+
"nbformat_minor": 5
|
323 |
+
}
|
models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 9,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"2\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
24 |
+
"100%|███████████████████████████████████████████| 48/48 [00:45<00:00, 1.07it/s]"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"===========================\n",
|
32 |
+
"accuracy: 99.125\n",
|
33 |
+
"===========================\n",
|
34 |
+
"Precision: 0.992\n",
|
35 |
+
"Recall: 0.991\n",
|
36 |
+
"F1 Score: 0.991\n"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"name": "stderr",
|
41 |
+
"output_type": "stream",
|
42 |
+
"text": [
|
43 |
+
"\n"
|
44 |
+
]
|
45 |
+
}
|
46 |
+
],
|
47 |
+
"source": [
|
48 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
49 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
50 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
51 |
+
"from tqdm import tqdm\n",
|
52 |
+
"import torch\n",
|
53 |
+
"import numpy as np\n",
|
54 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
55 |
+
"import torch\n",
|
56 |
+
"import torch.nn as nn\n",
|
57 |
+
"import torch.optim as optim\n",
|
58 |
+
"from tqdm import tqdm \n",
|
59 |
+
"import torch.nn.functional as F\n",
|
60 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
61 |
+
"import pickle\n",
|
62 |
+
"\n",
|
63 |
+
"torch.manual_seed(1)\n",
|
64 |
+
"# torch.manual_seed(42)\n",
|
65 |
+
"\n",
|
66 |
+
"\n",
|
67 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
68 |
+
"num_gpus = torch.cuda.device_count()\n",
|
69 |
+
"print(num_gpus)\n",
|
70 |
+
"\n",
|
71 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
72 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
73 |
+
"\n",
|
74 |
+
"num_classes = 2\n",
|
75 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
76 |
+
"\n",
|
77 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
78 |
+
"model = nn.DataParallel(model)\n",
|
79 |
+
"model = model.to(device)\n",
|
80 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
81 |
+
"print(\"num params \",params)\n",
|
82 |
+
"\n",
|
83 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
84 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
85 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
86 |
+
"model = model.eval()\n",
|
87 |
+
"\n",
|
88 |
+
"# eval\n",
|
89 |
+
"val_loss = 0.0\n",
|
90 |
+
"correct_valid = 0\n",
|
91 |
+
"total = 0\n",
|
92 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
93 |
+
"model.eval()\n",
|
94 |
+
"with torch.no_grad():\n",
|
95 |
+
" for images, labels in tqdm(testloader):\n",
|
96 |
+
" inputs, labels = images.to(device), labels\n",
|
97 |
+
" outputs = model(inputs, return_mask = True)\n",
|
98 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
99 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
100 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
101 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
102 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
103 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
104 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
105 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
106 |
+
" total += labels[0].size(0)\n",
|
107 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
108 |
+
" \n",
|
109 |
+
"# Calculate training accuracy after each epoch\n",
|
110 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
111 |
+
"print(\"===========================\")\n",
|
112 |
+
"print('accuracy: ', val_accuracy)\n",
|
113 |
+
"print(\"===========================\")\n",
|
114 |
+
"\n",
|
115 |
+
"import pickle\n",
|
116 |
+
"\n",
|
117 |
+
"# Pickle the dictionary to a file\n",
|
118 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
119 |
+
" pickle.dump(results, f)\n",
|
120 |
+
"\n",
|
121 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
122 |
+
"\n",
|
123 |
+
"# Example binary labels\n",
|
124 |
+
"true = results['true'] # ground truth\n",
|
125 |
+
"pred = results['pred'] # predicted\n",
|
126 |
+
"\n",
|
127 |
+
"# Compute metrics\n",
|
128 |
+
"precision = precision_score(true, pred)\n",
|
129 |
+
"recall = recall_score(true, pred)\n",
|
130 |
+
"f1 = f1_score(true, pred)\n",
|
131 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
132 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
133 |
+
"\n",
|
134 |
+
"# Compute FPR\n",
|
135 |
+
"fpr = fp / (fp + tn)\n",
|
136 |
+
"\n",
|
137 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
138 |
+
"\n",
|
139 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
140 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
141 |
+
"print(f\"F1 Score: {f1:.3f}\")\n"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": 10,
|
147 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [
|
150 |
+
{
|
151 |
+
"name": "stdout",
|
152 |
+
"output_type": "stream",
|
153 |
+
"text": [
|
154 |
+
"2\n",
|
155 |
+
"num params encoder 50840\n",
|
156 |
+
"num params 21496282\n"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"name": "stderr",
|
161 |
+
"output_type": "stream",
|
162 |
+
"text": [
|
163 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
164 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
165 |
+
"100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"name": "stdout",
|
170 |
+
"output_type": "stream",
|
171 |
+
"text": [
|
172 |
+
"===========================\n",
|
173 |
+
"accuracy: 98.77\n",
|
174 |
+
"===========================\n",
|
175 |
+
"Precision: 0.982\n",
|
176 |
+
"Recall: 0.993\n",
|
177 |
+
"F1 Score: 0.988\n"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"name": "stderr",
|
182 |
+
"output_type": "stream",
|
183 |
+
"text": [
|
184 |
+
"\n"
|
185 |
+
]
|
186 |
+
}
|
187 |
+
],
|
188 |
+
"source": [
|
189 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
190 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
191 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
192 |
+
"from tqdm import tqdm\n",
|
193 |
+
"import torch\n",
|
194 |
+
"import numpy as np\n",
|
195 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
196 |
+
"import torch\n",
|
197 |
+
"import torch.nn as nn\n",
|
198 |
+
"import torch.optim as optim\n",
|
199 |
+
"from tqdm import tqdm \n",
|
200 |
+
"import torch.nn.functional as F\n",
|
201 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
202 |
+
"import pickle\n",
|
203 |
+
"\n",
|
204 |
+
"torch.manual_seed(1)\n",
|
205 |
+
"\n",
|
206 |
+
"\n",
|
207 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
208 |
+
"num_gpus = torch.cuda.device_count()\n",
|
209 |
+
"print(num_gpus)\n",
|
210 |
+
"\n",
|
211 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
212 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
213 |
+
"\n",
|
214 |
+
"num_classes = 2\n",
|
215 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
216 |
+
"\n",
|
217 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
218 |
+
"model = nn.DataParallel(model)\n",
|
219 |
+
"model = model.to(device)\n",
|
220 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
221 |
+
"print(\"num params \",params)\n",
|
222 |
+
"\n",
|
223 |
+
"\n",
|
224 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
225 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
226 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
227 |
+
"model = model.eval()\n",
|
228 |
+
"\n",
|
229 |
+
"# eval\n",
|
230 |
+
"val_loss = 0.0\n",
|
231 |
+
"correct_valid = 0\n",
|
232 |
+
"total = 0\n",
|
233 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
234 |
+
"model.eval()\n",
|
235 |
+
"with torch.no_grad():\n",
|
236 |
+
" for images, labels in tqdm(testloader):\n",
|
237 |
+
" inputs, labels = images.to(device), labels\n",
|
238 |
+
" outputs = model(inputs, return_mask = True)\n",
|
239 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
240 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
241 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
242 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
243 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
244 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
245 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
246 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
247 |
+
" total += labels[0].size(0)\n",
|
248 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
249 |
+
" \n",
|
250 |
+
" \n",
|
251 |
+
"# Calculate training accuracy after each epoch\n",
|
252 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
253 |
+
"print(\"===========================\")\n",
|
254 |
+
"print('accuracy: ', val_accuracy)\n",
|
255 |
+
"print(\"===========================\")\n",
|
256 |
+
"\n",
|
257 |
+
"import pickle\n",
|
258 |
+
"\n",
|
259 |
+
"# Pickle the dictionary to a file\n",
|
260 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
261 |
+
" pickle.dump(results, f)\n",
|
262 |
+
"\n",
|
263 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
264 |
+
"\n",
|
265 |
+
"# Example binary labels\n",
|
266 |
+
"true = results['true'] # ground truth\n",
|
267 |
+
"pred = results['pred'] # predicted\n",
|
268 |
+
"\n",
|
269 |
+
"# Compute metrics\n",
|
270 |
+
"precision = precision_score(true, pred)\n",
|
271 |
+
"recall = recall_score(true, pred)\n",
|
272 |
+
"f1 = f1_score(true, pred)\n",
|
273 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
274 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
275 |
+
"\n",
|
276 |
+
"# Compute FPR\n",
|
277 |
+
"fpr = fp / (fp + tn)\n",
|
278 |
+
"\n",
|
279 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
280 |
+
"\n",
|
281 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
282 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
283 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": 11,
|
289 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
290 |
+
"metadata": {},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"2\n",
|
297 |
+
"num params encoder 50840\n",
|
298 |
+
"num params 21496282\n"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"name": "stderr",
|
303 |
+
"output_type": "stream",
|
304 |
+
"text": [
|
305 |
+
" 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
306 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
307 |
+
"100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"name": "stdout",
|
312 |
+
"output_type": "stream",
|
313 |
+
"text": [
|
314 |
+
"===========================\n",
|
315 |
+
"accuracy: 99.03\n",
|
316 |
+
"===========================\n",
|
317 |
+
"Precision: 0.990\n",
|
318 |
+
"Recall: 0.990\n",
|
319 |
+
"F1 Score: 0.990\n"
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"name": "stderr",
|
324 |
+
"output_type": "stream",
|
325 |
+
"text": [
|
326 |
+
"\n"
|
327 |
+
]
|
328 |
+
}
|
329 |
+
],
|
330 |
+
"source": [
|
331 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
332 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
333 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
334 |
+
"from tqdm import tqdm\n",
|
335 |
+
"import torch\n",
|
336 |
+
"import numpy as np\n",
|
337 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
338 |
+
"import torch\n",
|
339 |
+
"import torch.nn as nn\n",
|
340 |
+
"import torch.optim as optim\n",
|
341 |
+
"from tqdm import tqdm \n",
|
342 |
+
"import torch.nn.functional as F\n",
|
343 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
344 |
+
"import pickle\n",
|
345 |
+
"\n",
|
346 |
+
"torch.manual_seed(1)\n",
|
347 |
+
"# torch.manual_seed(42)\n",
|
348 |
+
"\n",
|
349 |
+
"\n",
|
350 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
351 |
+
"num_gpus = torch.cuda.device_count()\n",
|
352 |
+
"print(num_gpus)\n",
|
353 |
+
"\n",
|
354 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
355 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
356 |
+
"\n",
|
357 |
+
"num_classes = 2\n",
|
358 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
359 |
+
"\n",
|
360 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
361 |
+
"model = nn.DataParallel(model)\n",
|
362 |
+
"model = model.to(device)\n",
|
363 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
364 |
+
"print(\"num params \",params)\n",
|
365 |
+
"\n",
|
366 |
+
"\n",
|
367 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
368 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
369 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
370 |
+
"model = model.eval()\n",
|
371 |
+
"\n",
|
372 |
+
"# eval\n",
|
373 |
+
"val_loss = 0.0\n",
|
374 |
+
"correct_valid = 0\n",
|
375 |
+
"total = 0\n",
|
376 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
377 |
+
"model.eval()\n",
|
378 |
+
"with torch.no_grad():\n",
|
379 |
+
" for images, labels in tqdm(testloader):\n",
|
380 |
+
" inputs, labels = images.to(device), labels\n",
|
381 |
+
" outputs = model(inputs, return_mask = True)\n",
|
382 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
383 |
+
" results['output'].extend(outputs.cpu().numpy().tolist())\n",
|
384 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
385 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
386 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
387 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
388 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
389 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
390 |
+
" total += labels[0].size(0)\n",
|
391 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
392 |
+
" \n",
|
393 |
+
" \n",
|
394 |
+
"# Calculate training accuracy after each epoch\n",
|
395 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
396 |
+
"print(\"===========================\")\n",
|
397 |
+
"print('accuracy: ', val_accuracy)\n",
|
398 |
+
"print(\"===========================\")\n",
|
399 |
+
"\n",
|
400 |
+
"import pickle\n",
|
401 |
+
"\n",
|
402 |
+
"# Pickle the dictionary to a file\n",
|
403 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
404 |
+
" pickle.dump(results, f)\n",
|
405 |
+
"\n",
|
406 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
407 |
+
"\n",
|
408 |
+
"# Example binary labels\n",
|
409 |
+
"true = results['true'] # ground truth\n",
|
410 |
+
"pred = results['pred'] # predicted\n",
|
411 |
+
"\n",
|
412 |
+
"# Compute metrics\n",
|
413 |
+
"precision = precision_score(true, pred)\n",
|
414 |
+
"recall = recall_score(true, pred)\n",
|
415 |
+
"f1 = f1_score(true, pred)\n",
|
416 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
417 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
418 |
+
"\n",
|
419 |
+
"# Compute FPR\n",
|
420 |
+
"fpr = fp / (fp + tn)\n",
|
421 |
+
"\n",
|
422 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
423 |
+
"\n",
|
424 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
425 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
426 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
427 |
+
]
|
428 |
+
},
|
429 |
+
{
|
430 |
+
"cell_type": "code",
|
431 |
+
"execution_count": 17,
|
432 |
+
"id": "974e62d6-5088-4cd8-9721-6702717eadee",
|
433 |
+
"metadata": {},
|
434 |
+
"outputs": [
|
435 |
+
{
|
436 |
+
"name": "stdout",
|
437 |
+
"output_type": "stream",
|
438 |
+
"text": [
|
439 |
+
"98.97499999999998 0.1500555452713003\n",
|
440 |
+
"0.9913333333333333 0.0012472191289246482\n",
|
441 |
+
"0.9896666666666666 0.0012472191289246482\n"
|
442 |
+
]
|
443 |
+
}
|
444 |
+
],
|
445 |
+
"source": [
|
446 |
+
"# acc\n",
|
447 |
+
"print(np.mean([99.125,98.77, 99.03 ]), np.std([99.125,98.77, 99.03 ]))\n",
|
448 |
+
"# precision\n",
|
449 |
+
"print(np.mean([0.991,0.990, 0.993]), np.std([0.991,0.990, 0.993]))\n",
|
450 |
+
"# f1\n",
|
451 |
+
"print(np.mean([0.990,0.988,0.991 ]),np.std([0.990,0.988,0.991 ]))\n",
|
452 |
+
"# recall"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"cell_type": "code",
|
457 |
+
"execution_count": 18,
|
458 |
+
"id": "3eee97ad-114f-4090-b54a-6ec0cc7150f5",
|
459 |
+
"metadata": {},
|
460 |
+
"outputs": [
|
461 |
+
{
|
462 |
+
"name": "stdout",
|
463 |
+
"output_type": "stream",
|
464 |
+
"text": [
|
465 |
+
"False Positive Rate: 0.200\n"
|
466 |
+
]
|
467 |
+
}
|
468 |
+
],
|
469 |
+
"source": [
|
470 |
+
"from sklearn.metrics import confusion_matrix\n",
|
471 |
+
"\n",
|
472 |
+
"# Ground truth and predictions\n",
|
473 |
+
"true = [1, 0, 1, 1, 0, 1, 0, 0, 1, 0]\n",
|
474 |
+
"pred = [1, 0, 1, 0, 0, 1, 1, 0, 1, 0]\n",
|
475 |
+
"\n"
|
476 |
+
]
|
477 |
+
}
|
478 |
+
],
|
479 |
+
"metadata": {
|
480 |
+
"kernelspec": {
|
481 |
+
"display_name": "Python 3 (ipykernel)",
|
482 |
+
"language": "python",
|
483 |
+
"name": "python3"
|
484 |
+
},
|
485 |
+
"language_info": {
|
486 |
+
"codemirror_mode": {
|
487 |
+
"name": "ipython",
|
488 |
+
"version": 3
|
489 |
+
},
|
490 |
+
"file_extension": ".py",
|
491 |
+
"mimetype": "text/x-python",
|
492 |
+
"name": "python",
|
493 |
+
"nbconvert_exporter": "python",
|
494 |
+
"pygments_lexer": "ipython3",
|
495 |
+
"version": "3.11.9"
|
496 |
+
}
|
497 |
+
},
|
498 |
+
"nbformat": 4,
|
499 |
+
"nbformat_minor": 5
|
500 |
+
}
|
models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"2\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
23 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"name": "stdout",
|
28 |
+
"output_type": "stream",
|
29 |
+
"text": [
|
30 |
+
"===========================\n",
|
31 |
+
"accuracy: 99.195\n",
|
32 |
+
"===========================\n",
|
33 |
+
"False Positive Rate: 0.005\n",
|
34 |
+
"Precision: 0.995\n",
|
35 |
+
"Recall: 0.989\n",
|
36 |
+
"F1 Score: 0.992\n"
|
37 |
+
]
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"source": [
|
41 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
42 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
43 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
44 |
+
"from tqdm import tqdm\n",
|
45 |
+
"import torch\n",
|
46 |
+
"import numpy as np\n",
|
47 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
48 |
+
"import torch\n",
|
49 |
+
"import torch.nn as nn\n",
|
50 |
+
"import torch.optim as optim\n",
|
51 |
+
"from tqdm import tqdm \n",
|
52 |
+
"import torch.nn.functional as F\n",
|
53 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
54 |
+
"import pickle\n",
|
55 |
+
"\n",
|
56 |
+
"torch.manual_seed(1)\n",
|
57 |
+
"# torch.manual_seed(42)\n",
|
58 |
+
"\n",
|
59 |
+
"\n",
|
60 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
61 |
+
"num_gpus = torch.cuda.device_count()\n",
|
62 |
+
"print(num_gpus)\n",
|
63 |
+
"threshold = 0.992\n",
|
64 |
+
"\n",
|
65 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
66 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
67 |
+
"\n",
|
68 |
+
"num_classes = 2\n",
|
69 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
70 |
+
"\n",
|
71 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
72 |
+
"model = nn.DataParallel(model)\n",
|
73 |
+
"model = model.to(device)\n",
|
74 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
75 |
+
"print(\"num params \",params)\n",
|
76 |
+
"\n",
|
77 |
+
"model_1 = 'models_mask/model-43-99.235_42.pt'\n",
|
78 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
79 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
80 |
+
"model = model.eval()\n",
|
81 |
+
"\n",
|
82 |
+
"# eval\n",
|
83 |
+
"val_loss = 0.0\n",
|
84 |
+
"correct_valid = 0\n",
|
85 |
+
"total = 0\n",
|
86 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
87 |
+
"model.eval()\n",
|
88 |
+
"with torch.no_grad():\n",
|
89 |
+
" for images, labels in testloader:\n",
|
90 |
+
" inputs, labels = images.to(device), labels\n",
|
91 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
92 |
+
" selection = outputs[:, 1] > threshold\n",
|
93 |
+
" predicted = selection.int()\n",
|
94 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
95 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
96 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
97 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
98 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
99 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
100 |
+
" total += labels[0].size(0)\n",
|
101 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
102 |
+
" \n",
|
103 |
+
"# Calculate training accuracy after each epoch\n",
|
104 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
105 |
+
"print(\"===========================\")\n",
|
106 |
+
"print('accuracy: ', val_accuracy)\n",
|
107 |
+
"print(\"===========================\")\n",
|
108 |
+
"\n",
|
109 |
+
"import pickle\n",
|
110 |
+
"\n",
|
111 |
+
"# Pickle the dictionary to a file\n",
|
112 |
+
"with open('models_mask/test_42.pkl', 'wb') as f:\n",
|
113 |
+
" pickle.dump(results, f)\n",
|
114 |
+
"\n",
|
115 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
116 |
+
"from sklearn.metrics import confusion_matrix\n",
|
117 |
+
"\n",
|
118 |
+
"# Example binary labels\n",
|
119 |
+
"true = results['true'] # ground truth\n",
|
120 |
+
"pred = results['pred'] # predicted\n",
|
121 |
+
"\n",
|
122 |
+
"# Compute metrics\n",
|
123 |
+
"precision = precision_score(true, pred)\n",
|
124 |
+
"recall = recall_score(true, pred)\n",
|
125 |
+
"f1 = f1_score(true, pred)\n",
|
126 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
127 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
128 |
+
"\n",
|
129 |
+
"# Compute FPR\n",
|
130 |
+
"fpr = fp / (fp + tn)\n",
|
131 |
+
"\n",
|
132 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
133 |
+
"\n",
|
134 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
135 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
136 |
+
"print(f\"F1 Score: {f1:.3f}\")\n"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 2,
|
142 |
+
"id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"name": "stdout",
|
147 |
+
"output_type": "stream",
|
148 |
+
"text": [
|
149 |
+
"2\n",
|
150 |
+
"num params encoder 50840\n",
|
151 |
+
"num params 21496282\n"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"name": "stderr",
|
156 |
+
"output_type": "stream",
|
157 |
+
"text": [
|
158 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
159 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"name": "stdout",
|
164 |
+
"output_type": "stream",
|
165 |
+
"text": [
|
166 |
+
"===========================\n",
|
167 |
+
"accuracy: 99.195\n",
|
168 |
+
"===========================\n",
|
169 |
+
"False Positive Rate: 0.007\n",
|
170 |
+
"Precision: 0.993\n",
|
171 |
+
"Recall: 0.991\n",
|
172 |
+
"F1 Score: 0.992\n"
|
173 |
+
]
|
174 |
+
}
|
175 |
+
],
|
176 |
+
"source": [
|
177 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
178 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
179 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
180 |
+
"from tqdm import tqdm\n",
|
181 |
+
"import torch\n",
|
182 |
+
"import numpy as np\n",
|
183 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
184 |
+
"import torch\n",
|
185 |
+
"import torch.nn as nn\n",
|
186 |
+
"import torch.optim as optim\n",
|
187 |
+
"from tqdm import tqdm \n",
|
188 |
+
"import torch.nn.functional as F\n",
|
189 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
190 |
+
"import pickle\n",
|
191 |
+
"\n",
|
192 |
+
"torch.manual_seed(1)\n",
|
193 |
+
"\n",
|
194 |
+
"\n",
|
195 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
196 |
+
"num_gpus = torch.cuda.device_count()\n",
|
197 |
+
"print(num_gpus)\n",
|
198 |
+
"\n",
|
199 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
200 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
201 |
+
"\n",
|
202 |
+
"num_classes = 2\n",
|
203 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
204 |
+
"\n",
|
205 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
206 |
+
"model = nn.DataParallel(model)\n",
|
207 |
+
"model = model.to(device)\n",
|
208 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
209 |
+
"print(\"num params \",params)\n",
|
210 |
+
"\n",
|
211 |
+
"\n",
|
212 |
+
"model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
|
213 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
214 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
215 |
+
"model = model.eval()\n",
|
216 |
+
"\n",
|
217 |
+
"# eval\n",
|
218 |
+
"val_loss = 0.0\n",
|
219 |
+
"correct_valid = 0\n",
|
220 |
+
"total = 0\n",
|
221 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
222 |
+
"model.eval()\n",
|
223 |
+
"with torch.no_grad():\n",
|
224 |
+
" for images, labels in testloader:\n",
|
225 |
+
" inputs, labels = images.to(device), labels\n",
|
226 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
227 |
+
" selection = outputs[:, 1] > threshold\n",
|
228 |
+
" predicted = selection.int()\n",
|
229 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
230 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
231 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
232 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
233 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
234 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
235 |
+
" total += labels[0].size(0)\n",
|
236 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
237 |
+
" \n",
|
238 |
+
" \n",
|
239 |
+
"# Calculate training accuracy after each epoch\n",
|
240 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
241 |
+
"print(\"===========================\")\n",
|
242 |
+
"print('accuracy: ', val_accuracy)\n",
|
243 |
+
"print(\"===========================\")\n",
|
244 |
+
"\n",
|
245 |
+
"import pickle\n",
|
246 |
+
"\n",
|
247 |
+
"# Pickle the dictionary to a file\n",
|
248 |
+
"with open('models_mask/test_1.pkl', 'wb') as f:\n",
|
249 |
+
" pickle.dump(results, f)\n",
|
250 |
+
"\n",
|
251 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
252 |
+
"\n",
|
253 |
+
"# Example binary labels\n",
|
254 |
+
"true = results['true'] # ground truth\n",
|
255 |
+
"pred = results['pred'] # predicted\n",
|
256 |
+
"\n",
|
257 |
+
"# Compute metrics\n",
|
258 |
+
"precision = precision_score(true, pred)\n",
|
259 |
+
"recall = recall_score(true, pred)\n",
|
260 |
+
"f1 = f1_score(true, pred)\n",
|
261 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
262 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
263 |
+
"\n",
|
264 |
+
"# Compute FPR\n",
|
265 |
+
"fpr = fp / (fp + tn)\n",
|
266 |
+
"\n",
|
267 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
268 |
+
"\n",
|
269 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
270 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
271 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": 3,
|
277 |
+
"id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
|
278 |
+
"metadata": {},
|
279 |
+
"outputs": [
|
280 |
+
{
|
281 |
+
"name": "stdout",
|
282 |
+
"output_type": "stream",
|
283 |
+
"text": [
|
284 |
+
"2\n",
|
285 |
+
"num params encoder 50840\n",
|
286 |
+
"num params 21496282\n"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"name": "stderr",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
294 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"name": "stdout",
|
299 |
+
"output_type": "stream",
|
300 |
+
"text": [
|
301 |
+
"===========================\n",
|
302 |
+
"accuracy: 99.035\n",
|
303 |
+
"===========================\n",
|
304 |
+
"False Positive Rate: 0.007\n",
|
305 |
+
"Precision: 0.993\n",
|
306 |
+
"Recall: 0.987\n",
|
307 |
+
"F1 Score: 0.990\n"
|
308 |
+
]
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
313 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
314 |
+
"from utils import CustomDataset, TestingDataset, transform\n",
|
315 |
+
"from tqdm import tqdm\n",
|
316 |
+
"import torch\n",
|
317 |
+
"import numpy as np\n",
|
318 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
319 |
+
"import torch\n",
|
320 |
+
"import torch.nn as nn\n",
|
321 |
+
"import torch.optim as optim\n",
|
322 |
+
"from tqdm import tqdm \n",
|
323 |
+
"import torch.nn.functional as F\n",
|
324 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
325 |
+
"import pickle\n",
|
326 |
+
"\n",
|
327 |
+
"torch.manual_seed(1)\n",
|
328 |
+
"# torch.manual_seed(42)\n",
|
329 |
+
"\n",
|
330 |
+
"\n",
|
331 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
332 |
+
"num_gpus = torch.cuda.device_count()\n",
|
333 |
+
"print(num_gpus)\n",
|
334 |
+
"\n",
|
335 |
+
"test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
|
336 |
+
"test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
|
337 |
+
"\n",
|
338 |
+
"num_classes = 2\n",
|
339 |
+
"testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
340 |
+
"\n",
|
341 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
342 |
+
"model = nn.DataParallel(model)\n",
|
343 |
+
"model = model.to(device)\n",
|
344 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
345 |
+
"print(\"num params \",params)\n",
|
346 |
+
"\n",
|
347 |
+
"\n",
|
348 |
+
"model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
|
349 |
+
"# model_1 ='models/model-47-99.125.pt'\n",
|
350 |
+
"model.load_state_dict(torch.load(model_1, weights_only=True))\n",
|
351 |
+
"model = model.eval()\n",
|
352 |
+
"\n",
|
353 |
+
"# eval\n",
|
354 |
+
"val_loss = 0.0\n",
|
355 |
+
"correct_valid = 0\n",
|
356 |
+
"total = 0\n",
|
357 |
+
"results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
|
358 |
+
"model.eval()\n",
|
359 |
+
"with torch.no_grad():\n",
|
360 |
+
" for images, labels in testloader:\n",
|
361 |
+
" inputs, labels = images.to(device), labels\n",
|
362 |
+
" outputs = nn.Softmax(dim = 1)(model(inputs))\n",
|
363 |
+
" selection = outputs[:, 1] > threshold\n",
|
364 |
+
" predicted = selection.int()\n",
|
365 |
+
" results['pred'].extend(predicted.cpu().numpy().tolist())\n",
|
366 |
+
" results['true'].extend(labels[0].cpu().numpy().tolist())\n",
|
367 |
+
" results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
|
368 |
+
" results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
|
369 |
+
" results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
|
370 |
+
" results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
|
371 |
+
" total += labels[0].size(0)\n",
|
372 |
+
" correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
|
373 |
+
" \n",
|
374 |
+
" \n",
|
375 |
+
"# Calculate training accuracy after each epoch\n",
|
376 |
+
"val_accuracy = correct_valid / total * 100.0\n",
|
377 |
+
"print(\"===========================\")\n",
|
378 |
+
"print('accuracy: ', val_accuracy)\n",
|
379 |
+
"print(\"===========================\")\n",
|
380 |
+
"\n",
|
381 |
+
"import pickle\n",
|
382 |
+
"\n",
|
383 |
+
"# Pickle the dictionary to a file\n",
|
384 |
+
"with open('models_mask/test_7109.pkl', 'wb') as f:\n",
|
385 |
+
" pickle.dump(results, f)\n",
|
386 |
+
"\n",
|
387 |
+
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
388 |
+
"\n",
|
389 |
+
"# Example binary labels\n",
|
390 |
+
"true = results['true'] # ground truth\n",
|
391 |
+
"pred = results['pred'] # predicted\n",
|
392 |
+
"\n",
|
393 |
+
"# Compute metrics\n",
|
394 |
+
"precision = precision_score(true, pred)\n",
|
395 |
+
"recall = recall_score(true, pred)\n",
|
396 |
+
"f1 = f1_score(true, pred)\n",
|
397 |
+
"# Get confusion matrix: TN, FP, FN, TP\n",
|
398 |
+
"tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
|
399 |
+
"\n",
|
400 |
+
"# Compute FPR\n",
|
401 |
+
"fpr = fp / (fp + tn)\n",
|
402 |
+
"\n",
|
403 |
+
"print(f\"False Positive Rate: {fpr:.3f}\")\n",
|
404 |
+
"\n",
|
405 |
+
"print(f\"Precision: {precision:.3f}\")\n",
|
406 |
+
"print(f\"Recall: {recall:.3f}\")\n",
|
407 |
+
"print(f\"F1 Score: {f1:.3f}\")"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "code",
|
412 |
+
"execution_count": 6,
|
413 |
+
"id": "974e62d6-5088-4cd8-9721-6702717eadee",
|
414 |
+
"metadata": {},
|
415 |
+
"outputs": [
|
416 |
+
{
|
417 |
+
"name": "stdout",
|
418 |
+
"output_type": "stream",
|
419 |
+
"text": [
|
420 |
+
"99.14166666666665 0.07542472332656346\n",
|
421 |
+
"0.9913333333333333 0.0012472191289246482\n",
|
422 |
+
"0.9913333333333334 0.0012472191289246482\n",
|
423 |
+
"0.006333333333333333 0.0009428090415820634\n"
|
424 |
+
]
|
425 |
+
}
|
426 |
+
],
|
427 |
+
"source": [
|
428 |
+
"# acc\n",
|
429 |
+
"print(np.mean([99.195,99.195, 99.035 ]), np.std([99.195,99.195, 99.035]))\n",
|
430 |
+
"# recall\n",
|
431 |
+
"print(np.mean([0.991,0.991, 0.987]), np.std([0.991,0.990, 0.993]))\n",
|
432 |
+
"# f1\n",
|
433 |
+
"print(np.mean([0.992,0.992,0.990 ]),np.std([0.990,0.988,0.991 ]))\n",
|
434 |
+
"# fp\n",
|
435 |
+
"print(np.mean([0.005,0.007,0.007 ]),np.std([0.005,0.007,0.007]))\n"
|
436 |
+
]
|
437 |
+
}
|
438 |
+
],
|
439 |
+
"metadata": {
|
440 |
+
"kernelspec": {
|
441 |
+
"display_name": "Python 3 (ipykernel)",
|
442 |
+
"language": "python",
|
443 |
+
"name": "python3"
|
444 |
+
},
|
445 |
+
"language_info": {
|
446 |
+
"codemirror_mode": {
|
447 |
+
"name": "ipython",
|
448 |
+
"version": 3
|
449 |
+
},
|
450 |
+
"file_extension": ".py",
|
451 |
+
"mimetype": "text/x-python",
|
452 |
+
"name": "python",
|
453 |
+
"nbconvert_exporter": "python",
|
454 |
+
"pygments_lexer": "ipython3",
|
455 |
+
"version": "3.11.9"
|
456 |
+
}
|
457 |
+
},
|
458 |
+
"nbformat": 4,
|
459 |
+
"nbformat_minor": 5
|
460 |
+
}
|
models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"6\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
}
|
18 |
+
],
|
19 |
+
"source": [
|
20 |
+
"from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
|
21 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
22 |
+
"import torch\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
25 |
+
"import torch\n",
|
26 |
+
"import torch.nn as nn\n",
|
27 |
+
"import torch.optim as optim\n",
|
28 |
+
"from tqdm import tqdm \n",
|
29 |
+
"import torch.nn.functional as F\n",
|
30 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
31 |
+
"import pickle\n",
|
32 |
+
"\n",
|
33 |
+
"torch.manual_seed(1)\n",
|
34 |
+
"# torch.manual_seed(42)\n",
|
35 |
+
"\n",
|
36 |
+
"\n",
|
37 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
38 |
+
"num_gpus = torch.cuda.device_count()\n",
|
39 |
+
"print(num_gpus)\n",
|
40 |
+
"\n",
|
41 |
+
"# Create custom dataset instance\n",
|
42 |
+
"# Create custom dataset instance\n",
|
43 |
+
"data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
|
44 |
+
"dataset = CustomDataset(data_dir, transform=transform)\n",
|
45 |
+
"valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
|
46 |
+
"valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
|
47 |
+
"\n",
|
48 |
+
"\n",
|
49 |
+
"num_classes = 2\n",
|
50 |
+
"trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
51 |
+
"\n",
|
52 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
53 |
+
"model = nn.DataParallel(model)\n",
|
54 |
+
"model = model.to(device)\n",
|
55 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
56 |
+
"print(\"num params \",params)\n"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 2,
|
62 |
+
"id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"name": "stderr",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
" 0%| | 0/477 [00:00<?, ?batch/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
70 |
+
" with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
|
71 |
+
"100%|██████████████████████████████████████| 477/477 [08:57<00:00, 1.13s/batch]\n"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"ename": "NameError",
|
76 |
+
"evalue": "name 'validloader' is not defined",
|
77 |
+
"output_type": "error",
|
78 |
+
"traceback": [
|
79 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
80 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
81 |
+
"Cell \u001b[0;32mIn[2], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m validloader:\n\u001b[1;32m 30\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\u001b[38;5;241m.\u001b[39mto(device)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 31\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
|
82 |
+
"\u001b[0;31mNameError\u001b[0m: name 'validloader' is not defined"
|
83 |
+
]
|
84 |
+
}
|
85 |
+
],
|
86 |
+
"source": [
|
87 |
+
"criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))\n",
|
88 |
+
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
|
89 |
+
"scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)\n",
|
90 |
+
"\n",
|
91 |
+
"for epoch in range(5):\n",
|
92 |
+
" running_loss = 0.0\n",
|
93 |
+
" correct_train = 0\n",
|
94 |
+
" total_train = 0\n",
|
95 |
+
" with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
|
96 |
+
" model.train()\n",
|
97 |
+
" for i, (images, labels) in enumerate(tepoch):\n",
|
98 |
+
" inputs, labels = images.to(device), labels.to(device).float()\n",
|
99 |
+
" optimizer.zero_grad()\n",
|
100 |
+
" outputs = model(inputs, return_mask=False).to(device)\n",
|
101 |
+
" new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)\n",
|
102 |
+
" loss = criterion(outputs, new_label)\n",
|
103 |
+
" loss.backward()\n",
|
104 |
+
" optimizer.step()\n",
|
105 |
+
" running_loss += loss.item()\n",
|
106 |
+
" # Calculate training accuracy\n",
|
107 |
+
" _, predicted = torch.max(outputs.data, 1)\n",
|
108 |
+
" total_train += labels.size(0)\n",
|
109 |
+
" correct_train += (predicted == labels).sum().item() \n",
|
110 |
+
" val_loss = 0.0\n",
|
111 |
+
" correct_valid = 0\n",
|
112 |
+
" total = 0\n",
|
113 |
+
" model.eval()\n",
|
114 |
+
" with torch.no_grad():\n",
|
115 |
+
" for images, labels in validloader:\n",
|
116 |
+
" inputs, labels = images.to(device), labels.to(device).float()\n",
|
117 |
+
" optimizer.zero_grad()\n",
|
118 |
+
" outputs = model(inputs, return_mask=False)\n",
|
119 |
+
" new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)\n",
|
120 |
+
" loss = criterion(outputs, new_label)\n",
|
121 |
+
" val_loss += loss.item()\n",
|
122 |
+
" _, predicted = torch.max(outputs, 1)\n",
|
123 |
+
" total += labels.size(0)\n",
|
124 |
+
" correct_valid += (predicted == labels).sum().item()\n",
|
125 |
+
" scheduler.step(val_loss)\n",
|
126 |
+
" # Calculate training accuracy after each epoch\n",
|
127 |
+
" train_accuracy = 100 * correct_train / total_train\n",
|
128 |
+
" val_accuracy = correct_valid / total * 100.0\n",
|
129 |
+
"\n",
|
130 |
+
"\n",
|
131 |
+
" print(\"===========================\")\n",
|
132 |
+
" print('accuracy: ', epoch, train_accuracy, val_accuracy)\n",
|
133 |
+
" print('learning rate: ', scheduler.get_last_lr())\n",
|
134 |
+
" print(\"===========================\")"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": null,
|
140 |
+
"id": "3faa4a11-89fb-4556-ae87-3645a47fa00d",
|
141 |
+
"metadata": {},
|
142 |
+
"outputs": [],
|
143 |
+
"source": [
|
144 |
+
"train_accuracy = 100 * correct_train / total_train\n",
|
145 |
+
"print('accuracy: ', epoch, train_accuracy)"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": null,
|
151 |
+
"id": "e586c4d2-a7f4-4f14-81fc-4f84ffac52b3",
|
152 |
+
"metadata": {},
|
153 |
+
"outputs": [],
|
154 |
+
"source": [
|
155 |
+
"import sigpyproc.readers as r\n",
|
156 |
+
"import cv2\n",
|
157 |
+
"import numpy as np\n",
|
158 |
+
"import matplotlib.pyplot as plt\n",
|
159 |
+
"\n",
|
160 |
+
"from scipy.special import softmax\n",
|
161 |
+
"%matplotlib inline\n",
|
162 |
+
"path = '/mnt/primary/ata/projects/p051/fil_60565_59210_9756774_J0534+2200_0001/LoB.C0928/fil_60565_59210_9756774_J0534+2200_0001-beam0000.fil'\n",
|
163 |
+
"# path = '/mnt/primary/ata/projects/p051/fil_60564_62428_4679748_J0332+5434_0001/LoB.C0928/fil_60564_62428_4679748_J0332+5434_0001-beam0000.fil'\n",
|
164 |
+
"\n",
|
165 |
+
"# Get some metadata\n",
|
166 |
+
"\n",
|
167 |
+
"# Open the filterbank file\n",
|
168 |
+
"fil = r.FilReader(path)\n",
|
169 |
+
"header = fil.header\n",
|
170 |
+
"print(\"Header:\", header)\n",
|
171 |
+
"n=100\n",
|
172 |
+
"li = [ 7257608, 7324207, 10393163, 10641071, 11130537, 11085081,\n",
|
173 |
+
" 11419145, 11964112, 12329364, 13047181]\n",
|
174 |
+
"for el in li:\n",
|
175 |
+
" data = torch.tensor(fil.read_block(el-1024, 2048)).cuda()\n",
|
176 |
+
" print(data.shape)\n",
|
177 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
178 |
+
" print(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
179 |
+
" plt.figure(figsize=(10,10))\n",
|
180 |
+
" plt.imshow(data.cpu().numpy(), aspect = 10)\n",
|
181 |
+
" plt.show()"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": null,
|
187 |
+
"id": "609e5564-f14f-4bd1-b604-68e7e7d42834",
|
188 |
+
"metadata": {},
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"triggers = []\n",
|
192 |
+
"counter = 0\n",
|
193 |
+
"with torch.no_grad():\n",
|
194 |
+
" for i in range(2048,10201921, 2048 ):\n",
|
195 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
196 |
+
" # Shuffle the tensor using the random indices\n",
|
197 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
198 |
+
" triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
199 |
+
" counter += 1\n",
|
200 |
+
" if counter > 1000:\n",
|
201 |
+
" break"
|
202 |
+
]
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"cell_type": "code",
|
206 |
+
"execution_count": null,
|
207 |
+
"id": "08ee6dcf-cb30-4490-8624-4e52552fdf39",
|
208 |
+
"metadata": {},
|
209 |
+
"outputs": [],
|
210 |
+
"source": [
|
211 |
+
"print(triggers[0])"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": null,
|
217 |
+
"id": "8c56c6f5-5a0b-4854-8a94-066a9baf4cfc",
|
218 |
+
"metadata": {},
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"stack = np.stack(triggers)\n",
|
222 |
+
"positives = stack[:,0,1]\n",
|
223 |
+
"num_pos = np.where(positives > 0.5)[0].shape[0]\n",
|
224 |
+
"print(num_pos)"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"id": "eb1d1591-8855-4989-bf12-c8a9cdbf2a4d",
|
231 |
+
"metadata": {},
|
232 |
+
"outputs": [],
|
233 |
+
"source": [
|
234 |
+
"import pickle\n",
|
235 |
+
"\n",
|
236 |
+
"# Path to your pickle file\n",
|
237 |
+
"file_path = \"../dataset_generator/dir.pkl\"\n",
|
238 |
+
"\n",
|
239 |
+
"# Open and load the pickle file\n",
|
240 |
+
"with open(file_path, \"rb\") as file: # Use \"rb\" mode for reading binary files\n",
|
241 |
+
" data = pickle.load(file)\n",
|
242 |
+
"\n",
|
243 |
+
"# Print the contents of the file\n"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "code",
|
248 |
+
"execution_count": null,
|
249 |
+
"id": "46f61d7e-55fa-44fe-be94-d4ddb3c576f9",
|
250 |
+
"metadata": {},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"import sigpyproc.readers as r\n",
|
254 |
+
"import cv2\n",
|
255 |
+
"import numpy as np\n",
|
256 |
+
"import matplotlib.pyplot as plt\n",
|
257 |
+
"\n",
|
258 |
+
"from scipy.special import softmax\n",
|
259 |
+
"%matplotlib inline\n",
|
260 |
+
"path = data[0]\n",
|
261 |
+
"model.eval()\n",
|
262 |
+
"\n",
|
263 |
+
"fil = r.FilReader(path)\n",
|
264 |
+
"header = fil.header\n",
|
265 |
+
"print(\"Header:\", header)\n",
|
266 |
+
"n=100\n",
|
267 |
+
"\n",
|
268 |
+
"\n",
|
269 |
+
"triggers = []\n",
|
270 |
+
"counter = 0\n",
|
271 |
+
"for i in range(2048,10201921, 2048):\n",
|
272 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
273 |
+
" # Shuffle the tensor using the random indices\n",
|
274 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
275 |
+
" triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
|
276 |
+
" counter += 1\n",
|
277 |
+
" if counter > 1000:\n",
|
278 |
+
" break"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": null,
|
284 |
+
"id": "413d402e-2ce3-49fc-bbd4-a3cf1cc92388",
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [],
|
287 |
+
"source": [
|
288 |
+
"print(triggers[0])"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": null,
|
294 |
+
"id": "5c039dee-1b9b-4664-b42a-a79d780f37f1",
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"stack = np.stack(triggers)\n",
|
299 |
+
"positives = stack[:,0,1]\n",
|
300 |
+
"num_pos = np.where(positives > 0.5)[0].shape[0]\n",
|
301 |
+
"print(num_pos)"
|
302 |
+
]
|
303 |
+
}
|
304 |
+
],
|
305 |
+
"metadata": {
|
306 |
+
"kernelspec": {
|
307 |
+
"display_name": "Python 3 (ipykernel)",
|
308 |
+
"language": "python",
|
309 |
+
"name": "python3"
|
310 |
+
},
|
311 |
+
"language_info": {
|
312 |
+
"codemirror_mode": {
|
313 |
+
"name": "ipython",
|
314 |
+
"version": 3
|
315 |
+
},
|
316 |
+
"file_extension": ".py",
|
317 |
+
"mimetype": "text/x-python",
|
318 |
+
"name": "python",
|
319 |
+
"nbconvert_exporter": "python",
|
320 |
+
"pygments_lexer": "ipython3",
|
321 |
+
"version": "3.11.9"
|
322 |
+
}
|
323 |
+
},
|
324 |
+
"nbformat": 4,
|
325 |
+
"nbformat_minor": 5
|
326 |
+
}
|
models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b063cb5ad71d38a551a5a4beb9ae21399e5e633af128c2608645398509854239
|
3 |
+
size 16982029
|
models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 8,
|
6 |
+
"id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"6\n",
|
14 |
+
"num params encoder 50840\n",
|
15 |
+
"num params 21496282\n"
|
16 |
+
]
|
17 |
+
}
|
18 |
+
],
|
19 |
+
"source": [
|
20 |
+
"from utils import CustomDataset, transform, Convert_ONNX\n",
|
21 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
22 |
+
"import torch\n",
|
23 |
+
"import numpy as np\n",
|
24 |
+
"from resnet_model_mask import ResidualBlock, ResNet\n",
|
25 |
+
"import torch\n",
|
26 |
+
"import torch.nn as nn\n",
|
27 |
+
"import torch.optim as optim\n",
|
28 |
+
"from tqdm import tqdm \n",
|
29 |
+
"import torch.nn.functional as F\n",
|
30 |
+
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
|
31 |
+
"import pickle\n",
|
32 |
+
"\n",
|
33 |
+
"torch.manual_seed(1)\n",
|
34 |
+
"# torch.manual_seed(42)\n",
|
35 |
+
"\n",
|
36 |
+
"\n",
|
37 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
38 |
+
"num_gpus = torch.cuda.device_count()\n",
|
39 |
+
"print(num_gpus)\n",
|
40 |
+
"\n",
|
41 |
+
"# Create custom dataset instance\n",
|
42 |
+
"data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
|
43 |
+
"dataset = CustomDataset(data_dir, transform=transform)\n",
|
44 |
+
"valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
|
45 |
+
"valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"num_classes = 2\n",
|
49 |
+
"trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
|
50 |
+
"\n",
|
51 |
+
"model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
|
52 |
+
"model = nn.DataParallel(model)\n",
|
53 |
+
"model = model.to(device)\n",
|
54 |
+
"params = sum(p.numel() for p in model.parameters())\n",
|
55 |
+
"print(\"num params \",params)\n"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 9,
|
61 |
+
"id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"# model_path = 'models/model-62-98.78.pt'\n",
|
66 |
+
"# model_path = 'models/model-47-99.125.pt'\n",
|
67 |
+
"model_path = 'models_mask/model-37-99.175_42.pt'\n",
|
68 |
+
"\n",
|
69 |
+
"# model_path = 'models_mask/model-10-97.055_1.pt'\n",
|
70 |
+
"model.load_state_dict(torch.load(model_path, weights_only=True))\n",
|
71 |
+
"model = model.eval()"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 10,
|
77 |
+
"id": "89d108de-7eae-4bbd-837c-8e657082a1e6",
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [
|
80 |
+
{
|
81 |
+
"name": "stdout",
|
82 |
+
"output_type": "stream",
|
83 |
+
"text": [
|
84 |
+
"Header(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', data_type='raw data', nchans=192, foff=-0.5, fch1=1187.5, nbits=32, tsamp=6.4e-05, tstart=60692.03208333333, nsamples=28125184, nifs=1, coord=<SkyCoord (ICRS): (ra, dec) in deg\n",
|
85 |
+
" (83.63322, 22.01446)>, azimuth=<Angle 80.54659271 deg>, zenith=<Angle 66.41192055 deg>, telescope='Effelsberg LOFAR', backend='FAKE', source='crab', frame='topocentric', ibeam=0, nbeams=2, dm=0, period=0, accel=0, signed=False, rawdatafile='', stream_info=StreamInfo(entries=[FileInfo(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', hdrlen=338, datalen=21600141312, nsamples=28125184, tstart=60692.03208333333, tsamp=6.4e-05)]))\n"
|
86 |
+
]
|
87 |
+
}
|
88 |
+
],
|
89 |
+
"source": [
|
90 |
+
"import sigpyproc.readers as r\n",
|
91 |
+
"import cv2\n",
|
92 |
+
"import numpy as np\n",
|
93 |
+
"import matplotlib.pyplot as plt\n",
|
94 |
+
"fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
|
95 |
+
"# fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoB.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
|
96 |
+
"header = fil.header\n",
|
97 |
+
"print(header)"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 11,
|
103 |
+
"id": "0b276e6e-d6c8-41da-808d-542ee22133d1",
|
104 |
+
"metadata": {},
|
105 |
+
"outputs": [
|
106 |
+
{
|
107 |
+
"name": "stderr",
|
108 |
+
"output_type": "stream",
|
109 |
+
"text": [
|
110 |
+
" 0%| | 0/13732 [00:00<?, ?it/s]/tmp/ipykernel_19961/1777549771.py:15: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
|
111 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
112 |
+
" 3%|▉ | 351/13732 [00:13<08:27, 26.38it/s]\n"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"ename": "KeyboardInterrupt",
|
117 |
+
"evalue": "",
|
118 |
+
"output_type": "error",
|
119 |
+
"traceback": [
|
120 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
121 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
122 |
+
"Cell \u001b[0;32mIn[11], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m counter \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2048\u001b[39m,header\u001b[38;5;241m.\u001b[39mnsamples, \u001b[38;5;241m2048\u001b[39m)):\n\u001b[0;32m---> 13\u001b[0m data \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(fil\u001b[38;5;241m.\u001b[39mread_block(i\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1024\u001b[39m, \u001b[38;5;241m2048\u001b[39m))\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# Shuffle the tensor using the random indices\u001b[39;00m\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(transform(torch\u001b[38;5;241m.\u001b[39mtensor(data)\u001b[38;5;241m.\u001b[39mcuda())[\u001b[38;5;28;01mNone\u001b[39;00m])\n",
|
123 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"name": "stdout",
|
128 |
+
"output_type": "stream",
|
129 |
+
"text": [
|
130 |
+
"Error in callback <function flush_figures at 0x7f6c8689ae80> (for post_execute), with arguments args (),kwargs {}:\n"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"ename": "KeyboardInterrupt",
|
135 |
+
"evalue": "",
|
136 |
+
"output_type": "error",
|
137 |
+
"traceback": [
|
138 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
139 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
140 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:126\u001b[0m, in \u001b[0;36mflush_figures\u001b[0;34m()\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m InlineBackend\u001b[38;5;241m.\u001b[39minstance()\u001b[38;5;241m.\u001b[39mclose_figures:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;66;03m# ignore the tracking, just draw and close all figures\u001b[39;00m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m show(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 128\u001b[0m \u001b[38;5;66;03m# safely show traceback if in IPython, else raise\u001b[39;00m\n\u001b[1;32m 129\u001b[0m ip \u001b[38;5;241m=\u001b[39m get_ipython()\n",
|
141 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:90\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(close, block)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m figure_manager \u001b[38;5;129;01min\u001b[39;00m Gcf\u001b[38;5;241m.\u001b[39mget_all_fig_managers():\n\u001b[0;32m---> 90\u001b[0m display(\n\u001b[1;32m 91\u001b[0m figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure,\n\u001b[1;32m 92\u001b[0m metadata\u001b[38;5;241m=\u001b[39m_fetch_figure_metadata(figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure)\n\u001b[1;32m 93\u001b[0m )\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 95\u001b[0m show\u001b[38;5;241m.\u001b[39m_to_draw \u001b[38;5;241m=\u001b[39m []\n",
|
142 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/display_functions.py:298\u001b[0m, in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, raw, clear, *objs, **kwargs)\u001b[0m\n\u001b[1;32m 296\u001b[0m publish_display_data(data\u001b[38;5;241m=\u001b[39mobj, metadata\u001b[38;5;241m=\u001b[39mmetadata, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 298\u001b[0m format_dict, md_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mformat\u001b[39m(obj, include\u001b[38;5;241m=\u001b[39minclude, exclude\u001b[38;5;241m=\u001b[39mexclude)\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m format_dict:\n\u001b[1;32m 300\u001b[0m \u001b[38;5;66;03m# nothing to display (e.g. _ipython_display_ took over)\u001b[39;00m\n\u001b[1;32m 301\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n",
|
143 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:179\u001b[0m, in \u001b[0;36mDisplayFormatter.format\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m 177\u001b[0m md \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 179\u001b[0m data \u001b[38;5;241m=\u001b[39m formatter(obj)\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 181\u001b[0m \u001b[38;5;66;03m# FIXME: log the exception\u001b[39;00m\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
|
144 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/decorator.py:232\u001b[0m, in \u001b[0;36mdecorate.<locals>.fun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kwsyntax:\n\u001b[1;32m 231\u001b[0m args, kw \u001b[38;5;241m=\u001b[39m fix(args, kw, sig)\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m caller(func, \u001b[38;5;241m*\u001b[39m(extras \u001b[38;5;241m+\u001b[39m args), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n",
|
145 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:223\u001b[0m, in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"show traceback on failed format call\"\"\"\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 223\u001b[0m r \u001b[38;5;241m=\u001b[39m method(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;66;03m# don't warn on NotImplementedErrors\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_return(\u001b[38;5;28;01mNone\u001b[39;00m, args[\u001b[38;5;241m0\u001b[39m])\n",
|
146 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:340\u001b[0m, in \u001b[0;36mBaseFormatter.__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 340\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m printer(obj)\n\u001b[1;32m 341\u001b[0m \u001b[38;5;66;03m# Finally look for special method names\u001b[39;00m\n\u001b[1;32m 342\u001b[0m method \u001b[38;5;241m=\u001b[39m get_real_method(obj, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_method)\n",
|
147 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/pylabtools.py:152\u001b[0m, in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, base64, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend_bases\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FigureCanvasBase\n\u001b[1;32m 150\u001b[0m FigureCanvasBase(fig)\n\u001b[0;32m--> 152\u001b[0m fig\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mprint_figure(bytes_io, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n\u001b[1;32m 153\u001b[0m data \u001b[38;5;241m=\u001b[39m bytes_io\u001b[38;5;241m.\u001b[39mgetvalue()\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fmt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msvg\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
|
148 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/backend_bases.py:2164\u001b[0m, in \u001b[0;36mFigureCanvasBase.print_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m 2161\u001b[0m \u001b[38;5;66;03m# we do this instead of `self.figure.draw_without_rendering`\u001b[39;00m\n\u001b[1;32m 2162\u001b[0m \u001b[38;5;66;03m# so that we can inject the orientation\u001b[39;00m\n\u001b[1;32m 2163\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(renderer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_draw_disabled\u001b[39m\u001b[38;5;124m\"\u001b[39m, nullcontext)():\n\u001b[0;32m-> 2164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 2165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches:\n\u001b[1;32m 2166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtight\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
|
149 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:95\u001b[0m, in \u001b[0;36m_finalize_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(draw)\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdraw_wrapper\u001b[39m(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 95\u001b[0m result \u001b[38;5;241m=\u001b[39m draw(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m renderer\u001b[38;5;241m.\u001b[39m_rasterizing:\n\u001b[1;32m 97\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstop_rasterizing()\n",
|
150 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
151 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/figure.py:3154\u001b[0m, in \u001b[0;36mFigure.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3151\u001b[0m \u001b[38;5;66;03m# ValueError can occur when resizing a window.\u001b[39;00m\n\u001b[1;32m 3153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpatch\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[0;32m-> 3154\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3155\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3157\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m sfig \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msubfigs:\n\u001b[1;32m 3158\u001b[0m sfig\u001b[38;5;241m.\u001b[39mdraw(renderer)\n",
|
152 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
|
153 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
154 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/axes/_base.py:3070\u001b[0m, in \u001b[0;36m_AxesBase.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3067\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artists_rasterized:\n\u001b[1;32m 3068\u001b[0m _draw_rasterized(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure, artists_rasterized, renderer)\n\u001b[0;32m-> 3070\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3071\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3073\u001b[0m renderer\u001b[38;5;241m.\u001b[39mclose_group(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maxes\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 3074\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstale \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
155 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
|
156 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
157 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:649\u001b[0m, in \u001b[0;36m_ImageBase.draw\u001b[0;34m(self, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 647\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im, trans)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 649\u001b[0m im, l, b, trans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_image(\n\u001b[1;32m 650\u001b[0m renderer, renderer\u001b[38;5;241m.\u001b[39mget_image_magnification())\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m im \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 652\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im)\n",
|
158 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:939\u001b[0m, in \u001b[0;36mAxesImage.make_image\u001b[0;34m(self, renderer, magnification, unsampled)\u001b[0m\n\u001b[1;32m 936\u001b[0m transformed_bbox \u001b[38;5;241m=\u001b[39m TransformedBbox(bbox, trans)\n\u001b[1;32m 937\u001b[0m clip \u001b[38;5;241m=\u001b[39m ((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_box() \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes\u001b[38;5;241m.\u001b[39mbbox) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_on()\n\u001b[1;32m 938\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mbbox)\n\u001b[0;32m--> 939\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_image(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_A, bbox, transformed_bbox, clip,\n\u001b[1;32m 940\u001b[0m magnification, unsampled\u001b[38;5;241m=\u001b[39munsampled)\n",
|
159 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:526\u001b[0m, in \u001b[0;36m_ImageBase._make_image\u001b[0;34m(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)\u001b[0m\n\u001b[1;32m 521\u001b[0m mask \u001b[38;5;241m=\u001b[39m (np\u001b[38;5;241m.\u001b[39mwhere(A\u001b[38;5;241m.\u001b[39mmask, np\u001b[38;5;241m.\u001b[39mfloat32(np\u001b[38;5;241m.\u001b[39mnan), np\u001b[38;5;241m.\u001b[39mfloat32(\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m A\u001b[38;5;241m.\u001b[39mmask\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m A\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;66;03m# nontrivial mask\u001b[39;00m\n\u001b[1;32m 523\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m np\u001b[38;5;241m.\u001b[39mones_like(A, np\u001b[38;5;241m.\u001b[39mfloat32))\n\u001b[1;32m 524\u001b[0m \u001b[38;5;66;03m# we always have to interpolate the mask to account for\u001b[39;00m\n\u001b[1;32m 525\u001b[0m \u001b[38;5;66;03m# non-affine transformations\u001b[39;00m\n\u001b[0;32m--> 526\u001b[0m out_alpha \u001b[38;5;241m=\u001b[39m _resample(\u001b[38;5;28mself\u001b[39m, mask, out_shape, t, resample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m mask \u001b[38;5;66;03m# Make sure we don't use mask anymore!\u001b[39;00m\n\u001b[1;32m 528\u001b[0m \u001b[38;5;66;03m# Agg updates out_alpha in place. 
If the pixel has no image\u001b[39;00m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;66;03m# data it will not be updated (and still be 0 as we initialized\u001b[39;00m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;66;03m# it), if input data that would go into that output pixel than\u001b[39;00m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;66;03m# it will be `nan`, if all the input data for a pixel is good\u001b[39;00m\n\u001b[1;32m 532\u001b[0m \u001b[38;5;66;03m# it will be 1, and if there is _some_ good data in that output\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;66;03m# pixel it will be between [0, 1] (such as a rotated image).\u001b[39;00m\n",
|
160 |
+
"File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:208\u001b[0m, in \u001b[0;36m_resample\u001b[0;34m(image_obj, data, out_shape, transform, resample, alpha)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m resample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 207\u001b[0m resample \u001b[38;5;241m=\u001b[39m image_obj\u001b[38;5;241m.\u001b[39mget_resample()\n\u001b[0;32m--> 208\u001b[0m _image\u001b[38;5;241m.\u001b[39mresample(data, out, transform,\n\u001b[1;32m 209\u001b[0m _interpd_[interpolation],\n\u001b[1;32m 210\u001b[0m resample,\n\u001b[1;32m 211\u001b[0m alpha,\n\u001b[1;32m 212\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filternorm(),\n\u001b[1;32m 213\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filterrad())\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
|
161 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
162 |
+
]
|
163 |
+
}
|
164 |
+
],
|
165 |
+
"source": [
|
166 |
+
"import sigpyproc.readers as r\n",
|
167 |
+
"import cv2\n",
|
168 |
+
"import numpy as np\n",
|
169 |
+
"import matplotlib.pyplot as plt\n",
|
170 |
+
"from scipy.special import softmax\n",
|
171 |
+
"from tqdm import tqdm\n",
|
172 |
+
"%matplotlib inline\n",
|
173 |
+
"\n",
|
174 |
+
"header = fil.header\n",
|
175 |
+
"triggers = []\n",
|
176 |
+
"counter = 0\n",
|
177 |
+
"for i in tqdm(range(2048,header.nsamples, 2048)):\n",
|
178 |
+
" data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
|
179 |
+
" # Shuffle the tensor using the random indices\n",
|
180 |
+
" out = model(transform(torch.tensor(data).cuda())[None])\n",
|
181 |
+
" out = softmax(out.detach().cpu().numpy(), axis=1)\n",
|
182 |
+
" triggers.append(out)\n",
|
183 |
+
" counter += 1\n",
|
184 |
+
" # if counter > 1000:\n",
|
185 |
+
" # break\n",
|
186 |
+
" # if out[0, 1]>0.999:\n",
|
187 |
+
" # key = data.cpu().numpy()\n",
|
188 |
+
" # plt.figure(figsize=(10,10))\n",
|
189 |
+
" # plt.imshow(data.cpu().numpy(), aspect = 10, vmax = 54557.824)\n",
|
190 |
+
"stack = np.stack(triggers)\n",
|
191 |
+
"positives = stack[:,0,1]\n",
|
192 |
+
"num_pos = np.where(positives > 0.999)[0].shape[0]\n",
|
193 |
+
"print(num_pos)"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": null,
|
199 |
+
"id": "64df934d-f4a2-49f0-857d-2661b1d78b21",
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"np.flipud()"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"cell_type": "code",
|
208 |
+
"execution_count": null,
|
209 |
+
"id": "1eafb2c1-857e-48be-aa8b-18669c0e0f8c",
|
210 |
+
"metadata": {},
|
211 |
+
"outputs": [],
|
212 |
+
"source": [
|
213 |
+
"plt.figure(figsize=(10,10))\n",
|
214 |
+
"# plt.imshow(key, aspect = 10, vmax = 54557.824)\n",
|
215 |
+
"plt.imshow(key, aspect = 10)"
|
216 |
+
]
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"execution_count": null,
|
221 |
+
"id": "ed3783c3-8ed1-46d6-91e4-e906dfa44913",
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [],
|
224 |
+
"source": [
|
225 |
+
"key.shape"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"id": "8b56a356-a582-4f5d-a8e2-20f725a48fb3",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"total_data =[]\n",
|
236 |
+
"for i in range(32):\n",
|
237 |
+
" total_data.append(key)\n",
|
238 |
+
"total_data = torch.tensor(np.array(total_data))\n",
|
239 |
+
"total_data.cpu().detach().numpy().tofile(\"crab_in.bin\")\n",
|
240 |
+
"print(total_data.shape)\n",
|
241 |
+
"outputs_data = []\n",
|
242 |
+
"for i in range(32):\n",
|
243 |
+
" temp = model(transform(total_data.cuda()[i,:,:])[None])\n",
|
244 |
+
" print(temp)\n",
|
245 |
+
" # outputs_data.append(softmax(temp.detach().cpu().numpy(), axis=1))\n",
|
246 |
+
" outputs_data.append(temp.detach().cpu().numpy())\n",
|
247 |
+
"outputs_data = torch.tensor(outputs_data)\n",
|
248 |
+
"outputs_data.cpu().detach().numpy().tofile(\"crab_out.bin\")"
|
249 |
+
]
|
250 |
+
}
|
251 |
+
],
|
252 |
+
"metadata": {
|
253 |
+
"kernelspec": {
|
254 |
+
"display_name": "Python 3 (ipykernel)",
|
255 |
+
"language": "python",
|
256 |
+
"name": "python3"
|
257 |
+
},
|
258 |
+
"language_info": {
|
259 |
+
"codemirror_mode": {
|
260 |
+
"name": "ipython",
|
261 |
+
"version": 3
|
262 |
+
},
|
263 |
+
"file_extension": ".py",
|
264 |
+
"mimetype": "text/x-python",
|
265 |
+
"name": "python",
|
266 |
+
"nbconvert_exporter": "python",
|
267 |
+
"pygments_lexer": "ipython3",
|
268 |
+
"version": "3.11.9"
|
269 |
+
}
|
270 |
+
},
|
271 |
+
"nbformat": 4,
|
272 |
+
"nbformat_minor": 5
|
273 |
+
}
|
models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/.ipynb_checkpoints/resnet_model-checkpoint.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class ResidualBlock(nn.Module):
|
6 |
+
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
|
7 |
+
super(ResidualBlock, self).__init__()
|
8 |
+
self.conv1 = nn.Sequential(
|
9 |
+
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
|
10 |
+
nn.BatchNorm2d(out_channels),
|
11 |
+
nn.ReLU())
|
12 |
+
self.conv2 = nn.Sequential(
|
13 |
+
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
|
14 |
+
nn.BatchNorm2d(out_channels))
|
15 |
+
self.downsample = downsample
|
16 |
+
self.relu = nn.ReLU()
|
17 |
+
self.out_channels = out_channels
|
18 |
+
self.dropout_percentage = 0.5
|
19 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
20 |
+
self.batchnorm_mod = nn.BatchNorm2d(out_channels)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
residual = x
|
24 |
+
out = self.conv1(x)
|
25 |
+
out = self.dropout1(out)
|
26 |
+
# out = self.batchnorm_mod(out)
|
27 |
+
out = self.conv2(out)
|
28 |
+
out = self.dropout1(out)
|
29 |
+
# out = self.batchnorm_mod(out)
|
30 |
+
if self.downsample:
|
31 |
+
residual = self.downsample(x)
|
32 |
+
out += residual
|
33 |
+
out = self.relu(out)
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
class ResNet(nn.Module):
|
38 |
+
def __init__(self, inchan, block, layers, num_classes = 10):
|
39 |
+
super(ResNet, self).__init__()
|
40 |
+
self.inplanes = 64
|
41 |
+
self.eps = 1e-5
|
42 |
+
self.relu = nn.ReLU()
|
43 |
+
self.conv1 = nn.Sequential(
|
44 |
+
nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
|
45 |
+
nn.BatchNorm2d(64),
|
46 |
+
nn.ReLU())
|
47 |
+
self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
|
48 |
+
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
|
49 |
+
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
|
50 |
+
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
|
51 |
+
self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
|
52 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
53 |
+
self.fc = nn.Linear(39424, num_classes)
|
54 |
+
self.dropout_percentage = 0.3
|
55 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
56 |
+
|
57 |
+
# Encoder
|
58 |
+
self.encoder = nn.Sequential(
|
59 |
+
nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
|
60 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
61 |
+
nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
|
62 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
63 |
+
nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
|
64 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
65 |
+
nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
|
66 |
+
nn.Sigmoid()
|
67 |
+
)
|
68 |
+
params = sum(p.numel() for p in self.encoder.parameters())
|
69 |
+
print("num params encoder ",params)
|
70 |
+
|
71 |
+
def norm(self, x):
|
72 |
+
shifted = x-x.min()
|
73 |
+
maxes = torch.amax(abs(shifted), dim=(-2, -1))
|
74 |
+
repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
|
75 |
+
x = shifted/repeated_maxes
|
76 |
+
return x
|
77 |
+
|
78 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
79 |
+
downsample = None
|
80 |
+
if stride != 1 or self.inplanes != planes:
|
81 |
+
downsample = nn.Sequential(
|
82 |
+
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
|
83 |
+
nn.BatchNorm2d(planes),
|
84 |
+
)
|
85 |
+
layers = []
|
86 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
87 |
+
self.inplanes = planes
|
88 |
+
for i in range(1, blocks):
|
89 |
+
layers.append(block(self.inplanes, planes))
|
90 |
+
return nn.Sequential(*layers)
|
91 |
+
|
92 |
+
def forward(self, x, return_mask=False):
|
93 |
+
# x = self.norm(x)
|
94 |
+
x = self.conv1(x)
|
95 |
+
x = self.maxpool(x)
|
96 |
+
x = self.layer0(x)
|
97 |
+
x = self.layer1(x)
|
98 |
+
x = self.layer2(x)
|
99 |
+
x = self.layer3(x)
|
100 |
+
x = self.avgpool(x)
|
101 |
+
x = x.view(x.size(0), -1)
|
102 |
+
x = self.dropout1(x)
|
103 |
+
x = self.fc(x)
|
104 |
+
# return x
|
105 |
+
if return_mask:
|
106 |
+
return x, self.mask, self.value
|
107 |
+
else:
|
108 |
+
return x
|
109 |
+
|
110 |
+
|
111 |
+
class ConvAutoencoder(nn.Module):
|
112 |
+
def __init__(self):
|
113 |
+
super(ConvAutoencoder, self).__init__()
|
114 |
+
|
115 |
+
# Encoder
|
116 |
+
self.encoder = nn.Sequential(
|
117 |
+
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
|
118 |
+
nn.ReLU(),
|
119 |
+
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
|
120 |
+
nn.ReLU(),
|
121 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
|
122 |
+
nn.ReLU(),
|
123 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
|
124 |
+
nn.ReLU()
|
125 |
+
)
|
126 |
+
|
127 |
+
# Fully connected latent space
|
128 |
+
self.fc1 = nn.Linear(128 * 12 * 16, 8)
|
129 |
+
self.fc2 = nn.Linear(8, 128 * 12 * 16)
|
130 |
+
|
131 |
+
# Decoder
|
132 |
+
self.decoder = nn.Sequential(
|
133 |
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
|
134 |
+
nn.ReLU(),
|
135 |
+
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
|
136 |
+
nn.ReLU(),
|
137 |
+
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
|
138 |
+
nn.ReLU(),
|
139 |
+
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
|
140 |
+
nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
|
141 |
+
)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
# Encode
|
145 |
+
x = self.encoder(x)
|
146 |
+
|
147 |
+
# Flatten the encoded output
|
148 |
+
x = x.view(x.size(0), -1)
|
149 |
+
|
150 |
+
# Fully connected latent space
|
151 |
+
x = self.fc1(x)
|
152 |
+
x = self.fc2(x)
|
153 |
+
|
154 |
+
# Reshape the output to the shape suitable for the decoder
|
155 |
+
x = x.view(x.size(0), 128, 12, 16)
|
156 |
+
|
157 |
+
# Decode
|
158 |
+
x = self.decoder(x)
|
159 |
+
|
160 |
+
return x
|
models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class ResidualBlock(nn.Module):
|
6 |
+
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
|
7 |
+
super(ResidualBlock, self).__init__()
|
8 |
+
self.conv1 = nn.Sequential(
|
9 |
+
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
|
10 |
+
nn.BatchNorm2d(out_channels),
|
11 |
+
nn.ReLU())
|
12 |
+
self.conv2 = nn.Sequential(
|
13 |
+
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
|
14 |
+
nn.BatchNorm2d(out_channels))
|
15 |
+
self.downsample = downsample
|
16 |
+
self.relu = nn.ReLU()
|
17 |
+
self.out_channels = out_channels
|
18 |
+
self.dropout_percentage = 0.5
|
19 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
20 |
+
self.batchnorm_mod = nn.BatchNorm2d(out_channels)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
residual = x
|
24 |
+
out = self.conv1(x)
|
25 |
+
out = self.dropout1(out)
|
26 |
+
# out = self.batchnorm_mod(out)
|
27 |
+
out = self.conv2(out)
|
28 |
+
out = self.dropout1(out)
|
29 |
+
# out = self.batchnorm_mod(out)
|
30 |
+
if self.downsample:
|
31 |
+
residual = self.downsample(x)
|
32 |
+
out += residual
|
33 |
+
out = self.relu(out)
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
class ResNet(nn.Module):
|
38 |
+
def __init__(self, inchan, block, layers, num_classes = 10):
|
39 |
+
super(ResNet, self).__init__()
|
40 |
+
self.inplanes = 64
|
41 |
+
self.eps = 1e-5
|
42 |
+
self.relu = nn.ReLU()
|
43 |
+
self.conv1 = nn.Sequential(
|
44 |
+
nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
|
45 |
+
nn.BatchNorm2d(64),
|
46 |
+
nn.ReLU())
|
47 |
+
self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
|
48 |
+
self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
|
49 |
+
self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
|
50 |
+
self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
|
51 |
+
self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
|
52 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
53 |
+
self.fc = nn.Linear(39424, num_classes)
|
54 |
+
self.dropout_percentage = 0.3
|
55 |
+
self.dropout1 = nn.Dropout(p=self.dropout_percentage)
|
56 |
+
|
57 |
+
# Encoder
|
58 |
+
self.encoder = nn.Sequential(
|
59 |
+
nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
|
60 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
61 |
+
nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
|
62 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
63 |
+
nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
|
64 |
+
nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
|
65 |
+
nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
|
66 |
+
nn.Sigmoid()
|
67 |
+
)
|
68 |
+
params = sum(p.numel() for p in self.encoder.parameters())
|
69 |
+
print("num params encoder ",params)
|
70 |
+
|
71 |
+
def norm(self, x):
|
72 |
+
shifted = x-x.min()
|
73 |
+
maxes = torch.amax(abs(shifted), dim=(-2, -1))
|
74 |
+
repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
|
75 |
+
x = shifted/repeated_maxes
|
76 |
+
return x
|
77 |
+
|
78 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
79 |
+
downsample = None
|
80 |
+
if stride != 1 or self.inplanes != planes:
|
81 |
+
downsample = nn.Sequential(
|
82 |
+
nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
|
83 |
+
nn.BatchNorm2d(planes),
|
84 |
+
)
|
85 |
+
layers = []
|
86 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
87 |
+
self.inplanes = planes
|
88 |
+
for i in range(1, blocks):
|
89 |
+
layers.append(block(self.inplanes, planes))
|
90 |
+
return nn.Sequential(*layers)
|
91 |
+
|
92 |
+
def forward(self, x, return_mask=False):
|
93 |
+
# # m = self.encoder(x).unsqueeze(-1).repeat(1, 1, 1, x.shape[-1])
|
94 |
+
m = self.encoder(x)
|
95 |
+
self.mask = m
|
96 |
+
self.value = x
|
97 |
+
# # m = nn.Sigmoid()(self.encoder(x))
|
98 |
+
x = x * m
|
99 |
+
# x = self.norm(x)
|
100 |
+
x = self.conv1(x)
|
101 |
+
x = self.maxpool(x)
|
102 |
+
x = self.layer0(x)
|
103 |
+
x = self.layer1(x)
|
104 |
+
x = self.layer2(x)
|
105 |
+
x = self.layer3(x)
|
106 |
+
x = self.avgpool(x)
|
107 |
+
x = x.view(x.size(0), -1)
|
108 |
+
x = self.dropout1(x)
|
109 |
+
x = self.fc(x)
|
110 |
+
return x
|
111 |
+
# if return_mask:
|
112 |
+
# return x, self.mask, self.value
|
113 |
+
# else:
|
114 |
+
# return x
|
115 |
+
|
116 |
+
|
117 |
+
class ConvAutoencoder(nn.Module):
|
118 |
+
def __init__(self):
|
119 |
+
super(ConvAutoencoder, self).__init__()
|
120 |
+
|
121 |
+
# Encoder
|
122 |
+
self.encoder = nn.Sequential(
|
123 |
+
nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
|
124 |
+
nn.ReLU(),
|
125 |
+
nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
|
126 |
+
nn.ReLU(),
|
127 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
|
128 |
+
nn.ReLU(),
|
129 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
|
130 |
+
nn.ReLU()
|
131 |
+
)
|
132 |
+
|
133 |
+
# Fully connected latent space
|
134 |
+
self.fc1 = nn.Linear(128 * 12 * 16, 8)
|
135 |
+
self.fc2 = nn.Linear(8, 128 * 12 * 16)
|
136 |
+
|
137 |
+
# Decoder
|
138 |
+
self.decoder = nn.Sequential(
|
139 |
+
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
|
140 |
+
nn.ReLU(),
|
141 |
+
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
|
142 |
+
nn.ReLU(),
|
143 |
+
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
|
144 |
+
nn.ReLU(),
|
145 |
+
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
|
146 |
+
nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
|
147 |
+
)
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
# Encode
|
151 |
+
x = self.encoder(x)
|
152 |
+
|
153 |
+
# Flatten the encoded output
|
154 |
+
x = x.view(x.size(0), -1)
|
155 |
+
|
156 |
+
# Fully connected latent space
|
157 |
+
x = self.fc1(x)
|
158 |
+
x = self.fc2(x)
|
159 |
+
|
160 |
+
# Reshape the output to the shape suitable for the decoder
|
161 |
+
x = x.view(x.size(0), 128, 12, 16)
|
162 |
+
|
163 |
+
# Decode
|
164 |
+
x = self.decoder(x)
|
165 |
+
|
166 |
+
return x
|
models/.ipynb_checkpoints/train-checkpoint.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from resnet_model import ResidualBlock, ResNet
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.optim as optim
|
9 |
+
import tqdm
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
12 |
+
import pickle
|
13 |
+
import sys
|
14 |
+
|
15 |
+
ind = int(sys.argv[1])
|
16 |
+
seeds = [1,42,7109,2002,32]
|
17 |
+
seed = seeds[ind]
|
18 |
+
print("using seed: ",seed)
|
19 |
+
torch.manual_seed(seed)
|
20 |
+
|
21 |
+
|
22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23 |
+
num_gpus = torch.cuda.device_count()
|
24 |
+
print(num_gpus)
|
25 |
+
|
26 |
+
# Create custom dataset instance
|
27 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
28 |
+
dataset = CustomDataset(data_dir, transform=transform)
|
29 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
30 |
+
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
|
31 |
+
|
32 |
+
|
33 |
+
num_classes = 2
|
34 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
35 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
36 |
+
|
37 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
38 |
+
model = nn.DataParallel(model)
|
39 |
+
model = model.to(device)
|
40 |
+
params = sum(p.numel() for p in model.parameters())
|
41 |
+
print("num params ",params)
|
42 |
+
torch.save(model.state_dict(), 'models/test.pt')
|
43 |
+
model.load_state_dict(torch.load('models/test.pt'))
|
44 |
+
|
45 |
+
preproc_model = preproc()
|
46 |
+
Convert_ONNX(model.module,'models/test.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
47 |
+
Convert_ONNX(preproc_model,'models/preproc.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
48 |
+
|
49 |
+
# Define optimizer and loss function
|
50 |
+
|
51 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
52 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
53 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
54 |
+
|
55 |
+
|
56 |
+
from tqdm import tqdm
|
57 |
+
# Training loop
|
58 |
+
epochs = 10000
|
59 |
+
for epoch in range(epochs):
|
60 |
+
running_loss = 0.0
|
61 |
+
correct_train = 0
|
62 |
+
total_train = 0
|
63 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
64 |
+
model.train()
|
65 |
+
for i, (images, labels) in enumerate(tepoch):
|
66 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
67 |
+
optimizer.zero_grad()
|
68 |
+
outputs = model(inputs, return_mask=False).to(device)
|
69 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
70 |
+
loss = criterion(outputs, new_label)
|
71 |
+
loss.backward()
|
72 |
+
optimizer.step()
|
73 |
+
running_loss += loss.item()
|
74 |
+
# Calculate training accuracy
|
75 |
+
_, predicted = torch.max(outputs.data, 1)
|
76 |
+
total_train += labels.size(0)
|
77 |
+
correct_train += (predicted == labels).sum().item()
|
78 |
+
val_loss = 0.0
|
79 |
+
correct_valid = 0
|
80 |
+
total = 0
|
81 |
+
model.eval()
|
82 |
+
with torch.no_grad():
|
83 |
+
for images, labels in validloader:
|
84 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
85 |
+
optimizer.zero_grad()
|
86 |
+
outputs = model(inputs, return_mask=False)
|
87 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
88 |
+
loss = criterion(outputs, new_label)
|
89 |
+
val_loss += loss.item()
|
90 |
+
_, predicted = torch.max(outputs, 1)
|
91 |
+
total += labels.size(0)
|
92 |
+
correct_valid += (predicted == labels).sum().item()
|
93 |
+
scheduler.step(val_loss)
|
94 |
+
# Calculate training accuracy after each epoch
|
95 |
+
train_accuracy = 100 * correct_train / total_train
|
96 |
+
val_accuracy = correct_valid / total * 100.0
|
97 |
+
torch.save(model.state_dict(), 'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.pt')
|
98 |
+
Convert_ONNX(model.module,'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.onnx', input_data_mock=inputs)
|
99 |
+
|
100 |
+
print("===========================")
|
101 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
102 |
+
print('learning rate: ', scheduler.get_last_lr())
|
103 |
+
print("===========================")
|
104 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
105 |
+
break
|
models/.ipynb_checkpoints/train-mask-8-checkpoint.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from resnet_model_mask import ResidualBlock, ResNet
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.optim as optim
|
9 |
+
import tqdm
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
12 |
+
import pickle
|
13 |
+
import sys
|
14 |
+
# [1,42,7109,2002,32]
|
15 |
+
ind = int(sys.argv[1])
|
16 |
+
seeds = [1,42,7109,2002,32]
|
17 |
+
seed = seeds[ind]
|
18 |
+
torch.manual_seed(seed)
|
19 |
+
|
20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
21 |
+
num_gpus = torch.cuda.device_count()
|
22 |
+
print(num_gpus)
|
23 |
+
|
24 |
+
# Create custom dataset instance
|
25 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
26 |
+
dataset = CustomDataset(data_dir, bit8 = True, transform=transform)
|
27 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
28 |
+
valid_dataset = CustomDataset(valid_data_dir, bit8 = True, transform=transform)
|
29 |
+
|
30 |
+
|
31 |
+
num_classes = 2
|
32 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
33 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
34 |
+
|
35 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
36 |
+
model = nn.DataParallel(model)
|
37 |
+
model = model.to(device)
|
38 |
+
params = sum(p.numel() for p in model.parameters())
|
39 |
+
print("num params ",params)
|
40 |
+
torch.save(model.state_dict(), f'models_8/test_{seed}.pt')
|
41 |
+
model.load_state_dict(torch.load(f'models_8/test_{seed}.pt'))
|
42 |
+
|
43 |
+
preproc_model = preproc()
|
44 |
+
Convert_ONNX(model.module,f'models_8/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
45 |
+
Convert_ONNX(preproc_model,f'models_8/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
46 |
+
|
47 |
+
# Define optimizer and loss function
|
48 |
+
|
49 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
50 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
51 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
52 |
+
|
53 |
+
|
54 |
+
from tqdm import tqdm
|
55 |
+
# Training loop
|
56 |
+
epochs = 10000
|
57 |
+
for epoch in range(epochs):
|
58 |
+
running_loss = 0.0
|
59 |
+
correct_train = 0
|
60 |
+
total_train = 0
|
61 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
62 |
+
model.train()
|
63 |
+
for i, (images, labels) in enumerate(tepoch):
|
64 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
65 |
+
optimizer.zero_grad()
|
66 |
+
outputs = model(inputs, return_mask=False).to(device)
|
67 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
68 |
+
loss = criterion(outputs, new_label)
|
69 |
+
loss.backward()
|
70 |
+
optimizer.step()
|
71 |
+
running_loss += loss.item()
|
72 |
+
# Calculate training accuracy
|
73 |
+
_, predicted = torch.max(outputs.data, 1)
|
74 |
+
total_train += labels.size(0)
|
75 |
+
correct_train += (predicted == labels).sum().item()
|
76 |
+
val_loss = 0.0
|
77 |
+
correct_valid = 0
|
78 |
+
total = 0
|
79 |
+
model.eval()
|
80 |
+
with torch.no_grad():
|
81 |
+
for images, labels in validloader:
|
82 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
83 |
+
optimizer.zero_grad()
|
84 |
+
outputs = model(inputs, return_mask=False)
|
85 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
86 |
+
loss = criterion(outputs, new_label)
|
87 |
+
val_loss += loss.item()
|
88 |
+
_, predicted = torch.max(outputs, 1)
|
89 |
+
total += labels.size(0)
|
90 |
+
correct_valid += (predicted == labels).sum().item()
|
91 |
+
scheduler.step(val_loss)
|
92 |
+
# Calculate training accuracy after each epoch
|
93 |
+
train_accuracy = 100 * correct_train / total_train
|
94 |
+
val_accuracy = correct_valid / total * 100.0
|
95 |
+
torch.save(model.state_dict(), 'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
|
96 |
+
Convert_ONNX(model.module,'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
|
97 |
+
|
98 |
+
print("===========================")
|
99 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
100 |
+
print('learning rate: ', scheduler.get_last_lr())
|
101 |
+
print("===========================")
|
102 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
103 |
+
break
|
models/.ipynb_checkpoints/train-mask-checkpoint.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from resnet_model_mask import ResidualBlock, ResNet
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.optim as optim
|
9 |
+
import tqdm
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
12 |
+
import pickle
|
13 |
+
import sys
|
14 |
+
|
15 |
+
ind = int(sys.argv[1])
|
16 |
+
seeds = [1,42,7109,2002,32]
|
17 |
+
seed = seeds[ind]
|
18 |
+
print("using seed: ",seed)
|
19 |
+
torch.manual_seed(seed)
|
20 |
+
|
21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
22 |
+
num_gpus = torch.cuda.device_count()
|
23 |
+
print(num_gpus)
|
24 |
+
|
25 |
+
# Create custom dataset instance
|
26 |
+
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
|
27 |
+
dataset = CustomDataset(data_dir, transform=transform)
|
28 |
+
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
|
29 |
+
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
|
30 |
+
|
31 |
+
|
32 |
+
num_classes = 2
|
33 |
+
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
|
34 |
+
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
|
35 |
+
|
36 |
+
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
|
37 |
+
model = nn.DataParallel(model)
|
38 |
+
model = model.to(device)
|
39 |
+
params = sum(p.numel() for p in model.parameters())
|
40 |
+
print("num params ",params)
|
41 |
+
torch.save(model.state_dict(), f'models_mask/test_{seed}.pt')
|
42 |
+
model.load_state_dict(torch.load(f'models_mask/test_{seed}.pt'))
|
43 |
+
|
44 |
+
preproc_model = preproc()
|
45 |
+
Convert_ONNX(model.module,f'models_mask/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
|
46 |
+
Convert_ONNX(preproc_model,f'models_mask/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
|
47 |
+
|
48 |
+
# Define optimizer and loss function
|
49 |
+
|
50 |
+
criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
|
51 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
52 |
+
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
|
53 |
+
|
54 |
+
|
55 |
+
from tqdm import tqdm
|
56 |
+
# Training loop
|
57 |
+
epochs = 10000
|
58 |
+
for epoch in range(epochs):
|
59 |
+
running_loss = 0.0
|
60 |
+
correct_train = 0
|
61 |
+
total_train = 0
|
62 |
+
with tqdm(trainloader, unit="batch") as tepoch:
|
63 |
+
model.train()
|
64 |
+
for i, (images, labels) in enumerate(tepoch):
|
65 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
66 |
+
optimizer.zero_grad()
|
67 |
+
outputs = model(inputs, return_mask=False).to(device)
|
68 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
|
69 |
+
loss = criterion(outputs, new_label)
|
70 |
+
loss.backward()
|
71 |
+
optimizer.step()
|
72 |
+
running_loss += loss.item()
|
73 |
+
# Calculate training accuracy
|
74 |
+
_, predicted = torch.max(outputs.data, 1)
|
75 |
+
total_train += labels.size(0)
|
76 |
+
correct_train += (predicted == labels).sum().item()
|
77 |
+
val_loss = 0.0
|
78 |
+
correct_valid = 0
|
79 |
+
total = 0
|
80 |
+
model.eval()
|
81 |
+
with torch.no_grad():
|
82 |
+
for images, labels in validloader:
|
83 |
+
inputs, labels = images.to(device), labels.to(device).float()
|
84 |
+
optimizer.zero_grad()
|
85 |
+
outputs = model(inputs, return_mask=False)
|
86 |
+
new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
|
87 |
+
loss = criterion(outputs, new_label)
|
88 |
+
val_loss += loss.item()
|
89 |
+
_, predicted = torch.max(outputs, 1)
|
90 |
+
total += labels.size(0)
|
91 |
+
correct_valid += (predicted == labels).sum().item()
|
92 |
+
scheduler.step(val_loss)
|
93 |
+
# Calculate training accuracy after each epoch
|
94 |
+
train_accuracy = 100 * correct_train / total_train
|
95 |
+
val_accuracy = correct_valid / total * 100.0
|
96 |
+
torch.save(model.state_dict(), 'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
|
97 |
+
Convert_ONNX(model.module,'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
|
98 |
+
|
99 |
+
print("===========================")
|
100 |
+
print('accuracy: ', epoch, train_accuracy, val_accuracy)
|
101 |
+
print('learning rate: ', scheduler.get_last_lr())
|
102 |
+
print("===========================")
|
103 |
+
if scheduler.get_last_lr()[0] <1e-6:
|
104 |
+
break
|
models/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pickle
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from copy import deepcopy
|
7 |
+
from blimpy import Waterfall
|
8 |
+
from tqdm import tqdm
|
9 |
+
from copy import deepcopy
|
10 |
+
from sigpyproc.readers import FilReader
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def load_pickled_data(file_path):
|
15 |
+
with open(file_path, 'rb') as f:
|
16 |
+
data = pickle.load(f)
|
17 |
+
return data
|
18 |
+
|
19 |
+
# Custom dataset class
|
20 |
+
class CustomDataset(Dataset):
    """Dataset of pickled spectrogram samples arranged one sub-directory per class.

    Each pickle is a dict holding at least 'data' (and '8_data' for the 8-bit
    variant); __getitem__ returns (tensor, label).
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []   # full paths to pickled samples
        self.labels = []   # integer class index per sample
        # sorted() makes the class -> index mapping deterministic across runs
        # (bare os.listdir order is filesystem-dependent).
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.bit8 = bit8
        # Load image paths and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]
        # Load the pickled sample from disk.
        image = load_pickled_data(image_path)
        if self.bit8:
            tensor = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            tensor = torch.from_numpy(image['data'])
        # Bug fix: previously `new_image` was unbound (NameError) when
        # transform was None; now the raw tensor is returned in that case.
        if self.transform is not None:
            tensor = self.transform(tensor)
        return tensor, label
|
52 |
+
|
53 |
+
# Custom dataset class
|
54 |
+
class CustomDataset_Masked(Dataset):
    # Dataset returning (image, label, burst_mask) triples; the burst mask is a
    # thresholded binary map replicated across the 3 image channels, for
    # mask-supervised training.
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []   # full paths to pickled samples
        self.labels = []   # integer class index per sample
        # NOTE(review): os.listdir order is filesystem-dependent, so the
        # class -> index mapping may differ between runs/machines.
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Load images and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                self.images.append(image_path)
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]

        label = self.labels[idx]
        # Load the pickled sample; expected keys: 'burst', 'data'.
        image = load_pickled_data(image_path)
        if self.transform is not None:
            if image['burst'].max() ==0:
                # Empty burst map: keep the all-zero mask as-is.
                new_burst = torch.from_numpy(image['burst'])
            else:
                # Scale the burst map to [0, 1] before thresholding.
                new_burst = torch.from_numpy(image['burst']/image['burst'].max())
            # Binarise at 10% of the peak.
            ind = new_burst > 0.1
            ind_not = new_burst <= 0.1
            new_burst[ind] = 1
            new_burst[ind_not] = 0
            # `.data` suggests image['data'] is a numpy masked array -- TODO confirm.
            new_image = self.transform(torch.from_numpy(image['data'].data))
            # Replicate the binary mask over the transform's 3 output channels.
            new_burst_arr = torch.zeros_like(new_image)
            new_burst_arr[ 0, :,:] = new_burst
            new_burst_arr[ 1, :,:] = new_burst
            new_burst_arr[ 2, :,:] = new_burst
        # NOTE(review): if transform is None, new_image/new_burst_arr are
        # unbound and this raises NameError -- callers must pass a transform.
        return new_image, label, new_burst_arr
|
95 |
+
|
96 |
+
# Custom dataset class
|
97 |
+
class TestingDataset(Dataset):
    """Evaluation dataset: like CustomDataset but also returns the injection
    parameters (dm, freq_ref, snr, boxcard) stored in each pickled sample's
    'params' dict.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []   # full paths to pickled samples
        self.labels = []   # integer class index per sample
        # sorted() -> deterministic class -> index mapping across filesystems.
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.bit8 = bit8
        # Load image paths and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]
        # Load the pickled sample from disk.
        image = load_pickled_data(image_path)
        params = image['params']  # previously read twice; once is enough
        if self.bit8:
            tensor = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            tensor = torch.from_numpy(image['data'])
        # Bug fix: previously `new_image` was unbound (NameError) when
        # transform was None; now the raw tensor is returned in that case.
        if self.transform is not None:
            tensor = self.transform(tensor)
        params['labels'] = label
        return tensor, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
|
131 |
+
|
132 |
+
# Custom dataset class
|
133 |
+
class SearchDataset(Dataset):
    # Slices a filterbank observation into fixed-width time windows for a
    # sliding-window search; each item is one (transformed) window.
    def __init__(self, data_dir, transform=None, pickle_data=False):
        # Number of time samples per window fed to the network.
        self.window_size = 2048

        if pickle_data:
            # `data_dir` is a pickle holding {'header': ..., 'data': ...}.
            with open(data_dir, 'rb') as f:
                self.d = pickle.load(f)
            self.header = self.d['header']
            # Axis 1 appears to be a singleton (pol/IF) axis -- TODO confirm.
            self.images = self.crop(self.d['data'][:,0,:], self.window_size)
        else:
            # `data_dir` is a filterbank/HDF5 file readable by blimpy.
            self.obs = Waterfall(data_dir, max_load = 50)
            self.header = self.obs.header
            self.images = self.crop(self.obs.data[:,0,:], self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400  # seconds per day, for MJD arithmetic

    def crop(self, data, window_size = 2048):
        # Split (time, freq) data into consecutive non-overlapping windows of
        # `window_size` samples; any trailing remainder is dropped.
        # NOTE(review): the frequency axis is hard-coded to 192 channels.
        n_samp = data.shape[0]//window_size
        new_data = np.zeros((n_samp, window_size, 192 ))
        for i in range(n_samp):
            new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
        return new_data

    def __len__(self):
        return self.images.shape[0]
    def __getitem__(self, idx):
        # Transpose to (freq, time) for the model.
        data = self.images[idx, :, :].T
        tindex = idx * self.window_size
        # Window start time as MJD; computed but currently not returned.
        time = self.header['tsamp'] * tindex / self.SEC_PER_DAY + self.header['tstart']
        if self.transform is not None:
            new_image = self.transform(data)
        # NOTE(review): if transform is None, `new_image` is unbound -> NameError.
        return new_image, idx
|
165 |
+
|
166 |
+
# Custom dataset class
|
167 |
+
class SearchDataset_Sigproc(Dataset):
    # Same windowed-search dataset as SearchDataset, but reads the filterbank
    # via sigpyproc's FilReader instead of blimpy.
    def __init__(self, data_dir, transform=None):
        # Number of time samples per window fed to the network.
        self.window_size = 2048
        fil = FilReader(data_dir)
        self.header = fil.header
        # print("check shape ",fil.read_block(0, fil.header.nsamples).shape)
        # Read the whole observation and trim 1024 samples at each edge.
        read_data = fil.read_block(0, fil.header.nsamples)[:,1024:-1024]
        # Swap to (time, freq) ordering before windowing.
        read_data = np.swapaxes(read_data, 0,-1)
        self.images = self.crop(read_data, self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400  # seconds per day, for MJD arithmetic

    def crop(self, data, window_size = 2048):
        # Split (time, freq) data into consecutive non-overlapping windows of
        # `window_size` samples; any trailing remainder is dropped.
        # NOTE(review): the frequency axis is hard-coded to 192 channels.
        n_samp = data.shape[0]//window_size
        new_data = np.zeros((n_samp, window_size, 192 ))
        for i in range(n_samp):
            new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
        return new_data

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        # Transpose to (freq, time) for the model.
        data = self.images[idx, :, :].T
        tindex = idx * self.window_size
        # Window start time as MJD; computed but currently not returned.
        time = self.header.tsamp * tindex / self.SEC_PER_DAY + self.header.tstart
        if self.transform is not None:
            new_image = self.transform(torch.from_numpy(data))
        # NOTE(review): if transform is None, `new_image` is unbound -> NameError.
        return new_image, idx
|
196 |
+
|
197 |
+
# def renorm(data):
|
198 |
+
# shifted = data - data.min()
|
199 |
+
# shifted = shifted/shifted.max()
|
200 |
+
# return shifted
|
201 |
+
|
202 |
+
def renorm(data):
    """Z-score standardise a tensor: subtract its mean, divide by its std."""
    centred = data - torch.mean(data)
    return centred / torch.std(data)
|
208 |
+
|
209 |
+
def transform(data):
    # Build a multi-channel representation of one (freq, time) spectrogram:
    # channel 0 is the raw data, channels 1..len(masks_rms) are copies with
    # outliers zeroed, all log10-scaled and standardised; the time axis is then
    # sliced into 8 chunks that are stacked as extra channels.
    copy_data = data.detach().clone()
    rms = torch.std(data)
    mean = torch.mean(data)
    # -1: zero values BELOW 1*rms + mean; 5: zero values ABOVE 5*rms + mean.
    masks_rms = [-1, 5]
    new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
    # Layer 0: standardised log of the unmasked data (+1e-10 avoids log10(0)).
    new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
    for i in range(1, len(masks_rms)+1):
        scale = masks_rms[i-1]
        copy_data = data.detach().clone() #deepcopy(data)
        if scale < 0:
            ind = copy_data < abs(scale) * rms + mean
            copy_data[ind] = 0
        else:
            ind = copy_data > (scale) * rms + mean
            copy_data[ind] = 0
        new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
    new_data = new_data.type(torch.float32)
    # Slice the last (time) axis into 8 equal chunks and fold them into the
    # channel axis: (3, H, W) -> (3*8, H, W/8).
    # assumes W is divisible by 8 -- TODO confirm for all inputs.
    slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
    new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
    new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
    return new_data
|
231 |
+
|
232 |
+
|
233 |
+
def renorm_batched(data):
    """Standardise each sample of a batch independently (zero mean, unit std).

    All dimensions except the leading batch dimension are reduced over.
    """
    reduce_dims = tuple(range(1, data.ndim))
    mu = torch.mean(data, dim=reduce_dims, keepdim=True)
    sigma = torch.std(data, dim=reduce_dims, keepdim=True)
    return (data - mu) / sigma
|
238 |
+
|
239 |
+
def transform_batched(data):
    # Batched version of `transform`: for each sample builds 3 log-scaled
    # layers (raw + 2 outlier-masked), slices the time axis into 8 chunks and
    # folds everything into 24 channels per sample.
    copy_data = data.detach().clone()
    rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
    mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
    # -1: zero values BELOW 1*rms + mean; 5: zero values ABOVE 5*rms + mean.
    masks_rms = [-1, 5]

    # Prepare the new_data tensor
    num_masks = len(masks_rms) + 1
    new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)

    # First layer: Apply renorm(log10(copy_data + epsilon)); 1e-10 avoids log10(0).
    new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
    for i, scale in enumerate(masks_rms, start=1):
        copy_data = data.detach().clone()

        # Apply masking based on the scale
        if scale < 0:
            ind = copy_data < abs(scale) * rms + mean
        else:
            ind = copy_data > scale * rms + mean
        copy_data[ind] = 0

        # Renormalize and log10 transform
        new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))

    # Convert to float32
    new_data = new_data.type(torch.float32)

    # Chunk along the last (time) dimension and stack; assumes the time axis
    # is divisible by 8 -- TODO confirm for all inputs.
    slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
    new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
    new_data = torch.swapaxes(new_data, 0,1)
    # Reshape into final format: (batch, 3 masks * 8 chunks = 24, H, W/8).
    new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
    return new_data
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
class preproc(nn.Module):
    """Preprocessing module: optionally flips the frequency axis, then applies
    the batched multi-mask transform.

    Args (forward):
        x: batched input tensor.
        flip: when True (default), flip along dim -2 before transforming.

    Returns:
        The output of transform_batched on the (possibly flipped) input.
    """

    def forward(self, x, flip=True):
        # Bug fix: the transform result was previously discarded and an
        # undefined name `template` was returned (NameError at runtime).
        if flip:
            template = transform_batched(torch.flip(x, dims=(-2,)))
        else:
            template = transform_batched(x)
        return template
|
284 |
+
|
285 |
+
# class preproc_debug(nn.Module):
|
286 |
+
# def forward(self, x):
|
287 |
+
# template = torch.zeros((32, 24, 192, 256))
|
288 |
+
# # for i in torch.arange(x.shape[0]): # Use a tensor-based range
|
289 |
+
# template[0,:,:,:] = transform_debug(torch.flip(x[0,:,:], dims = (0,)))
|
290 |
+
# template[1,:,:,:] = transform_debug(torch.flip(x[1,:,:], dims = (0,)))
|
291 |
+
# template[2,:,:,:] = transform_debug(torch.flip(x[2,:,:], dims = (0,)))
|
292 |
+
# template[3,:,:,:] = transform_debug(torch.flip(x[3,:,:], dims = (0,)))
|
293 |
+
# template[4,:,:,:] = transform_debug(torch.flip(x[4,:,:], dims = (0,)))
|
294 |
+
# template[5,:,:,:] = transform_debug(torch.flip(x[5,:,:], dims = (0,)))
|
295 |
+
# template[6,:,:,:] = transform_debug(torch.flip(x[6,:,:], dims = (0,)))
|
296 |
+
# template[7,:,:,:] = transform_debug(torch.flip(x[7,:,:], dims = (0,)))
|
297 |
+
# template[8,:,:,:] = transform_debug(torch.flip(x[8,:,:], dims = (0,)))
|
298 |
+
# template[9,:,:,:] = transform_debug(torch.flip(x[9,:,:], dims = (0,)))
|
299 |
+
# template[10,:,:,:] = transform_debug(torch.flip(x[10,:,:], dims = (0,)))
|
300 |
+
# template[11,:,:,:] = transform_debug(torch.flip(x[11,:,:], dims = (0,)))
|
301 |
+
# template[12,:,:,:] = transform_debug(torch.flip(x[12,:,:], dims = (0,)))
|
302 |
+
# template[13,:,:,:] = transform_debug(torch.flip(x[13,:,:], dims = (0,)))
|
303 |
+
# template[14,:,:,:] = transform_debug(torch.flip(x[14,:,:], dims = (0,)))
|
304 |
+
# template[15,:,:,:] = transform_debug(torch.flip(x[15,:,:], dims = (0,)))
|
305 |
+
# template[16,:,:,:] = transform_debug(torch.flip(x[16,:,:], dims = (0,)))
|
306 |
+
# template[17,:,:,:] = transform_debug(torch.flip(x[17,:,:], dims = (0,)))
|
307 |
+
# template[18,:,:,:] = transform_debug(torch.flip(x[18,:,:], dims = (0,)))
|
308 |
+
# template[19,:,:,:] = transform_debug(torch.flip(x[19,:,:], dims = (0,)))
|
309 |
+
# template[20,:,:,:] = transform_debug(torch.flip(x[20,:,:], dims = (0,)))
|
310 |
+
# template[21,:,:,:] = transform_debug(torch.flip(x[21,:,:], dims = (0,)))
|
311 |
+
# template[22,:,:,:] = transform_debug(torch.flip(x[22,:,:], dims = (0,)))
|
312 |
+
# template[23,:,:,:] = transform_debug(torch.flip(x[23,:,:], dims = (0,)))
|
313 |
+
# template[24,:,:,:] = transform_debug(torch.flip(x[24,:,:], dims = (0,)))
|
314 |
+
# template[25,:,:,:] = transform_debug(torch.flip(x[25,:,:], dims = (0,)))
|
315 |
+
# template[26,:,:,:] = transform_debug(torch.flip(x[26,:,:], dims = (0,)))
|
316 |
+
# template[27,:,:,:] = transform_debug(torch.flip(x[27,:,:], dims = (0,)))
|
317 |
+
# template[28,:,:,:] = transform_debug(torch.flip(x[28,:,:], dims = (0,)))
|
318 |
+
# template[29,:,:,:] = transform_debug(torch.flip(x[29,:,:], dims = (0,)))
|
319 |
+
# template[30,:,:,:] = transform_debug(torch.flip(x[30,:,:], dims = (0,)))
|
320 |
+
# template[31,:,:,:] = transform_debug(torch.flip(x[31,:,:], dims = (0,)))
|
321 |
+
# return template
|
322 |
+
|
323 |
+
# def transform_debug(data):
|
324 |
+
# copy_data = data.detach().clone()
|
325 |
+
# rms = torch.std(data)
|
326 |
+
# mean = torch.mean(data)
|
327 |
+
# masks_rms = [-1, 5]
|
328 |
+
# new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
|
329 |
+
# new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
|
330 |
+
# for i in range(1, len(masks_rms)+1):
|
331 |
+
# scale = masks_rms[i-1]
|
332 |
+
# copy_data = data.detach().clone()
|
333 |
+
# if scale < 0:
|
334 |
+
# ind = copy_data < abs(scale) * rms + mean
|
335 |
+
# copy_data[ind] = 0
|
336 |
+
# else:
|
337 |
+
# ind = copy_data > (scale) * rms + mean
|
338 |
+
# copy_data[ind] = 0
|
339 |
+
# new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
|
340 |
+
# new_data = new_data.type(torch.float32)
|
341 |
+
# slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
|
342 |
+
# new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
|
343 |
+
# new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
|
344 |
+
# return new_data
|
345 |
+
|
346 |
+
def renorm_batched(data):
    """Min-max normalise each sample of a batch to [0, 1].

    NOTE(review): this redefinition shadows the z-score `renorm_batched`
    defined earlier in this file -- confirm which one is intended.

    Generalised: uses keepdim broadcasting instead of the previous hard-coded
    `.expand(batch, 192, 2048)`, so any trailing 2-D shape works; behaviour is
    unchanged for the original (batch, 192, 2048) inputs.
    """
    # Per-sample minimum over the last two (freq, time) dimensions.
    mins = torch.amin(data, (-2, -1), keepdim=True)
    shifted = data - mins
    # Per-sample maximum of the shifted data; division broadcasts back.
    maxs = torch.amax(shifted, (-2, -1), keepdim=True)
    return shifted / maxs
|
356 |
+
|
357 |
+
|
358 |
+
def transform_mask(data):
    """Min-max normalise a 2-D array to [0, 1] and replicate it across three
    channels, returning a float32 array of shape (3, H, W)."""
    shifted = data - data.min()
    scaled = shifted / shifted.max()
    # Same normalised plane in every channel.
    return np.stack((scaled, scaled, scaled), axis=0).astype(np.float32)
|
367 |
+
|
368 |
+
|
369 |
+
#Function to Convert to ONNX
|
370 |
+
#Function to Convert to ONNX
def Convert_ONNX(model, saveloc, input_data_mock):
    """Export `model` to ONNX at `saveloc`, tracing with `input_data_mock`.

    The first axis of both input and output is exported as a dynamic
    batch dimension.
    """
    print("Saving to ONNX")
    # set the model to inference mode
    model.eval()

    # torch.autograd.Variable has been a no-op since PyTorch 0.4; pass the
    # tensor directly as the tracing input.
    dummy_input = input_data_mock

    # Export the model
    torch.onnx.export(model,                     # model being run
                      dummy_input,               # model input (or a tuple for multiple inputs)
                      saveloc,                   # where to save the model
                      input_names=['modelInput'],    # the model's input names
                      output_names=['modelOutput'],  # the model's output names
                      dynamic_axes={'modelInput': {0: 'batch_size'},   # variable length axes
                                    'modelOutput': {0: 'batch_size'}})
    print(" ")
    print('Model has been converted to ONNX')
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
|
393 |
+
|
models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pickle
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from copy import deepcopy
|
7 |
+
from blimpy import Waterfall
|
8 |
+
from tqdm import tqdm
|
9 |
+
from copy import deepcopy
|
10 |
+
from sigpyproc.readers import FilReader
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def renorm_batched(data):
    """Standardise each sample of a batch independently (zero mean, unit std).

    All dimensions except the leading batch dimension are reduced over.
    """
    reduce_dims = tuple(range(1, data.ndim))
    mu = torch.mean(data, dim=reduce_dims, keepdim=True)
    sigma = torch.std(data, dim=reduce_dims, keepdim=True)
    return (data - mu) / sigma
|
19 |
+
|
20 |
+
def transform_batched(data):
    # Batched preprocessing: for each sample builds 3 log-scaled layers
    # (raw + 2 outlier-masked), slices the time axis into 8 chunks and folds
    # everything into 24 channels per sample.
    copy_data = data.detach().clone()
    rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
    mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
    # -1: zero values BELOW 1*rms + mean; 5: zero values ABOVE 5*rms + mean.
    masks_rms = [-1, 5]

    # Prepare the new_data tensor
    num_masks = len(masks_rms) + 1
    new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)

    # First layer: Apply renorm(log10(copy_data + epsilon)); 1e-10 avoids log10(0).
    new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
    for i, scale in enumerate(masks_rms, start=1):
        copy_data = data.detach().clone()

        # Apply masking based on the scale
        if scale < 0:
            ind = copy_data < abs(scale) * rms + mean
        else:
            ind = copy_data > scale * rms + mean
        copy_data[ind] = 0

        # Renormalize and log10 transform
        new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))

    # Convert to float32
    new_data = new_data.type(torch.float32)

    # Chunk along the last (time) dimension and stack; assumes the time axis
    # is divisible by 8 -- TODO confirm for all inputs.
    slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
    new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
    new_data = torch.swapaxes(new_data, 0,1)
    # Reshape into final format: (batch, 3 masks * 8 chunks = 24, H, W/8).
    new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
    return new_data
|
55 |
+
|
56 |
+
class preproc_flip(nn.Module):
    """Preprocessing module that flips the frequency axis (dim -2) before
    applying the batched transform."""

    def forward(self, x, flip=True):
        # `flip` is accepted for interface parity but the flip is always applied.
        flipped = torch.flip(x, dims=(-2,))
        return transform_batched(flipped)
|
60 |
+
|
61 |
+
class preproc(nn.Module):
    """Preprocessing module that applies the batched transform as-is."""

    def forward(self, x, flip=True):
        # `flip` is unused; kept for signature compatibility with preproc_flip.
        return transform_batched(x)
|
65 |
+
|
models/HITS-FEB-10.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ce8b602cef03e22c666cdea792741411f623fb5eb0a254ef1ffd9a32864d754
|
3 |
+
size 270858960
|
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739230556_9.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7cc8774df726c0db3358653dc48894f70f2964cac7604dc5319b44f8f6340b71
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739230556_9.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739231399_1.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f66e781e7776a11067d0b103b36e9f42b7130d5297ab8936a814af659e125524
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739231399_1.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739231802_11.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9edde89b6a0526420dcff5e35516979a20f379fac7a2c98d59b9dae897bf4426
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739231802_11.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739234628_13.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:550f498a863202da1dad26237f7daf0a4f347e7fecde3650e9abee7611994078
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739234628_13.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739234628_14.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b49681cd9e55c1b3f81be0d2e9c04aacece284581b5b534cf8531ef5464b084
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739234628_14.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739235333_29.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04f284312e3c2d80e49d617bec875f1b6664bbe21e1418ba9fee8f0cafe70e24
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739235333_29.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_100000000_1739235841_12.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5b8008a6d13ef7b36fe03fd9dab1e58a6bbb770b433ffb77a3eb11a5033a09f
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_100000000_1739235841_12.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_50233055_1739232802_29.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2d0a49b0bf3825cfbbdbd91c465764f82b326e172c378685c9de87a44f1296e
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_50233055_1739232802_29.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_52111435_1739229641_28.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:caa8876c4c3185ad998f1a61a2b4176dec62d5ad38e3ab585aa9ff730848f83f
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_52111435_1739229641_28.png
ADDED
![]() |
Git LFS Details
|
models/HITS-FEB-10/hit_52550001_1739233595_4.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a69ae7f971552525b8beac6adee0baa0eb066158fa73e2748265e7a803f8c9f
|
3 |
+
size 1572992
|
models/HITS-FEB-10/hit_52550001_1739233595_4.png
ADDED
![]() |
Git LFS Details
|