{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "frbnn-eval-title",
   "metadata": {},
   "source": [
    "# Test-set evaluation of masked-ResNet FRB classifiers\n",
    "\n",
    "Loads each checkpoint listed in `CHECKPOINTS`, runs it once over the held-out test set in `TEST_DATA_DIR`, prints its test accuracy, and pickles the per-sample results (raw model `output`, `pred`, and the `true` / `dm` / `freq` / `snr` / `boxcar` label fields yielded by `TestingDataset`) for downstream analysis.\n",
    "\n",
    "Run top to bottom on a fresh kernel: the model and test loader are built once and reused for every checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98165c88-8ead-4fae-9ea8-6b2e82996fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from utils import TestingDataset, transform\n",
    "from resnet_model_mask import ResidualBlock, ResNet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "frbnn-eval-config-md",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64733667-75c3-4fd3-ab9f-62b85c5e27e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(1)\n",
    "\n",
    "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print('num GPUs:', torch.cuda.device_count())\n",
    "\n",
    "TEST_DATA_DIR = '/mnt/buf1/pma/frbnn/test_ready'\n",
    "NUM_CLASSES = 2\n",
    "BATCH_SIZE = 420\n",
    "# A previous run with 32 workers died with 'DataLoader worker ... exited unexpectedly'\n",
    "# (FileNotFoundError inside multiprocessing.resource_sharer) after a KeyboardInterrupt.\n",
    "# If that recurs without an interrupt, lower this or try\n",
    "# torch.multiprocessing.set_sharing_strategy('file_system').\n",
    "NUM_WORKERS = 32\n",
    "\n",
    "# (checkpoint to evaluate, pickle file that receives its per-sample results)\n",
    "CHECKPOINTS = [\n",
    "    ('models_mask/model-43-99.235_42.pt', 'models_mask/test_42.pkl'),\n",
    "    ('models_mask/model-36-99.11999999999999_1.pt', 'models_mask/test_1.pkl'),\n",
    "    ('models_mask/model-26-99.13_7109.pt', 'models_mask/test_7109.pkl'),\n",
    "]\n",
    "\n",
    "# results-dict key -> position in the `labels` sequence yielded by TestingDataset\n",
    "LABEL_FIELDS = {'true': 0, 'dm': 1, 'freq': 2, 'snr': 3, 'boxcar': 4}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "frbnn-eval-data-md",
   "metadata": {},
   "source": [
    "## Test data and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe74ada8-43e4-4c73-b772-0ef18983345d",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset = TestingDataset(TEST_DATA_DIR, transform=transform)\n",
    "# Evaluation needs no shuffling; a fixed order keeps every checkpoint's pickle\n",
    "# row-aligned with the dataset (and with each other).\n",
    "testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)\n",
    "\n",
    "# The checkpoints load into a DataParallel-wrapped model (presumably 'module.'-prefixed\n",
    "# keys), so wrap before loading.\n",
    "model = nn.DataParallel(ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=NUM_CLASSES)).to(DEVICE)\n",
    "print('num params ', sum(p.numel() for p in model.parameters()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "frbnn-eval-fn-md",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "frbnn-eval-fn",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_checkpoint(model, loader, checkpoint_path, output_path, device=DEVICE):\n",
    "    \"\"\"Load a checkpoint into `model`, score it on `loader`, and pickle the results.\n",
    "\n",
    "    Args:\n",
    "        model: DataParallel-wrapped ResNet; its weights are overwritten in place.\n",
    "        loader: DataLoader yielding (images, labels), with `labels` indexed as in LABEL_FIELDS.\n",
    "        checkpoint_path: path to a state_dict written by torch.save.\n",
    "        output_path: destination of the pickled results dict.\n",
    "        device: device the input batches (and loaded tensors) are placed on.\n",
    "\n",
    "    Returns:\n",
    "        Test accuracy in percent.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `loader` yields no samples.\n",
    "    \"\"\"\n",
    "    # map_location lets a checkpoint saved on GPU load on a CPU-only / fewer-GPU host.\n",
    "    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.eval()\n",
    "\n",
    "    results = {'output': [], 'pred': [], 'true': [], 'freq': [], 'snr': [], 'dm': [], 'boxcar': []}\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in tqdm(loader, desc=checkpoint_path):\n",
    "            # return_mask=True matches the original call; the result is reduced over\n",
    "            # dim 1 as per-class scores (assumed (batch, NUM_CLASSES) -- TODO confirm).\n",
    "            outputs = model(images.to(device), return_mask=True)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            predicted = predicted.cpu()\n",
    "            results['output'].extend(outputs.cpu().tolist())\n",
    "            results['pred'].extend(predicted.tolist())\n",
    "            for key, idx in LABEL_FIELDS.items():\n",
    "                results[key].extend(labels[idx].cpu().tolist())\n",
    "            targets = labels[LABEL_FIELDS['true']].cpu()\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError(f'{checkpoint_path}: test loader yielded no samples')\n",
    "\n",
    "    with open(output_path, 'wb') as f:\n",
    "        pickle.dump(results, f)\n",
    "    return correct / total * 100.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "frbnn-eval-run",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy (percent) per checkpoint; per-sample results go to each output pickle.\n",
    "accuracies = {}\n",
    "for checkpoint_path, output_path in CHECKPOINTS:\n",
    "    accuracies[checkpoint_path] = evaluate_checkpoint(model, testloader, checkpoint_path, output_path)\n",
    "    print(f'{checkpoint_path}: test accuracy {accuracies[checkpoint_path]:.3f}%')\n",
    "accuracies"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}