peterma02 commited on
Commit
f3972ea
·
verified ·
1 Parent(s): 25dedb0

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +65 -0
  2. models/.ipynb_checkpoints/Untitled-checkpoint.ipynb +766 -0
  3. models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb +0 -0
  4. models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb +0 -0
  5. models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb +0 -0
  6. models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb +0 -0
  7. models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb +0 -0
  8. models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb +305 -0
  9. models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb +485 -0
  10. models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb +372 -0
  11. models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb +483 -0
  12. models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb +323 -0
  13. models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb +500 -0
  14. models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb +460 -0
  15. models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb +0 -0
  16. models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb +326 -0
  17. models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb +3 -0
  18. models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb +0 -0
  19. models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb +273 -0
  20. models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb +0 -0
  21. models/.ipynb_checkpoints/resnet_model-checkpoint.py +160 -0
  22. models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py +166 -0
  23. models/.ipynb_checkpoints/train-checkpoint.py +105 -0
  24. models/.ipynb_checkpoints/train-mask-8-checkpoint.py +103 -0
  25. models/.ipynb_checkpoints/train-mask-checkpoint.py +104 -0
  26. models/.ipynb_checkpoints/utils-checkpoint.py +393 -0
  27. models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py +65 -0
  28. models/HITS-FEB-10.zip +3 -0
  29. models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png +3 -0
  30. models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png +3 -0
  31. models/HITS-FEB-10/hit_100000000_1739230556_9.npy +3 -0
  32. models/HITS-FEB-10/hit_100000000_1739230556_9.png +3 -0
  33. models/HITS-FEB-10/hit_100000000_1739231399_1.npy +3 -0
  34. models/HITS-FEB-10/hit_100000000_1739231399_1.png +3 -0
  35. models/HITS-FEB-10/hit_100000000_1739231802_11.npy +3 -0
  36. models/HITS-FEB-10/hit_100000000_1739231802_11.png +3 -0
  37. models/HITS-FEB-10/hit_100000000_1739234628_13.npy +3 -0
  38. models/HITS-FEB-10/hit_100000000_1739234628_13.png +3 -0
  39. models/HITS-FEB-10/hit_100000000_1739234628_14.npy +3 -0
  40. models/HITS-FEB-10/hit_100000000_1739234628_14.png +3 -0
  41. models/HITS-FEB-10/hit_100000000_1739235333_29.npy +3 -0
  42. models/HITS-FEB-10/hit_100000000_1739235333_29.png +3 -0
  43. models/HITS-FEB-10/hit_100000000_1739235841_12.npy +3 -0
  44. models/HITS-FEB-10/hit_100000000_1739235841_12.png +3 -0
  45. models/HITS-FEB-10/hit_50233055_1739232802_29.npy +3 -0
  46. models/HITS-FEB-10/hit_50233055_1739232802_29.png +3 -0
  47. models/HITS-FEB-10/hit_52111435_1739229641_28.npy +3 -0
  48. models/HITS-FEB-10/hit_52111435_1739229641_28.png +3 -0
  49. models/HITS-FEB-10/hit_52550001_1739233595_4.npy +3 -0
  50. models/HITS-FEB-10/hit_52550001_1739233595_4.png +3 -0
.gitattributes CHANGED
@@ -36,3 +36,68 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
37
  accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
38
  accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
37
  accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
38
  accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
39
+ models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb filter=lfs diff=lfs merge=lfs -text
40
+ models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png filter=lfs diff=lfs merge=lfs -text
41
+ models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png filter=lfs diff=lfs merge=lfs -text
42
+ models/HITS-FEB-10/hit_100000000_1739230556_9.png filter=lfs diff=lfs merge=lfs -text
43
+ models/HITS-FEB-10/hit_100000000_1739231399_1.png filter=lfs diff=lfs merge=lfs -text
44
+ models/HITS-FEB-10/hit_100000000_1739231802_11.png filter=lfs diff=lfs merge=lfs -text
45
+ models/HITS-FEB-10/hit_100000000_1739234628_13.png filter=lfs diff=lfs merge=lfs -text
46
+ models/HITS-FEB-10/hit_100000000_1739234628_14.png filter=lfs diff=lfs merge=lfs -text
47
+ models/HITS-FEB-10/hit_100000000_1739235333_29.png filter=lfs diff=lfs merge=lfs -text
48
+ models/HITS-FEB-10/hit_100000000_1739235841_12.png filter=lfs diff=lfs merge=lfs -text
49
+ models/HITS-FEB-10/hit_50233055_1739232802_29.png filter=lfs diff=lfs merge=lfs -text
50
+ models/HITS-FEB-10/hit_52111435_1739229641_28.png filter=lfs diff=lfs merge=lfs -text
51
+ models/HITS-FEB-10/hit_52550001_1739233595_4.png filter=lfs diff=lfs merge=lfs -text
52
+ models/HITS-FEB-10/hit_57096732_1739234611_11.png filter=lfs diff=lfs merge=lfs -text
53
+ models/HITS-FEB-10/hit_57521253_1739232651_11.png filter=lfs diff=lfs merge=lfs -text
54
+ models/HITS-FEB-10/hit_58032264_1739232672_22.png filter=lfs diff=lfs merge=lfs -text
55
+ models/HITS-FEB-10/hit_58165746_1739230560_10.png filter=lfs diff=lfs merge=lfs -text
56
+ models/HITS-FEB-10/hit_62177701_1739230031_23.png filter=lfs diff=lfs merge=lfs -text
57
+ models/HITS-FEB-10/hit_64237575_1739233604_2.png filter=lfs diff=lfs merge=lfs -text
58
+ models/HITS-FEB-10/hit_67249737_1739231769_23.png filter=lfs diff=lfs merge=lfs -text
59
+ models/HITS-FEB-10/hit_71882680_1739230648_29.png filter=lfs diff=lfs merge=lfs -text
60
+ models/HITS-FEB-10/hit_72677566_1739232113_26.png filter=lfs diff=lfs merge=lfs -text
61
+ models/HITS-FEB-10/hit_74160848_1739234611_5.png filter=lfs diff=lfs merge=lfs -text
62
+ models/HITS-FEB-10/hit_75109552_1739231790_10.png filter=lfs diff=lfs merge=lfs -text
63
+ models/HITS-FEB-10/hit_79640130_1739231950_5.png filter=lfs diff=lfs merge=lfs -text
64
+ models/HITS-FEB-10/hit_81910572_1739231764_29.png filter=lfs diff=lfs merge=lfs -text
65
+ models/HITS-FEB-10/hit_83296520_1739233906_31.png filter=lfs diff=lfs merge=lfs -text
66
+ models/HITS-FEB-10/hit_84171229_1739231886_5.png filter=lfs diff=lfs merge=lfs -text
67
+ models/HITS-FEB-10/hit_84411238_1739233784_3.png filter=lfs diff=lfs merge=lfs -text
68
+ models/HITS-FEB-10/hit_87957837_1739232059_29.png filter=lfs diff=lfs merge=lfs -text
69
+ models/HITS-FEB-10/hit_88699241_1739235027_30.png filter=lfs diff=lfs merge=lfs -text
70
+ models/HITS-FEB-10/hit_90233018_1739235740_0.png filter=lfs diff=lfs merge=lfs -text
71
+ models/HITS-FEB-10/hit_93808281_1739233104_14.png filter=lfs diff=lfs merge=lfs -text
72
+ models/HITS-FEB-10/hit_93821122_1739232374_4.png filter=lfs diff=lfs merge=lfs -text
73
+ models/HITS-FEB-10/hit_94705507_1739231215_6.png filter=lfs diff=lfs merge=lfs -text
74
+ models/HITS-FEB-10/hit_95400645_1739230724_21.png filter=lfs diff=lfs merge=lfs -text
75
+ models/HITS-FEB-10/hit_95544329_1739233784_17.png filter=lfs diff=lfs merge=lfs -text
76
+ models/HITS-FEB-10/hit_96119644_1739232285_1.png filter=lfs diff=lfs merge=lfs -text
77
+ models/HITS-FEB-10/hit_96369222_1739233852_3.png filter=lfs diff=lfs merge=lfs -text
78
+ models/HITS-FEB-10/hit_96470689_1739233843_18.png filter=lfs diff=lfs merge=lfs -text
79
+ models/HITS-FEB-10/hit_96471079_1739232857_0.png filter=lfs diff=lfs merge=lfs -text
80
+ models/HITS-FEB-10/hit_96497133_1739233692_27.png filter=lfs diff=lfs merge=lfs -text
81
+ models/HITS-FEB-10/hit_98139322_1739231500_5.png filter=lfs diff=lfs merge=lfs -text
82
+ models/HITS-FEB-10/hit_98582385_1739232894_12.png filter=lfs diff=lfs merge=lfs -text
83
+ models/HITS-FEB-10/hit_98697207_1739229930_16.png filter=lfs diff=lfs merge=lfs -text
84
+ models/HITS-FEB-10/hit_99172221_1739232365_29.png filter=lfs diff=lfs merge=lfs -text
85
+ models/HITS-FEB-10/hit_99314646_1739233667_1.png filter=lfs diff=lfs merge=lfs -text
86
+ models/HITS-FEB-10/hit_99756914_1739230207_8.png filter=lfs diff=lfs merge=lfs -text
87
+ models/HITS-FEB-10/hit_99939211_1739233705_24.png filter=lfs diff=lfs merge=lfs -text
88
+ models/HITS-FEB-10/hit_99972041_1739234066_9.png filter=lfs diff=lfs merge=lfs -text
89
+ models/HITS-FEB-10/hit_99977277_1739231773_10.png filter=lfs diff=lfs merge=lfs -text
90
+ models/HITS-FEB-10/hit_99986237_1739234058_8.png filter=lfs diff=lfs merge=lfs -text
91
+ models/HITS-FEB-10/hit_99998032_1739232348_29.png filter=lfs diff=lfs merge=lfs -text
92
+ models/HITS-FEB-10/hit_99999287_1739233700_3.png filter=lfs diff=lfs merge=lfs -text
93
+ models/HITS-FEB-10/hit_99999351_1739235476_0.png filter=lfs diff=lfs merge=lfs -text
94
+ models/HITS-FEB-10/hit_99999979_1739232399_15.png filter=lfs diff=lfs merge=lfs -text
95
+ models/combined_frb_detections.pdf filter=lfs diff=lfs merge=lfs -text
96
+ models/combined_frb_detections.png filter=lfs diff=lfs merge=lfs -text
97
+ models/hits.png filter=lfs diff=lfs merge=lfs -text
98
+ models/hits_crab.pdf filter=lfs diff=lfs merge=lfs -text
99
+ models/models_mask/accuracy_vs_all_parameters.png filter=lfs diff=lfs merge=lfs -text
100
+ models/models_mask/accuracy_vs_dm.png filter=lfs diff=lfs merge=lfs -text
101
+ models/models_mask/accuracy_vs_snr.png filter=lfs diff=lfs merge=lfs -text
102
+ models/recover_crab.ipynb filter=lfs diff=lfs merge=lfs -text
103
+ models/recover_new_crab-debug.ipynb filter=lfs diff=lfs merge=lfs -text
models/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 10,
6
+ "id": "5577ffee-a5c9-4648-8849-95c2c7ebcebe",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
11
+ "from utils_batched_preproc import transform_batched, preproc_flip\n",
12
+ "from torch.utils.data import Dataset, DataLoader\n",
13
+ "import torch\n",
14
+ "import numpy as np\n",
15
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
16
+ "import torch\n",
17
+ "import torch.nn as nn\n",
18
+ "import torch.optim as optim\n",
19
+ "import tqdm \n",
20
+ "import torch.nn.functional as F\n",
21
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
22
+ "import pickle\n",
23
+ "import torch\n",
24
+ "from functorch import vmap"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 3,
30
+ "id": "f1180d60-83e7-47ca-aa09-58d26af3c706",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "# def renorm_batched(data):\n",
35
+ "# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
36
+ "# std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n",
37
+ "# standardized_data = (data - mean) / std\n",
38
+ "# return standardized_data\n",
39
+ "\n",
40
+ "# def transform_batched(data):\n",
41
+ "# copy_data = data.detach().clone()\n",
42
+ "# rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std\n",
43
+ "# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean\n",
44
+ "# masks_rms = [-1, 5]\n",
45
+ " \n",
46
+ "# # Prepare the new_data tensor\n",
47
+ "# num_masks = len(masks_rms) + 1\n",
48
+ "# new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)\n",
49
+ "\n",
50
+ "# # First layer: Apply renorm(log10(copy_data + epsilon))\n",
51
+ "# new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
52
+ "# for i, scale in enumerate(masks_rms, start=1):\n",
53
+ "# copy_data = data.detach().clone()\n",
54
+ " \n",
55
+ "# # Apply masking based on the scale\n",
56
+ "# if scale < 0:\n",
57
+ "# ind = copy_data < abs(scale) * rms + mean\n",
58
+ "# else:\n",
59
+ "# ind = copy_data > scale * rms + mean\n",
60
+ "# copy_data[ind] = 0\n",
61
+ " \n",
62
+ "# # Renormalize and log10 transform\n",
63
+ "# new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))\n",
64
+ " \n",
65
+ "# # Convert to float32\n",
66
+ "# new_data = new_data.type(torch.float32)\n",
67
+ "\n",
68
+ "# # Chunk along the last dimension and stack\n",
69
+ "# slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing\n",
70
+ "# new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1\n",
71
+ "# new_data = torch.swapaxes(new_data, 0,1)\n",
72
+ "# # Reshape into final format\n",
73
+ "# new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions\n",
74
+ "# return new_data\n",
75
+ "\n",
76
+ "\n"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 11,
82
+ "id": "81cc81a8-ecef-43ef-a5cd-c35765384812",
83
+ "metadata": {
84
+ "scrolled": true
85
+ },
86
+ "outputs": [
87
+ {
88
+ "name": "stdout",
89
+ "output_type": "stream",
90
+ "text": [
91
+ "num params encoder 50840\n"
92
+ ]
93
+ },
94
+ {
95
+ "name": "stderr",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "/tmp/ipykernel_19147/1680389579.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
99
+ " model.load_state_dict(torch.load(model_path))\n"
100
+ ]
101
+ },
102
+ {
103
+ "data": {
104
+ "text/plain": [
105
+ "DataParallel(\n",
106
+ " (module): ResNet(\n",
107
+ " (relu): ReLU()\n",
108
+ " (conv1): Sequential(\n",
109
+ " (0): Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
110
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
111
+ " (2): ReLU()\n",
112
+ " )\n",
113
+ " (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=1, dilation=1, ceil_mode=False)\n",
114
+ " (layer0): Sequential(\n",
115
+ " (0): ResidualBlock(\n",
116
+ " (conv1): Sequential(\n",
117
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
118
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
119
+ " (2): ReLU()\n",
120
+ " )\n",
121
+ " (conv2): Sequential(\n",
122
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
123
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
124
+ " )\n",
125
+ " (relu): ReLU()\n",
126
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
127
+ " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
128
+ " )\n",
129
+ " (1): ResidualBlock(\n",
130
+ " (conv1): Sequential(\n",
131
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
132
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
133
+ " (2): ReLU()\n",
134
+ " )\n",
135
+ " (conv2): Sequential(\n",
136
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
137
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
138
+ " )\n",
139
+ " (relu): ReLU()\n",
140
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
141
+ " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
142
+ " )\n",
143
+ " (2): ResidualBlock(\n",
144
+ " (conv1): Sequential(\n",
145
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
146
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
147
+ " (2): ReLU()\n",
148
+ " )\n",
149
+ " (conv2): Sequential(\n",
150
+ " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
151
+ " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
152
+ " )\n",
153
+ " (relu): ReLU()\n",
154
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
155
+ " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
156
+ " )\n",
157
+ " )\n",
158
+ " (layer1): Sequential(\n",
159
+ " (0): ResidualBlock(\n",
160
+ " (conv1): Sequential(\n",
161
+ " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
162
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
163
+ " (2): ReLU()\n",
164
+ " )\n",
165
+ " (conv2): Sequential(\n",
166
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
167
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
168
+ " )\n",
169
+ " (downsample): Sequential(\n",
170
+ " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))\n",
171
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
172
+ " )\n",
173
+ " (relu): ReLU()\n",
174
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
175
+ " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
176
+ " )\n",
177
+ " (1): ResidualBlock(\n",
178
+ " (conv1): Sequential(\n",
179
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
180
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
181
+ " (2): ReLU()\n",
182
+ " )\n",
183
+ " (conv2): Sequential(\n",
184
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
185
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
186
+ " )\n",
187
+ " (relu): ReLU()\n",
188
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
189
+ " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
190
+ " )\n",
191
+ " (2): ResidualBlock(\n",
192
+ " (conv1): Sequential(\n",
193
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
194
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
195
+ " (2): ReLU()\n",
196
+ " )\n",
197
+ " (conv2): Sequential(\n",
198
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
199
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
200
+ " )\n",
201
+ " (relu): ReLU()\n",
202
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
203
+ " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
204
+ " )\n",
205
+ " (3): ResidualBlock(\n",
206
+ " (conv1): Sequential(\n",
207
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
208
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
209
+ " (2): ReLU()\n",
210
+ " )\n",
211
+ " (conv2): Sequential(\n",
212
+ " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
213
+ " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
214
+ " )\n",
215
+ " (relu): ReLU()\n",
216
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
217
+ " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
218
+ " )\n",
219
+ " )\n",
220
+ " (layer2): Sequential(\n",
221
+ " (0): ResidualBlock(\n",
222
+ " (conv1): Sequential(\n",
223
+ " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
224
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
225
+ " (2): ReLU()\n",
226
+ " )\n",
227
+ " (conv2): Sequential(\n",
228
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
229
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
230
+ " )\n",
231
+ " (downsample): Sequential(\n",
232
+ " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))\n",
233
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
234
+ " )\n",
235
+ " (relu): ReLU()\n",
236
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
237
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
238
+ " )\n",
239
+ " (1): ResidualBlock(\n",
240
+ " (conv1): Sequential(\n",
241
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
242
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
243
+ " (2): ReLU()\n",
244
+ " )\n",
245
+ " (conv2): Sequential(\n",
246
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
247
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
248
+ " )\n",
249
+ " (relu): ReLU()\n",
250
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
251
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
252
+ " )\n",
253
+ " (2): ResidualBlock(\n",
254
+ " (conv1): Sequential(\n",
255
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
256
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
257
+ " (2): ReLU()\n",
258
+ " )\n",
259
+ " (conv2): Sequential(\n",
260
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
261
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
262
+ " )\n",
263
+ " (relu): ReLU()\n",
264
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
265
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
266
+ " )\n",
267
+ " (3): ResidualBlock(\n",
268
+ " (conv1): Sequential(\n",
269
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
270
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
271
+ " (2): ReLU()\n",
272
+ " )\n",
273
+ " (conv2): Sequential(\n",
274
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
275
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
276
+ " )\n",
277
+ " (relu): ReLU()\n",
278
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
279
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " )\n",
281
+ " (4): ResidualBlock(\n",
282
+ " (conv1): Sequential(\n",
283
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
284
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
285
+ " (2): ReLU()\n",
286
+ " )\n",
287
+ " (conv2): Sequential(\n",
288
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
289
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
290
+ " )\n",
291
+ " (relu): ReLU()\n",
292
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
293
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
294
+ " )\n",
295
+ " (5): ResidualBlock(\n",
296
+ " (conv1): Sequential(\n",
297
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
298
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
299
+ " (2): ReLU()\n",
300
+ " )\n",
301
+ " (conv2): Sequential(\n",
302
+ " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
303
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
304
+ " )\n",
305
+ " (relu): ReLU()\n",
306
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
307
+ " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
308
+ " )\n",
309
+ " )\n",
310
+ " (layer3): Sequential(\n",
311
+ " (0): ResidualBlock(\n",
312
+ " (conv1): Sequential(\n",
313
+ " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
314
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
315
+ " (2): ReLU()\n",
316
+ " )\n",
317
+ " (conv2): Sequential(\n",
318
+ " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
319
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
320
+ " )\n",
321
+ " (downsample): Sequential(\n",
322
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n",
323
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
324
+ " )\n",
325
+ " (relu): ReLU()\n",
326
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
327
+ " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
328
+ " )\n",
329
+ " (1): ResidualBlock(\n",
330
+ " (conv1): Sequential(\n",
331
+ " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
332
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
333
+ " (2): ReLU()\n",
334
+ " )\n",
335
+ " (conv2): Sequential(\n",
336
+ " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
337
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
338
+ " )\n",
339
+ " (relu): ReLU()\n",
340
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
341
+ " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
342
+ " )\n",
343
+ " (2): ResidualBlock(\n",
344
+ " (conv1): Sequential(\n",
345
+ " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
346
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
347
+ " (2): ReLU()\n",
348
+ " )\n",
349
+ " (conv2): Sequential(\n",
350
+ " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
351
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
352
+ " )\n",
353
+ " (relu): ReLU()\n",
354
+ " (dropout1): Dropout(p=0.5, inplace=False)\n",
355
+ " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
356
+ " )\n",
357
+ " )\n",
358
+ " (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
359
+ " (fc): Linear(in_features=39424, out_features=2, bias=True)\n",
360
+ " (dropout1): Dropout(p=0.3, inplace=False)\n",
361
+ " (encoder): Sequential(\n",
362
+ " (0): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
363
+ " (1): ReLU(inplace=True)\n",
364
+ " (2): Dropout(p=0.3, inplace=False)\n",
365
+ " (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
366
+ " (4): ReLU(inplace=True)\n",
367
+ " (5): Dropout(p=0.3, inplace=False)\n",
368
+ " (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
369
+ " (7): ReLU(inplace=True)\n",
370
+ " (8): Dropout(p=0.3, inplace=False)\n",
371
+ " (9): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
372
+ " (10): Sigmoid()\n",
373
+ " )\n",
374
+ " )\n",
375
+ ")"
376
+ ]
377
+ },
378
+ "execution_count": 11,
379
+ "metadata": {},
380
+ "output_type": "execute_result"
381
+ }
382
+ ],
383
+ "source": [
384
+ "model_path = 'models/model-47-99.125.pt'\n",
385
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
386
+ "\n",
387
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=2).to(device)\n",
388
+ "model = nn.DataParallel(model)\n",
389
+ "model = model.to(device)\n",
390
+ "model.load_state_dict(torch.load(model_path))\n",
391
+ "model.eval()"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 12,
397
+ "id": "58b8c338-df2f-4ef0-92cf-409c9f034cab",
398
+ "metadata": {},
399
+ "outputs": [
400
+ {
401
+ "name": "stderr",
402
+ "output_type": "stream",
403
+ "text": [
404
+ "/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
405
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
406
+ ]
407
+ },
408
+ {
409
+ "name": "stdout",
410
+ "output_type": "stream",
411
+ "text": [
412
+ "tensor([[ 4.1780, -4.1750],\n",
413
+ " [ 4.6414, -4.6303],\n",
414
+ " [ 5.0103, -5.0162],\n",
415
+ " [ 4.8273, -4.8311],\n",
416
+ " [ 4.8523, -4.8661],\n",
417
+ " [ 4.8855, -4.9074],\n",
418
+ " [ 4.4973, -4.5213],\n",
419
+ " [ 5.5996, -5.6192],\n",
420
+ " [ 4.7929, -4.8116],\n",
421
+ " [ 5.5999, -5.5925],\n",
422
+ " [ 4.7918, -4.7998],\n",
423
+ " [ 4.0914, -4.0766],\n",
424
+ " [ 0.7072, -0.6955],\n",
425
+ " [ 4.7136, -4.7234],\n",
426
+ " [ 5.3918, -5.4307],\n",
427
+ " [ 4.5491, -4.5524],\n",
428
+ " [ 4.5412, -4.5391],\n",
429
+ " [ 4.6264, -4.6137],\n",
430
+ " [ 3.9378, -3.9300],\n",
431
+ " [ 5.0673, -5.0792],\n",
432
+ " [ 5.7389, -5.7330],\n",
433
+ " [ 5.2259, -5.2326],\n",
434
+ " [ 5.3856, -5.4036],\n",
435
+ " [ 5.0781, -5.1232],\n",
436
+ " [ 5.2432, -5.2584],\n",
437
+ " [ 5.8163, -5.8209],\n",
438
+ " [ 4.7730, -4.7823],\n",
439
+ " [ 5.1320, -5.1657],\n",
440
+ " [ 5.6486, -5.6485],\n",
441
+ " [ 3.7626, -3.7674],\n",
442
+ " [ 4.1834, -4.1797],\n",
443
+ " [ 4.4452, -4.4566]], device='cuda:0', grad_fn=<GatherBackward>)\n"
444
+ ]
445
+ }
446
+ ],
447
+ "source": [
448
+ "test_in = abs(torch.randn(32, 192, 2048).to(device))\n",
449
+ "results = []\n",
450
+ "for i in range(32):\n",
451
+ " results.append(transform(test_in[i,:,:]))\n",
452
+ "intermediate = torch.stack(results).cuda()\n",
453
+ "out = model(intermediate)\n",
454
+ "test_in.cpu().detach().numpy().tofile(\"input.bin\")\n",
455
+ "intermediate.cpu().detach().numpy().tofile(\"intermediate.bin\")\n",
456
+ "out.cpu().detach().numpy().tofile(\"output.bin\")\n",
457
+ "print(out)"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": 13,
463
+ "id": "ad56299a-44e4-4d6b-afcc-18a5f4cf0138",
464
+ "metadata": {},
465
+ "outputs": [
466
+ {
467
+ "ename": "NameError",
468
+ "evalue": "name 'preproc_flip' is not defined",
469
+ "output_type": "error",
470
+ "traceback": [
471
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
472
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
473
+ "Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m preproc_model \u001b[38;5;241m=\u001b[39m preproc_flip()\n\u001b[1;32m 2\u001b[0m Convert_ONNX(preproc_model,\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodels_mask/preproc_flip.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m, input_data_mock\u001b[38;5;241m=\u001b[39mtest_in\u001b[38;5;241m.\u001b[39mto(device))\n",
474
+ "\u001b[0;31mNameError\u001b[0m: name 'preproc_flip' is not defined"
475
+ ]
476
+ }
477
+ ],
478
+ "source": [
479
+ "preproc_model = preproc_flip()\n",
480
+ "Convert_ONNX(preproc_model,f'models_mask/preproc_flip.onnx', input_data_mock=test_in.to(device))\n",
481
+ "# Convert_ONNX(model.module,f'models_mask/model_test.onnx', input_data_mock=intermediate.to(device))"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": 7,
487
+ "id": "30e84a9b-0d4f-4cb2-a92b-2e3f0b2ccb20",
488
+ "metadata": {},
489
+ "outputs": [
490
+ {
491
+ "data": {
492
+ "text/plain": [
493
+ "torch.Size([32, 192, 2048])"
494
+ ]
495
+ },
496
+ "execution_count": 7,
497
+ "metadata": {},
498
+ "output_type": "execute_result"
499
+ }
500
+ ],
501
+ "source": [
502
+ "test_in.shape"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": 13,
508
+ "id": "1bb26727-7914-470e-bb48-43d7ee81cb50",
509
+ "metadata": {},
510
+ "outputs": [
511
+ {
512
+ "data": {
513
+ "text/plain": [
514
+ "tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
515
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
516
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
517
+ " ...,\n",
518
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
519
+ " [0., 0., 0., ..., 0., 0., 0.],\n",
520
+ " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')"
521
+ ]
522
+ },
523
+ "execution_count": 13,
524
+ "metadata": {},
525
+ "output_type": "execute_result"
526
+ }
527
+ ],
528
+ "source": [
529
+ "import torch\n",
530
+ "torch.flip(test_in[0,:,:], dims = (0,)) - torch.flipud(test_in[0,:,:])"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 29,
536
+ "id": "aeaaab90-6a2a-4851-a1ca-28c54a446573",
537
+ "metadata": {},
538
+ "outputs": [
539
+ {
540
+ "name": "stdout",
541
+ "output_type": "stream",
542
+ "text": [
543
+ "tensor(float)\n",
544
+ "torch.float32\n",
545
+ "Input Name: modelInput\n",
546
+ "Output Name: modelOutput\n",
547
+ "[array([[ 4.3262615, -4.3409047],\n",
548
+ " [ 4.9648395, -4.968621 ],\n",
549
+ " [ 5.5126643, -5.522872 ],\n",
550
+ " [ 4.7735534, -4.8004475],\n",
551
+ " [ 4.0924144, -4.112945 ],\n",
552
+ " [ 4.588802 , -4.6043544],\n",
553
+ " [ 4.6231914, -4.617625 ],\n",
554
+ " [ 5.229881 , -5.2555394],\n",
555
+ " [ 4.877381 , -4.882144 ],\n",
556
+ " [ 5.2514744, -5.2786503],\n",
557
+ " [ 4.2948875, -4.3169603],\n",
558
+ " [ 4.5997186, -4.6177607],\n",
559
+ " [ 4.9509926, -4.9685597],\n",
560
+ " [ 4.933158 , -4.9568825],\n",
561
+ " [ 4.747336 , -4.7639017],\n",
562
+ " [ 5.020595 , -5.0202913],\n",
563
+ " [ 4.914437 , -4.9206715],\n",
564
+ " [ 5.193108 , -5.1925435],\n",
565
+ " [ 4.5233765, -4.512763 ],\n",
566
+ " [ 4.7573333, -4.762632 ],\n",
567
+ " [ 5.268702 , -5.2838397],\n",
568
+ " [ 4.857734 , -4.8605857],\n",
569
+ " [ 5.1886744, -5.2047734],\n",
570
+ " [ 5.512568 , -5.5503583],\n",
571
+ " [ 5.320961 , -5.344709 ],\n",
572
+ " [ 4.1023226, -4.1073256],\n",
573
+ " [ 5.17857 , -5.185736 ],\n",
574
+ " [ 4.997028 , -4.9933476],\n",
575
+ " [ 4.771303 , -4.767269 ],\n",
576
+ " [ 5.312805 , -5.3265243],\n",
577
+ " [ 5.0030336, -5.0492 ],\n",
578
+ " [ 5.429731 , -5.4249325]], dtype=float32)]\n"
579
+ ]
580
+ }
581
+ ],
582
+ "source": [
583
+ "import onnxruntime as ort\n",
584
+ "import onnx\n",
585
+ "\n",
586
+ "# Path to your ONNX model\n",
587
+ "model_path = \"models/model-47-99.125.onnx\"\n",
588
+ "\n",
589
+ "# Load the ONNX model\n",
590
+ "session = ort.InferenceSession(model_path)\n",
591
+ "\n",
592
+ "# Get input and output details\n",
593
+ "input_name = session.get_inputs()[0].name\n",
594
+ "output_name = session.get_outputs()[0].name\n",
595
+ "\n",
596
+ "print(session.get_inputs()[0].type)\n",
597
+ "print(test_in.dtype)\n",
598
+ "\n",
599
+ "print(f\"Input Name: {input_name}\")\n",
600
+ "print(f\"Output Name: {output_name}\")\n",
601
+ "\n",
602
+ "# Example Input Data (Replace with your actual input data)\n",
603
+ "import numpy as np\n",
604
+ "\n",
605
+ "# Perform inference\n",
606
+ "outputs = session.run([output_name], {input_name: intermediate.cpu().numpy()})\n",
607
+ "print(outputs)\n",
608
+ "\n",
609
+ "onnx_model = onnx.load(model_path)"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "execution_count": 30,
615
+ "id": "f250739d-4c8a-4752-964a-d0b929c396f4",
616
+ "metadata": {},
617
+ "outputs": [],
618
+ "source": [
619
+ "# import onnxruntime as ort\n",
620
+ "# import onnx\n",
621
+ "\n",
622
+ "# # Path to your ONNX model\n",
623
+ "# model_path = \"models_mask/preproc_test.onnx\"\n",
624
+ "\n",
625
+ "# # Load the ONNX model\n",
626
+ "# session = ort.InferenceSession(model_path)\n",
627
+ "\n",
628
+ "# # Get input and output details\n",
629
+ "# input_name = session.get_inputs()[0].name\n",
630
+ "# output_name = session.get_outputs()[0].name\n",
631
+ "\n",
632
+ "# print(session.get_inputs()[0].type)\n",
633
+ "# print(test_in.dtype)\n",
634
+ "\n",
635
+ "# print(f\"Input Name: {input_name}\")\n",
636
+ "# print(f\"Output Name: {output_name}\")\n",
637
+ "\n",
638
+ "# # Example Input Data (Replace with your actual input data)\n",
639
+ "# import numpy as np\n",
640
+ "\n",
641
+ "# # Perform inference\n",
642
+ "# outputs = session.run([output_name], {input_name: test_in.cpu().numpy()})\n",
643
+ "# print(\"Model Output:\", outputs)\n",
644
+ "\n",
645
+ "# onnx_model = onnx.load(model_path)"
646
+ ]
647
+ },
648
+ {
649
+ "cell_type": "code",
650
+ "execution_count": 8,
651
+ "id": "24fed4e7-4838-44cc-9c3a-0862bdbe173a",
652
+ "metadata": {},
653
+ "outputs": [
654
+ {
655
+ "data": {
656
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAey0lEQVR4nO3df0xd9f3H8deFymXMcpURL8WCzG124o/L5Jd0dpblboQ6snZZxvaHItm6ZcFFc6NL+w+4rJMsMUiynIXtmyDZr8gaIy5zqanXH/gDQwvFVfEXjhkWvZc26r3luoBezvePxau0UHvhlvs5nOcjuX/ccw/nvDkhl2fuPedej23btgAAAAyRk+0BAAAAPok4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGCUTdkeIF2Li4t66623tHnzZnk8nmyPAwAAzoFt2zp16pRKS0uVk3P210YcFydvvfWWysrKsj0GAABYhZmZGW3duvWs6zgmTizLkmVZ+vDDDyX975crLCzM8lQAAOBcxONxlZWVafPmzZ+6rsdp360Tj8fl8/kUi8WIEwAAHCKd/9+cEAsAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMQpwAAACjECcAAMAoxAkAADBKVuJkenpajY2Nqqys1DXXXKNEIpGNMQAAgIE2ZWOnt956qw4cOKAdO3bonXfekdfrzcYYy7vbd9r9WHbmAADApdY9Tl566SVdcMEF2rFjhySpqKhovUcAAAAGS/ttneHhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo6mHnv99dd14YUXqqWlRdddd53uueeeNf0CAABgY0k7ThKJhAKBgCzLWvbxwcFBhUIhdXV1aXx8XIFAQE1NTZqdnZUkffjhh3r66af129/+ViMjIzp8+LAOHz68tt8CAABsGGnHSXNzsw4cOKA9e/Ys+3hPT4/27t2r9vZ2VVZWqq+vTwUFBerv75ckXXrppaqpqVFZWZm8Xq927dqliYmJFfc3Pz+veDy+5AYAADaujF6ts7CwoLGxMQWDwY93kJOjYDCokZERSVJtba1mZ2f17rvvanFxUcPDw7ryyitX3GZ3d7d8Pl/qVlZWlsmRAQCAYTIaJydPnlQymZTf71+y3O/3KxKJSJI2bdqke+65R1/72td07bXX6ktf+pK+9a1vrbjN/fv3KxaLpW4zMzOZHBkAABgmK5cSNzc3q7m5+ZzW9Xq9Zl1qDAAAzquMvnJSXFys3NxcRaPRJcuj0ahKSkrWtG3LslRZWana2to1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw2E1NDSsadsdHR2anJzUkSNH1jomAAAwWNpv68zNzWlqaip1f3p6WhMTEyoqKlJ5eblCoZDa2tpUU1Ojuro69fb2KpFIqL29PaODAwCAjSntODl69KgaGxtT90OhkCSpra1NAwMDam1t1YkTJ9TZ2alIJKKqqiodOnTojJNkAQAAluOxbdvO9hDnwrIsWZalZDKp1157TbFYTIWFhZnfEd+tAwBAxsXjcfl8vnP6/52VbyVeDc45AQDAHRwTJwAAwB0cEydcSgwAgDs4Jk54WwcAAHdwTJwAAAB3IE4AAIBRiBMAAGAUx8QJJ8QCAOAOjokTTogFAMAdHBMnAADAHYgTAABgFOIEAAAYxTFxwgmxAAC4g2PihBNiAQBwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAURwTJ1ytAwCAOzgmTrhaBwAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFMfECR/CBgCAOzgmTvgQNgAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwsfXAwDgDo6JEz6+HgAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABhlUzZ2WlFRocLCQuXk5Ojiiy/WE088kY0xAACAgbISJ5L03HPP6cILL8zW7gEAgKF4WwcAABgl7TgZHh5WS0uLSktL5fF4NDQ0dMY6lmWpoqJC+fn5qq+v1+jo6JLHPR6PbrzxRtXW1urPf/7zqocHAAAbT9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsbGqdZ555RmNjY/rb3/6me+65R//85z9X/xsAAIANJe04aW5u1oEDB7Rnz55lH+/p6dHevXvV3t6uyspK9fX1qaCgQP39/al1Lr30UknSli1btGvXLo2Pj6+4v/n5ecXj8SU3AACwcWX0nJOFhQWNjY0pGAx+vIOcHAWDQY2MjEj63ysvp06dkiTNzc3p8ccf11VXXbXiNru7u+Xz+VK3srKyTI4MAAAMk9E4OXnypJLJpPx+/5Llfr9fkUhEkhSNRnXDDTcoEAjo+uuv1y233KLa2toVt7l//37FYrHUbWZmJpMjAwAAw6z7pcSXX365XnjhhXNe3+v1yuv1nseJAACASTL6yklxcbFyc3MVjUaXLI9GoyopKVnTti3LUmVl5VlfZQEAAM6X0TjJy8tTdXW1wuFwatni4qLC4bAaGhrWtO2Ojg5NTk7qyJEjax0TAAAYLO23debm5jQ1NZW6Pz09rYmJCRUVFam8vFyhUEhtbW2qqalRXV2dent7lUgk1N7entHBAQDAxpR2nBw9elSNjY2p+6FQSJLU1tamgYEBtba26sSJE+rs7FQkElFVVZUOHTp0xkmy6bIsS5ZlKZlMrmk7AADAbB7btu1sD5GOeDwun8+nWCymwsLCzO/gbt9p92OZ3wcAAC6Tzv9vvlsHAAAYhTgBAABGcUyccCkxAADu4Jg44VJiAADcwTFxAgAA3IE4AQAARnFMnHDOCQAA7uCYOOGcEwAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTHxAlX6wAA4A6OiROu1gEAwB0cEycAAMAdiBMAAGAU4gQAABiFOAEAAEZxTJxwtQ4AAO7gmDjhah0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEZxTJzwOScAALiDY+KEzzkBAMAdHBMnAADAHYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTHxAnfrQMAgDs4Jk74bh0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABglKzFyfvvv6/LLrtMd955Z7ZGAAAABspanPzqV7/S9ddfn63dAwAAQ2UlTl5//XW98soram5uzsbuAQCAwdKOk+HhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo4uefzOO+9Ud3f3qocGAAAbV9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsrCTp4Ycf1hVXXKErrrhibZMDAIANaVO6P9Dc3HzWt2N6enq0d+9etbe3S5L6+vr0yCOPqL+/X/v27dPzzz+vBx54QAcPHtTc3Jw++OADFRYWqrOzc9ntzc/Pa35+PnU/Ho+nOzIAAHCQjJ5zsrCwoLGxMQWDwY93kJOjYDCokZERSVJ3d7dmZmb073//W/fee6/27t27Yph8tL7P50vdysrKMjkyAAAwTEbj5OTJk0omk/L7/UuW+/1+RSKRVW1z//79isViqdvMzEwmRgUAAIZK+22dTLr11ls/dR2v1yuv13v+hwEAAEbI6CsnxcXFys3NVTQaXbI8Go2qpKRkTdu2LEuVlZWqra1d03YAAIDZMhoneXl5qq6uVjgcTi1bXFxUOBxWQ0PDmrbd0dGhyclJHTlyZK1jAgAAg6X9ts7c3JympqZS96enpzUxMaGioiKVl5crFAqpra1NNTU1qqurU29vrxKJROrqHQAAgLNJO06OHj2qxsbG1P1QKCRJamtr08DAgFpbW3XixAl1dnYqEomoqqpKhw4dOuMk2XRZliXLspRMJte0HQAAYDaPbdt2todIRzwel8/nUywWU2FhYeZ3cLfvtPuxzO8DAACXSef/d9a++A8AAGA5xAkAADCKY+KES4kBAHAHx8QJlxIDAOAOjokTAADgDsQJAAAwimPihHNOAABwB8fECeecAADgDo6JEwAA4A7ECQAAMApxAgAAjOKYOOGEWAAA3MExccIJsQAAuINj4gQAALgDcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oRLiQEAcAfHxAmXEgMA4A6OiRMAAOAOxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oTPOQEAwB0cEyd8zgkAAO7gmDgBAADuQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjOKYOOHj6wEAcAfHxAkfXw8AgDs4Jk4AAIA7ECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMsu5x8t5776mmpkZVVVW6+uqr9X//93/rPQIAADDYpvXe4ebNmzU8PKyCggIlEgldffXV+s53vqPPfe5z6z0KAAAw0Lq/cpKbm6uCggJJ0vz8vGzblm3b6z0GAAAwVNpxMjw8rJaWFpWWlsrj8WhoaOiMdSzLUkVFhfLz81VfX6/R0dElj7/33nsKBALaunWr7rrrLhUXF6/6FwAAABtL2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Oxsap2LLrpIL7zwgqanp/WXv/xF0Wh09b8BAADYUNKOk+bmZh04cEB79uxZ9vGenh7t3btX7e3tqqysVF9fnwoKCtTf33/Gun6/X4FAQE8//fSK+5ufn1c8Hl9yAwAAG1dGzzlZWFjQ2NiYgsHgxzvIyVEwGNTIyIgkKRqN6tSpU5KkWCym4eFhbdu2bcVtdnd3y+fzpW5lZWWZHBkAABgmo3Fy8uRJJZNJ+f3+Jcv9fr8ikYgk6c0339SOHTsUCAS0Y8cO/exnP9M111yz4jb379+vWCyWus3MzGRyZAAAYJh1v5S4rq5OExMT57y+1+uV1+s9fwMBAACjZPSVk+LiYuXm5p5xgms0GlVJScmatm1ZliorK1VbW7um7QAAALNlNE7y8vJUXV2tcDicWra4uKhwOKyGhoY1bbujo0OTk5M6cuTIWscEAAAGS/ttnbm5OU1NTaXuT09Pa2JiQkVFRSovL1coFFJbW5tqampUV1en3t5eJRIJtbe3Z3RwAACwMaUdJ0ePHlVjY2PqfigUkiS1tbVpYGBAra2tOnHihDo7OxWJRFRVVaVDhw6dcZIsAADAcjy2Qz473rIsWZalZDKp1157TbFYTIWFhZnf0d2+0+7HMr8PAABcJh6Py+fzndP/73X/bp3V4pwTAADcwTFxAgAA3MExccKlxAAAuINj4oS3dQAAcAfHxAkAAHAH4gQAABiFOAEAAEZxTJxwQiwAAO7gmDjhhFgAANzBMXECAADcgTgBAABGIU4AAIBRHBMnnBALAIA7OCZOOCEWAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYxTFxwtU6AAC4g2PihKt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwuecAADgDo6JEz7nBAAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfLcOAADu4Jg44bt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFHWPU5mZma0c+dOVVZW6tprr9XBgwfXewQAAGCwTeu+w02b1Nvbq6qqKkUiEVVXV2vXrl367Gc/u96jAAAAA617nGzZskVbtmyRJJWUlKi4uFjvvPMOcQIAACSt4m2d4eFhtbS0qLS0VB6PR0NDQ2esY1mWKioqlJ+fr/r6eo2Oji67rbGxMSWTSZWVlaU9OAAA2JjSjpNEIqFAICDLspZ9fHBwUKFQSF1dXRofH1cgEFBTU5NmZ2eXrPfOO+/olltu0e9///vVTQ4AADaktN/WaW5uVnNz84qP9/T0aO/evWpvb5ck9fX16ZFHHlF/f7/27dsnSZqfn9fu3bu1b98+bd++/az7m5+f1/z8fOp+PB5Pd2QAAOAgGb1aZ2FhQWNjYwoGgx/vICdHwWBQIyMjkiTbtnXrrbfq61//um6++eZP3WZ3d7d8Pl/qxltAAABsbBmNk5MnTyqZTMrv9y9Z7vf7FYlEJEnPPvusBgcHNTQ0pKqqKlVVVen48eMrbnP//v2KxWKp28zMTCZHBgAAhln3q3VuuOEGLS4unvP6Xq9XXq/3PE4EAABMktFXToqLi5Wbm6toNLpkeTQaVUlJyZq2bVmWKisrVVtbu6btAAAAs2U0TvLy8lRdXa1wOJxatri4qHA4rIaGhjVtu6OjQ5OTkzpy5MhaxwQAAAZL+22dubk5TU1Npe5PT09rYmJCRUVFKi8vVygUUltbm2pqalRXV6fe3l4lEonU1TsAAABnk3acHD16VI2Njan7oVBIktTW1qaBgQG1trbqxIkT6uzsVCQSUVVVlQ4dOnTGSbLpsixLlmUpmUyuaTsAAMBsHtu27WwPkY54PC6fz6dYLKbCwsLM7+Bu32n3Y5nfBwAALpPO/+91/1ZiAACAsyFOAACAURwTJ1xKDACAOzgmTriUGAAAd3BMnAAAAHcgTgAAgFEcEyeccwIAgDs4Jk445wQAAHdwTJwAAAB3IE4AAIBRiBMAAGAU4gQAABjFMXHC1ToAALiDY+KEq3UAAHAHx8QJAABwB+IEAAAYhTgBAABGIU4AAIBRHBMnXK0DAIA7OCZOuFoHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfM4JAADu4Jg44XNOAABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABjFMXHCd+sAAOAOjokTvlsHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGCUrcbJnzx5dfPHF+u53v5uN3QMAAINlJU5uv/12/eEPf8jGrgEAgOGyEic7d+7U5s2bs7FrAABguLTjZHh4WC0tLSotLZXH49HQ0NAZ61iWpYqKCuXn56u+vl6jo6OZmBUAALhA2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Ozsqgacn59XPB5fcgMAABtX2nHS3NysAwcOaM+ePcs+3tPTo71796q9vV2VlZXq6+tTQUGB+vv7VzVgd3e3fD5f6lZWVraq7QAAAGfI6DknCwsLGhsbUzAY/HgHOTkKBoMaGRlZ1Tb379+vWCyWus3MzGRqXAAAYKBNmdzYyZMnlUwm5ff7lyz3+/165ZVXUveDwaBeeOEFJRIJbd26VQcPHlRDQ8Oy2/R6vfJ6vZkcEwAAGCyjcXKuHnvssbR/xrIsWZalZDJ5HiYCAMCl7vYtsyy2/nN8Qkbf1ikuLlZubq6i0eiS5dFoVCUlJWvadkdHhyYnJ3XkyJE1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw+EV37YBAAD4pLTf1pmbm9PU1FTq/vT0tCYmJlRUVKTy8nKFQiG1tbWppqZGdXV16u3tVSKRUHt7e0YHBwAAG1PacXL06FE1Njam7odCIUlSW1ubBgYG1NraqhMnTqizs1ORSERVVVU6dOjQGSfJpotzTgAAcAePbdt2todIRzwel8/nUywWU2FhYeZ3cPqJQVk+KQgAgPNqnU6ITef/d1a+WwcAAGAlxAkAADCKY+LEsixVVlaqtrY226MAAIDzyDFxwuecAADgDo6JEwAA4A7ECQAAMIpj4oRzTgAAcAfHxAnnnAAA4A6OiRMAAOAOxAkAADBK2t+tk20ffdp+PB4/PzuYP+3T/M/XfgAAMMHp//ek8/K/76P/2+fyrTmO+W6dj774b2FhQW+88Ua2xwEAAKswMzOjrVu3nnUdx8TJRxYXF/XWW29p8+bN8ng8Gd12PB5XWVmZZmZmzs+XCroIxzJzOJaZw7HMHI5l5rjlWNq2rVOnTqm0tFQ5OWc/q8Rxb+vk5OR8anGtVWFh4Yb+A1lPHMvM4VhmDscycziWmeOGY+nzLfMNyMvghFgAAGAU4gQAABiFOPkEr9errq4ueb3ebI/ieBzLzOFYZg7HMnM4lpnDsTyT406IBQAAGxuvnAAAAKMQJwAAwCjECQAAMApxAgAAjOK6OLEsSxUVFcrPz1d9fb1GR0fPuv7Bgwf15S9/Wfn5+brmmmv0j3/8Y50mNV86x3JgYEAej2fJLT8/fx2nNdPw8LBaWlpUWloqj8ejoaGhT/2ZJ598Utddd528Xq+++MUvamBg4LzP6QTpHssnn3zyjL9Jj8ejSCSyPgMbrLu7W7W1tdq8ebMuueQS7d69W6+++uqn/hzPl2dazbHk+dJlcTI4OKhQKKSuri6Nj48rEAioqalJs7Ozy67/3HPP6Qc/+IF++MMf6tixY9q9e7d2796tF198cZ0nN0+6x1L636cfvv3226nbm2++uY4TmymRSCgQCMiyrHNaf3p6WjfddJMaGxs1MTGhO+64Qz/60Y/06KOPnudJzZfusfzIq6++uuTv8pJLLjlPEzrHU089pY6ODj3//PM6fPiwPvjgA33zm99UIpFY8Wd4vlzeao6lxPOlbBepq6uzOzo6UveTyaRdWlpqd3d3L7v+9773Pfumm25asqy+vt7+yU9+cl7ndIJ0j+X9999v+3y+dZrOmSTZDz300FnX+fnPf25fddVVS5a1trbaTU1N53Ey5zmXY/nEE0/Ykux33313XWZystnZWVuS/dRTT624Ds+X5+ZcjiXPl7btmldOFhYWNDY2pmAwmFqWk5OjYDCokZGRZX9mZGRkyfqS1NTUtOL6brGaYylJc3Nzuuyyy1RWVqZvf/vbeumll9Zj3A2Fv8nMq6qq0pYtW/SNb3xDzz77bLbHMVIsFpMkFRUVrbgOf5vn5lyOpcTzpWvi5OTJk0omk/L7/UuW+/3+Fd9jjkQiaa3vFqs5ltu2bVN/f78efvhh/elPf9Li4qK2b9+u//znP+sx8oax0t9kPB7Xf//73yxN5UxbtmxRX1+fHnzwQT344IMqKyvTzp07NT4+nu3RjLK4uKg77rhDX/3qV3X11VevuB7Pl5/uXI8lz5cO/FZiOFNDQ4MaGhpS97dv364rr7xSv/vd7/TLX/4yi5PBrbZt26Zt27al7m/fvl1vvPGG7rvvPv3xj3/M4mRm6ejo0Isvvqhnnnkm26M43rkeS54vXfTKSXFxsXJzcxWNRpcsj0ajKikpWfZnSkpK0lrfLVZzLE93wQUX6Ctf+YqmpqbOx4gb1kp/k4WFhfrMZz6Tpak2jrq6Ov4mP+G2227T3//+dz3xxBPaunXrWdfl+fLs0jmWp3Pj86Vr4iQvL0/V1dUKh8OpZYuLiwqHw0sK9ZMaGhqWrC9Jhw8fXnF9t1jNsTxdMpnU8ePHtWXLlvM15obE3+T5NTExwd+kJNu2ddttt+mhhx7S448/rs9//vOf+jP8bS5vNcfydK58vsz2Gbnr6YEHHrC9Xq89MDBgT05O2j/+8Y/tiy66yI5EIrZt2/bNN99s79u3L7X+s88+a2/atMm+99577Zdfftnu6uqyL7jgAvv48ePZ+hWMke6x/MUvfmE/+uij9htvvGGPjY3Z3//+9+38/Hz7pZdeytavYIRTp07Zx44ds48dO2ZLsnt6euxjx47Zb775pm3btr1v3z775ptvTq3/r3/9yy4oKLDvuusu++WXX7Yty7Jzc3PtQ4cOZetXMEa6x/K+++6zh4aG7Ndff90+fvy4ffvtt9s5OTn2Y489lq1fwRg//elPbZ/PZz/55JP222+/nbq9//77qXV4vjw3qzmWPF/atqvixLZt+ze/+Y1dXl5u5+Xl2XV1dfbzzz+feuzGG2+029ralqz/17/+1b7iiivsvLw8+6qrrrIfeeSRdZ7YXOkcyzvuuCO1rt/vt3ft2mWPj49nYWqzfHQ56+m3j45dW1ubfeONN57xM1VVVXZeXp59+eWX2/fff/+6z22idI/lr3/9a/sLX/iCnZ+fbxcVFdk7d+60H3/88ewMb5jljqOkJX9rPF+em9UcS54vbdtj27a9fq/TAAAAnJ1rzjkBAADOQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwyv8D6KAeY7AISbEAAAAASUVORK5CYII=",
657
+ "text/plain": [
658
+ "<Figure size 640x480 with 1 Axes>"
659
+ ]
660
+ },
661
+ "metadata": {},
662
+ "output_type": "display_data"
663
+ }
664
+ ],
665
+ "source": [
666
+ "import matplotlib.pyplot as plt\n",
667
+ "%matplotlib inline\n",
668
+ "plt.hist(abs(intermediate-outputs[0]).ravel(), bins = 100)\n",
669
+ "plt.yscale('log')\n",
670
+ "plt.show()"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "code",
675
+ "execution_count": 15,
676
+ "id": "71cb219e-b91a-4629-99f6-00db786903c7",
677
+ "metadata": {},
678
+ "outputs": [
679
+ {
680
+ "data": {
681
+ "text/plain": [
682
+ "tensor([1.1902e-03, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00,\n",
683
+ " 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00])"
684
+ ]
685
+ },
686
+ "execution_count": 15,
687
+ "metadata": {},
688
+ "output_type": "execute_result"
689
+ }
690
+ ],
691
+ "source": [
692
+ "torch.sort(abs(intermediate-outputs[0]).ravel())[0][-10:]"
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "execution_count": null,
698
+ "id": "92ba5920-5451-4bb1-af0e-5ea987841ab1",
699
+ "metadata": {},
700
+ "outputs": [],
701
+ "source": [
702
+ "import onnxruntime as ort\n",
703
+ "\n",
704
+ "session_options = ort.SessionOptions()\n",
705
+ "session_options.log_severity_level = 0 # Verbose logging\n",
706
+ "session = ort.InferenceSession(\"models_mask/preproc_test.onnx\", sess_options=session_options)"
707
+ ]
708
+ },
709
+ {
710
+ "cell_type": "code",
711
+ "execution_count": null,
712
+ "id": "3277b343-245d-4ac8-a91c-373061dcbf53",
713
+ "metadata": {},
714
+ "outputs": [],
715
+ "source": [
716
+ "import matplotlib.pyplot as plt\n",
717
+ "%matplotlib inline\n",
718
+ "plt.imshow(outputs[0][0,8,:,:])\n",
719
+ "plt.show()"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "execution_count": null,
725
+ "id": "67e99ef8-e49a-4037-a818-244555b0bdc5",
726
+ "metadata": {},
727
+ "outputs": [],
728
+ "source": [
729
+ "import onnx\n",
730
+ "\n",
731
+ "# Path to your ONNX model\n",
732
+ "model_path = \"models/model-47-99.125.onnx\"\n",
733
+ "\n",
734
+ "# Load the ONNX model\n",
735
+ "onnx_model = onnx.load(model_path)\n",
736
+ "\n",
737
+ "# Check the model for validity\n",
738
+ "onnx.checker.check_model(onnx_model)\n",
739
+ "\n",
740
+ "# Print model graph structure (optional)\n",
741
+ "print(onnx.helper.printable_graph(onnx_model.graph))\n"
742
+ ]
743
+ }
744
+ ],
745
+ "metadata": {
746
+ "kernelspec": {
747
+ "display_name": "Python 3 (ipykernel)",
748
+ "language": "python",
749
+ "name": "python3"
750
+ },
751
+ "language_info": {
752
+ "codemirror_mode": {
753
+ "name": "ipython",
754
+ "version": 3
755
+ },
756
+ "file_extension": ".py",
757
+ "mimetype": "text/x-python",
758
+ "name": "python",
759
+ "nbconvert_exporter": "python",
760
+ "pygments_lexer": "ipython3",
761
+ "version": "3.11.9"
762
+ }
763
+ },
764
+ "nbformat": 4,
765
+ "nbformat_minor": 5
766
+ }
models/.ipynb_checkpoints/benchmark_model-8bit-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/benchmark_model-Copy1-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/benchmark_model-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/benchmark_model_treshold-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/benchmark_model_vanilla-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/eval_basic-checkpoint.ipynb ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "6\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ " 0%| | 0/48 [00:22<?, ?it/s]\n"
25
+ ]
26
+ },
27
+ {
28
+ "ename": "AttributeError",
29
+ "evalue": "Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n",
30
+ "output_type": "error",
31
+ "traceback": [
32
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
33
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
34
+ "Cell \u001b[0;32mIn[1], line 50\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[0;32m---> 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 51\u001b[0m _, predicted \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmax(outputs, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 52\u001b[0m results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mextend(outputs\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mtolist())\n",
35
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
36
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
37
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:186\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodule_kwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 185\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 186\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparallel_apply(replicas, inputs, module_kwargs)\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
38
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:201\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Any]:\n\u001b[0;32m--> 201\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parallel_apply(replicas, inputs, kwargs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(replicas)])\n",
39
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:108\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 106\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m--> 108\u001b[0m output\u001b[38;5;241m.\u001b[39mreraise()\n\u001b[1;32m 109\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
40
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/_utils.py:706\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 703\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 704\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 706\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
41
+ "\u001b[0;31mAttributeError\u001b[0m: Caught AttributeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py\", line 83, in _worker\n output = module(*input, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1553, in _wrapped_call_impl\n return self._call_impl(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1562, in _call_impl\n return forward_call(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/home/pma/projects/frbnn_narrow/CNN/resnet_model.py\", line 106, in forward\n return x, self.mask, self.value\n ^^^^^^^^^\n File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/modules/module.py\", line 1729, in __getattr__\n raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{name}'\")\nAttributeError: 'ResNet' object has no attribute 'mask'\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
47
+ "from torch.utils.data import Dataset, DataLoader\n",
48
+ "from utils import CustomDataset, TestingDataset, transform\n",
49
+ "from tqdm import tqdm\n",
50
+ "import torch\n",
51
+ "import numpy as np\n",
52
+ "from resnet_model import ResidualBlock, ResNet\n",
53
+ "import torch\n",
54
+ "import torch.nn as nn\n",
55
+ "import torch.optim as optim\n",
56
+ "from tqdm import tqdm \n",
57
+ "import torch.nn.functional as F\n",
58
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
59
+ "import pickle\n",
60
+ "\n",
61
+ "torch.manual_seed(1)\n",
62
+ "# torch.manual_seed(42)\n",
63
+ "\n",
64
+ "\n",
65
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
66
+ "num_gpus = torch.cuda.device_count()\n",
67
+ "print(num_gpus)\n",
68
+ "\n",
69
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
70
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
71
+ "\n",
72
+ "num_classes = 2\n",
73
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
74
+ "\n",
75
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
76
+ "model = nn.DataParallel(model)\n",
77
+ "model = model.to(device)\n",
78
+ "params = sum(p.numel() for p in model.parameters())\n",
79
+ "print(\"num params \",params)\n",
80
+ "\n",
81
+ "model_1 = 'models/model-23-99.045.pt'\n",
82
+ "# model_1 ='models/model-47-99.125.pt'\n",
83
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
84
+ "model = model.eval()\n",
85
+ "\n",
86
+ "# eval\n",
87
+ "val_loss = 0.0\n",
88
+ "correct_valid = 0\n",
89
+ "total = 0\n",
90
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
91
+ "model.eval()\n",
92
+ "with torch.no_grad():\n",
93
+ " for images, labels in tqdm(testloader):\n",
94
+ " inputs, labels = images.to(device), labels\n",
95
+ " outputs = model(inputs)\n",
96
+ " _, predicted = torch.max(outputs, 1)\n",
97
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
98
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
99
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
100
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
101
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
102
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
103
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
104
+ " total += labels[0].size(0)\n",
105
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
106
+ "# Calculate training accuracy after each epoch\n",
107
+ "val_accuracy = correct_valid / total * 100.0\n",
108
+ "print(\"===========================\")\n",
109
+ "print('accuracy: ', val_accuracy)\n",
110
+ "print(\"===========================\")\n",
111
+ "\n",
112
+ "import pickle\n",
113
+ "\n",
114
+ "# Pickle the dictionary to a file\n",
115
+ "with open('models/test_42.pkl', 'wb') as f:\n",
116
+ " pickle.dump(results, f)"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": [
126
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
127
+ "from torch.utils.data import Dataset, DataLoader\n",
128
+ "from utils import CustomDataset, TestingDataset, transform\n",
129
+ "from tqdm import tqdm\n",
130
+ "import torch\n",
131
+ "import numpy as np\n",
132
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
133
+ "import torch\n",
134
+ "import torch.nn as nn\n",
135
+ "import torch.optim as optim\n",
136
+ "from tqdm import tqdm \n",
137
+ "import torch.nn.functional as F\n",
138
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
139
+ "import pickle\n",
140
+ "\n",
141
+ "torch.manual_seed(1)\n",
142
+ "# torch.manual_seed(42)\n",
143
+ "\n",
144
+ "\n",
145
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
146
+ "num_gpus = torch.cuda.device_count()\n",
147
+ "print(num_gpus)\n",
148
+ "\n",
149
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
150
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
151
+ "\n",
152
+ "num_classes = 2\n",
153
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
154
+ "\n",
155
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
156
+ "model = nn.DataParallel(model)\n",
157
+ "model = model.to(device)\n",
158
+ "params = sum(p.numel() for p in model.parameters())\n",
159
+ "print(\"num params \",params)\n",
160
+ "\n",
161
+ "\n",
162
+ "model_1 = 'models/model-14-98.005.pt'\n",
163
+ "# model_1 ='models/model-47-99.125.pt'\n",
164
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
165
+ "model = model.eval()\n",
166
+ "\n",
167
+ "# eval\n",
168
+ "val_loss = 0.0\n",
169
+ "correct_valid = 0\n",
170
+ "total = 0\n",
171
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
172
+ "model.eval()\n",
173
+ "with torch.no_grad():\n",
174
+ " for images, labels in tqdm(testloader):\n",
175
+ " inputs, labels = images.to(device), labels\n",
176
+ " outputs = model(inputs)\n",
177
+ " _, predicted = torch.max(outputs, 1)\n",
178
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
179
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
180
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
181
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
182
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
183
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
184
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
185
+ " total += labels[0].size(0)\n",
186
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
187
+ " \n",
188
+ "# Calculate training accuracy after each epoch\n",
189
+ "val_accuracy = correct_valid / total * 100.0\n",
190
+ "print(\"===========================\")\n",
191
+ "print('accuracy: ', val_accuracy)\n",
192
+ "print(\"===========================\")\n",
193
+ "\n",
194
+ "import pickle\n",
195
+ "\n",
196
+ "# Pickle the dictionary to a file\n",
197
+ "with open('models/test_1.pkl', 'wb') as f:\n",
198
+ " pickle.dump(results, f)"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
209
+ "from torch.utils.data import Dataset, DataLoader\n",
210
+ "from utils import CustomDataset, TestingDataset, transform\n",
211
+ "from tqdm import tqdm\n",
212
+ "import torch\n",
213
+ "import numpy as np\n",
214
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
215
+ "import torch\n",
216
+ "import torch.nn as nn\n",
217
+ "import torch.optim as optim\n",
218
+ "from tqdm import tqdm \n",
219
+ "import torch.nn.functional as F\n",
220
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
221
+ "import pickle\n",
222
+ "\n",
223
+ "torch.manual_seed(1)\n",
224
+ "# torch.manual_seed(42)\n",
225
+ "\n",
226
+ "\n",
227
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
228
+ "num_gpus = torch.cuda.device_count()\n",
229
+ "print(num_gpus)\n",
230
+ "\n",
231
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
232
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
233
+ "\n",
234
+ "num_classes = 2\n",
235
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
236
+ "\n",
237
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
238
+ "model = nn.DataParallel(model)\n",
239
+ "model = model.to(device)\n",
240
+ "params = sum(p.numel() for p in model.parameters())\n",
241
+ "print(\"num params \",params)\n",
242
+ "\n",
243
+ "\n",
244
+ "model_1 = 'models/model-28-98.955.pt'\n",
245
+ "# model_1 ='models/model-47-99.125.pt'\n",
246
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
247
+ "model = model.eval()\n",
248
+ "\n",
249
+ "# eval\n",
250
+ "val_loss = 0.0\n",
251
+ "correct_valid = 0\n",
252
+ "total = 0\n",
253
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
254
+ "model.eval()\n",
255
+ "with torch.no_grad():\n",
256
+ " for images, labels in tqdm(testloader):\n",
257
+ " inputs, labels = images.to(device), labels\n",
258
+ " outputs = model(inputs)\n",
259
+ " _, predicted = torch.max(outputs, 1)\n",
260
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
261
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
262
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
263
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
264
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
265
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
266
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
267
+ " total += labels[0].size(0)\n",
268
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
269
+ " \n",
270
+ "# Calculate training accuracy after each epoch\n",
271
+ "val_accuracy = correct_valid / total * 100.0\n",
272
+ "print(\"===========================\")\n",
273
+ "print('accuracy: ', val_accuracy)\n",
274
+ "print(\"===========================\")\n",
275
+ "\n",
276
+ "import pickle\n",
277
+ "\n",
278
+ "# Pickle the dictionary to a file\n",
279
+ "with open('models/test_7109.pkl', 'wb') as f:\n",
280
+ " pickle.dump(results, f)"
281
+ ]
282
+ }
283
+ ],
284
+ "metadata": {
285
+ "kernelspec": {
286
+ "display_name": "Python 3 (ipykernel)",
287
+ "language": "python",
288
+ "name": "python3"
289
+ },
290
+ "language_info": {
291
+ "codemirror_mode": {
292
+ "name": "ipython",
293
+ "version": 3
294
+ },
295
+ "file_extension": ".py",
296
+ "mimetype": "text/x-python",
297
+ "name": "python",
298
+ "nbconvert_exporter": "python",
299
+ "pygments_lexer": "ipython3",
300
+ "version": "3.11.9"
301
+ }
302
+ },
303
+ "nbformat": 4,
304
+ "nbformat_minor": 5
305
+ }
models/.ipynb_checkpoints/eval_basic-extend-checkpoint.ipynb ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ "100%|███████████████████████████████████████████| 48/48 [00:39<00:00, 1.21it/s]"
25
+ ]
26
+ },
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "===========================\n",
32
+ "accuracy: 98.82\n",
33
+ "===========================\n",
34
+ "False Positive Rate: 0.010\n",
35
+ "Precision: 0.990\n",
36
+ "Recall: 0.986\n",
37
+ "F1 Score: 0.988\n"
38
+ ]
39
+ },
40
+ {
41
+ "name": "stderr",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "\n"
45
+ ]
46
+ }
47
+ ],
48
+ "source": [
49
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
50
+ "from torch.utils.data import Dataset, DataLoader\n",
51
+ "from utils import CustomDataset, TestingDataset, transform\n",
52
+ "from tqdm import tqdm\n",
53
+ "import torch\n",
54
+ "import numpy as np\n",
55
+ "from resnet_model import ResidualBlock, ResNet\n",
56
+ "import torch\n",
57
+ "import torch.nn as nn\n",
58
+ "import torch.optim as optim\n",
59
+ "from tqdm import tqdm \n",
60
+ "import torch.nn.functional as F\n",
61
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
62
+ "import pickle\n",
63
+ "\n",
64
+ "torch.manual_seed(1)\n",
65
+ "# torch.manual_seed(42)\n",
66
+ "\n",
67
+ "\n",
68
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
69
+ "num_gpus = torch.cuda.device_count()\n",
70
+ "print(num_gpus)\n",
71
+ "\n",
72
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
73
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
74
+ "\n",
75
+ "num_classes = 2\n",
76
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
77
+ "\n",
78
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
79
+ "model = nn.DataParallel(model)\n",
80
+ "model = model.to(device)\n",
81
+ "params = sum(p.numel() for p in model.parameters())\n",
82
+ "print(\"num params \",params)\n",
83
+ "\n",
84
+ "model_1 = 'models/model-23-99.045.pt'\n",
85
+ "# model_1 ='models/model-47-99.125.pt'\n",
86
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
87
+ "model = model.eval()\n",
88
+ "\n",
89
+ "# eval\n",
90
+ "val_loss = 0.0\n",
91
+ "correct_valid = 0\n",
92
+ "total = 0\n",
93
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
94
+ "model.eval()\n",
95
+ "with torch.no_grad():\n",
96
+ " for images, labels in tqdm(testloader):\n",
97
+ " inputs, labels = images.to(device), labels\n",
98
+ " outputs = model(inputs)\n",
99
+ " _, predicted = torch.max(outputs, 1)\n",
100
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
101
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
102
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
103
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
104
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
105
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
106
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
107
+ " total += labels[0].size(0)\n",
108
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
109
+ "# Calculate training accuracy after each epoch\n",
110
+ "val_accuracy = correct_valid / total * 100.0\n",
111
+ "print(\"===========================\")\n",
112
+ "print('accuracy: ', val_accuracy)\n",
113
+ "print(\"===========================\")\n",
114
+ "\n",
115
+ "import pickle\n",
116
+ "\n",
117
+ "# Pickle the dictionary to a file\n",
118
+ "with open('models/test_42.pkl', 'wb') as f:\n",
119
+ " pickle.dump(results, f)\n",
120
+ "\n",
121
+ "\n",
122
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
123
+ "from sklearn.metrics import confusion_matrix\n",
124
+ "\n",
125
+ "# Example binary labels\n",
126
+ "true = results['true'] # ground truth\n",
127
+ "pred = results['pred'] # predicted\n",
128
+ "\n",
129
+ "# Compute metrics\n",
130
+ "precision = precision_score(true, pred)\n",
131
+ "recall = recall_score(true, pred)\n",
132
+ "f1 = f1_score(true, pred)\n",
133
+ "# Get confusion matrix: TN, FP, FN, TP\n",
134
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
135
+ "\n",
136
+ "# Compute FPR\n",
137
+ "fpr = fp / (fp + tn)\n",
138
+ "\n",
139
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
140
+ "\n",
141
+ "print(f\"Precision: {precision:.3f}\")\n",
142
+ "print(f\"Recall: {recall:.3f}\")\n",
143
+ "print(f\"F1 Score: {f1:.3f}\")"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 2,
149
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stdout",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "2\n",
157
+ "num params encoder 50840\n",
158
+ "num params 21496282\n"
159
+ ]
160
+ },
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
166
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
167
+ "100%|███████████████████████████████████████████| 48/48 [00:51<00:00, 1.07s/it]"
168
+ ]
169
+ },
170
+ {
171
+ "name": "stdout",
172
+ "output_type": "stream",
173
+ "text": [
174
+ "===========================\n",
175
+ "accuracy: 97.185\n",
176
+ "===========================\n",
177
+ "False Positive Rate: 0.038\n",
178
+ "Precision: 0.963\n",
179
+ "Recall: 0.981\n",
180
+ "F1 Score: 0.972\n"
181
+ ]
182
+ },
183
+ {
184
+ "name": "stderr",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "\n"
188
+ ]
189
+ }
190
+ ],
191
+ "source": [
192
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
193
+ "from torch.utils.data import Dataset, DataLoader\n",
194
+ "from utils import CustomDataset, TestingDataset, transform\n",
195
+ "from tqdm import tqdm\n",
196
+ "import torch\n",
197
+ "import numpy as np\n",
198
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
199
+ "import torch\n",
200
+ "import torch.nn as nn\n",
201
+ "import torch.optim as optim\n",
202
+ "from tqdm import tqdm \n",
203
+ "import torch.nn.functional as F\n",
204
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
205
+ "import pickle\n",
206
+ "\n",
207
+ "torch.manual_seed(1)\n",
208
+ "# torch.manual_seed(42)\n",
209
+ "\n",
210
+ "\n",
211
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
212
+ "num_gpus = torch.cuda.device_count()\n",
213
+ "print(num_gpus)\n",
214
+ "\n",
215
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
216
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
217
+ "\n",
218
+ "num_classes = 2\n",
219
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
220
+ "\n",
221
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
222
+ "model = nn.DataParallel(model)\n",
223
+ "model = model.to(device)\n",
224
+ "params = sum(p.numel() for p in model.parameters())\n",
225
+ "print(\"num params \",params)\n",
226
+ "\n",
227
+ "\n",
228
+ "model_1 = 'models/model-14-98.005.pt'\n",
229
+ "# model_1 ='models/model-47-99.125.pt'\n",
230
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
231
+ "model = model.eval()\n",
232
+ "\n",
233
+ "# eval\n",
234
+ "val_loss = 0.0\n",
235
+ "correct_valid = 0\n",
236
+ "total = 0\n",
237
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
238
+ "model.eval()\n",
239
+ "with torch.no_grad():\n",
240
+ " for images, labels in tqdm(testloader):\n",
241
+ " inputs, labels = images.to(device), labels\n",
242
+ " outputs = model(inputs)\n",
243
+ " _, predicted = torch.max(outputs, 1)\n",
244
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
245
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
246
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
247
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
248
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
249
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
250
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
251
+ " total += labels[0].size(0)\n",
252
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
253
+ " \n",
254
+ "# Calculate training accuracy after each epoch\n",
255
+ "val_accuracy = correct_valid / total * 100.0\n",
256
+ "print(\"===========================\")\n",
257
+ "print('accuracy: ', val_accuracy)\n",
258
+ "print(\"===========================\")\n",
259
+ "\n",
260
+ "import pickle\n",
261
+ "\n",
262
+ "# Pickle the dictionary to a file\n",
263
+ "with open('models/test_1.pkl', 'wb') as f:\n",
264
+ " pickle.dump(results, f)\n",
265
+ "\n",
266
+ "\n",
267
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
268
+ "from sklearn.metrics import confusion_matrix\n",
269
+ "\n",
270
+ "# Example binary labels\n",
271
+ "true = results['true'] # ground truth\n",
272
+ "pred = results['pred'] # predicted\n",
273
+ "\n",
274
+ "# Compute metrics\n",
275
+ "precision = precision_score(true, pred)\n",
276
+ "recall = recall_score(true, pred)\n",
277
+ "f1 = f1_score(true, pred)\n",
278
+ "# Get confusion matrix: TN, FP, FN, TP\n",
279
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
280
+ "\n",
281
+ "# Compute FPR\n",
282
+ "fpr = fp / (fp + tn)\n",
283
+ "\n",
284
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
285
+ "\n",
286
+ "print(f\"Precision: {precision:.3f}\")\n",
287
+ "print(f\"Recall: {recall:.3f}\")\n",
288
+ "print(f\"F1 Score: {f1:.3f}\")"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 3,
294
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
295
+ "metadata": {},
296
+ "outputs": [
297
+ {
298
+ "name": "stdout",
299
+ "output_type": "stream",
300
+ "text": [
301
+ "2\n",
302
+ "num params encoder 50840\n",
303
+ "num params 21496282\n"
304
+ ]
305
+ },
306
+ {
307
+ "name": "stderr",
308
+ "output_type": "stream",
309
+ "text": [
310
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
311
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
312
+ "100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.11s/it]"
313
+ ]
314
+ },
315
+ {
316
+ "name": "stdout",
317
+ "output_type": "stream",
318
+ "text": [
319
+ "===========================\n",
320
+ "accuracy: 98.455\n",
321
+ "===========================\n",
322
+ "False Positive Rate: 0.010\n",
323
+ "Precision: 0.990\n",
324
+ "Recall: 0.979\n",
325
+ "F1 Score: 0.984\n"
326
+ ]
327
+ },
328
+ {
329
+ "name": "stderr",
330
+ "output_type": "stream",
331
+ "text": [
332
+ "\n"
333
+ ]
334
+ }
335
+ ],
336
+ "source": [
337
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
338
+ "from torch.utils.data import Dataset, DataLoader\n",
339
+ "from utils import CustomDataset, TestingDataset, transform\n",
340
+ "from tqdm import tqdm\n",
341
+ "import torch\n",
342
+ "import numpy as np\n",
343
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
344
+ "import torch\n",
345
+ "import torch.nn as nn\n",
346
+ "import torch.optim as optim\n",
347
+ "from tqdm import tqdm \n",
348
+ "import torch.nn.functional as F\n",
349
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
350
+ "import pickle\n",
351
+ "\n",
352
+ "torch.manual_seed(1)\n",
353
+ "# torch.manual_seed(42)\n",
354
+ "\n",
355
+ "\n",
356
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
357
+ "num_gpus = torch.cuda.device_count()\n",
358
+ "print(num_gpus)\n",
359
+ "\n",
360
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
361
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
362
+ "\n",
363
+ "num_classes = 2\n",
364
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
365
+ "\n",
366
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
367
+ "model = nn.DataParallel(model)\n",
368
+ "model = model.to(device)\n",
369
+ "params = sum(p.numel() for p in model.parameters())\n",
370
+ "print(\"num params \",params)\n",
371
+ "\n",
372
+ "\n",
373
+ "model_1 = 'models/model-28-98.955.pt'\n",
374
+ "# model_1 ='models/model-47-99.125.pt'\n",
375
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
376
+ "model = model.eval()\n",
377
+ "\n",
378
+ "# eval\n",
379
+ "val_loss = 0.0\n",
380
+ "correct_valid = 0\n",
381
+ "total = 0\n",
382
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
383
+ "model.eval()\n",
384
+ "with torch.no_grad():\n",
385
+ " for images, labels in tqdm(testloader):\n",
386
+ " inputs, labels = images.to(device), labels\n",
387
+ " outputs = model(inputs)\n",
388
+ " _, predicted = torch.max(outputs, 1)\n",
389
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
390
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
391
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
392
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
393
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
394
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
395
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
396
+ " total += labels[0].size(0)\n",
397
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
398
+ " \n",
399
+ "# Calculate training accuracy after each epoch\n",
400
+ "val_accuracy = correct_valid / total * 100.0\n",
401
+ "print(\"===========================\")\n",
402
+ "print('accuracy: ', val_accuracy)\n",
403
+ "print(\"===========================\")\n",
404
+ "\n",
405
+ "import pickle\n",
406
+ "\n",
407
+ "# Pickle the dictionary to a file\n",
408
+ "with open('models/test_7109.pkl', 'wb') as f:\n",
409
+ " pickle.dump(results, f)\n",
410
+ "\n",
411
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
412
+ "from sklearn.metrics import confusion_matrix\n",
413
+ "\n",
414
+ "# Example binary labels\n",
415
+ "true = results['true'] # ground truth\n",
416
+ "pred = results['pred'] # predicted\n",
417
+ "\n",
418
+ "# Compute metrics\n",
419
+ "precision = precision_score(true, pred)\n",
420
+ "recall = recall_score(true, pred)\n",
421
+ "f1 = f1_score(true, pred)\n",
422
+ "# Get confusion matrix: TN, FP, FN, TP\n",
423
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
424
+ "\n",
425
+ "# Compute FPR\n",
426
+ "fpr = fp / (fp + tn)\n",
427
+ "\n",
428
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
429
+ "\n",
430
+ "print(f\"Precision: {precision:.3f}\")\n",
431
+ "print(f\"Recall: {recall:.3f}\")\n",
432
+ "print(f\"F1 Score: {f1:.3f}\")"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": 4,
438
+ "id": "ad4ef08f-6a9b-495f-bb93-0a13251f0adb",
439
+ "metadata": {},
440
+ "outputs": [
441
+ {
442
+ "name": "stdout",
443
+ "output_type": "stream",
444
+ "text": [
445
+ "98.15333333333332 0.7007416705811665\n",
446
+ "0.9803333333333333 0.0009428090415820641\n",
447
+ "0.9813333333333333 0.006798692684790386\n",
448
+ "0.019333333333333334 0.013199326582148887\n"
449
+ ]
450
+ }
451
+ ],
452
+ "source": [
453
+ "# acc\n",
454
+ "print(np.mean([98.82,97.185,98.455]), np.std([98.82,97.185,98.455]))\n",
455
+ "# recall\n",
456
+ "print(np.mean([0.981,0.981, 0.979]), np.std([0.981,0.981, 0.979]))\n",
457
+ "# f1\n",
458
+ "print(np.mean([0.988,0.972,0.984]),np.std([0.988,0.972,0.984]))\n",
459
+ "# fp\n",
460
+ "print(np.mean([0.010,0.038,0.010]),np.std([0.010,0.038,0.010]))"
461
+ ]
462
+ }
463
+ ],
464
+ "metadata": {
465
+ "kernelspec": {
466
+ "display_name": "Python 3 (ipykernel)",
467
+ "language": "python",
468
+ "name": "python3"
469
+ },
470
+ "language_info": {
471
+ "codemirror_mode": {
472
+ "name": "ipython",
473
+ "version": 3
474
+ },
475
+ "file_extension": ".py",
476
+ "mimetype": "text/x-python",
477
+ "name": "python",
478
+ "nbconvert_exporter": "python",
479
+ "pygments_lexer": "ipython3",
480
+ "version": "3.11.9"
481
+ }
482
+ },
483
+ "nbformat": 4,
484
+ "nbformat_minor": 5
485
+ }
models/.ipynb_checkpoints/eval_mask-8-checkpoint.ipynb ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "6\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ "100%|███████████████████████████████████████████| 48/48 [01:43<00:00, 2.16s/it]"
25
+ ]
26
+ },
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "===========================\n",
32
+ "accuracy: 98.94\n",
33
+ "===========================\n"
34
+ ]
35
+ },
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "\n"
41
+ ]
42
+ }
43
+ ],
44
+ "source": [
45
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
46
+ "from torch.utils.data import Dataset, DataLoader\n",
47
+ "from utils import CustomDataset, TestingDataset, transform\n",
48
+ "from tqdm import tqdm\n",
49
+ "import torch\n",
50
+ "import numpy as np\n",
51
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
52
+ "import torch\n",
53
+ "import torch.nn as nn\n",
54
+ "import torch.optim as optim\n",
55
+ "from tqdm import tqdm \n",
56
+ "import torch.nn.functional as F\n",
57
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
58
+ "import pickle\n",
59
+ "\n",
60
+ "torch.manual_seed(1)\n",
61
+ "# torch.manual_seed(42)\n",
62
+ "\n",
63
+ "\n",
64
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
65
+ "num_gpus = torch.cuda.device_count()\n",
66
+ "print(num_gpus)\n",
67
+ "\n",
68
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
69
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
70
+ "\n",
71
+ "num_classes = 2\n",
72
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
73
+ "\n",
74
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
75
+ "model = nn.DataParallel(model)\n",
76
+ "model = model.to(device)\n",
77
+ "params = sum(p.numel() for p in model.parameters())\n",
78
+ "print(\"num params \",params)\n",
79
+ "\n",
80
+ "model_1 = 'models_8/model-25-99.31_7109.pt'\n",
81
+ "# model_1 ='models/model-47-99.125.pt'\n",
82
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
83
+ "model = model.eval()\n",
84
+ "\n",
85
+ "# eval\n",
86
+ "val_loss = 0.0\n",
87
+ "correct_valid = 0\n",
88
+ "total = 0\n",
89
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
90
+ "model.eval()\n",
91
+ "with torch.no_grad():\n",
92
+ " for images, labels in tqdm(testloader):\n",
93
+ " inputs, labels = images.to(device), labels\n",
94
+ " outputs = model(inputs, return_mask = True)\n",
95
+ " _, predicted = torch.max(outputs, 1)\n",
96
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
97
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
98
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
99
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
100
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
101
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
102
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
103
+ " total += labels[0].size(0)\n",
104
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
105
+ " \n",
106
+ " \n",
107
+ "# Calculate training accuracy after each epoch\n",
108
+ "val_accuracy = correct_valid / total * 100.0\n",
109
+ "print(\"===========================\")\n",
110
+ "print('accuracy: ', val_accuracy)\n",
111
+ "print(\"===========================\")\n",
112
+ "\n",
113
+ "import pickle\n",
114
+ "\n",
115
+ "# Pickle the dictionary to a file\n",
116
+ "with open('models_8/test_7109.pkl', 'wb') as f:\n",
117
+ " pickle.dump(results, f)"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 2,
123
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
124
+ "metadata": {},
125
+ "outputs": [
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "6\n",
131
+ "num params encoder 50840\n",
132
+ "num params 21496282\n"
133
+ ]
134
+ },
135
+ {
136
+ "name": "stderr",
137
+ "output_type": "stream",
138
+ "text": [
139
+ "100%|██████████████████��████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
140
+ ]
141
+ },
142
+ {
143
+ "name": "stdout",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "===========================\n",
147
+ "accuracy: 99.17\n",
148
+ "===========================\n"
149
+ ]
150
+ },
151
+ {
152
+ "name": "stderr",
153
+ "output_type": "stream",
154
+ "text": [
155
+ "\n"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
161
+ "from torch.utils.data import Dataset, DataLoader\n",
162
+ "from utils import CustomDataset, TestingDataset, transform\n",
163
+ "from tqdm import tqdm\n",
164
+ "import torch\n",
165
+ "import numpy as np\n",
166
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
167
+ "import torch\n",
168
+ "import torch.nn as nn\n",
169
+ "import torch.optim as optim\n",
170
+ "from tqdm import tqdm \n",
171
+ "import torch.nn.functional as F\n",
172
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
173
+ "import pickle\n",
174
+ "\n",
175
+ "torch.manual_seed(1)\n",
176
+ "# torch.manual_seed(42)\n",
177
+ "\n",
178
+ "\n",
179
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
180
+ "num_gpus = torch.cuda.device_count()\n",
181
+ "print(num_gpus)\n",
182
+ "\n",
183
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
184
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
185
+ "\n",
186
+ "num_classes = 2\n",
187
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
188
+ "\n",
189
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
190
+ "model = nn.DataParallel(model)\n",
191
+ "model = model.to(device)\n",
192
+ "params = sum(p.numel() for p in model.parameters())\n",
193
+ "print(\"num params \",params)\n",
194
+ "\n",
195
+ "\n",
196
+ "model_1 = 'models_8/model-44-99.445_42.pt'\n",
197
+ "# model_1 ='models/model-47-99.125.pt'\n",
198
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
199
+ "model = model.eval()\n",
200
+ "\n",
201
+ "# eval\n",
202
+ "val_loss = 0.0\n",
203
+ "correct_valid = 0\n",
204
+ "total = 0\n",
205
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
206
+ "model.eval()\n",
207
+ "with torch.no_grad():\n",
208
+ " for images, labels in tqdm(testloader):\n",
209
+ " inputs, labels = images.to(device), labels\n",
210
+ " outputs = model(inputs, return_mask = True)\n",
211
+ " _, predicted = torch.max(outputs, 1)\n",
212
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
213
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
214
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
215
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
216
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
217
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
218
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
219
+ " total += labels[0].size(0)\n",
220
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
221
+ " \n",
222
+ "# Calculate training accuracy after each epoch\n",
223
+ "val_accuracy = correct_valid / total * 100.0\n",
224
+ "print(\"===========================\")\n",
225
+ "print('accuracy: ', val_accuracy)\n",
226
+ "print(\"===========================\")\n",
227
+ "\n",
228
+ "import pickle\n",
229
+ "\n",
230
+ "# Pickle the dictionary to a file\n",
231
+ "with open('models_8/test_42.pkl', 'wb') as f:\n",
232
+ " pickle.dump(results, f)"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 3,
238
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
239
+ "metadata": {},
240
+ "outputs": [
241
+ {
242
+ "name": "stdout",
243
+ "output_type": "stream",
244
+ "text": [
245
+ "6\n",
246
+ "num params encoder 50840\n",
247
+ "num params 21496282\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stderr",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "100%|███████████████████████████████████████████| 48/48 [00:54<00:00, 1.14s/it]"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "===========================\n",
262
+ "accuracy: 99.035\n",
263
+ "===========================\n"
264
+ ]
265
+ },
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
276
+ "from torch.utils.data import Dataset, DataLoader\n",
277
+ "from utils import CustomDataset, TestingDataset, transform\n",
278
+ "from tqdm import tqdm\n",
279
+ "import torch\n",
280
+ "import numpy as np\n",
281
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
282
+ "import torch\n",
283
+ "import torch.nn as nn\n",
284
+ "import torch.optim as optim\n",
285
+ "from tqdm import tqdm \n",
286
+ "import torch.nn.functional as F\n",
287
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
288
+ "import pickle\n",
289
+ "\n",
290
+ "torch.manual_seed(1)\n",
291
+ "# torch.manual_seed(42)\n",
292
+ "\n",
293
+ "\n",
294
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
295
+ "num_gpus = torch.cuda.device_count()\n",
296
+ "print(num_gpus)\n",
297
+ "\n",
298
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
299
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
300
+ "\n",
301
+ "num_classes = 2\n",
302
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
303
+ "\n",
304
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
305
+ "model = nn.DataParallel(model)\n",
306
+ "model = model.to(device)\n",
307
+ "params = sum(p.numel() for p in model.parameters())\n",
308
+ "print(\"num params \",params)\n",
309
+ "\n",
310
+ "\n",
311
+ "model_1 = 'models_8/model-43-99.355_1.pt'\n",
312
+ "# model_1 ='models/model-47-99.125.pt'\n",
313
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
314
+ "model = model.eval()\n",
315
+ "\n",
316
+ "# eval\n",
317
+ "val_loss = 0.0\n",
318
+ "correct_valid = 0\n",
319
+ "total = 0\n",
320
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
321
+ "model.eval()\n",
322
+ "with torch.no_grad():\n",
323
+ " for images, labels in tqdm(testloader):\n",
324
+ " inputs, labels = images.to(device), labels\n",
325
+ " outputs = model(inputs, return_mask = True)\n",
326
+ " _, predicted = torch.max(outputs, 1)\n",
327
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
328
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
329
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
330
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
331
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
332
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
333
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
334
+ " total += labels[0].size(0)\n",
335
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
336
+ " \n",
337
+ "# Calculate training accuracy after each epoch\n",
338
+ "val_accuracy = correct_valid / total * 100.0\n",
339
+ "print(\"===========================\")\n",
340
+ "print('accuracy: ', val_accuracy)\n",
341
+ "print(\"===========================\")\n",
342
+ "\n",
343
+ "import pickle\n",
344
+ "\n",
345
+ "# Pickle the dictionary to a file\n",
346
+ "with open('models_8/test_1.pkl', 'wb') as f:\n",
347
+ " pickle.dump(results, f)"
348
+ ]
349
+ }
350
+ ],
351
+ "metadata": {
352
+ "kernelspec": {
353
+ "display_name": "Python 3 (ipykernel)",
354
+ "language": "python",
355
+ "name": "python3"
356
+ },
357
+ "language_info": {
358
+ "codemirror_mode": {
359
+ "name": "ipython",
360
+ "version": 3
361
+ },
362
+ "file_extension": ".py",
363
+ "mimetype": "text/x-python",
364
+ "name": "python",
365
+ "nbconvert_exporter": "python",
366
+ "pygments_lexer": "ipython3",
367
+ "version": "3.11.9"
368
+ }
369
+ },
370
+ "nbformat": 4,
371
+ "nbformat_minor": 5
372
+ }
models/.ipynb_checkpoints/eval_mask-8-extend-checkpoint.ipynb ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ "100%|███████████████████████████████████████████| 48/48 [00:44<00:00, 1.09it/s]"
25
+ ]
26
+ },
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "===========================\n",
32
+ "accuracy: 98.94\n",
33
+ "===========================\n",
34
+ "False Positive Rate: 0.004\n",
35
+ "Precision: 0.996\n",
36
+ "Recall: 0.983\n",
37
+ "F1 Score: 0.989\n"
38
+ ]
39
+ },
40
+ {
41
+ "name": "stderr",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "\n"
45
+ ]
46
+ }
47
+ ],
48
+ "source": [
49
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
50
+ "from torch.utils.data import Dataset, DataLoader\n",
51
+ "from utils import CustomDataset, TestingDataset, transform\n",
52
+ "from tqdm import tqdm\n",
53
+ "import torch\n",
54
+ "import numpy as np\n",
55
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
56
+ "import torch\n",
57
+ "import torch.nn as nn\n",
58
+ "import torch.optim as optim\n",
59
+ "from tqdm import tqdm \n",
60
+ "import torch.nn.functional as F\n",
61
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
62
+ "import pickle\n",
63
+ "\n",
64
+ "torch.manual_seed(1)\n",
65
+ "# torch.manual_seed(42)\n",
66
+ "\n",
67
+ "\n",
68
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
69
+ "num_gpus = torch.cuda.device_count()\n",
70
+ "print(num_gpus)\n",
71
+ "\n",
72
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
73
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
74
+ "\n",
75
+ "num_classes = 2\n",
76
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
77
+ "\n",
78
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
79
+ "model = nn.DataParallel(model)\n",
80
+ "model = model.to(device)\n",
81
+ "params = sum(p.numel() for p in model.parameters())\n",
82
+ "print(\"num params \",params)\n",
83
+ "\n",
84
+ "model_1 = 'models_8/model-25-99.31_7109.pt'\n",
85
+ "# model_1 ='models/model-47-99.125.pt'\n",
86
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
87
+ "model = model.eval()\n",
88
+ "\n",
89
+ "# eval\n",
90
+ "val_loss = 0.0\n",
91
+ "correct_valid = 0\n",
92
+ "total = 0\n",
93
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
94
+ "model.eval()\n",
95
+ "with torch.no_grad():\n",
96
+ " for images, labels in tqdm(testloader):\n",
97
+ " inputs, labels = images.to(device), labels\n",
98
+ " outputs = model(inputs, return_mask = True)\n",
99
+ " _, predicted = torch.max(outputs, 1)\n",
100
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
101
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
102
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
103
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
104
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
105
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
106
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
107
+ " total += labels[0].size(0)\n",
108
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
109
+ " \n",
110
+ " \n",
111
+ "# Calculate training accuracy after each epoch\n",
112
+ "val_accuracy = correct_valid / total * 100.0\n",
113
+ "print(\"===========================\")\n",
114
+ "print('accuracy: ', val_accuracy)\n",
115
+ "print(\"===========================\")\n",
116
+ "\n",
117
+ "import pickle\n",
118
+ "\n",
119
+ "# Pickle the dictionary to a file\n",
120
+ "with open('models_8/test_7109.pkl', 'wb') as f:\n",
121
+ " pickle.dump(results, f)\n",
122
+ "\n",
123
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
124
+ "from sklearn.metrics import confusion_matrix\n",
125
+ "\n",
126
+ "# Example binary labels\n",
127
+ "true = results['true'] # ground truth\n",
128
+ "pred = results['pred'] # predicted\n",
129
+ "\n",
130
+ "# Compute metrics\n",
131
+ "precision = precision_score(true, pred)\n",
132
+ "recall = recall_score(true, pred)\n",
133
+ "f1 = f1_score(true, pred)\n",
134
+ "# Get confusion matrix: TN, FP, FN, TP\n",
135
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
136
+ "\n",
137
+ "# Compute FPR\n",
138
+ "fpr = fp / (fp + tn)\n",
139
+ "\n",
140
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
141
+ "\n",
142
+ "print(f\"Precision: {precision:.3f}\")\n",
143
+ "print(f\"Recall: {recall:.3f}\")\n",
144
+ "print(f\"F1 Score: {f1:.3f}\")"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 3,
150
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
151
+ "metadata": {},
152
+ "outputs": [
153
+ {
154
+ "name": "stdout",
155
+ "output_type": "stream",
156
+ "text": [
157
+ "2\n",
158
+ "num params encoder 50840\n",
159
+ "num params 21496282\n"
160
+ ]
161
+ },
162
+ {
163
+ "name": "stderr",
164
+ "output_type": "stream",
165
+ "text": [
166
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
167
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
168
+ "100%|███████████████████████████████████████████| 48/48 [00:41<00:00, 1.15it/s]"
169
+ ]
170
+ },
171
+ {
172
+ "name": "stdout",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "===========================\n",
176
+ "accuracy: 99.17\n",
177
+ "===========================\n",
178
+ "False Positive Rate: 0.004\n",
179
+ "Precision: 0.995\n",
180
+ "Recall: 0.988\n",
181
+ "F1 Score: 0.992\n"
182
+ ]
183
+ },
184
+ {
185
+ "name": "stderr",
186
+ "output_type": "stream",
187
+ "text": [
188
+ "\n"
189
+ ]
190
+ }
191
+ ],
192
+ "source": [
193
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
194
+ "from torch.utils.data import Dataset, DataLoader\n",
195
+ "from utils import CustomDataset, TestingDataset, transform\n",
196
+ "from tqdm import tqdm\n",
197
+ "import torch\n",
198
+ "import numpy as np\n",
199
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
200
+ "import torch\n",
201
+ "import torch.nn as nn\n",
202
+ "import torch.optim as optim\n",
203
+ "from tqdm import tqdm \n",
204
+ "import torch.nn.functional as F\n",
205
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
206
+ "import pickle\n",
207
+ "\n",
208
+ "torch.manual_seed(1)\n",
209
+ "# torch.manual_seed(42)\n",
210
+ "\n",
211
+ "\n",
212
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
213
+ "num_gpus = torch.cuda.device_count()\n",
214
+ "print(num_gpus)\n",
215
+ "\n",
216
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
217
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
218
+ "\n",
219
+ "num_classes = 2\n",
220
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
221
+ "\n",
222
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
223
+ "model = nn.DataParallel(model)\n",
224
+ "model = model.to(device)\n",
225
+ "params = sum(p.numel() for p in model.parameters())\n",
226
+ "print(\"num params \",params)\n",
227
+ "\n",
228
+ "\n",
229
+ "model_1 = 'models_8/model-44-99.445_42.pt'\n",
230
+ "# model_1 ='models/model-47-99.125.pt'\n",
231
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
232
+ "model = model.eval()\n",
233
+ "\n",
234
+ "# eval\n",
235
+ "val_loss = 0.0\n",
236
+ "correct_valid = 0\n",
237
+ "total = 0\n",
238
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
239
+ "model.eval()\n",
240
+ "with torch.no_grad():\n",
241
+ " for images, labels in tqdm(testloader):\n",
242
+ " inputs, labels = images.to(device), labels\n",
243
+ " outputs = model(inputs, return_mask = True)\n",
244
+ " _, predicted = torch.max(outputs, 1)\n",
245
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
246
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
247
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
248
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
249
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
250
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
251
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
252
+ " total += labels[0].size(0)\n",
253
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
254
+ " \n",
255
+ "# Calculate training accuracy after each epoch\n",
256
+ "val_accuracy = correct_valid / total * 100.0\n",
257
+ "print(\"===========================\")\n",
258
+ "print('accuracy: ', val_accuracy)\n",
259
+ "print(\"===========================\")\n",
260
+ "\n",
261
+ "import pickle\n",
262
+ "\n",
263
+ "# Pickle the dictionary to a file\n",
264
+ "with open('models_8/test_42.pkl', 'wb') as f:\n",
265
+ " pickle.dump(results, f)\n",
266
+ "\n",
267
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
268
+ "\n",
269
+ "# Example binary labels\n",
270
+ "true = results['true'] # ground truth\n",
271
+ "pred = results['pred'] # predicted\n",
272
+ "\n",
273
+ "# Compute metrics\n",
274
+ "precision = precision_score(true, pred)\n",
275
+ "recall = recall_score(true, pred)\n",
276
+ "f1 = f1_score(true, pred)\n",
277
+ "# Get confusion matrix: TN, FP, FN, TP\n",
278
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
279
+ "\n",
280
+ "# Compute FPR\n",
281
+ "fpr = fp / (fp + tn)\n",
282
+ "\n",
283
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
284
+ "\n",
285
+ "print(f\"Precision: {precision:.3f}\")\n",
286
+ "print(f\"Recall: {recall:.3f}\")\n",
287
+ "print(f\"F1 Score: {f1:.3f}\")"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 4,
293
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "2\n",
301
+ "num params encoder 50840\n",
302
+ "num params 21496282\n"
303
+ ]
304
+ },
305
+ {
306
+ "name": "stderr",
307
+ "output_type": "stream",
308
+ "text": [
309
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
310
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
311
+ "100%|███████████████████████████████████████████| 48/48 [00:53<00:00, 1.12s/it]"
312
+ ]
313
+ },
314
+ {
315
+ "name": "stdout",
316
+ "output_type": "stream",
317
+ "text": [
318
+ "===========================\n",
319
+ "accuracy: 99.035\n",
320
+ "===========================\n",
321
+ "False Positive Rate: 0.010\n",
322
+ "Precision: 0.990\n",
323
+ "Recall: 0.990\n",
324
+ "F1 Score: 0.990\n"
325
+ ]
326
+ },
327
+ {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "\n"
332
+ ]
333
+ }
334
+ ],
335
+ "source": [
336
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
337
+ "from torch.utils.data import Dataset, DataLoader\n",
338
+ "from utils import CustomDataset, TestingDataset, transform\n",
339
+ "from tqdm import tqdm\n",
340
+ "import torch\n",
341
+ "import numpy as np\n",
342
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
343
+ "import torch\n",
344
+ "import torch.nn as nn\n",
345
+ "import torch.optim as optim\n",
346
+ "from tqdm import tqdm \n",
347
+ "import torch.nn.functional as F\n",
348
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
349
+ "import pickle\n",
350
+ "\n",
351
+ "torch.manual_seed(1)\n",
352
+ "# torch.manual_seed(42)\n",
353
+ "\n",
354
+ "\n",
355
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
356
+ "num_gpus = torch.cuda.device_count()\n",
357
+ "print(num_gpus)\n",
358
+ "\n",
359
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
360
+ "test_dataset = TestingDataset(test_data_dir, bit8 =True, transform=transform)\n",
361
+ "\n",
362
+ "num_classes = 2\n",
363
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
364
+ "\n",
365
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
366
+ "model = nn.DataParallel(model)\n",
367
+ "model = model.to(device)\n",
368
+ "params = sum(p.numel() for p in model.parameters())\n",
369
+ "print(\"num params \",params)\n",
370
+ "\n",
371
+ "\n",
372
+ "model_1 = 'models_8/model-43-99.355_1.pt'\n",
373
+ "# model_1 ='models/model-47-99.125.pt'\n",
374
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
375
+ "model = model.eval()\n",
376
+ "\n",
377
+ "# eval\n",
378
+ "val_loss = 0.0\n",
379
+ "correct_valid = 0\n",
380
+ "total = 0\n",
381
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
382
+ "model.eval()\n",
383
+ "with torch.no_grad():\n",
384
+ " for images, labels in tqdm(testloader):\n",
385
+ " inputs, labels = images.to(device), labels\n",
386
+ " outputs = model(inputs, return_mask = True)\n",
387
+ " _, predicted = torch.max(outputs, 1)\n",
388
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
389
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
390
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
391
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
392
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
393
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
394
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
395
+ " total += labels[0].size(0)\n",
396
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
397
+ " \n",
398
+ "# Calculate training accuracy after each epoch\n",
399
+ "val_accuracy = correct_valid / total * 100.0\n",
400
+ "print(\"===========================\")\n",
401
+ "print('accuracy: ', val_accuracy)\n",
402
+ "print(\"===========================\")\n",
403
+ "\n",
404
+ "import pickle\n",
405
+ "\n",
406
+ "# Pickle the dictionary to a file\n",
407
+ "with open('models_8/test_1.pkl', 'wb') as f:\n",
408
+ " pickle.dump(results, f)\n",
409
+ "\n",
410
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
411
+ "\n",
412
+ "# Example binary labels\n",
413
+ "true = results['true'] # ground truth\n",
414
+ "pred = results['pred'] # predicted\n",
415
+ "\n",
416
+ "# Compute metrics\n",
417
+ "precision = precision_score(true, pred)\n",
418
+ "recall = recall_score(true, pred)\n",
419
+ "f1 = f1_score(true, pred)\n",
420
+ "# Get confusion matrix: TN, FP, FN, TP\n",
421
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
422
+ "\n",
423
+ "# Compute FPR\n",
424
+ "fpr = fp / (fp + tn)\n",
425
+ "\n",
426
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
427
+ "\n",
428
+ "print(f\"Precision: {precision:.3f}\")\n",
429
+ "print(f\"Recall: {recall:.3f}\")\n",
430
+ "print(f\"F1 Score: {f1:.3f}\")"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": 7,
436
+ "id": "8444ced5-686c-4c6e-bb02-8c4a45ff8af9",
437
+ "metadata": {},
438
+ "outputs": [
439
+ {
440
+ "name": "stdout",
441
+ "output_type": "stream",
442
+ "text": [
443
+ "99.04833333333333 0.09436925111261553\n",
444
+ "0.9936666666666666 0.002624669291337273\n",
445
+ "0.9903333333333334 0.0012472191289246482\n",
446
+ "0.006000000000000001 0.00282842712474619\n"
447
+ ]
448
+ }
449
+ ],
450
+ "source": [
451
+ "# acc\n",
452
+ "print(np.mean([98.94,99.17, 99.035 ]), np.std([98.94,99.17, 99.035 ]))\n",
453
+ "# recall\n",
454
+ "print(np.mean([0.996,0.988, 0.990]), np.std([0.996,0.995, 0.990]))\n",
455
+ "# f1\n",
456
+ "print(np.mean([0.989,0.992,0.990 ]),np.std([0.989,0.992,0.990 ]))\n",
457
+ "# fp\n",
458
+ "print(np.mean([0.004,0.004,0.010]),np.std([0.004,0.004,0.010]))\n"
459
+ ]
460
+ }
461
+ ],
462
+ "metadata": {
463
+ "kernelspec": {
464
+ "display_name": "Python 3 (ipykernel)",
465
+ "language": "python",
466
+ "name": "python3"
467
+ },
468
+ "language_info": {
469
+ "codemirror_mode": {
470
+ "name": "ipython",
471
+ "version": 3
472
+ },
473
+ "file_extension": ".py",
474
+ "mimetype": "text/x-python",
475
+ "name": "python",
476
+ "nbconvert_exporter": "python",
477
+ "pygments_lexer": "ipython3",
478
+ "version": "3.11.9"
479
+ }
480
+ },
481
+ "nbformat": 4,
482
+ "nbformat_minor": 5
483
+ }
models/.ipynb_checkpoints/eval_mask-checkpoint.ipynb ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "6\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ " 8%|███▋ | 4/48 [02:22<25:11, 34.35s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7efcbc3f67d0>>\n",
25
+ "Traceback (most recent call last):\n",
26
+ " File \"/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 775, in _clean_thread_parent_frames\n",
27
+ " def _clean_thread_parent_frames(\n",
28
+ "\n",
29
+ "KeyboardInterrupt: \n",
30
+ " 10%|████▌ | 5/48 [02:53<24:56, 34.79s/it]\n"
31
+ ]
32
+ },
33
+ {
34
+ "ename": "RuntimeError",
35
+ "evalue": "DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly",
36
+ "output_type": "error",
37
+ "traceback": [
38
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
39
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
40
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1131\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1131\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_queue\u001b[38;5;241m.\u001b[39mget(timeout\u001b[38;5;241m=\u001b[39mtimeout)\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n",
41
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/queues.py:122\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# unserialize the data after having released the lock\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _ForkingPickler\u001b[38;5;241m.\u001b[39mloads(res)\n",
42
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/multiprocessing/reductions.py:496\u001b[0m, in \u001b[0;36mrebuild_storage_fd\u001b[0;34m(cls, df, size)\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrebuild_storage_fd\u001b[39m(\u001b[38;5;28mcls\u001b[39m, df, size):\n\u001b[0;32m--> 496\u001b[0m fd \u001b[38;5;241m=\u001b[39m df\u001b[38;5;241m.\u001b[39mdetach()\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
43
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:57\u001b[0m, in \u001b[0;36mDupFd.detach\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m'''Get the fd. This should only be called once.'''\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _resource_sharer\u001b[38;5;241m.\u001b[39mget_connection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_id) \u001b[38;5;28;01mas\u001b[39;00m conn:\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m reduction\u001b[38;5;241m.\u001b[39mrecv_handle(conn)\n",
44
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/resource_sharer.py:86\u001b[0m, in \u001b[0;36m_ResourceSharer.get_connection\u001b[0;34m(ident)\u001b[0m\n\u001b[1;32m 85\u001b[0m address, key \u001b[38;5;241m=\u001b[39m ident\n\u001b[0;32m---> 86\u001b[0m c \u001b[38;5;241m=\u001b[39m Client(address, authkey\u001b[38;5;241m=\u001b[39mprocess\u001b[38;5;241m.\u001b[39mcurrent_process()\u001b[38;5;241m.\u001b[39mauthkey)\n\u001b[1;32m 87\u001b[0m c\u001b[38;5;241m.\u001b[39msend((key, os\u001b[38;5;241m.\u001b[39mgetpid()))\n",
45
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:519\u001b[0m, in \u001b[0;36mClient\u001b[0;34m(address, family, authkey)\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 519\u001b[0m c \u001b[38;5;241m=\u001b[39m SocketClient(address)\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m authkey \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(authkey, \u001b[38;5;28mbytes\u001b[39m):\n",
46
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/multiprocessing/connection.py:647\u001b[0m, in \u001b[0;36mSocketClient\u001b[0;34m(address)\u001b[0m\n\u001b[1;32m 646\u001b[0m s\u001b[38;5;241m.\u001b[39msetblocking(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 647\u001b[0m s\u001b[38;5;241m.\u001b[39mconnect(address)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Connection(s\u001b[38;5;241m.\u001b[39mdetach())\n",
47
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory",
48
+ "\nThe above exception was the direct cause of the following exception:\n",
49
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
50
+ "Cell \u001b[0;32mIn[1], line 48\u001b[0m\n\u001b[1;32m 46\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m tqdm(testloader):\n\u001b[1;32m 49\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\n\u001b[1;32m 50\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(inputs, return_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m)\n",
51
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
52
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_data()\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
53
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1327\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1327\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_data()\n\u001b[1;32m 1328\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
54
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1293\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1290\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1292\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1293\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_try_get_data()\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
55
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1144\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1142\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(failed_workers) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1143\u001b[0m pids_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(w\u001b[38;5;241m.\u001b[39mpid) \u001b[38;5;28;01mfor\u001b[39;00m w \u001b[38;5;129;01min\u001b[39;00m failed_workers)\n\u001b[0;32m-> 1144\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDataLoader worker (pid(s) \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpids_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) exited unexpectedly\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 1145\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, queue\u001b[38;5;241m.\u001b[39mEmpty):\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n",
56
+ "\u001b[0;31mRuntimeError\u001b[0m: DataLoader worker (pid(s) 4158742, 4158790, 4158838, 4158886, 4158934, 4158982, 4159030, 4159078, 4159126, 4159174, 4159222, 4159270, 4159318, 4159366, 4159414, 4159462, 4159510, 4159558, 4159606, 4159654, 4159702, 4159750, 4159798, 4159846, 4159894, 4159942, 4159990, 4160038, 4160086, 4160134, 4160182, 4160230) exited unexpectedly"
57
+ ]
58
+ }
59
+ ],
60
+ "source": [
61
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
62
+ "from torch.utils.data import Dataset, DataLoader\n",
63
+ "from utils import CustomDataset, TestingDataset, transform\n",
64
+ "from tqdm import tqdm\n",
65
+ "import torch\n",
66
+ "import numpy as np\n",
67
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
68
+ "import torch\n",
69
+ "import torch.nn as nn\n",
70
+ "import torch.optim as optim\n",
71
+ "from tqdm import tqdm \n",
72
+ "import torch.nn.functional as F\n",
73
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
74
+ "import pickle\n",
75
+ "\n",
76
+ "torch.manual_seed(1)\n",
77
+ "# torch.manual_seed(42)\n",
78
+ "\n",
79
+ "\n",
80
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
81
+ "num_gpus = torch.cuda.device_count()\n",
82
+ "print(num_gpus)\n",
83
+ "\n",
84
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
85
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
86
+ "\n",
87
+ "num_classes = 2\n",
88
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
89
+ "\n",
90
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
91
+ "model = nn.DataParallel(model)\n",
92
+ "model = model.to(device)\n",
93
+ "params = sum(p.numel() for p in model.parameters())\n",
94
+ "print(\"num params \",params)\n",
95
+ "\n",
96
+ "model_1 = 'models_mask/model-43-99.235_42.pt'\n",
97
+ "# model_1 ='models/model-47-99.125.pt'\n",
98
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
99
+ "model = model.eval()\n",
100
+ "\n",
101
+ "# eval\n",
102
+ "val_loss = 0.0\n",
103
+ "correct_valid = 0\n",
104
+ "total = 0\n",
105
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
106
+ "model.eval()\n",
107
+ "with torch.no_grad():\n",
108
+ " for images, labels in tqdm(testloader):\n",
109
+ " inputs, labels = images.to(device), labels\n",
110
+ " outputs = model(inputs, return_mask = True)\n",
111
+ " _, predicted = torch.max(outputs, 1)\n",
112
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
113
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
114
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
115
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
116
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
117
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
118
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
119
+ " total += labels[0].size(0)\n",
120
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
121
+ " \n",
122
+ "# Calculate training accuracy after each epoch\n",
123
+ "val_accuracy = correct_valid / total * 100.0\n",
124
+ "print(\"===========================\")\n",
125
+ "print('accuracy: ', val_accuracy)\n",
126
+ "print(\"===========================\")\n",
127
+ "\n",
128
+ "import pickle\n",
129
+ "\n",
130
+ "# Pickle the dictionary to a file\n",
131
+ "with open('models_mask/test_42.pkl', 'wb') as f:\n",
132
+ " pickle.dump(results, f)"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
143
+ "from torch.utils.data import Dataset, DataLoader\n",
144
+ "from utils import CustomDataset, TestingDataset, transform\n",
145
+ "from tqdm import tqdm\n",
146
+ "import torch\n",
147
+ "import numpy as np\n",
148
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
149
+ "import torch\n",
150
+ "import torch.nn as nn\n",
151
+ "import torch.optim as optim\n",
152
+ "from tqdm import tqdm \n",
153
+ "import torch.nn.functional as F\n",
154
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
155
+ "import pickle\n",
156
+ "\n",
157
+ "torch.manual_seed(1)\n",
158
+ "# torch.manual_seed(42)\n",
159
+ "\n",
160
+ "\n",
161
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
162
+ "num_gpus = torch.cuda.device_count()\n",
163
+ "print(num_gpus)\n",
164
+ "\n",
165
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
166
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
167
+ "\n",
168
+ "num_classes = 2\n",
169
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
170
+ "\n",
171
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
172
+ "model = nn.DataParallel(model)\n",
173
+ "model = model.to(device)\n",
174
+ "params = sum(p.numel() for p in model.parameters())\n",
175
+ "print(\"num params \",params)\n",
176
+ "\n",
177
+ "\n",
178
+ "model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
179
+ "# model_1 ='models/model-47-99.125.pt'\n",
180
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
181
+ "model = model.eval()\n",
182
+ "\n",
183
+ "# eval\n",
184
+ "val_loss = 0.0\n",
185
+ "correct_valid = 0\n",
186
+ "total = 0\n",
187
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
188
+ "model.eval()\n",
189
+ "with torch.no_grad():\n",
190
+ " for images, labels in tqdm(testloader):\n",
191
+ " inputs, labels = images.to(device), labels\n",
192
+ " outputs = model(inputs, return_mask = True)\n",
193
+ " _, predicted = torch.max(outputs, 1)\n",
194
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
195
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
196
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
197
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
198
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
199
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
200
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
201
+ " total += labels[0].size(0)\n",
202
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
203
+ " \n",
204
+ " \n",
205
+ "# Calculate training accuracy after each epoch\n",
206
+ "val_accuracy = correct_valid / total * 100.0\n",
207
+ "print(\"===========================\")\n",
208
+ "print('accuracy: ', val_accuracy)\n",
209
+ "print(\"===========================\")\n",
210
+ "\n",
211
+ "import pickle\n",
212
+ "\n",
213
+ "# Pickle the dictionary to a file\n",
214
+ "with open('models_mask/test_1.pkl', 'wb') as f:\n",
215
+ " pickle.dump(results, f)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
226
+ "from torch.utils.data import Dataset, DataLoader\n",
227
+ "from utils import CustomDataset, TestingDataset, transform\n",
228
+ "from tqdm import tqdm\n",
229
+ "import torch\n",
230
+ "import numpy as np\n",
231
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
232
+ "import torch\n",
233
+ "import torch.nn as nn\n",
234
+ "import torch.optim as optim\n",
235
+ "from tqdm import tqdm \n",
236
+ "import torch.nn.functional as F\n",
237
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
238
+ "import pickle\n",
239
+ "\n",
240
+ "torch.manual_seed(1)\n",
241
+ "# torch.manual_seed(42)\n",
242
+ "\n",
243
+ "\n",
244
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
245
+ "num_gpus = torch.cuda.device_count()\n",
246
+ "print(num_gpus)\n",
247
+ "\n",
248
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
249
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
250
+ "\n",
251
+ "num_classes = 2\n",
252
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
253
+ "\n",
254
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
255
+ "model = nn.DataParallel(model)\n",
256
+ "model = model.to(device)\n",
257
+ "params = sum(p.numel() for p in model.parameters())\n",
258
+ "print(\"num params \",params)\n",
259
+ "\n",
260
+ "\n",
261
+ "model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
262
+ "# model_1 ='models/model-47-99.125.pt'\n",
263
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
264
+ "model = model.eval()\n",
265
+ "\n",
266
+ "# eval\n",
267
+ "val_loss = 0.0\n",
268
+ "correct_valid = 0\n",
269
+ "total = 0\n",
270
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
271
+ "model.eval()\n",
272
+ "with torch.no_grad():\n",
273
+ " for images, labels in tqdm(testloader):\n",
274
+ " inputs, labels = images.to(device), labels\n",
275
+ " outputs = model(inputs, return_mask = True)\n",
276
+ " _, predicted = torch.max(outputs, 1)\n",
277
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
278
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
279
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
280
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
281
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
282
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
283
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
284
+ " total += labels[0].size(0)\n",
285
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
286
+ " \n",
287
+ " \n",
288
+ "# Calculate training accuracy after each epoch\n",
289
+ "val_accuracy = correct_valid / total * 100.0\n",
290
+ "print(\"===========================\")\n",
291
+ "print('accuracy: ', val_accuracy)\n",
292
+ "print(\"===========================\")\n",
293
+ "\n",
294
+ "import pickle\n",
295
+ "\n",
296
+ "# Pickle the dictionary to a file\n",
297
+ "with open('models_mask/test_7109.pkl', 'wb') as f:\n",
298
+ " pickle.dump(results, f)"
299
+ ]
300
+ }
301
+ ],
302
+ "metadata": {
303
+ "kernelspec": {
304
+ "display_name": "Python 3 (ipykernel)",
305
+ "language": "python",
306
+ "name": "python3"
307
+ },
308
+ "language_info": {
309
+ "codemirror_mode": {
310
+ "name": "ipython",
311
+ "version": 3
312
+ },
313
+ "file_extension": ".py",
314
+ "mimetype": "text/x-python",
315
+ "name": "python",
316
+ "nbconvert_exporter": "python",
317
+ "pygments_lexer": "ipython3",
318
+ "version": "3.11.9"
319
+ }
320
+ },
321
+ "nbformat": 4,
322
+ "nbformat_minor": 5
323
+ }
models/.ipynb_checkpoints/eval_mask-extend-checkpoint.ipynb ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 9,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
24
+ "100%|███████████████████████████████████████████| 48/48 [00:45<00:00, 1.07it/s]"
25
+ ]
26
+ },
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "===========================\n",
32
+ "accuracy: 99.125\n",
33
+ "===========================\n",
34
+ "Precision: 0.992\n",
35
+ "Recall: 0.991\n",
36
+ "F1 Score: 0.991\n"
37
+ ]
38
+ },
39
+ {
40
+ "name": "stderr",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
49
+ "from torch.utils.data import Dataset, DataLoader\n",
50
+ "from utils import CustomDataset, TestingDataset, transform\n",
51
+ "from tqdm import tqdm\n",
52
+ "import torch\n",
53
+ "import numpy as np\n",
54
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
55
+ "import torch\n",
56
+ "import torch.nn as nn\n",
57
+ "import torch.optim as optim\n",
58
+ "from tqdm import tqdm \n",
59
+ "import torch.nn.functional as F\n",
60
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
61
+ "import pickle\n",
62
+ "\n",
63
+ "torch.manual_seed(1)\n",
64
+ "# torch.manual_seed(42)\n",
65
+ "\n",
66
+ "\n",
67
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
68
+ "num_gpus = torch.cuda.device_count()\n",
69
+ "print(num_gpus)\n",
70
+ "\n",
71
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
72
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
73
+ "\n",
74
+ "num_classes = 2\n",
75
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
76
+ "\n",
77
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
78
+ "model = nn.DataParallel(model)\n",
79
+ "model = model.to(device)\n",
80
+ "params = sum(p.numel() for p in model.parameters())\n",
81
+ "print(\"num params \",params)\n",
82
+ "\n",
83
+ "model_1 = 'models_mask/model-43-99.235_42.pt'\n",
84
+ "# model_1 ='models/model-47-99.125.pt'\n",
85
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
86
+ "model = model.eval()\n",
87
+ "\n",
88
+ "# eval\n",
89
+ "val_loss = 0.0\n",
90
+ "correct_valid = 0\n",
91
+ "total = 0\n",
92
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
93
+ "model.eval()\n",
94
+ "with torch.no_grad():\n",
95
+ " for images, labels in tqdm(testloader):\n",
96
+ " inputs, labels = images.to(device), labels\n",
97
+ " outputs = model(inputs, return_mask = True)\n",
98
+ " _, predicted = torch.max(outputs, 1)\n",
99
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
100
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
101
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
102
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
103
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
104
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
105
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
106
+ " total += labels[0].size(0)\n",
107
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
108
+ " \n",
109
+ "# Calculate training accuracy after each epoch\n",
110
+ "val_accuracy = correct_valid / total * 100.0\n",
111
+ "print(\"===========================\")\n",
112
+ "print('accuracy: ', val_accuracy)\n",
113
+ "print(\"===========================\")\n",
114
+ "\n",
115
+ "import pickle\n",
116
+ "\n",
117
+ "# Pickle the dictionary to a file\n",
118
+ "with open('models_mask/test_42.pkl', 'wb') as f:\n",
119
+ " pickle.dump(results, f)\n",
120
+ "\n",
121
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
122
+ "\n",
123
+ "# Example binary labels\n",
124
+ "true = results['true'] # ground truth\n",
125
+ "pred = results['pred'] # predicted\n",
126
+ "\n",
127
+ "# Compute metrics\n",
128
+ "precision = precision_score(true, pred)\n",
129
+ "recall = recall_score(true, pred)\n",
130
+ "f1 = f1_score(true, pred)\n",
131
+ "# Get confusion matrix: TN, FP, FN, TP\n",
132
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
133
+ "\n",
134
+ "# Compute FPR\n",
135
+ "fpr = fp / (fp + tn)\n",
136
+ "\n",
137
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
138
+ "\n",
139
+ "print(f\"Precision: {precision:.3f}\")\n",
140
+ "print(f\"Recall: {recall:.3f}\")\n",
141
+ "print(f\"F1 Score: {f1:.3f}\")\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": 10,
147
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
148
+ "metadata": {},
149
+ "outputs": [
150
+ {
151
+ "name": "stdout",
152
+ "output_type": "stream",
153
+ "text": [
154
+ "2\n",
155
+ "num params encoder 50840\n",
156
+ "num params 21496282\n"
157
+ ]
158
+ },
159
+ {
160
+ "name": "stderr",
161
+ "output_type": "stream",
162
+ "text": [
163
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
164
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
165
+ "100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
166
+ ]
167
+ },
168
+ {
169
+ "name": "stdout",
170
+ "output_type": "stream",
171
+ "text": [
172
+ "===========================\n",
173
+ "accuracy: 98.77\n",
174
+ "===========================\n",
175
+ "Precision: 0.982\n",
176
+ "Recall: 0.993\n",
177
+ "F1 Score: 0.988\n"
178
+ ]
179
+ },
180
+ {
181
+ "name": "stderr",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "\n"
185
+ ]
186
+ }
187
+ ],
188
+ "source": [
189
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
190
+ "from torch.utils.data import Dataset, DataLoader\n",
191
+ "from utils import CustomDataset, TestingDataset, transform\n",
192
+ "from tqdm import tqdm\n",
193
+ "import torch\n",
194
+ "import numpy as np\n",
195
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
196
+ "import torch\n",
197
+ "import torch.nn as nn\n",
198
+ "import torch.optim as optim\n",
199
+ "from tqdm import tqdm \n",
200
+ "import torch.nn.functional as F\n",
201
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
202
+ "import pickle\n",
203
+ "\n",
204
+ "torch.manual_seed(1)\n",
205
+ "\n",
206
+ "\n",
207
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
208
+ "num_gpus = torch.cuda.device_count()\n",
209
+ "print(num_gpus)\n",
210
+ "\n",
211
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
212
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
213
+ "\n",
214
+ "num_classes = 2\n",
215
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
216
+ "\n",
217
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
218
+ "model = nn.DataParallel(model)\n",
219
+ "model = model.to(device)\n",
220
+ "params = sum(p.numel() for p in model.parameters())\n",
221
+ "print(\"num params \",params)\n",
222
+ "\n",
223
+ "\n",
224
+ "model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
225
+ "# model_1 ='models/model-47-99.125.pt'\n",
226
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
227
+ "model = model.eval()\n",
228
+ "\n",
229
+ "# eval\n",
230
+ "val_loss = 0.0\n",
231
+ "correct_valid = 0\n",
232
+ "total = 0\n",
233
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
234
+ "model.eval()\n",
235
+ "with torch.no_grad():\n",
236
+ " for images, labels in tqdm(testloader):\n",
237
+ " inputs, labels = images.to(device), labels\n",
238
+ " outputs = model(inputs, return_mask = True)\n",
239
+ " _, predicted = torch.max(outputs, 1)\n",
240
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
241
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
242
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
243
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
244
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
245
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
246
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
247
+ " total += labels[0].size(0)\n",
248
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
249
+ " \n",
250
+ " \n",
251
+ "# Calculate training accuracy after each epoch\n",
252
+ "val_accuracy = correct_valid / total * 100.0\n",
253
+ "print(\"===========================\")\n",
254
+ "print('accuracy: ', val_accuracy)\n",
255
+ "print(\"===========================\")\n",
256
+ "\n",
257
+ "import pickle\n",
258
+ "\n",
259
+ "# Pickle the dictionary to a file\n",
260
+ "with open('models_mask/test_1.pkl', 'wb') as f:\n",
261
+ " pickle.dump(results, f)\n",
262
+ "\n",
263
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
264
+ "\n",
265
+ "# Example binary labels\n",
266
+ "true = results['true'] # ground truth\n",
267
+ "pred = results['pred'] # predicted\n",
268
+ "\n",
269
+ "# Compute metrics\n",
270
+ "precision = precision_score(true, pred)\n",
271
+ "recall = recall_score(true, pred)\n",
272
+ "f1 = f1_score(true, pred)\n",
273
+ "# Get confusion matrix: TN, FP, FN, TP\n",
274
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
275
+ "\n",
276
+ "# Compute FPR\n",
277
+ "fpr = fp / (fp + tn)\n",
278
+ "\n",
279
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
280
+ "\n",
281
+ "print(f\"Precision: {precision:.3f}\")\n",
282
+ "print(f\"Recall: {recall:.3f}\")\n",
283
+ "print(f\"F1 Score: {f1:.3f}\")"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": 11,
289
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
290
+ "metadata": {},
291
+ "outputs": [
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "2\n",
297
+ "num params encoder 50840\n",
298
+ "num params 21496282\n"
299
+ ]
300
+ },
301
+ {
302
+ "name": "stderr",
303
+ "output_type": "stream",
304
+ "text": [
305
+ " 0%| | 0/48 [00:00<?, ?it/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
306
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
307
+ "100%|███████████████████████████████████████████| 48/48 [00:43<00:00, 1.11it/s]"
308
+ ]
309
+ },
310
+ {
311
+ "name": "stdout",
312
+ "output_type": "stream",
313
+ "text": [
314
+ "===========================\n",
315
+ "accuracy: 99.03\n",
316
+ "===========================\n",
317
+ "Precision: 0.990\n",
318
+ "Recall: 0.990\n",
319
+ "F1 Score: 0.990\n"
320
+ ]
321
+ },
322
+ {
323
+ "name": "stderr",
324
+ "output_type": "stream",
325
+ "text": [
326
+ "\n"
327
+ ]
328
+ }
329
+ ],
330
+ "source": [
331
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
332
+ "from torch.utils.data import Dataset, DataLoader\n",
333
+ "from utils import CustomDataset, TestingDataset, transform\n",
334
+ "from tqdm import tqdm\n",
335
+ "import torch\n",
336
+ "import numpy as np\n",
337
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
338
+ "import torch\n",
339
+ "import torch.nn as nn\n",
340
+ "import torch.optim as optim\n",
341
+ "from tqdm import tqdm \n",
342
+ "import torch.nn.functional as F\n",
343
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
344
+ "import pickle\n",
345
+ "\n",
346
+ "torch.manual_seed(1)\n",
347
+ "# torch.manual_seed(42)\n",
348
+ "\n",
349
+ "\n",
350
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
351
+ "num_gpus = torch.cuda.device_count()\n",
352
+ "print(num_gpus)\n",
353
+ "\n",
354
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
355
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
356
+ "\n",
357
+ "num_classes = 2\n",
358
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
359
+ "\n",
360
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
361
+ "model = nn.DataParallel(model)\n",
362
+ "model = model.to(device)\n",
363
+ "params = sum(p.numel() for p in model.parameters())\n",
364
+ "print(\"num params \",params)\n",
365
+ "\n",
366
+ "\n",
367
+ "model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
368
+ "# model_1 ='models/model-47-99.125.pt'\n",
369
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
370
+ "model = model.eval()\n",
371
+ "\n",
372
+ "# eval\n",
373
+ "val_loss = 0.0\n",
374
+ "correct_valid = 0\n",
375
+ "total = 0\n",
376
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
377
+ "model.eval()\n",
378
+ "with torch.no_grad():\n",
379
+ " for images, labels in tqdm(testloader):\n",
380
+ " inputs, labels = images.to(device), labels\n",
381
+ " outputs = model(inputs, return_mask = True)\n",
382
+ " _, predicted = torch.max(outputs, 1)\n",
383
+ " results['output'].extend(outputs.cpu().numpy().tolist())\n",
384
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
385
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
386
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
387
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
388
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
389
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
390
+ " total += labels[0].size(0)\n",
391
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
392
+ " \n",
393
+ " \n",
394
+ "# Calculate training accuracy after each epoch\n",
395
+ "val_accuracy = correct_valid / total * 100.0\n",
396
+ "print(\"===========================\")\n",
397
+ "print('accuracy: ', val_accuracy)\n",
398
+ "print(\"===========================\")\n",
399
+ "\n",
400
+ "import pickle\n",
401
+ "\n",
402
+ "# Pickle the dictionary to a file\n",
403
+ "with open('models_mask/test_7109.pkl', 'wb') as f:\n",
404
+ " pickle.dump(results, f)\n",
405
+ "\n",
406
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
407
+ "\n",
408
+ "# Example binary labels\n",
409
+ "true = results['true'] # ground truth\n",
410
+ "pred = results['pred'] # predicted\n",
411
+ "\n",
412
+ "# Compute metrics\n",
413
+ "precision = precision_score(true, pred)\n",
414
+ "recall = recall_score(true, pred)\n",
415
+ "f1 = f1_score(true, pred)\n",
416
+ "# Get confusion matrix: TN, FP, FN, TP\n",
417
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
418
+ "\n",
419
+ "# Compute FPR\n",
420
+ "fpr = fp / (fp + tn)\n",
421
+ "\n",
422
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
423
+ "\n",
424
+ "print(f\"Precision: {precision:.3f}\")\n",
425
+ "print(f\"Recall: {recall:.3f}\")\n",
426
+ "print(f\"F1 Score: {f1:.3f}\")"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": 17,
432
+ "id": "974e62d6-5088-4cd8-9721-6702717eadee",
433
+ "metadata": {},
434
+ "outputs": [
435
+ {
436
+ "name": "stdout",
437
+ "output_type": "stream",
438
+ "text": [
439
+ "98.97499999999998 0.1500555452713003\n",
440
+ "0.9913333333333333 0.0012472191289246482\n",
441
+ "0.9896666666666666 0.0012472191289246482\n"
442
+ ]
443
+ }
444
+ ],
445
+ "source": [
446
+ "# acc\n",
447
+ "print(np.mean([99.125,98.77, 99.03 ]), np.std([99.125,98.77, 99.03 ]))\n",
448
+ "# precision\n",
449
+ "print(np.mean([0.991,0.990, 0.993]), np.std([0.991,0.990, 0.993]))\n",
450
+ "# f1\n",
451
+ "print(np.mean([0.990,0.988,0.991 ]),np.std([0.990,0.988,0.991 ]))\n",
452
+ "# recall"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "execution_count": 18,
458
+ "id": "3eee97ad-114f-4090-b54a-6ec0cc7150f5",
459
+ "metadata": {},
460
+ "outputs": [
461
+ {
462
+ "name": "stdout",
463
+ "output_type": "stream",
464
+ "text": [
465
+ "False Positive Rate: 0.200\n"
466
+ ]
467
+ }
468
+ ],
469
+ "source": [
470
+ "from sklearn.metrics import confusion_matrix\n",
471
+ "\n",
472
+ "# Ground truth and predictions\n",
473
+ "true = [1, 0, 1, 1, 0, 1, 0, 0, 1, 0]\n",
474
+ "pred = [1, 0, 1, 0, 0, 1, 1, 0, 1, 0]\n",
475
+ "\n"
476
+ ]
477
+ }
478
+ ],
479
+ "metadata": {
480
+ "kernelspec": {
481
+ "display_name": "Python 3 (ipykernel)",
482
+ "language": "python",
483
+ "name": "python3"
484
+ },
485
+ "language_info": {
486
+ "codemirror_mode": {
487
+ "name": "ipython",
488
+ "version": 3
489
+ },
490
+ "file_extension": ".py",
491
+ "mimetype": "text/x-python",
492
+ "name": "python",
493
+ "nbconvert_exporter": "python",
494
+ "pygments_lexer": "ipython3",
495
+ "version": "3.11.9"
496
+ }
497
+ },
498
+ "nbformat": 4,
499
+ "nbformat_minor": 5
500
+ }
models/.ipynb_checkpoints/eval_mask_threshold-extend-checkpoint.ipynb ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ },
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
23
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
24
+ ]
25
+ },
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "===========================\n",
31
+ "accuracy: 99.195\n",
32
+ "===========================\n",
33
+ "False Positive Rate: 0.005\n",
34
+ "Precision: 0.995\n",
35
+ "Recall: 0.989\n",
36
+ "F1 Score: 0.992\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
42
+ "from torch.utils.data import Dataset, DataLoader\n",
43
+ "from utils import CustomDataset, TestingDataset, transform\n",
44
+ "from tqdm import tqdm\n",
45
+ "import torch\n",
46
+ "import numpy as np\n",
47
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
48
+ "import torch\n",
49
+ "import torch.nn as nn\n",
50
+ "import torch.optim as optim\n",
51
+ "from tqdm import tqdm \n",
52
+ "import torch.nn.functional as F\n",
53
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
54
+ "import pickle\n",
55
+ "\n",
56
+ "torch.manual_seed(1)\n",
57
+ "# torch.manual_seed(42)\n",
58
+ "\n",
59
+ "\n",
60
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
61
+ "num_gpus = torch.cuda.device_count()\n",
62
+ "print(num_gpus)\n",
63
+ "threshold = 0.992\n",
64
+ "\n",
65
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
66
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
67
+ "\n",
68
+ "num_classes = 2\n",
69
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
70
+ "\n",
71
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
72
+ "model = nn.DataParallel(model)\n",
73
+ "model = model.to(device)\n",
74
+ "params = sum(p.numel() for p in model.parameters())\n",
75
+ "print(\"num params \",params)\n",
76
+ "\n",
77
+ "model_1 = 'models_mask/model-43-99.235_42.pt'\n",
78
+ "# model_1 ='models/model-47-99.125.pt'\n",
79
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
80
+ "model = model.eval()\n",
81
+ "\n",
82
+ "# eval\n",
83
+ "val_loss = 0.0\n",
84
+ "correct_valid = 0\n",
85
+ "total = 0\n",
86
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
87
+ "model.eval()\n",
88
+ "with torch.no_grad():\n",
89
+ " for images, labels in testloader:\n",
90
+ " inputs, labels = images.to(device), labels\n",
91
+ " outputs = nn.Softmax(dim = 1)(model(inputs))\n",
92
+ " selection = outputs[:, 1] > threshold\n",
93
+ " predicted = selection.int()\n",
94
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
95
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
96
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
97
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
98
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
99
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
100
+ " total += labels[0].size(0)\n",
101
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
102
+ " \n",
103
+ "# Calculate training accuracy after each epoch\n",
104
+ "val_accuracy = correct_valid / total * 100.0\n",
105
+ "print(\"===========================\")\n",
106
+ "print('accuracy: ', val_accuracy)\n",
107
+ "print(\"===========================\")\n",
108
+ "\n",
109
+ "import pickle\n",
110
+ "\n",
111
+ "# Pickle the dictionary to a file\n",
112
+ "with open('models_mask/test_42.pkl', 'wb') as f:\n",
113
+ " pickle.dump(results, f)\n",
114
+ "\n",
115
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
116
+ "from sklearn.metrics import confusion_matrix\n",
117
+ "\n",
118
+ "# Example binary labels\n",
119
+ "true = results['true'] # ground truth\n",
120
+ "pred = results['pred'] # predicted\n",
121
+ "\n",
122
+ "# Compute metrics\n",
123
+ "precision = precision_score(true, pred)\n",
124
+ "recall = recall_score(true, pred)\n",
125
+ "f1 = f1_score(true, pred)\n",
126
+ "# Get confusion matrix: TN, FP, FN, TP\n",
127
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
128
+ "\n",
129
+ "# Compute FPR\n",
130
+ "fpr = fp / (fp + tn)\n",
131
+ "\n",
132
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
133
+ "\n",
134
+ "print(f\"Precision: {precision:.3f}\")\n",
135
+ "print(f\"Recall: {recall:.3f}\")\n",
136
+ "print(f\"F1 Score: {f1:.3f}\")\n"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 2,
142
+ "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
143
+ "metadata": {},
144
+ "outputs": [
145
+ {
146
+ "name": "stdout",
147
+ "output_type": "stream",
148
+ "text": [
149
+ "2\n",
150
+ "num params encoder 50840\n",
151
+ "num params 21496282\n"
152
+ ]
153
+ },
154
+ {
155
+ "name": "stderr",
156
+ "output_type": "stream",
157
+ "text": [
158
+ "/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
159
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
160
+ ]
161
+ },
162
+ {
163
+ "name": "stdout",
164
+ "output_type": "stream",
165
+ "text": [
166
+ "===========================\n",
167
+ "accuracy: 99.195\n",
168
+ "===========================\n",
169
+ "False Positive Rate: 0.007\n",
170
+ "Precision: 0.993\n",
171
+ "Recall: 0.991\n",
172
+ "F1 Score: 0.992\n"
173
+ ]
174
+ }
175
+ ],
176
+ "source": [
177
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
178
+ "from torch.utils.data import Dataset, DataLoader\n",
179
+ "from utils import CustomDataset, TestingDataset, transform\n",
180
+ "from tqdm import tqdm\n",
181
+ "import torch\n",
182
+ "import numpy as np\n",
183
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
184
+ "import torch\n",
185
+ "import torch.nn as nn\n",
186
+ "import torch.optim as optim\n",
187
+ "from tqdm import tqdm \n",
188
+ "import torch.nn.functional as F\n",
189
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
190
+ "import pickle\n",
191
+ "\n",
192
+ "torch.manual_seed(1)\n",
193
+ "\n",
194
+ "\n",
195
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
196
+ "num_gpus = torch.cuda.device_count()\n",
197
+ "print(num_gpus)\n",
198
+ "\n",
199
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
200
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
201
+ "\n",
202
+ "num_classes = 2\n",
203
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
204
+ "\n",
205
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
206
+ "model = nn.DataParallel(model)\n",
207
+ "model = model.to(device)\n",
208
+ "params = sum(p.numel() for p in model.parameters())\n",
209
+ "print(\"num params \",params)\n",
210
+ "\n",
211
+ "\n",
212
+ "model_1 = 'models_mask/model-36-99.11999999999999_1.pt'\n",
213
+ "# model_1 ='models/model-47-99.125.pt'\n",
214
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
215
+ "model = model.eval()\n",
216
+ "\n",
217
+ "# eval\n",
218
+ "val_loss = 0.0\n",
219
+ "correct_valid = 0\n",
220
+ "total = 0\n",
221
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
222
+ "model.eval()\n",
223
+ "with torch.no_grad():\n",
224
+ " for images, labels in testloader:\n",
225
+ " inputs, labels = images.to(device), labels\n",
226
+ " outputs = nn.Softmax(dim = 1)(model(inputs))\n",
227
+ " selection = outputs[:, 1] > threshold\n",
228
+ " predicted = selection.int()\n",
229
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
230
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
231
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
232
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
233
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
234
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
235
+ " total += labels[0].size(0)\n",
236
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
237
+ " \n",
238
+ " \n",
239
+ "# Calculate training accuracy after each epoch\n",
240
+ "val_accuracy = correct_valid / total * 100.0\n",
241
+ "print(\"===========================\")\n",
242
+ "print('accuracy: ', val_accuracy)\n",
243
+ "print(\"===========================\")\n",
244
+ "\n",
245
+ "import pickle\n",
246
+ "\n",
247
+ "# Pickle the dictionary to a file\n",
248
+ "with open('models_mask/test_1.pkl', 'wb') as f:\n",
249
+ " pickle.dump(results, f)\n",
250
+ "\n",
251
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
252
+ "\n",
253
+ "# Example binary labels\n",
254
+ "true = results['true'] # ground truth\n",
255
+ "pred = results['pred'] # predicted\n",
256
+ "\n",
257
+ "# Compute metrics\n",
258
+ "precision = precision_score(true, pred)\n",
259
+ "recall = recall_score(true, pred)\n",
260
+ "f1 = f1_score(true, pred)\n",
261
+ "# Get confusion matrix: TN, FP, FN, TP\n",
262
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
263
+ "\n",
264
+ "# Compute FPR\n",
265
+ "fpr = fp / (fp + tn)\n",
266
+ "\n",
267
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
268
+ "\n",
269
+ "print(f\"Precision: {precision:.3f}\")\n",
270
+ "print(f\"Recall: {recall:.3f}\")\n",
271
+ "print(f\"F1 Score: {f1:.3f}\")"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 3,
277
+ "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
278
+ "metadata": {},
279
+ "outputs": [
280
+ {
281
+ "name": "stdout",
282
+ "output_type": "stream",
283
+ "text": [
284
+ "2\n",
285
+ "num params encoder 50840\n",
286
+ "num params 21496282\n"
287
+ ]
288
+ },
289
+ {
290
+ "name": "stderr",
291
+ "output_type": "stream",
292
+ "text": [
293
+ "/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
294
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n"
295
+ ]
296
+ },
297
+ {
298
+ "name": "stdout",
299
+ "output_type": "stream",
300
+ "text": [
301
+ "===========================\n",
302
+ "accuracy: 99.035\n",
303
+ "===========================\n",
304
+ "False Positive Rate: 0.007\n",
305
+ "Precision: 0.993\n",
306
+ "Recall: 0.987\n",
307
+ "F1 Score: 0.990\n"
308
+ ]
309
+ }
310
+ ],
311
+ "source": [
312
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
313
+ "from torch.utils.data import Dataset, DataLoader\n",
314
+ "from utils import CustomDataset, TestingDataset, transform\n",
315
+ "from tqdm import tqdm\n",
316
+ "import torch\n",
317
+ "import numpy as np\n",
318
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
319
+ "import torch\n",
320
+ "import torch.nn as nn\n",
321
+ "import torch.optim as optim\n",
322
+ "from tqdm import tqdm \n",
323
+ "import torch.nn.functional as F\n",
324
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
325
+ "import pickle\n",
326
+ "\n",
327
+ "torch.manual_seed(1)\n",
328
+ "# torch.manual_seed(42)\n",
329
+ "\n",
330
+ "\n",
331
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
332
+ "num_gpus = torch.cuda.device_count()\n",
333
+ "print(num_gpus)\n",
334
+ "\n",
335
+ "test_data_dir = '/mnt/buf1/pma/frbnn/test_ready'\n",
336
+ "test_dataset = TestingDataset(test_data_dir, transform=transform)\n",
337
+ "\n",
338
+ "num_classes = 2\n",
339
+ "testloader = DataLoader(test_dataset, batch_size=420, shuffle=True, num_workers=32)\n",
340
+ "\n",
341
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
342
+ "model = nn.DataParallel(model)\n",
343
+ "model = model.to(device)\n",
344
+ "params = sum(p.numel() for p in model.parameters())\n",
345
+ "print(\"num params \",params)\n",
346
+ "\n",
347
+ "\n",
348
+ "model_1 = 'models_mask/model-26-99.13_7109.pt'\n",
349
+ "# model_1 ='models/model-47-99.125.pt'\n",
350
+ "model.load_state_dict(torch.load(model_1, weights_only=True))\n",
351
+ "model = model.eval()\n",
352
+ "\n",
353
+ "# eval\n",
354
+ "val_loss = 0.0\n",
355
+ "correct_valid = 0\n",
356
+ "total = 0\n",
357
+ "results = {'output': [],'pred': [], 'true':[], 'freq':[], 'snr':[], 'dm':[], 'boxcar':[]}\n",
358
+ "model.eval()\n",
359
+ "with torch.no_grad():\n",
360
+ " for images, labels in testloader:\n",
361
+ " inputs, labels = images.to(device), labels\n",
362
+ " outputs = nn.Softmax(dim = 1)(model(inputs))\n",
363
+ " selection = outputs[:, 1] > threshold\n",
364
+ " predicted = selection.int()\n",
365
+ " results['pred'].extend(predicted.cpu().numpy().tolist())\n",
366
+ " results['true'].extend(labels[0].cpu().numpy().tolist())\n",
367
+ " results['freq'].extend(labels[2].cpu().numpy().tolist())\n",
368
+ " results['dm'].extend(labels[1].cpu().numpy().tolist())\n",
369
+ " results['snr'].extend(labels[3].cpu().numpy().tolist())\n",
370
+ " results['boxcar'].extend(labels[4].cpu().numpy().tolist())\n",
371
+ " total += labels[0].size(0)\n",
372
+ " correct_valid += (predicted.cpu() == labels[0].cpu()).sum().item()\n",
373
+ " \n",
374
+ " \n",
375
+ "# Calculate training accuracy after each epoch\n",
376
+ "val_accuracy = correct_valid / total * 100.0\n",
377
+ "print(\"===========================\")\n",
378
+ "print('accuracy: ', val_accuracy)\n",
379
+ "print(\"===========================\")\n",
380
+ "\n",
381
+ "import pickle\n",
382
+ "\n",
383
+ "# Pickle the dictionary to a file\n",
384
+ "with open('models_mask/test_7109.pkl', 'wb') as f:\n",
385
+ " pickle.dump(results, f)\n",
386
+ "\n",
387
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
388
+ "\n",
389
+ "# Example binary labels\n",
390
+ "true = results['true'] # ground truth\n",
391
+ "pred = results['pred'] # predicted\n",
392
+ "\n",
393
+ "# Compute metrics\n",
394
+ "precision = precision_score(true, pred)\n",
395
+ "recall = recall_score(true, pred)\n",
396
+ "f1 = f1_score(true, pred)\n",
397
+ "# Get confusion matrix: TN, FP, FN, TP\n",
398
+ "tn, fp, fn, tp = confusion_matrix(true, pred).ravel()\n",
399
+ "\n",
400
+ "# Compute FPR\n",
401
+ "fpr = fp / (fp + tn)\n",
402
+ "\n",
403
+ "print(f\"False Positive Rate: {fpr:.3f}\")\n",
404
+ "\n",
405
+ "print(f\"Precision: {precision:.3f}\")\n",
406
+ "print(f\"Recall: {recall:.3f}\")\n",
407
+ "print(f\"F1 Score: {f1:.3f}\")"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": 6,
413
+ "id": "974e62d6-5088-4cd8-9721-6702717eadee",
414
+ "metadata": {},
415
+ "outputs": [
416
+ {
417
+ "name": "stdout",
418
+ "output_type": "stream",
419
+ "text": [
420
+ "99.14166666666665 0.07542472332656346\n",
421
+ "0.9913333333333333 0.0012472191289246482\n",
422
+ "0.9913333333333334 0.0012472191289246482\n",
423
+ "0.006333333333333333 0.0009428090415820634\n"
424
+ ]
425
+ }
426
+ ],
427
+ "source": [
428
+ "# acc\n",
429
+ "print(np.mean([99.195,99.195, 99.035 ]), np.std([99.195,99.195, 99.035]))\n",
430
+ "# recall\n",
431
+ "print(np.mean([0.991,0.991, 0.987]), np.std([0.991,0.990, 0.993]))\n",
432
+ "# f1\n",
433
+ "print(np.mean([0.992,0.992,0.990 ]),np.std([0.990,0.988,0.991 ]))\n",
434
+ "# fp\n",
435
+ "print(np.mean([0.005,0.007,0.007 ]),np.std([0.005,0.007,0.007]))\n"
436
+ ]
437
+ }
438
+ ],
439
+ "metadata": {
440
+ "kernelspec": {
441
+ "display_name": "Python 3 (ipykernel)",
442
+ "language": "python",
443
+ "name": "python3"
444
+ },
445
+ "language_info": {
446
+ "codemirror_mode": {
447
+ "name": "ipython",
448
+ "version": 3
449
+ },
450
+ "file_extension": ".py",
451
+ "mimetype": "text/x-python",
452
+ "name": "python",
453
+ "nbconvert_exporter": "python",
454
+ "pygments_lexer": "ipython3",
455
+ "version": "3.11.9"
456
+ }
457
+ },
458
+ "nbformat": 4,
459
+ "nbformat_minor": 5
460
+ }
models/.ipynb_checkpoints/plot_reatime_hits-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/practice_cnn_train-checkpoint.ipynb ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "6\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "from utils import CustomDataset, transform, preproc, Convert_ONNX\n",
21
+ "from torch.utils.data import Dataset, DataLoader\n",
22
+ "import torch\n",
23
+ "import numpy as np\n",
24
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
25
+ "import torch\n",
26
+ "import torch.nn as nn\n",
27
+ "import torch.optim as optim\n",
28
+ "from tqdm import tqdm \n",
29
+ "import torch.nn.functional as F\n",
30
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
31
+ "import pickle\n",
32
+ "\n",
33
+ "torch.manual_seed(1)\n",
34
+ "# torch.manual_seed(42)\n",
35
+ "\n",
36
+ "\n",
37
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
38
+ "num_gpus = torch.cuda.device_count()\n",
39
+ "print(num_gpus)\n",
40
+ "\n",
41
+ "# Create custom dataset instance\n",
42
+ "# Create custom dataset instance\n",
43
+ "data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
44
+ "dataset = CustomDataset(data_dir, transform=transform)\n",
45
+ "valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
46
+ "valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
47
+ "\n",
48
+ "\n",
49
+ "num_classes = 2\n",
50
+ "trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
51
+ "\n",
52
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
53
+ "model = nn.DataParallel(model)\n",
54
+ "model = model.to(device)\n",
55
+ "params = sum(p.numel() for p in model.parameters())\n",
56
+ "print(\"num params \",params)\n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
63
+ "metadata": {},
64
+ "outputs": [
65
+ {
66
+ "name": "stderr",
67
+ "output_type": "stream",
68
+ "text": [
69
+ " 0%| | 0/477 [00:00<?, ?batch/s]/home/pma/.conda/envs/frbnn/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
70
+ " with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):\n",
71
+ "100%|██████████████████████████████████████| 477/477 [08:57<00:00, 1.13s/batch]\n"
72
+ ]
73
+ },
74
+ {
75
+ "ename": "NameError",
76
+ "evalue": "name 'validloader' is not defined",
77
+ "output_type": "error",
78
+ "traceback": [
79
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
80
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
81
+ "Cell \u001b[0;32mIn[2], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m images, labels \u001b[38;5;129;01min\u001b[39;00m validloader:\n\u001b[1;32m 30\u001b[0m inputs, labels \u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mto(device), labels\u001b[38;5;241m.\u001b[39mto(device)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 31\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
82
+ "\u001b[0;31mNameError\u001b[0m: name 'validloader' is not defined"
83
+ ]
84
+ }
85
+ ],
86
+ "source": [
87
+ "criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))\n",
88
+ "optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
89
+ "scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)\n",
90
+ "\n",
91
+ "for epoch in range(5):\n",
92
+ " running_loss = 0.0\n",
93
+ " correct_train = 0\n",
94
+ " total_train = 0\n",
95
+ " with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
96
+ " model.train()\n",
97
+ " for i, (images, labels) in enumerate(tepoch):\n",
98
+ " inputs, labels = images.to(device), labels.to(device).float()\n",
99
+ " optimizer.zero_grad()\n",
100
+ " outputs = model(inputs, return_mask=False).to(device)\n",
101
+ " new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)\n",
102
+ " loss = criterion(outputs, new_label)\n",
103
+ " loss.backward()\n",
104
+ " optimizer.step()\n",
105
+ " running_loss += loss.item()\n",
106
+ " # Calculate training accuracy\n",
107
+ " _, predicted = torch.max(outputs.data, 1)\n",
108
+ " total_train += labels.size(0)\n",
109
+ " correct_train += (predicted == labels).sum().item() \n",
110
+ " val_loss = 0.0\n",
111
+ " correct_valid = 0\n",
112
+ " total = 0\n",
113
+ " model.eval()\n",
114
+ " with torch.no_grad():\n",
115
+ " for images, labels in validloader:\n",
116
+ " inputs, labels = images.to(device), labels.to(device).float()\n",
117
+ " optimizer.zero_grad()\n",
118
+ " outputs = model(inputs, return_mask=False)\n",
119
+ " new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)\n",
120
+ " loss = criterion(outputs, new_label)\n",
121
+ " val_loss += loss.item()\n",
122
+ " _, predicted = torch.max(outputs, 1)\n",
123
+ " total += labels.size(0)\n",
124
+ " correct_valid += (predicted == labels).sum().item()\n",
125
+ " scheduler.step(val_loss)\n",
126
+ " # Calculate training accuracy after each epoch\n",
127
+ " train_accuracy = 100 * correct_train / total_train\n",
128
+ " val_accuracy = correct_valid / total * 100.0\n",
129
+ "\n",
130
+ "\n",
131
+ " print(\"===========================\")\n",
132
+ " print('accuracy: ', epoch, train_accuracy, val_accuracy)\n",
133
+ " print('learning rate: ', scheduler.get_last_lr())\n",
134
+ " print(\"===========================\")"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "3faa4a11-89fb-4556-ae87-3645a47fa00d",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "train_accuracy = 100 * correct_train / total_train\n",
145
+ "print('accuracy: ', epoch, train_accuracy)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "e586c4d2-a7f4-4f14-81fc-4f84ffac52b3",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "import sigpyproc.readers as r\n",
156
+ "import cv2\n",
157
+ "import numpy as np\n",
158
+ "import matplotlib.pyplot as plt\n",
159
+ "\n",
160
+ "from scipy.special import softmax\n",
161
+ "%matplotlib inline\n",
162
+ "path = '/mnt/primary/ata/projects/p051/fil_60565_59210_9756774_J0534+2200_0001/LoB.C0928/fil_60565_59210_9756774_J0534+2200_0001-beam0000.fil'\n",
163
+ "# path = '/mnt/primary/ata/projects/p051/fil_60564_62428_4679748_J0332+5434_0001/LoB.C0928/fil_60564_62428_4679748_J0332+5434_0001-beam0000.fil'\n",
164
+ "\n",
165
+ "# Get some metadata\n",
166
+ "\n",
167
+ "# Open the filterbank file\n",
168
+ "fil = r.FilReader(path)\n",
169
+ "header = fil.header\n",
170
+ "print(\"Header:\", header)\n",
171
+ "n=100\n",
172
+ "li = [ 7257608, 7324207, 10393163, 10641071, 11130537, 11085081,\n",
173
+ " 11419145, 11964112, 12329364, 13047181]\n",
174
+ "for el in li:\n",
175
+ " data = torch.tensor(fil.read_block(el-1024, 2048)).cuda()\n",
176
+ " print(data.shape)\n",
177
+ " out = model(transform(torch.tensor(data).cuda())[None])\n",
178
+ " print(softmax(out.detach().cpu().numpy(), axis=1))\n",
179
+ " plt.figure(figsize=(10,10))\n",
180
+ " plt.imshow(data.cpu().numpy(), aspect = 10)\n",
181
+ " plt.show()"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "609e5564-f14f-4bd1-b604-68e7e7d42834",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "triggers = []\n",
192
+ "counter = 0\n",
193
+ "with torch.no_grad():\n",
194
+ " for i in range(2048,10201921, 2048 ):\n",
195
+ " data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
196
+ " # Shuffle the tensor using the random indices\n",
197
+ " out = model(transform(torch.tensor(data).cuda())[None])\n",
198
+ " triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
199
+ " counter += 1\n",
200
+ " if counter > 1000:\n",
201
+ " break"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "id": "08ee6dcf-cb30-4490-8624-4e52552fdf39",
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "print(triggers[0])"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "id": "8c56c6f5-5a0b-4854-8a94-066a9baf4cfc",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "stack = np.stack(triggers)\n",
222
+ "positives = stack[:,0,1]\n",
223
+ "num_pos = np.where(positives > 0.5)[0].shape[0]\n",
224
+ "print(num_pos)"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "eb1d1591-8855-4989-bf12-c8a9cdbf2a4d",
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "import pickle\n",
235
+ "\n",
236
+ "# Path to your pickle file\n",
237
+ "file_path = \"../dataset_generator/dir.pkl\"\n",
238
+ "\n",
239
+ "# Open and load the pickle file\n",
240
+ "with open(file_path, \"rb\") as file: # Use \"rb\" mode for reading binary files\n",
241
+ " data = pickle.load(file)\n",
242
+ "\n",
243
+ "# Print the contents of the file\n"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "46f61d7e-55fa-44fe-be94-d4ddb3c576f9",
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "import sigpyproc.readers as r\n",
254
+ "import cv2\n",
255
+ "import numpy as np\n",
256
+ "import matplotlib.pyplot as plt\n",
257
+ "\n",
258
+ "from scipy.special import softmax\n",
259
+ "%matplotlib inline\n",
260
+ "path = data[0]\n",
261
+ "model.eval()\n",
262
+ "\n",
263
+ "fil = r.FilReader(path)\n",
264
+ "header = fil.header\n",
265
+ "print(\"Header:\", header)\n",
266
+ "n=100\n",
267
+ "\n",
268
+ "\n",
269
+ "triggers = []\n",
270
+ "counter = 0\n",
271
+ "for i in range(2048,10201921, 2048):\n",
272
+ " data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
273
+ " # Shuffle the tensor using the random indices\n",
274
+ " out = model(transform(torch.tensor(data).cuda())[None])\n",
275
+ " triggers.append(softmax(out.detach().cpu().numpy(), axis=1))\n",
276
+ " counter += 1\n",
277
+ " if counter > 1000:\n",
278
+ " break"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "id": "413d402e-2ce3-49fc-bbd4-a3cf1cc92388",
285
+ "metadata": {},
286
+ "outputs": [],
287
+ "source": [
288
+ "print(triggers[0])"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "id": "5c039dee-1b9b-4664-b42a-a79d780f37f1",
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "stack = np.stack(triggers)\n",
299
+ "positives = stack[:,0,1]\n",
300
+ "num_pos = np.where(positives > 0.5)[0].shape[0]\n",
301
+ "print(num_pos)"
302
+ ]
303
+ }
304
+ ],
305
+ "metadata": {
306
+ "kernelspec": {
307
+ "display_name": "Python 3 (ipykernel)",
308
+ "language": "python",
309
+ "name": "python3"
310
+ },
311
+ "language_info": {
312
+ "codemirror_mode": {
313
+ "name": "ipython",
314
+ "version": 3
315
+ },
316
+ "file_extension": ".py",
317
+ "mimetype": "text/x-python",
318
+ "name": "python",
319
+ "nbconvert_exporter": "python",
320
+ "pygments_lexer": "ipython3",
321
+ "version": "3.11.9"
322
+ }
323
+ },
324
+ "nbformat": 4,
325
+ "nbformat_minor": 5
326
+ }
models/.ipynb_checkpoints/recover_crab-checkpoint.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b063cb5ad71d38a551a5a4beb9ae21399e5e633af128c2608645398509854239
3
+ size 16982029
models/.ipynb_checkpoints/recover_new_crab-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/recover_new_crab-debug-checkpoint.ipynb ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 8,
6
+ "id": "851f001a-3882-42cf-8e45-1bb7c4193d20",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "6\n",
14
+ "num params encoder 50840\n",
15
+ "num params 21496282\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "from utils import CustomDataset, transform, Convert_ONNX\n",
21
+ "from torch.utils.data import Dataset, DataLoader\n",
22
+ "import torch\n",
23
+ "import numpy as np\n",
24
+ "from resnet_model_mask import ResidualBlock, ResNet\n",
25
+ "import torch\n",
26
+ "import torch.nn as nn\n",
27
+ "import torch.optim as optim\n",
28
+ "from tqdm import tqdm \n",
29
+ "import torch.nn.functional as F\n",
30
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
31
+ "import pickle\n",
32
+ "\n",
33
+ "torch.manual_seed(1)\n",
34
+ "# torch.manual_seed(42)\n",
35
+ "\n",
36
+ "\n",
37
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
38
+ "num_gpus = torch.cuda.device_count()\n",
39
+ "print(num_gpus)\n",
40
+ "\n",
41
+ "# Create custom dataset instance\n",
42
+ "data_dir = '/mnt/buf0/pma/frbnn/train_ready'\n",
43
+ "dataset = CustomDataset(data_dir, transform=transform)\n",
44
+ "valid_data_dir = '/mnt/buf0/pma/frbnn/valid_ready'\n",
45
+ "valid_dataset = CustomDataset(valid_data_dir, transform=transform)\n",
46
+ "\n",
47
+ "\n",
48
+ "num_classes = 2\n",
49
+ "trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)\n",
50
+ "\n",
51
+ "model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)\n",
52
+ "model = nn.DataParallel(model)\n",
53
+ "model = model.to(device)\n",
54
+ "params = sum(p.numel() for p in model.parameters())\n",
55
+ "print(\"num params \",params)\n"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 9,
61
+ "id": "676a6ffa-5bed-403d-ba03-627f14b36de2",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# model_path = 'models/model-62-98.78.pt'\n",
66
+ "# model_path = 'models/model-47-99.125.pt'\n",
67
+ "model_path = 'models_mask/model-37-99.175_42.pt'\n",
68
+ "\n",
69
+ "# model_path = 'models_mask/model-10-97.055_1.pt'\n",
70
+ "model.load_state_dict(torch.load(model_path, weights_only=True))\n",
71
+ "model = model.eval()"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 10,
77
+ "id": "89d108de-7eae-4bbd-837c-8e657082a1e6",
78
+ "metadata": {},
79
+ "outputs": [
80
+ {
81
+ "name": "stdout",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "Header(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', data_type='raw data', nchans=192, foff=-0.5, fch1=1187.5, nbits=32, tsamp=6.4e-05, tstart=60692.03208333333, nsamples=28125184, nifs=1, coord=<SkyCoord (ICRS): (ra, dec) in deg\n",
85
+ " (83.63322, 22.01446)>, azimuth=<Angle 80.54659271 deg>, zenith=<Angle 66.41192055 deg>, telescope='Effelsberg LOFAR', backend='FAKE', source='crab', frame='topocentric', ibeam=0, nbeams=2, dm=0, period=0, accel=0, signed=False, rawdatafile='', stream_info=StreamInfo(entries=[FileInfo(filename='/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil', hdrlen=338, datalen=21600141312, nsamples=28125184, tstart=60692.03208333333, tsamp=6.4e-05)]))\n"
86
+ ]
87
+ }
88
+ ],
89
+ "source": [
90
+ "import sigpyproc.readers as r\n",
91
+ "import cv2\n",
92
+ "import numpy as np\n",
93
+ "import matplotlib.pyplot as plt\n",
94
+ "fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoA.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
95
+ "# fil = r.FilReader('/mnt/primary/ata/projects/p031/fil_60692_02772_253151611_crab_0001/LoB.C0736/fil_60692_02772_253151611_crab_0001-beam0000.fil')\n",
96
+ "header = fil.header\n",
97
+ "print(header)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": 11,
103
+ "id": "0b276e6e-d6c8-41da-808d-542ee22133d1",
104
+ "metadata": {},
105
+ "outputs": [
106
+ {
107
+ "name": "stderr",
108
+ "output_type": "stream",
109
+ "text": [
110
+ " 0%| | 0/13732 [00:00<?, ?it/s]/tmp/ipykernel_19961/1777549771.py:15: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
111
+ " out = model(transform(torch.tensor(data).cuda())[None])\n",
112
+ " 3%|▉ | 351/13732 [00:13<08:27, 26.38it/s]\n"
113
+ ]
114
+ },
115
+ {
116
+ "ename": "KeyboardInterrupt",
117
+ "evalue": "",
118
+ "output_type": "error",
119
+ "traceback": [
120
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
121
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
122
+ "Cell \u001b[0;32mIn[11], line 13\u001b[0m\n\u001b[1;32m 11\u001b[0m counter \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2048\u001b[39m,header\u001b[38;5;241m.\u001b[39mnsamples, \u001b[38;5;241m2048\u001b[39m)):\n\u001b[0;32m---> 13\u001b[0m data \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(fil\u001b[38;5;241m.\u001b[39mread_block(i\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1024\u001b[39m, \u001b[38;5;241m2048\u001b[39m))\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# Shuffle the tensor using the random indices\u001b[39;00m\n\u001b[1;32m 15\u001b[0m out \u001b[38;5;241m=\u001b[39m model(transform(torch\u001b[38;5;241m.\u001b[39mtensor(data)\u001b[38;5;241m.\u001b[39mcuda())[\u001b[38;5;28;01mNone\u001b[39;00m])\n",
123
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
124
+ ]
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "Error in callback <function flush_figures at 0x7f6c8689ae80> (for post_execute), with arguments args (),kwargs {}:\n"
131
+ ]
132
+ },
133
+ {
134
+ "ename": "KeyboardInterrupt",
135
+ "evalue": "",
136
+ "output_type": "error",
137
+ "traceback": [
138
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
139
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
140
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:126\u001b[0m, in \u001b[0;36mflush_figures\u001b[0;34m()\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m InlineBackend\u001b[38;5;241m.\u001b[39minstance()\u001b[38;5;241m.\u001b[39mclose_figures:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;66;03m# ignore the tracking, just draw and close all figures\u001b[39;00m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 126\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m show(\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 128\u001b[0m \u001b[38;5;66;03m# safely show traceback if in IPython, else raise\u001b[39;00m\n\u001b[1;32m 129\u001b[0m ip \u001b[38;5;241m=\u001b[39m get_ipython()\n",
141
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib_inline/backend_inline.py:90\u001b[0m, in \u001b[0;36mshow\u001b[0;34m(close, block)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m figure_manager \u001b[38;5;129;01min\u001b[39;00m Gcf\u001b[38;5;241m.\u001b[39mget_all_fig_managers():\n\u001b[0;32m---> 90\u001b[0m display(\n\u001b[1;32m 91\u001b[0m figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure,\n\u001b[1;32m 92\u001b[0m metadata\u001b[38;5;241m=\u001b[39m_fetch_figure_metadata(figure_manager\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mfigure)\n\u001b[1;32m 93\u001b[0m )\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 95\u001b[0m show\u001b[38;5;241m.\u001b[39m_to_draw \u001b[38;5;241m=\u001b[39m []\n",
142
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/display_functions.py:298\u001b[0m, in \u001b[0;36mdisplay\u001b[0;34m(include, exclude, metadata, transient, display_id, raw, clear, *objs, **kwargs)\u001b[0m\n\u001b[1;32m 296\u001b[0m publish_display_data(data\u001b[38;5;241m=\u001b[39mobj, metadata\u001b[38;5;241m=\u001b[39mmetadata, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 298\u001b[0m format_dict, md_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mformat\u001b[39m(obj, include\u001b[38;5;241m=\u001b[39minclude, exclude\u001b[38;5;241m=\u001b[39mexclude)\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m format_dict:\n\u001b[1;32m 300\u001b[0m \u001b[38;5;66;03m# nothing to display (e.g. _ipython_display_ took over)\u001b[39;00m\n\u001b[1;32m 301\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n",
143
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:179\u001b[0m, in \u001b[0;36mDisplayFormatter.format\u001b[0;34m(self, obj, include, exclude)\u001b[0m\n\u001b[1;32m 177\u001b[0m md \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 179\u001b[0m data \u001b[38;5;241m=\u001b[39m formatter(obj)\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 181\u001b[0m \u001b[38;5;66;03m# FIXME: log the exception\u001b[39;00m\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
144
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/decorator.py:232\u001b[0m, in \u001b[0;36mdecorate.<locals>.fun\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kwsyntax:\n\u001b[1;32m 231\u001b[0m args, kw \u001b[38;5;241m=\u001b[39m fix(args, kw, sig)\n\u001b[0;32m--> 232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m caller(func, \u001b[38;5;241m*\u001b[39m(extras \u001b[38;5;241m+\u001b[39m args), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n",
145
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:223\u001b[0m, in \u001b[0;36mcatch_format_error\u001b[0;34m(method, self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"show traceback on failed format call\"\"\"\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 223\u001b[0m r \u001b[38;5;241m=\u001b[39m method(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;66;03m# don't warn on NotImplementedErrors\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_return(\u001b[38;5;28;01mNone\u001b[39;00m, args[\u001b[38;5;241m0\u001b[39m])\n",
146
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/formatters.py:340\u001b[0m, in \u001b[0;36mBaseFormatter.__call__\u001b[0;34m(self, obj)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 340\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m printer(obj)\n\u001b[1;32m 341\u001b[0m \u001b[38;5;66;03m# Finally look for special method names\u001b[39;00m\n\u001b[1;32m 342\u001b[0m method \u001b[38;5;241m=\u001b[39m get_real_method(obj, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint_method)\n",
147
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/IPython/core/pylabtools.py:152\u001b[0m, in \u001b[0;36mprint_figure\u001b[0;34m(fig, fmt, bbox_inches, base64, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend_bases\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FigureCanvasBase\n\u001b[1;32m 150\u001b[0m FigureCanvasBase(fig)\n\u001b[0;32m--> 152\u001b[0m fig\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mprint_figure(bytes_io, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n\u001b[1;32m 153\u001b[0m data \u001b[38;5;241m=\u001b[39m bytes_io\u001b[38;5;241m.\u001b[39mgetvalue()\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fmt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msvg\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
148
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/backend_bases.py:2164\u001b[0m, in \u001b[0;36mFigureCanvasBase.print_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m 2161\u001b[0m \u001b[38;5;66;03m# we do this instead of `self.figure.draw_without_rendering`\u001b[39;00m\n\u001b[1;32m 2162\u001b[0m \u001b[38;5;66;03m# so that we can inject the orientation\u001b[39;00m\n\u001b[1;32m 2163\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(renderer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_draw_disabled\u001b[39m\u001b[38;5;124m\"\u001b[39m, nullcontext)():\n\u001b[0;32m-> 2164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 2165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches:\n\u001b[1;32m 2166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtight\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
149
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:95\u001b[0m, in \u001b[0;36m_finalize_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(draw)\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdraw_wrapper\u001b[39m(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 95\u001b[0m result \u001b[38;5;241m=\u001b[39m draw(artist, renderer, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m renderer\u001b[38;5;241m.\u001b[39m_rasterizing:\n\u001b[1;32m 97\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstop_rasterizing()\n",
150
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
151
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/figure.py:3154\u001b[0m, in \u001b[0;36mFigure.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3151\u001b[0m \u001b[38;5;66;03m# ValueError can occur when resizing a window.\u001b[39;00m\n\u001b[1;32m 3153\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpatch\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[0;32m-> 3154\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3155\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3157\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m sfig \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msubfigs:\n\u001b[1;32m 3158\u001b[0m sfig\u001b[38;5;241m.\u001b[39mdraw(renderer)\n",
152
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
153
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
154
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/axes/_base.py:3070\u001b[0m, in \u001b[0;36m_AxesBase.draw\u001b[0;34m(self, renderer)\u001b[0m\n\u001b[1;32m 3067\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artists_rasterized:\n\u001b[1;32m 3068\u001b[0m _draw_rasterized(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure, artists_rasterized, renderer)\n\u001b[0;32m-> 3070\u001b[0m mimage\u001b[38;5;241m.\u001b[39m_draw_list_compositing_images(\n\u001b[1;32m 3071\u001b[0m renderer, \u001b[38;5;28mself\u001b[39m, artists, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39msuppressComposite)\n\u001b[1;32m 3073\u001b[0m renderer\u001b[38;5;241m.\u001b[39mclose_group(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124maxes\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 3074\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstale \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
155
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:132\u001b[0m, in \u001b[0;36m_draw_list_compositing_images\u001b[0;34m(renderer, parent, artists, suppress_composite)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m not_composite \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m has_images:\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m artists:\n\u001b[0;32m--> 132\u001b[0m a\u001b[38;5;241m.\u001b[39mdraw(renderer)\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# Composite any adjacent images together\u001b[39;00m\n\u001b[1;32m 135\u001b[0m image_group \u001b[38;5;241m=\u001b[39m []\n",
156
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/artist.py:72\u001b[0m, in \u001b[0;36mallow_rasterization.<locals>.draw_wrapper\u001b[0;34m(artist, renderer)\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 70\u001b[0m renderer\u001b[38;5;241m.\u001b[39mstart_filter()\n\u001b[0;32m---> 72\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m draw(artist, renderer)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m artist\u001b[38;5;241m.\u001b[39mget_agg_filter() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
157
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:649\u001b[0m, in \u001b[0;36m_ImageBase.draw\u001b[0;34m(self, renderer, *args, **kwargs)\u001b[0m\n\u001b[1;32m 647\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im, trans)\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 649\u001b[0m im, l, b, trans \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_image(\n\u001b[1;32m 650\u001b[0m renderer, renderer\u001b[38;5;241m.\u001b[39mget_image_magnification())\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m im \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 652\u001b[0m renderer\u001b[38;5;241m.\u001b[39mdraw_image(gc, l, b, im)\n",
158
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:939\u001b[0m, in \u001b[0;36mAxesImage.make_image\u001b[0;34m(self, renderer, magnification, unsampled)\u001b[0m\n\u001b[1;32m 936\u001b[0m transformed_bbox \u001b[38;5;241m=\u001b[39m TransformedBbox(bbox, trans)\n\u001b[1;32m 937\u001b[0m clip \u001b[38;5;241m=\u001b[39m ((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_box() \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes\u001b[38;5;241m.\u001b[39mbbox) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_clip_on()\n\u001b[1;32m 938\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure\u001b[38;5;241m.\u001b[39mbbox)\n\u001b[0;32m--> 939\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_image(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_A, bbox, transformed_bbox, clip,\n\u001b[1;32m 940\u001b[0m magnification, unsampled\u001b[38;5;241m=\u001b[39munsampled)\n",
159
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:526\u001b[0m, in \u001b[0;36m_ImageBase._make_image\u001b[0;34m(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)\u001b[0m\n\u001b[1;32m 521\u001b[0m mask \u001b[38;5;241m=\u001b[39m (np\u001b[38;5;241m.\u001b[39mwhere(A\u001b[38;5;241m.\u001b[39mmask, np\u001b[38;5;241m.\u001b[39mfloat32(np\u001b[38;5;241m.\u001b[39mnan), np\u001b[38;5;241m.\u001b[39mfloat32(\u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m A\u001b[38;5;241m.\u001b[39mmask\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m==\u001b[39m A\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;66;03m# nontrivial mask\u001b[39;00m\n\u001b[1;32m 523\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m np\u001b[38;5;241m.\u001b[39mones_like(A, np\u001b[38;5;241m.\u001b[39mfloat32))\n\u001b[1;32m 524\u001b[0m \u001b[38;5;66;03m# we always have to interpolate the mask to account for\u001b[39;00m\n\u001b[1;32m 525\u001b[0m \u001b[38;5;66;03m# non-affine transformations\u001b[39;00m\n\u001b[0;32m--> 526\u001b[0m out_alpha \u001b[38;5;241m=\u001b[39m _resample(\u001b[38;5;28mself\u001b[39m, mask, out_shape, t, resample\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m mask \u001b[38;5;66;03m# Make sure we don't use mask anymore!\u001b[39;00m\n\u001b[1;32m 528\u001b[0m \u001b[38;5;66;03m# Agg updates out_alpha in place. If the pixel has no image\u001b[39;00m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;66;03m# data it will not be updated (and still be 0 as we initialized\u001b[39;00m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;66;03m# it), if input data that would go into that output pixel than\u001b[39;00m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;66;03m# it will be `nan`, if all the input data for a pixel is good\u001b[39;00m\n\u001b[1;32m 532\u001b[0m \u001b[38;5;66;03m# it will be 1, and if there is _some_ good data in that output\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;66;03m# pixel it will be between [0, 1] (such as a rotated image).\u001b[39;00m\n",
160
+ "File \u001b[0;32m~/.conda/envs/frbnn/lib/python3.11/site-packages/matplotlib/image.py:208\u001b[0m, in \u001b[0;36m_resample\u001b[0;34m(image_obj, data, out_shape, transform, resample, alpha)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m resample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 207\u001b[0m resample \u001b[38;5;241m=\u001b[39m image_obj\u001b[38;5;241m.\u001b[39mget_resample()\n\u001b[0;32m--> 208\u001b[0m _image\u001b[38;5;241m.\u001b[39mresample(data, out, transform,\n\u001b[1;32m 209\u001b[0m _interpd_[interpolation],\n\u001b[1;32m 210\u001b[0m resample,\n\u001b[1;32m 211\u001b[0m alpha,\n\u001b[1;32m 212\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filternorm(),\n\u001b[1;32m 213\u001b[0m image_obj\u001b[38;5;241m.\u001b[39mget_filterrad())\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
161
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
162
+ ]
163
+ }
164
+ ],
165
+ "source": [
166
+ "import sigpyproc.readers as r\n",
167
+ "import cv2\n",
168
+ "import numpy as np\n",
169
+ "import matplotlib.pyplot as plt\n",
170
+ "from scipy.special import softmax\n",
171
+ "from tqdm import tqdm\n",
172
+ "%matplotlib inline\n",
173
+ "\n",
174
+ "header = fil.header\n",
175
+ "triggers = []\n",
176
+ "counter = 0\n",
177
+ "for i in tqdm(range(2048,header.nsamples, 2048)):\n",
178
+ " data = torch.tensor(fil.read_block(i-1024, 2048)).cuda()\n",
179
+ " # Shuffle the tensor using the random indices\n",
180
+ " out = model(transform(torch.tensor(data).cuda())[None])\n",
181
+ " out = softmax(out.detach().cpu().numpy(), axis=1)\n",
182
+ " triggers.append(out)\n",
183
+ " counter += 1\n",
184
+ " # if counter > 1000:\n",
185
+ " # break\n",
186
+ " # if out[0, 1]>0.999:\n",
187
+ " # key = data.cpu().numpy()\n",
188
+ " # plt.figure(figsize=(10,10))\n",
189
+ " # plt.imshow(data.cpu().numpy(), aspect = 10, vmax = 54557.824)\n",
190
+ "stack = np.stack(triggers)\n",
191
+ "positives = stack[:,0,1]\n",
192
+ "num_pos = np.where(positives > 0.999)[0].shape[0]\n",
193
+ "print(num_pos)"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "id": "64df934d-f4a2-49f0-857d-2661b1d78b21",
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "np.flipud()"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "id": "1eafb2c1-857e-48be-aa8b-18669c0e0f8c",
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "plt.figure(figsize=(10,10))\n",
214
+ "# plt.imshow(key, aspect = 10, vmax = 54557.824)\n",
215
+ "plt.imshow(key, aspect = 10)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "ed3783c3-8ed1-46d6-91e4-e906dfa44913",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": [
225
+ "key.shape"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "8b56a356-a582-4f5d-a8e2-20f725a48fb3",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "total_data =[]\n",
236
+ "for i in range(32):\n",
237
+ " total_data.append(key)\n",
238
+ "total_data = torch.tensor(np.array(total_data))\n",
239
+ "total_data.cpu().detach().numpy().tofile(\"crab_in.bin\")\n",
240
+ "print(total_data.shape)\n",
241
+ "outputs_data = []\n",
242
+ "for i in range(32):\n",
243
+ " temp = model(transform(total_data.cuda()[i,:,:])[None])\n",
244
+ " print(temp)\n",
245
+ " # outputs_data.append(softmax(temp.detach().cpu().numpy(), axis=1))\n",
246
+ " outputs_data.append(temp.detach().cpu().numpy())\n",
247
+ "outputs_data = torch.tensor(outputs_data)\n",
248
+ "outputs_data.cpu().detach().numpy().tofile(\"crab_out.bin\")"
249
+ ]
250
+ }
251
+ ],
252
+ "metadata": {
253
+ "kernelspec": {
254
+ "display_name": "Python 3 (ipykernel)",
255
+ "language": "python",
256
+ "name": "python3"
257
+ },
258
+ "language_info": {
259
+ "codemirror_mode": {
260
+ "name": "ipython",
261
+ "version": 3
262
+ },
263
+ "file_extension": ".py",
264
+ "mimetype": "text/x-python",
265
+ "name": "python",
266
+ "nbconvert_exporter": "python",
267
+ "pygments_lexer": "ipython3",
268
+ "version": "3.11.9"
269
+ }
270
+ },
271
+ "nbformat": 4,
272
+ "nbformat_minor": 5
273
+ }
models/.ipynb_checkpoints/recover_new_frb-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/.ipynb_checkpoints/resnet_model-checkpoint.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ResidualBlock(nn.Module):
6
+ def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
7
+ super(ResidualBlock, self).__init__()
8
+ self.conv1 = nn.Sequential(
9
+ nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
10
+ nn.BatchNorm2d(out_channels),
11
+ nn.ReLU())
12
+ self.conv2 = nn.Sequential(
13
+ nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
14
+ nn.BatchNorm2d(out_channels))
15
+ self.downsample = downsample
16
+ self.relu = nn.ReLU()
17
+ self.out_channels = out_channels
18
+ self.dropout_percentage = 0.5
19
+ self.dropout1 = nn.Dropout(p=self.dropout_percentage)
20
+ self.batchnorm_mod = nn.BatchNorm2d(out_channels)
21
+
22
+ def forward(self, x):
23
+ residual = x
24
+ out = self.conv1(x)
25
+ out = self.dropout1(out)
26
+ # out = self.batchnorm_mod(out)
27
+ out = self.conv2(out)
28
+ out = self.dropout1(out)
29
+ # out = self.batchnorm_mod(out)
30
+ if self.downsample:
31
+ residual = self.downsample(x)
32
+ out += residual
33
+ out = self.relu(out)
34
+ return out
35
+
36
+
37
+ class ResNet(nn.Module):
38
+ def __init__(self, inchan, block, layers, num_classes = 10):
39
+ super(ResNet, self).__init__()
40
+ self.inplanes = 64
41
+ self.eps = 1e-5
42
+ self.relu = nn.ReLU()
43
+ self.conv1 = nn.Sequential(
44
+ nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
45
+ nn.BatchNorm2d(64),
46
+ nn.ReLU())
47
+ self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
48
+ self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
49
+ self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
50
+ self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
51
+ self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
52
+ self.avgpool = nn.AvgPool2d(7, stride=1)
53
+ self.fc = nn.Linear(39424, num_classes)
54
+ self.dropout_percentage = 0.3
55
+ self.dropout1 = nn.Dropout(p=self.dropout_percentage)
56
+
57
+ # Encoder
58
+ self.encoder = nn.Sequential(
59
+ nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
60
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
61
+ nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
62
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
63
+ nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
64
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
65
+ nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
66
+ nn.Sigmoid()
67
+ )
68
+ params = sum(p.numel() for p in self.encoder.parameters())
69
+ print("num params encoder ",params)
70
+
71
+ def norm(self, x):
72
+ shifted = x-x.min()
73
+ maxes = torch.amax(abs(shifted), dim=(-2, -1))
74
+ repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
75
+ x = shifted/repeated_maxes
76
+ return x
77
+
78
+ def _make_layer(self, block, planes, blocks, stride=1):
79
+ downsample = None
80
+ if stride != 1 or self.inplanes != planes:
81
+ downsample = nn.Sequential(
82
+ nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
83
+ nn.BatchNorm2d(planes),
84
+ )
85
+ layers = []
86
+ layers.append(block(self.inplanes, planes, stride, downsample))
87
+ self.inplanes = planes
88
+ for i in range(1, blocks):
89
+ layers.append(block(self.inplanes, planes))
90
+ return nn.Sequential(*layers)
91
+
92
+ def forward(self, x, return_mask=False):
93
+ # x = self.norm(x)
94
+ x = self.conv1(x)
95
+ x = self.maxpool(x)
96
+ x = self.layer0(x)
97
+ x = self.layer1(x)
98
+ x = self.layer2(x)
99
+ x = self.layer3(x)
100
+ x = self.avgpool(x)
101
+ x = x.view(x.size(0), -1)
102
+ x = self.dropout1(x)
103
+ x = self.fc(x)
104
+ # return x
105
+ if return_mask:
106
+ return x, self.mask, self.value
107
+ else:
108
+ return x
109
+
110
+
111
+ class ConvAutoencoder(nn.Module):
112
+ def __init__(self):
113
+ super(ConvAutoencoder, self).__init__()
114
+
115
+ # Encoder
116
+ self.encoder = nn.Sequential(
117
+ nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
118
+ nn.ReLU(),
119
+ nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
120
+ nn.ReLU(),
121
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
122
+ nn.ReLU(),
123
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
124
+ nn.ReLU()
125
+ )
126
+
127
+ # Fully connected latent space
128
+ self.fc1 = nn.Linear(128 * 12 * 16, 8)
129
+ self.fc2 = nn.Linear(8, 128 * 12 * 16)
130
+
131
+ # Decoder
132
+ self.decoder = nn.Sequential(
133
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
134
+ nn.ReLU(),
135
+ nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
136
+ nn.ReLU(),
137
+ nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
138
+ nn.ReLU(),
139
+ nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
140
+ nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
141
+ )
142
+
143
+ def forward(self, x):
144
+ # Encode
145
+ x = self.encoder(x)
146
+
147
+ # Flatten the encoded output
148
+ x = x.view(x.size(0), -1)
149
+
150
+ # Fully connected latent space
151
+ x = self.fc1(x)
152
+ x = self.fc2(x)
153
+
154
+ # Reshape the output to the shape suitable for the decoder
155
+ x = x.view(x.size(0), 128, 12, 16)
156
+
157
+ # Decode
158
+ x = self.decoder(x)
159
+
160
+ return x
models/.ipynb_checkpoints/resnet_model_mask-checkpoint.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class ResidualBlock(nn.Module):
6
+ def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
7
+ super(ResidualBlock, self).__init__()
8
+ self.conv1 = nn.Sequential(
9
+ nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
10
+ nn.BatchNorm2d(out_channels),
11
+ nn.ReLU())
12
+ self.conv2 = nn.Sequential(
13
+ nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
14
+ nn.BatchNorm2d(out_channels))
15
+ self.downsample = downsample
16
+ self.relu = nn.ReLU()
17
+ self.out_channels = out_channels
18
+ self.dropout_percentage = 0.5
19
+ self.dropout1 = nn.Dropout(p=self.dropout_percentage)
20
+ self.batchnorm_mod = nn.BatchNorm2d(out_channels)
21
+
22
+ def forward(self, x):
23
+ residual = x
24
+ out = self.conv1(x)
25
+ out = self.dropout1(out)
26
+ # out = self.batchnorm_mod(out)
27
+ out = self.conv2(out)
28
+ out = self.dropout1(out)
29
+ # out = self.batchnorm_mod(out)
30
+ if self.downsample:
31
+ residual = self.downsample(x)
32
+ out += residual
33
+ out = self.relu(out)
34
+ return out
35
+
36
+
37
+ class ResNet(nn.Module):
38
+ def __init__(self, inchan, block, layers, num_classes = 10):
39
+ super(ResNet, self).__init__()
40
+ self.inplanes = 64
41
+ self.eps = 1e-5
42
+ self.relu = nn.ReLU()
43
+ self.conv1 = nn.Sequential(
44
+ nn.Conv2d(inchan, 64, kernel_size = 7, stride = 2, padding = 3),
45
+ nn.BatchNorm2d(64),
46
+ nn.ReLU())
47
+ self.maxpool = nn.MaxPool2d(kernel_size = (2, 2), stride = 2, padding = 1)
48
+ self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
49
+ self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
50
+ self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
51
+ self.layer3 = self._make_layer(block, 512, layers[3], stride = 1)
52
+ self.avgpool = nn.AvgPool2d(7, stride=1)
53
+ self.fc = nn.Linear(39424, num_classes)
54
+ self.dropout_percentage = 0.3
55
+ self.dropout1 = nn.Dropout(p=self.dropout_percentage)
56
+
57
+ # Encoder
58
+ self.encoder = nn.Sequential(
59
+ nn.Conv2d(24, 32, kernel_size = 3, stride =1, padding = 1),
60
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
61
+ nn.Conv2d(32, 64, kernel_size = 3, stride =1, padding = 1),
62
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
63
+ nn.Conv2d(64, 32, kernel_size = 3, stride = 1, padding = 1),
64
+ nn.ReLU(True),nn.Dropout(p=self.dropout_percentage),
65
+ nn.Conv2d(32, 24, kernel_size = 3, stride = 1, padding = 1),
66
+ nn.Sigmoid()
67
+ )
68
+ params = sum(p.numel() for p in self.encoder.parameters())
69
+ print("num params encoder ",params)
70
+
71
+ def norm(self, x):
72
+ shifted = x-x.min()
73
+ maxes = torch.amax(abs(shifted), dim=(-2, -1))
74
+ repeated_maxes = maxes.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[-2],x.shape[-1])
75
+ x = shifted/repeated_maxes
76
+ return x
77
+
78
+ def _make_layer(self, block, planes, blocks, stride=1):
79
+ downsample = None
80
+ if stride != 1 or self.inplanes != planes:
81
+ downsample = nn.Sequential(
82
+ nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
83
+ nn.BatchNorm2d(planes),
84
+ )
85
+ layers = []
86
+ layers.append(block(self.inplanes, planes, stride, downsample))
87
+ self.inplanes = planes
88
+ for i in range(1, blocks):
89
+ layers.append(block(self.inplanes, planes))
90
+ return nn.Sequential(*layers)
91
+
92
+ def forward(self, x, return_mask=False):
93
+ # # m = self.encoder(x).unsqueeze(-1).repeat(1, 1, 1, x.shape[-1])
94
+ m = self.encoder(x)
95
+ self.mask = m
96
+ self.value = x
97
+ # # m = nn.Sigmoid()(self.encoder(x))
98
+ x = x * m
99
+ # x = self.norm(x)
100
+ x = self.conv1(x)
101
+ x = self.maxpool(x)
102
+ x = self.layer0(x)
103
+ x = self.layer1(x)
104
+ x = self.layer2(x)
105
+ x = self.layer3(x)
106
+ x = self.avgpool(x)
107
+ x = x.view(x.size(0), -1)
108
+ x = self.dropout1(x)
109
+ x = self.fc(x)
110
+ return x
111
+ # if return_mask:
112
+ # return x, self.mask, self.value
113
+ # else:
114
+ # return x
115
+
116
+
117
+ class ConvAutoencoder(nn.Module):
118
+ def __init__(self):
119
+ super(ConvAutoencoder, self).__init__()
120
+
121
+ # Encoder
122
+ self.encoder = nn.Sequential(
123
+ nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), # (16, 96, 128)
124
+ nn.ReLU(),
125
+ nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (32, 48, 64)
126
+ nn.ReLU(),
127
+ nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # (64, 24, 32)
128
+ nn.ReLU(),
129
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),# (128, 12, 16)
130
+ nn.ReLU()
131
+ )
132
+
133
+ # Fully connected latent space
134
+ self.fc1 = nn.Linear(128 * 12 * 16, 8)
135
+ self.fc2 = nn.Linear(8, 128 * 12 * 16)
136
+
137
+ # Decoder
138
+ self.decoder = nn.Sequential(
139
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), # (64, 24, 32)
140
+ nn.ReLU(),
141
+ nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # (32, 48, 64)
142
+ nn.ReLU(),
143
+ nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # (16, 96, 128)
144
+ nn.ReLU(),
145
+ nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1), # (3, 192, 256)
146
+ nn.Sigmoid() # Using Sigmoid for the final activation to get output in range [0, 1]
147
+ )
148
+
149
+ def forward(self, x):
150
+ # Encode
151
+ x = self.encoder(x)
152
+
153
+ # Flatten the encoded output
154
+ x = x.view(x.size(0), -1)
155
+
156
+ # Fully connected latent space
157
+ x = self.fc1(x)
158
+ x = self.fc2(x)
159
+
160
+ # Reshape the output to the shape suitable for the decoder
161
+ x = x.view(x.size(0), 128, 12, 16)
162
+
163
+ # Decode
164
+ x = self.decoder(x)
165
+
166
+ return x
models/.ipynb_checkpoints/train-checkpoint.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import CustomDataset, transform, preproc, Convert_ONNX
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
+ import numpy as np
5
+ from resnet_model import ResidualBlock, ResNet
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import tqdm
10
+ import torch.nn.functional as F
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
+ import pickle
13
+ import sys
14
+
15
+ ind = int(sys.argv[1])
16
+ seeds = [1,42,7109,2002,32]
17
+ seed = seeds[ind]
18
+ print("using seed: ",seed)
19
+ torch.manual_seed(seed)
20
+
21
+
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ num_gpus = torch.cuda.device_count()
24
+ print(num_gpus)
25
+
26
+ # Create custom dataset instance
27
+ data_dir = '/mnt/buf1/pma/frbnn/train_ready'
28
+ dataset = CustomDataset(data_dir, transform=transform)
29
+ valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
30
+ valid_dataset = CustomDataset(valid_data_dir, transform=transform)
31
+
32
+
33
+ num_classes = 2
34
+ trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
35
+ validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
36
+
37
+ model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
38
+ model = nn.DataParallel(model)
39
+ model = model.to(device)
40
+ params = sum(p.numel() for p in model.parameters())
41
+ print("num params ",params)
42
+ torch.save(model.state_dict(), 'models/test.pt')
43
+ model.load_state_dict(torch.load('models/test.pt'))
44
+
45
+ preproc_model = preproc()
46
+ Convert_ONNX(model.module,'models/test.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
47
+ Convert_ONNX(preproc_model,'models/preproc.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
48
+
49
+ # Define optimizer and loss function
50
+
51
+ criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
52
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
53
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
54
+
55
+
56
+ from tqdm import tqdm
57
+ # Training loop
58
+ epochs = 10000
59
+ for epoch in range(epochs):
60
+ running_loss = 0.0
61
+ correct_train = 0
62
+ total_train = 0
63
+ with tqdm(trainloader, unit="batch") as tepoch:
64
+ model.train()
65
+ for i, (images, labels) in enumerate(tepoch):
66
+ inputs, labels = images.to(device), labels.to(device).float()
67
+ optimizer.zero_grad()
68
+ outputs = model(inputs, return_mask=False).to(device)
69
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
70
+ loss = criterion(outputs, new_label)
71
+ loss.backward()
72
+ optimizer.step()
73
+ running_loss += loss.item()
74
+ # Calculate training accuracy
75
+ _, predicted = torch.max(outputs.data, 1)
76
+ total_train += labels.size(0)
77
+ correct_train += (predicted == labels).sum().item()
78
+ val_loss = 0.0
79
+ correct_valid = 0
80
+ total = 0
81
+ model.eval()
82
+ with torch.no_grad():
83
+ for images, labels in validloader:
84
+ inputs, labels = images.to(device), labels.to(device).float()
85
+ optimizer.zero_grad()
86
+ outputs = model(inputs, return_mask=False)
87
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
88
+ loss = criterion(outputs, new_label)
89
+ val_loss += loss.item()
90
+ _, predicted = torch.max(outputs, 1)
91
+ total += labels.size(0)
92
+ correct_valid += (predicted == labels).sum().item()
93
+ scheduler.step(val_loss)
94
+ # Calculate training accuracy after each epoch
95
+ train_accuracy = 100 * correct_train / total_train
96
+ val_accuracy = correct_valid / total * 100.0
97
+ torch.save(model.state_dict(), 'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.pt')
98
+ Convert_ONNX(model.module,'models/model-'+str(epoch)+'-'+str(val_accuracy)+'.onnx', input_data_mock=inputs)
99
+
100
+ print("===========================")
101
+ print('accuracy: ', epoch, train_accuracy, val_accuracy)
102
+ print('learning rate: ', scheduler.get_last_lr())
103
+ print("===========================")
104
+ if scheduler.get_last_lr()[0] <1e-6:
105
+ break
models/.ipynb_checkpoints/train-mask-8-checkpoint.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import CustomDataset, transform, preproc, Convert_ONNX
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
+ import numpy as np
5
+ from resnet_model_mask import ResidualBlock, ResNet
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import tqdm
10
+ import torch.nn.functional as F
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
+ import pickle
13
+ import sys
14
+ # [1,42,7109,2002,32]
15
+ ind = int(sys.argv[1])
16
+ seeds = [1,42,7109,2002,32]
17
+ seed = seeds[ind]
18
+ torch.manual_seed(seed)
19
+
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+ num_gpus = torch.cuda.device_count()
22
+ print(num_gpus)
23
+
24
+ # Create custom dataset instance
25
+ data_dir = '/mnt/buf1/pma/frbnn/train_ready'
26
+ dataset = CustomDataset(data_dir, bit8 = True, transform=transform)
27
+ valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
28
+ valid_dataset = CustomDataset(valid_data_dir, bit8 = True, transform=transform)
29
+
30
+
31
+ num_classes = 2
32
+ trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
33
+ validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
34
+
35
+ model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
36
+ model = nn.DataParallel(model)
37
+ model = model.to(device)
38
+ params = sum(p.numel() for p in model.parameters())
39
+ print("num params ",params)
40
+ torch.save(model.state_dict(), f'models_8/test_{seed}.pt')
41
+ model.load_state_dict(torch.load(f'models_8/test_{seed}.pt'))
42
+
43
+ preproc_model = preproc()
44
+ Convert_ONNX(model.module,f'models_8/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
45
+ Convert_ONNX(preproc_model,f'models_8/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
46
+
47
+ # Define optimizer and loss function
48
+
49
+ criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
50
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
51
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
52
+
53
+
54
+ from tqdm import tqdm
55
+ # Training loop
56
+ epochs = 10000
57
+ for epoch in range(epochs):
58
+ running_loss = 0.0
59
+ correct_train = 0
60
+ total_train = 0
61
+ with tqdm(trainloader, unit="batch") as tepoch:
62
+ model.train()
63
+ for i, (images, labels) in enumerate(tepoch):
64
+ inputs, labels = images.to(device), labels.to(device).float()
65
+ optimizer.zero_grad()
66
+ outputs = model(inputs, return_mask=False).to(device)
67
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
68
+ loss = criterion(outputs, new_label)
69
+ loss.backward()
70
+ optimizer.step()
71
+ running_loss += loss.item()
72
+ # Calculate training accuracy
73
+ _, predicted = torch.max(outputs.data, 1)
74
+ total_train += labels.size(0)
75
+ correct_train += (predicted == labels).sum().item()
76
+ val_loss = 0.0
77
+ correct_valid = 0
78
+ total = 0
79
+ model.eval()
80
+ with torch.no_grad():
81
+ for images, labels in validloader:
82
+ inputs, labels = images.to(device), labels.to(device).float()
83
+ optimizer.zero_grad()
84
+ outputs = model(inputs, return_mask=False)
85
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
86
+ loss = criterion(outputs, new_label)
87
+ val_loss += loss.item()
88
+ _, predicted = torch.max(outputs, 1)
89
+ total += labels.size(0)
90
+ correct_valid += (predicted == labels).sum().item()
91
+ scheduler.step(val_loss)
92
+ # Calculate training accuracy after each epoch
93
+ train_accuracy = 100 * correct_train / total_train
94
+ val_accuracy = correct_valid / total * 100.0
95
+ torch.save(model.state_dict(), 'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
96
+ Convert_ONNX(model.module,'models_8/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
97
+
98
+ print("===========================")
99
+ print('accuracy: ', epoch, train_accuracy, val_accuracy)
100
+ print('learning rate: ', scheduler.get_last_lr())
101
+ print("===========================")
102
+ if scheduler.get_last_lr()[0] <1e-6:
103
+ break
models/.ipynb_checkpoints/train-mask-checkpoint.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import CustomDataset, transform, preproc, Convert_ONNX
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch
4
+ import numpy as np
5
+ from resnet_model_mask import ResidualBlock, ResNet
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import tqdm
10
+ import torch.nn.functional as F
11
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
12
+ import pickle
13
+ import sys
14
+
15
+ ind = int(sys.argv[1])
16
+ seeds = [1,42,7109,2002,32]
17
+ seed = seeds[ind]
18
+ print("using seed: ",seed)
19
+ torch.manual_seed(seed)
20
+
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+ num_gpus = torch.cuda.device_count()
23
+ print(num_gpus)
24
+
25
+ # Create custom dataset instance
26
+ data_dir = '/mnt/buf1/pma/frbnn/train_ready'
27
+ dataset = CustomDataset(data_dir, transform=transform)
28
+ valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
29
+ valid_dataset = CustomDataset(valid_data_dir, transform=transform)
30
+
31
+
32
+ num_classes = 2
33
+ trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
34
+ validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
35
+
36
+ model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
37
+ model = nn.DataParallel(model)
38
+ model = model.to(device)
39
+ params = sum(p.numel() for p in model.parameters())
40
+ print("num params ",params)
41
+ torch.save(model.state_dict(), f'models_mask/test_{seed}.pt')
42
+ model.load_state_dict(torch.load(f'models_mask/test_{seed}.pt'))
43
+
44
+ preproc_model = preproc()
45
+ Convert_ONNX(model.module,f'models_mask/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
46
+ Convert_ONNX(preproc_model,f'models_mask/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
47
+
48
+ # Define optimizer and loss function
49
+
50
+ criterion = nn.CrossEntropyLoss(weight = torch.tensor([1,1]).to(device))
51
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
52
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
53
+
54
+
55
+ from tqdm import tqdm
56
+ # Training loop
57
+ epochs = 10000
58
+ for epoch in range(epochs):
59
+ running_loss = 0.0
60
+ correct_train = 0
61
+ total_train = 0
62
+ with tqdm(trainloader, unit="batch") as tepoch:
63
+ model.train()
64
+ for i, (images, labels) in enumerate(tepoch):
65
+ inputs, labels = images.to(device), labels.to(device).float()
66
+ optimizer.zero_grad()
67
+ outputs = model(inputs, return_mask=False).to(device)
68
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32).to(device)
69
+ loss = criterion(outputs, new_label)
70
+ loss.backward()
71
+ optimizer.step()
72
+ running_loss += loss.item()
73
+ # Calculate training accuracy
74
+ _, predicted = torch.max(outputs.data, 1)
75
+ total_train += labels.size(0)
76
+ correct_train += (predicted == labels).sum().item()
77
+ val_loss = 0.0
78
+ correct_valid = 0
79
+ total = 0
80
+ model.eval()
81
+ with torch.no_grad():
82
+ for images, labels in validloader:
83
+ inputs, labels = images.to(device), labels.to(device).float()
84
+ optimizer.zero_grad()
85
+ outputs = model(inputs, return_mask=False)
86
+ new_label = F.one_hot(labels.type(torch.int64),num_classes=2).type(torch.float32)
87
+ loss = criterion(outputs, new_label)
88
+ val_loss += loss.item()
89
+ _, predicted = torch.max(outputs, 1)
90
+ total += labels.size(0)
91
+ correct_valid += (predicted == labels).sum().item()
92
+ scheduler.step(val_loss)
93
+ # Calculate training accuracy after each epoch
94
+ train_accuracy = 100 * correct_train / total_train
95
+ val_accuracy = correct_valid / total * 100.0
96
+ torch.save(model.state_dict(), 'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.pt')
97
+ Convert_ONNX(model.module,'models_mask/model-'+str(epoch)+'-'+str(val_accuracy)+f'_{seed}.onnx', input_data_mock=inputs)
98
+
99
+ print("===========================")
100
+ print('accuracy: ', epoch, train_accuracy, val_accuracy)
101
+ print('learning rate: ', scheduler.get_last_lr())
102
+ print("===========================")
103
+ if scheduler.get_last_lr()[0] <1e-6:
104
+ break
models/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle
3
+ from torch.utils.data import Dataset, DataLoader
4
+ import os
5
+ import torch
6
+ from copy import deepcopy
7
+ from blimpy import Waterfall
8
+ from tqdm import tqdm
9
+ from copy import deepcopy
10
+ from sigpyproc.readers import FilReader
11
+ from torch import nn
12
+
13
+
14
+ def load_pickled_data(file_path):
15
+ with open(file_path, 'rb') as f:
16
+ data = pickle.load(f)
17
+ return data
18
+
19
+ # Custom dataset class
20
+ class CustomDataset(Dataset):
21
+ def __init__(self, data_dir, bit8=False, transform=None):
22
+ self.data_dir = data_dir
23
+ self.transform = transform
24
+ self.images = []
25
+ self.labels = []
26
+ self.classes = os.listdir(data_dir)
27
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
28
+ self.bit8 = bit8
29
+ # Load images and labels
30
+ for cls in self.classes:
31
+ class_dir = os.path.join(data_dir, cls)
32
+ for image_name in os.listdir(class_dir):
33
+ image_path = os.path.join(class_dir, image_name)
34
+ self.images.append(image_path)
35
+ self.labels.append(self.class_to_idx[cls])
36
+
37
+ def __len__(self):
38
+ return len(self.images)
39
+
40
+ def __getitem__(self, idx):
41
+ image_path = self.images[idx]
42
+ label = self.labels[idx]
43
+ # Load image
44
+ image = load_pickled_data(image_path)
45
+ if self.transform is not None:
46
+ if self.bit8 == True:
47
+ new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
48
+ else:
49
+ new_image = self.transform(torch.from_numpy(image['data']))
50
+ # new_image = self.transform(image['data'])
51
+ return new_image, label
52
+
53
+ # Custom dataset class
54
+ class CustomDataset_Masked(Dataset):
55
+ def __init__(self, data_dir, transform=None):
56
+ self.data_dir = data_dir
57
+ self.transform = transform
58
+ self.images = []
59
+ self.labels = []
60
+ self.classes = os.listdir(data_dir)
61
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
62
+
63
+ # Load images and labels
64
+ for cls in self.classes:
65
+ class_dir = os.path.join(data_dir, cls)
66
+ for image_name in os.listdir(class_dir):
67
+ image_path = os.path.join(class_dir, image_name)
68
+ self.images.append(image_path)
69
+ self.labels.append(self.class_to_idx[cls])
70
+
71
+ def __len__(self):
72
+ return len(self.images)
73
+
74
+ def __getitem__(self, idx):
75
+ image_path = self.images[idx]
76
+
77
+ label = self.labels[idx]
78
+ # Load image
79
+ image = load_pickled_data(image_path)
80
+ if self.transform is not None:
81
+ if image['burst'].max() ==0:
82
+ new_burst = torch.from_numpy(image['burst'])
83
+ else:
84
+ new_burst = torch.from_numpy(image['burst']/image['burst'].max())
85
+ ind = new_burst > 0.1
86
+ ind_not = new_burst <= 0.1
87
+ new_burst[ind] = 1
88
+ new_burst[ind_not] = 0
89
+ new_image = self.transform(torch.from_numpy(image['data'].data))
90
+ new_burst_arr = torch.zeros_like(new_image)
91
+ new_burst_arr[ 0, :,:] = new_burst
92
+ new_burst_arr[ 1, :,:] = new_burst
93
+ new_burst_arr[ 2, :,:] = new_burst
94
+ return new_image, label, new_burst_arr
95
+
96
+ # Custom dataset class
97
+ class TestingDataset(Dataset):
98
+ def __init__(self, data_dir, bit8=False, transform=None):
99
+ self.data_dir = data_dir
100
+ self.transform = transform
101
+ self.images = []
102
+ self.labels = []
103
+ self.classes = os.listdir(data_dir)
104
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
105
+ self.bit8 = bit8
106
+ # Load images and labels
107
+ for cls in self.classes:
108
+ class_dir = os.path.join(data_dir, cls)
109
+ for image_name in os.listdir(class_dir):
110
+ image_path = os.path.join(class_dir, image_name)
111
+ self.images.append(image_path)
112
+ self.labels.append(self.class_to_idx[cls])
113
+
114
+ def __len__(self):
115
+ return len(self.images)
116
+
117
+ def __getitem__(self, idx):
118
+ image_path = self.images[idx]
119
+ label = self.labels[idx]
120
+ # Load image
121
+ image = load_pickled_data(image_path)
122
+ params = image['params']
123
+ if self.transform is not None:
124
+ params = image['params']
125
+ if self.bit8 == True:
126
+ new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
127
+ else:
128
+ new_image = self.transform(torch.from_numpy(image['data']))
129
+ params['labels'] = label
130
+ return new_image, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
131
+
132
+ # Custom dataset class
133
+ class SearchDataset(Dataset):
134
+ def __init__(self, data_dir, transform=None, pickle_data=False):
135
+ self.window_size = 2048
136
+
137
+ if pickle_data:
138
+ with open(data_dir, 'rb') as f:
139
+ self.d = pickle.load(f)
140
+ self.header = self.d['header']
141
+ self.images = self.crop(self.d['data'][:,0,:], self.window_size)
142
+ else:
143
+ self.obs = Waterfall(data_dir, max_load = 50)
144
+ self.header = self.obs.header
145
+ self.images = self.crop(self.obs.data[:,0,:], self.window_size)
146
+ self.transform = transform
147
+ self.SEC_PER_DAY = 86400
148
+
149
+ def crop(self, data, window_size = 2048):
150
+ n_samp = data.shape[0]//window_size
151
+ new_data = np.zeros((n_samp, window_size, 192 ))
152
+ for i in range(n_samp):
153
+ new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
154
+ return new_data
155
+
156
+ def __len__(self):
157
+ return self.images.shape[0]
158
+ def __getitem__(self, idx):
159
+ data = self.images[idx, :, :].T
160
+ tindex = idx * self.window_size
161
+ time = self.header['tsamp'] * tindex / self.SEC_PER_DAY + self.header['tstart']
162
+ if self.transform is not None:
163
+ new_image = self.transform(data)
164
+ return new_image, idx
165
+
166
+ # Custom dataset class
167
+ class SearchDataset_Sigproc(Dataset):
168
+ def __init__(self, data_dir, transform=None):
169
+ self.window_size = 2048
170
+ fil = FilReader(data_dir)
171
+ self.header = fil.header
172
+ # print("check shape ",fil.read_block(0, fil.header.nsamples).shape)
173
+ read_data = fil.read_block(0, fil.header.nsamples)[:,1024:-1024]
174
+ read_data = np.swapaxes(read_data, 0,-1)
175
+ self.images = self.crop(read_data, self.window_size)
176
+ self.transform = transform
177
+ self.SEC_PER_DAY = 86400
178
+
179
+ def crop(self, data, window_size = 2048):
180
+ n_samp = data.shape[0]//window_size
181
+ new_data = np.zeros((n_samp, window_size, 192 ))
182
+ for i in range(n_samp):
183
+ new_data[i, :,:] = data[ i*window_size : (i+1)*window_size, :]
184
+ return new_data
185
+
186
+ def __len__(self):
187
+ return self.images.shape[0]
188
+
189
+ def __getitem__(self, idx):
190
+ data = self.images[idx, :, :].T
191
+ tindex = idx * self.window_size
192
+ time = self.header.tsamp * tindex / self.SEC_PER_DAY + self.header.tstart
193
+ if self.transform is not None:
194
+ new_image = self.transform(torch.from_numpy(data))
195
+ return new_image, idx
196
+
197
+ # def renorm(data):
198
+ # shifted = data - data.min()
199
+ # shifted = shifted/shifted.max()
200
+ # return shifted
201
+
202
+ def renorm(data):
203
+ mean = torch.mean(data)
204
+ std = torch.std(data)
205
+ # Standardize the data
206
+ standardized_data = (data - mean) / std
207
+ return standardized_data
208
+
209
+ def transform(data):
210
+ copy_data = data.detach().clone()
211
+ rms = torch.std(data)
212
+ mean = torch.mean(data)
213
+ masks_rms = [-1, 5]
214
+ new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
215
+ new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
216
+ for i in range(1, len(masks_rms)+1):
217
+ scale = masks_rms[i-1]
218
+ copy_data = data.detach().clone() #deepcopy(data)
219
+ if scale < 0:
220
+ ind = copy_data < abs(scale) * rms + mean
221
+ copy_data[ind] = 0
222
+ else:
223
+ ind = copy_data > (scale) * rms + mean
224
+ copy_data[ind] = 0
225
+ new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
226
+ new_data = new_data.type(torch.float32)
227
+ slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
228
+ new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
229
+ new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
230
+ return new_data
231
+
232
+
233
+ def renorm_batched(data):
234
+ mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)
235
+ std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)
236
+ standardized_data = (data - mean) / std
237
+ return standardized_data
238
+
239
+ def transform_batched(data):
240
+ copy_data = data.detach().clone()
241
+ rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
242
+ mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
243
+ masks_rms = [-1, 5]
244
+
245
+ # Prepare the new_data tensor
246
+ num_masks = len(masks_rms) + 1
247
+ new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)
248
+
249
+ # First layer: Apply renorm(log10(copy_data + epsilon))
250
+ new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
251
+ for i, scale in enumerate(masks_rms, start=1):
252
+ copy_data = data.detach().clone()
253
+
254
+ # Apply masking based on the scale
255
+ if scale < 0:
256
+ ind = copy_data < abs(scale) * rms + mean
257
+ else:
258
+ ind = copy_data > scale * rms + mean
259
+ copy_data[ind] = 0
260
+
261
+ # Renormalize and log10 transform
262
+ new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))
263
+
264
+ # Convert to float32
265
+ new_data = new_data.type(torch.float32)
266
+
267
+ # Chunk along the last dimension and stack
268
+ slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
269
+ new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
270
+ new_data = torch.swapaxes(new_data, 0,1)
271
+ # Reshape into final format
272
+ new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
273
+ return new_data
274
+
275
+
276
+
277
+ class preproc(nn.Module):
278
+ def forward(self, x, flip=True):
279
+ if flip:
280
+ transform_batched(torch.flip(x, dims = (-2,)))
281
+ else:
282
+ transform_batched(x)
283
+ return template
284
+
285
+ # class preproc_debug(nn.Module):
286
+ # def forward(self, x):
287
+ # template = torch.zeros((32, 24, 192, 256))
288
+ # # for i in torch.arange(x.shape[0]): # Use a tensor-based range
289
+ # template[0,:,:,:] = transform_debug(torch.flip(x[0,:,:], dims = (0,)))
290
+ # template[1,:,:,:] = transform_debug(torch.flip(x[1,:,:], dims = (0,)))
291
+ # template[2,:,:,:] = transform_debug(torch.flip(x[2,:,:], dims = (0,)))
292
+ # template[3,:,:,:] = transform_debug(torch.flip(x[3,:,:], dims = (0,)))
293
+ # template[4,:,:,:] = transform_debug(torch.flip(x[4,:,:], dims = (0,)))
294
+ # template[5,:,:,:] = transform_debug(torch.flip(x[5,:,:], dims = (0,)))
295
+ # template[6,:,:,:] = transform_debug(torch.flip(x[6,:,:], dims = (0,)))
296
+ # template[7,:,:,:] = transform_debug(torch.flip(x[7,:,:], dims = (0,)))
297
+ # template[8,:,:,:] = transform_debug(torch.flip(x[8,:,:], dims = (0,)))
298
+ # template[9,:,:,:] = transform_debug(torch.flip(x[9,:,:], dims = (0,)))
299
+ # template[10,:,:,:] = transform_debug(torch.flip(x[10,:,:], dims = (0,)))
300
+ # template[11,:,:,:] = transform_debug(torch.flip(x[11,:,:], dims = (0,)))
301
+ # template[12,:,:,:] = transform_debug(torch.flip(x[12,:,:], dims = (0,)))
302
+ # template[13,:,:,:] = transform_debug(torch.flip(x[13,:,:], dims = (0,)))
303
+ # template[14,:,:,:] = transform_debug(torch.flip(x[14,:,:], dims = (0,)))
304
+ # template[15,:,:,:] = transform_debug(torch.flip(x[15,:,:], dims = (0,)))
305
+ # template[16,:,:,:] = transform_debug(torch.flip(x[16,:,:], dims = (0,)))
306
+ # template[17,:,:,:] = transform_debug(torch.flip(x[17,:,:], dims = (0,)))
307
+ # template[18,:,:,:] = transform_debug(torch.flip(x[18,:,:], dims = (0,)))
308
+ # template[19,:,:,:] = transform_debug(torch.flip(x[19,:,:], dims = (0,)))
309
+ # template[20,:,:,:] = transform_debug(torch.flip(x[20,:,:], dims = (0,)))
310
+ # template[21,:,:,:] = transform_debug(torch.flip(x[21,:,:], dims = (0,)))
311
+ # template[22,:,:,:] = transform_debug(torch.flip(x[22,:,:], dims = (0,)))
312
+ # template[23,:,:,:] = transform_debug(torch.flip(x[23,:,:], dims = (0,)))
313
+ # template[24,:,:,:] = transform_debug(torch.flip(x[24,:,:], dims = (0,)))
314
+ # template[25,:,:,:] = transform_debug(torch.flip(x[25,:,:], dims = (0,)))
315
+ # template[26,:,:,:] = transform_debug(torch.flip(x[26,:,:], dims = (0,)))
316
+ # template[27,:,:,:] = transform_debug(torch.flip(x[27,:,:], dims = (0,)))
317
+ # template[28,:,:,:] = transform_debug(torch.flip(x[28,:,:], dims = (0,)))
318
+ # template[29,:,:,:] = transform_debug(torch.flip(x[29,:,:], dims = (0,)))
319
+ # template[30,:,:,:] = transform_debug(torch.flip(x[30,:,:], dims = (0,)))
320
+ # template[31,:,:,:] = transform_debug(torch.flip(x[31,:,:], dims = (0,)))
321
+ # return template
322
+
323
+ # def transform_debug(data):
324
+ # copy_data = data.detach().clone()
325
+ # rms = torch.std(data)
326
+ # mean = torch.mean(data)
327
+ # masks_rms = [-1, 5]
328
+ # new_data = torch.zeros((len(masks_rms)+1, data.shape[0], data.shape[1]))
329
+ # new_data[0,:,:] = renorm(torch.log10(copy_data+1e-10))
330
+ # for i in range(1, len(masks_rms)+1):
331
+ # scale = masks_rms[i-1]
332
+ # copy_data = data.detach().clone()
333
+ # if scale < 0:
334
+ # ind = copy_data < abs(scale) * rms + mean
335
+ # copy_data[ind] = 0
336
+ # else:
337
+ # ind = copy_data > (scale) * rms + mean
338
+ # copy_data[ind] = 0
339
+ # new_data[i,:,:] = renorm(torch.log10(copy_data+1e-10))
340
+ # new_data = new_data.type(torch.float32)
341
+ # slices = torch.chunk(new_data, 8, dim=-1) # dim=1 is the height dimension
342
+ # new_data = torch.stack(slices, dim=1) # New axis is inserted at dim=1
343
+ # new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
344
+ # return new_data
345
+
346
+ def renorm_batched(data):
347
+ mins = torch.amin(data, (-2, -1))
348
+ mins = mins.unsqueeze(1).unsqueeze(2)
349
+ mins = mins.expand(data.shape[0], 192, 2048)
350
+ shifted = data - mins
351
+ maxs = torch.amax(shifted, (-2, -1))
352
+ maxs = maxs.unsqueeze(1).unsqueeze(2)
353
+ maxs = maxs.expand(data.shape[0], 192, 2048)
354
+ shifted = shifted/maxs
355
+ return shifted
356
+
357
+
358
+ def transform_mask(data):
359
+ copy_data = deepcopy(data)
360
+ shift = copy_data - copy_data.min()
361
+ normalized_data = shift / shift.max()
362
+ new_data = np.zeros((3, data.shape[0], data.shape[1]))
363
+ for i in range(3):
364
+ new_data[i,:,:] = normalized_data
365
+ new_data = new_data.astype(np.float32)
366
+ return new_data
367
+
368
+
369
+ #Function to Convert to ONNX
370
+ def Convert_ONNX(model, saveloc, input_data_mock):
371
+ print("Saving to ONNX")
372
+ # set the model to inference mode
373
+ model.eval()
374
+
375
+ # Let's create a dummy input tensor
376
+ dummy_input = torch.autograd.Variable(input_data_mock)
377
+
378
+ # Export the model
379
+ torch.onnx.export(model, # model being run
380
+ dummy_input, # model input (or a tuple for multiple inputs)
381
+ saveloc, # where to save the model
382
+ input_names = ['modelInput'], # the model's input names
383
+ output_names = ['modelOutput'], # the model's output names
384
+ dynamic_axes={'modelInput' : {0 : 'batch_size'}, # variable length axes
385
+ 'modelOutput' : {0 : 'batch_size'}} )
386
+ print(" ")
387
+ print('Model has been converted to ONNX')
388
+
389
+
390
+
391
+
392
+
393
+
models/.ipynb_checkpoints/utils_batched_preproc-checkpoint.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle
3
+ from torch.utils.data import Dataset, DataLoader
4
+ import os
5
+ import torch
6
+ from copy import deepcopy
7
+ from blimpy import Waterfall
8
+ from tqdm import tqdm
9
+ from copy import deepcopy
10
+ from sigpyproc.readers import FilReader
11
+ from torch import nn
12
+
13
+
14
+ def renorm_batched(data):
15
+ mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)
16
+ std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)
17
+ standardized_data = (data - mean) / std
18
+ return standardized_data
19
+
20
+ def transform_batched(data):
21
+ copy_data = data.detach().clone()
22
+ rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std
23
+ mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean
24
+ masks_rms = [-1, 5]
25
+
26
+ # Prepare the new_data tensor
27
+ num_masks = len(masks_rms) + 1
28
+ new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)
29
+
30
+ # First layer: Apply renorm(log10(copy_data + epsilon))
31
+ new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))
32
+ for i, scale in enumerate(masks_rms, start=1):
33
+ copy_data = data.detach().clone()
34
+
35
+ # Apply masking based on the scale
36
+ if scale < 0:
37
+ ind = copy_data < abs(scale) * rms + mean
38
+ else:
39
+ ind = copy_data > scale * rms + mean
40
+ copy_data[ind] = 0
41
+
42
+ # Renormalize and log10 transform
43
+ new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))
44
+
45
+ # Convert to float32
46
+ new_data = new_data.type(torch.float32)
47
+
48
+ # Chunk along the last dimension and stack
49
+ slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing
50
+ new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1
51
+ new_data = torch.swapaxes(new_data, 0,1)
52
+ # Reshape into final format
53
+ new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions
54
+ return new_data
55
+
56
+ class preproc_flip(nn.Module):
57
+ def forward(self, x, flip=True):
58
+ template = transform_batched(torch.flip(x, dims = (-2,)))
59
+ return template
60
+
61
+ class preproc(nn.Module):
62
+ def forward(self, x, flip=True):
63
+ template = transform_batched(x)
64
+ return template
65
+
models/HITS-FEB-10.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce8b602cef03e22c666cdea792741411f623fb5eb0a254ef1ffd9a32864d754
3
+ size 270858960
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739230556_9-checkpoint.png ADDED

Git LFS Details

  • SHA256: ee9fdd79492b5751eb009da3213f9512ce915172f8d0428fe781de9a3d8a1a8e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.41 MB
models/HITS-FEB-10/.ipynb_checkpoints/hit_100000000_1739231399_1-checkpoint.png ADDED

Git LFS Details

  • SHA256: 6633230e8e09dd3383a8b456be80fe1a44de28daf4d0e515fce46053c0b972a2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB
models/HITS-FEB-10/hit_100000000_1739230556_9.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cc8774df726c0db3358653dc48894f70f2964cac7604dc5319b44f8f6340b71
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739230556_9.png ADDED

Git LFS Details

  • SHA256: ee9fdd79492b5751eb009da3213f9512ce915172f8d0428fe781de9a3d8a1a8e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.41 MB
models/HITS-FEB-10/hit_100000000_1739231399_1.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f66e781e7776a11067d0b103b36e9f42b7130d5297ab8936a814af659e125524
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739231399_1.png ADDED

Git LFS Details

  • SHA256: 6633230e8e09dd3383a8b456be80fe1a44de28daf4d0e515fce46053c0b972a2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB
models/HITS-FEB-10/hit_100000000_1739231802_11.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9edde89b6a0526420dcff5e35516979a20f379fac7a2c98d59b9dae897bf4426
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739231802_11.png ADDED

Git LFS Details

  • SHA256: 614b3b142d2ef53041569aa046b1874725725b1132e4b1441938912c9fcd88ad
  • Pointer size: 132 Bytes
  • Size of remote file: 3.94 MB
models/HITS-FEB-10/hit_100000000_1739234628_13.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:550f498a863202da1dad26237f7daf0a4f347e7fecde3650e9abee7611994078
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739234628_13.png ADDED

Git LFS Details

  • SHA256: 4cff023a127e7b204632f84f3900e78e83168cdca96d532ca21678f4d65f5f42
  • Pointer size: 132 Bytes
  • Size of remote file: 4.04 MB
models/HITS-FEB-10/hit_100000000_1739234628_14.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b49681cd9e55c1b3f81be0d2e9c04aacece284581b5b534cf8531ef5464b084
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739234628_14.png ADDED

Git LFS Details

  • SHA256: 3a437b3baafa88de6023b97175c1ec11dc9347359494a2b35eff5863933a29c1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.11 MB
models/HITS-FEB-10/hit_100000000_1739235333_29.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04f284312e3c2d80e49d617bec875f1b6664bbe21e1418ba9fee8f0cafe70e24
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739235333_29.png ADDED

Git LFS Details

  • SHA256: 890d153131412d535d8c71c442dbe8d3ea761b014d4f992e3d714d4091f09418
  • Pointer size: 132 Bytes
  • Size of remote file: 4.12 MB
models/HITS-FEB-10/hit_100000000_1739235841_12.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5b8008a6d13ef7b36fe03fd9dab1e58a6bbb770b433ffb77a3eb11a5033a09f
3
+ size 1572992
models/HITS-FEB-10/hit_100000000_1739235841_12.png ADDED

Git LFS Details

  • SHA256: 2d3317775f2972c176ef9ea904b9ca597bd47dac99173ef3951907d66793e83c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
models/HITS-FEB-10/hit_50233055_1739232802_29.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2d0a49b0bf3825cfbbdbd91c465764f82b326e172c378685c9de87a44f1296e
3
+ size 1572992
models/HITS-FEB-10/hit_50233055_1739232802_29.png ADDED

Git LFS Details

  • SHA256: fe65c81a869a9390f746611a2f1c4ca734c924ed462c884055e1be2f105c1396
  • Pointer size: 132 Bytes
  • Size of remote file: 3.78 MB
models/HITS-FEB-10/hit_52111435_1739229641_28.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caa8876c4c3185ad998f1a61a2b4176dec62d5ad38e3ab585aa9ff730848f83f
3
+ size 1572992
models/HITS-FEB-10/hit_52111435_1739229641_28.png ADDED

Git LFS Details

  • SHA256: bfbc7557ca622b3f6447f9913b3f342c27b419f2403030e58ffe4110bc4c3975
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB
models/HITS-FEB-10/hit_52550001_1739233595_4.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a69ae7f971552525b8beac6adee0baa0eb066158fa73e2748265e7a803f8c9f
3
+ size 1572992
models/HITS-FEB-10/hit_52550001_1739233595_4.png ADDED

Git LFS Details

  • SHA256: d59e32f7015d41494863bace9175ac032da00591bf085bc183b4befcab2ebbf7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB