{ "cells": [ { "cell_type": "code", "execution_count": 10, "id": "5577ffee-a5c9-4648-8849-95c2c7ebcebe", "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import tqdm\n", "from torch.func import vmap  # functorch.vmap is deprecated; torch.func.vmap is its replacement\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "# project-local modules\n", "from resnet_model_mask import ResidualBlock, ResNet\n", "from utils import CustomDataset, transform, Convert_ONNX\n", "from utils_batched_preproc import transform_batched, preproc_flip" ] }, { "cell_type": "code", "execution_count": 3, "id": "f1180d60-83e7-47ca-aa09-58d26af3c706", "metadata": {}, "outputs": [], "source": [ "# def renorm_batched(data):\n", "# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n", "# std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)\n", "# standardized_data = (data - mean) / std\n", "# return standardized_data\n", "\n", "# def transform_batched(data):\n", "# copy_data = data.detach().clone()\n", "# rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise std\n", "# mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # Batch-wise mean\n", "# masks_rms = [-1, 5]\n", " \n", "# # Prepare the new_data tensor\n", "# num_masks = len(masks_rms) + 1\n", "# new_data = torch.zeros((num_masks, *data.shape), device=data.device) # Shape: (num_masks, batch_size, ..., ...)\n", "\n", "# # First layer: Apply renorm(log10(copy_data + epsilon))\n", "# new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))\n", "# for i, scale in enumerate(masks_rms, start=1):\n", "# copy_data = data.detach().clone()\n", " \n", "# # Apply masking based on the scale\n", "# if scale < 0:\n", "# ind = copy_data < abs(scale) * rms + mean\n", "# else:\n", "# ind = copy_data > scale * rms + 
mean\n", "# copy_data[ind] = 0\n", " \n", "# # Renormalize and log10 transform\n", "# new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))\n", " \n", "# # Convert to float32\n", "# new_data = new_data.type(torch.float32)\n", "\n", "# # Chunk along the last dimension and stack\n", "# slices = torch.chunk(new_data, 8, dim=-1) # Adjust for batch-wise slicing\n", "# new_data = torch.stack(slices, dim=2) # Insert a new axis at dim=1\n", "# new_data = torch.swapaxes(new_data, 0,1)\n", "# # Reshape into final format\n", "# new_data = new_data.reshape( new_data.size(0), 24, new_data.size(3), new_data.size(4)) # Flatten batch and mask dimensions\n", "# return new_data\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "81cc81a8-ecef-43ef-a5cd-c35765384812", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num params encoder 50840\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_19147/1680389579.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", " model.load_state_dict(torch.load(model_path))\n" ] }, { "data": { "text/plain": [ "DataParallel(\n", " (module): ResNet(\n", " (relu): ReLU()\n", " (conv1): Sequential(\n", " (0): Conv2d(24, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (maxpool): MaxPool2d(kernel_size=(2, 2), stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer0): Sequential(\n", " (0): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer1): Sequential(\n", " (0): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): 
ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): 
Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (5): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): ResidualBlock(\n", " (conv1): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU()\n", " )\n", " (conv2): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (relu): ReLU()\n", " (dropout1): Dropout(p=0.5, inplace=False)\n", " (batchnorm_mod): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n", " (fc): Linear(in_features=39424, out_features=2, bias=True)\n", " (dropout1): Dropout(p=0.3, inplace=False)\n", " (encoder): Sequential(\n", " (0): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Dropout(p=0.3, inplace=False)\n", " (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (4): ReLU(inplace=True)\n", " (5): Dropout(p=0.3, 
inplace=False)\n", " (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (7): ReLU(inplace=True)\n", " (8): Dropout(p=0.3, inplace=False)\n", " (9): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (10): Sigmoid()\n", " )\n", " )\n", ")" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_path = 'models/model-47-99.125.pt'\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# The checkpoint is passed straight to load_state_dict, so a weights-only load is sufficient;\n", "# map_location lets a CUDA-saved checkpoint load on the CPU fallback path.\n", "# Wrapped in DataParallel before loading, as in the original run (keys presumably carry the `module.` prefix).\n", "model = nn.DataParallel(ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=2)).to(device)\n", "model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": null, "id": "58b8c338-df2f-4ef0-92cf-409c9f034cab", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(0)  # fixed seed so the .bin reference files written below are reproducible\n", "test_in = abs(torch.randn(32, 192, 2048).to(device))\n", "# `transform` is applied per sample, to one (192, 2048) slice at a time\n", "intermediate = torch.stack([transform(sample) for sample in test_in]).to(device)\n", "with torch.no_grad():  # inference only; no autograd graph needed\n", "    out = model(intermediate)\n", "test_in.cpu().detach().numpy().tofile(\"input.bin\")\n", "intermediate.cpu().detach().numpy().tofile(\"intermediate.bin\")\n", "out.cpu().detach().numpy().tofile(\"output.bin\")\n", "print(out)" ] }, { "cell_type": "code", "execution_count": null, "id": "ad56299a-44e4-4d6b-afcc-18a5f4cf0138", "metadata": {}, "outputs": [], "source": [ "preproc_model = preproc_flip()\n", "Convert_ONNX(preproc_model, 'models_mask/preproc_flip.onnx', input_data_mock=test_in.to(device))\n", "# Convert_ONNX(model.module,f'models_mask/model_test.onnx', input_data_mock=intermediate.to(device))" ] }, { "cell_type": "code", "execution_count": 7, "id": "30e84a9b-0d4f-4cb2-a92b-2e3f0b2ccb20", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([32, 192, 2048])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_in.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "1bb26727-7914-470e-bb48-43d7ee81cb50", "metadata": {}, "outputs": [], "source": [ "# sanity check: flipping along dim 0 is identical to flipud\n", "assert torch.equal(torch.flip(test_in[0, :, :], dims=(0,)), torch.flipud(test_in[0, :, :]))" ] }, { "cell_type": "code", "execution_count": 29, "id": "aeaaab90-6a2a-4851-a1ca-28c54a446573", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(float)\n", "torch.float32\n", "Input Name: modelInput\n", "Output Name: modelOutput\n", "[array([[ 4.3262615, -4.3409047],\n", " [ 4.9648395, -4.968621 ],\n", " [ 5.5126643, -5.522872 ],\n", " [ 4.7735534, -4.8004475],\n", " [ 4.0924144, -4.112945 
],\n", " [ 4.588802 , -4.6043544],\n", " [ 4.6231914, -4.617625 ],\n", " [ 5.229881 , -5.2555394],\n", " [ 4.877381 , -4.882144 ],\n", " [ 5.2514744, -5.2786503],\n", " [ 4.2948875, -4.3169603],\n", " [ 4.5997186, -4.6177607],\n", " [ 4.9509926, -4.9685597],\n", " [ 4.933158 , -4.9568825],\n", " [ 4.747336 , -4.7639017],\n", " [ 5.020595 , -5.0202913],\n", " [ 4.914437 , -4.9206715],\n", " [ 5.193108 , -5.1925435],\n", " [ 4.5233765, -4.512763 ],\n", " [ 4.7573333, -4.762632 ],\n", " [ 5.268702 , -5.2838397],\n", " [ 4.857734 , -4.8605857],\n", " [ 5.1886744, -5.2047734],\n", " [ 5.512568 , -5.5503583],\n", " [ 5.320961 , -5.344709 ],\n", " [ 4.1023226, -4.1073256],\n", " [ 5.17857 , -5.185736 ],\n", " [ 4.997028 , -4.9933476],\n", " [ 4.771303 , -4.767269 ],\n", " [ 5.312805 , -5.3265243],\n", " [ 5.0030336, -5.0492 ],\n", " [ 5.429731 , -5.4249325]], dtype=float32)]\n" ] } ], "source": [ "import onnxruntime as ort\n", "\n", "# Run the exported classifier on the same `intermediate` batch that was fed to the PyTorch model\n", "onnx_model_path = \"models/model-47-99.125.onnx\"  # not `model_path`, which holds the .pt checkpoint path\n", "session = ort.InferenceSession(onnx_model_path)\n", "\n", "session_input = session.get_inputs()[0]\n", "input_name = session_input.name\n", "output_name = session.get_outputs()[0].name\n", "\n", "print(session_input.type)\n", "print(intermediate.dtype)  # dtype of the tensor actually fed to the session\n", "print(f\"Input Name: {input_name}\")\n", "print(f\"Output Name: {output_name}\")\n", "\n", "outputs = session.run([output_name], {input_name: intermediate.detach().cpu().numpy()})\n", "print(outputs)" ] }, { "cell_type": "code", "execution_count": 30, "id": "f250739d-4c8a-4752-964a-d0b929c396f4", "metadata": {}, "outputs": [], "source": [ "# import onnxruntime as ort\n", "# import onnx\n", "\n", "# # Path to your ONNX model\n", "# model_path = \"models_mask/preproc_test.onnx\"\n", "\n", "# # Load the 
ONNX model\n", "# session = ort.InferenceSession(model_path)\n", "\n", "# # Get input and output details\n", "# input_name = session.get_inputs()[0].name\n", "# output_name = session.get_outputs()[0].name\n", "\n", "# print(session.get_inputs()[0].type)\n", "# print(test_in.dtype)\n", "\n", "# print(f\"Input Name: {input_name}\")\n", "# print(f\"Output Name: {output_name}\")\n", "\n", "# # Example Input Data (Replace with your actual input data)\n", "# import numpy as np\n", "\n", "# # Perform inference\n", "# outputs = session.run([output_name], {input_name: test_in.cpu().numpy()})\n", "# print(\"Model Output:\", outputs)\n", "\n", "# onnx_model = onnx.load(model_path)" ] }, { "cell_type": "code", "execution_count": 8, "id": "24fed4e7-4838-44cc-9c3a-0862bdbe173a", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAey0lEQVR4nO3df0xd9f3H8deFymXMcpURL8WCzG124o/L5Jd0dpblboQ6snZZxvaHItm6ZcFFc6NL+w+4rJMsMUiynIXtmyDZr8gaIy5zqanXH/gDQwvFVfEXjhkWvZc26r3luoBezvePxau0UHvhlvs5nOcjuX/ccw/nvDkhl2fuPedej23btgAAAAyRk+0BAAAAPok4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGCUTdkeIF2Li4t66623tHnzZnk8nmyPAwAAzoFt2zp16pRKS0uVk3P210YcFydvvfWWysrKsj0GAABYhZmZGW3duvWs6zgmTizLkmVZ+vDDDyX975crLCzM8lQAAOBcxONxlZWVafPmzZ+6rsdp360Tj8fl8/kUi8WIEwAAHCKd/9+cEAsAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMQpwAAACjECcAAMAoxAkAADBKVuJkenpajY2Nqqys1DXXXKNEIpGNMQAAgIE2ZWOnt956qw4cOKAdO3bonXfekdfrzcYYy7vbd9r9WHbmAADApdY9Tl566SVdcMEF2rFjhySpqKhovUcAAAAGS/ttneHhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo6mHnv99dd14YUXqqWlRdddd53uueeeNf0CAABgY0k7ThKJhAKBgCzLWvbxwcFBhUIhdXV1aXx8XIFAQE1NTZqdnZUkffjhh3r66af129/+ViMjIzp8+LAOHz68tt8CAABsGGnHSXNzsw4cOKA9e/Ys+3hPT4/27t2r9vZ2VVZWqq+vTwUFBerv75ckXXrppaqpqVFZWZm8Xq927dqliYmJFfc3Pz+veDy+5AYAADaujF6ts7CwoLGxMQWDwY93kJOjYDCokZERSVJtba1mZ2f17rvvanFxUcPDw7ryyitX3GZ3d7d8Pl/qVlZWl
smRAQCAYTIaJydPnlQymZTf71+y3O/3KxKJSJI2bdqke+65R1/72td07bXX6ktf+pK+9a1vrbjN/fv3KxaLpW4zMzOZHBkAABgmK5cSNzc3q7m5+ZzW9Xq9Zl1qDAAAzquMvnJSXFys3NxcRaPRJcuj0ahKSkrWtG3LslRZWana2to1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw2E1NDSsadsdHR2anJzUkSNH1jomAAAwWNpv68zNzWlqaip1f3p6WhMTEyoqKlJ5eblCoZDa2tpUU1Ojuro69fb2KpFIqL29PaODAwCAjSntODl69KgaGxtT90OhkCSpra1NAwMDam1t1YkTJ9TZ2alIJKKqqiodOnTojJNkAQAAluOxbdvO9hDnwrIsWZalZDKp1157TbFYTIWFhZnfEd+tAwBAxsXjcfl8vnP6/52VbyVeDc45AQDAHRwTJwAAwB0cEydcSgwAgDs4Jk54WwcAAHdwTJwAAAB3IE4AAIBRiBMAAGAUx8QJJ8QCAOAOjokTTogFAMAdHBMnAADAHYgTAABgFOIEAAAYxTFxwgmxAAC4g2PihBNiAQBwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAURwTJ1ytAwCAOzgmTrhaBwAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFMfECR/CBgCAOzgmTvgQNgAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwsfXAwDgDo6JEz6+HgAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABhlUzZ2WlFRocLCQuXk5Ojiiy/WE088kY0xAACAgbISJ5L03HPP6cILL8zW7gEAgKF4WwcAABgl7TgZHh5WS0uLSktL5fF4NDQ0dMY6lmWpoqJC+fn5qq+v1+jo6JLHPR6PbrzxRtXW1urPf/7zqocHAAAbT9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsbGqdZ555RmNjY/rb3/6me+65R//85z9X/xsAAIANJe04aW5u1oEDB7Rnz55lH+/p6dHevXvV3t6uyspK9fX1qaCgQP39/al1Lr30UknSli1btGvXLo2Pj6+4v/n5ecXj8SU3AACwcWX0nJOFhQWNjY0pGAx+vIOcHAWDQY2MjEj63ysvp06dkiTNzc3p8ccf11VXXbXiNru7u+Xz+VK3srKyTI4MAAAMk9E4OXnypJLJpPx+/5Llfr9fkUhEkhSNRnXDDTcoEAjo+uuv1y233KLa2toVt7l//37FYrHUbWZmJpMjAwAAw6z7pcSXX365XnjhhXNe3+v1yuv1nseJAACASTL6yklxcbFyc3MVjUaXLI9GoyopKVnTti3LUmVl5VlfZQEAAM6X0TjJy8tTdXW1wuFwatni4qLC4bAaGhrWtO2Ojg5NTk7qyJEjax0TAAAYLO23debm5jQ1NZW6Pz09rYmJCRUVFam8vFyhUEhtbW2qqalRXV2dent7lUgk1N7entHBAQDAxpR2nBw9elSNjY2p+6FQSJLU1tamgYEBtba26sSJE+rs7FQkElFVVZUOHTp0xkmy6bIsS5ZlKZlMrmk7AADAbB7btu1sD5GOeDwun8+nWCymwsLCzO/gbt9p92OZ3wcAAC6Tzv9vvlsHAAAYhTgBAABGcUyccCkxAADu4Jg44VJiAADcwTFxAgAA3IE4AQAARnFMnHDOCQAA7uCYOOGcEwAA3MExcQIAANyBOAEAAEYhTgAAgFGIEwAAYBTHxAlX6wAA4A6OiROu1gEAwB0cEycAAMAdiBMAAGAU4gQAABiFOAEAAEZxTJxwtQ4AAO7gmDjhah0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAA
GAU4gQAABiFOAEAAEZxTJzwOScAALiDY+KEzzkBAMAdHBMnAADAHYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTHxAnfrQMAgDs4Jk74bh0AANzBMXECAADcgTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABglKzFyfvvv6/LLrtMd955Z7ZGAAAABspanPzqV7/S9ddfn63dAwAAQ2UlTl5//XW98soram5uzsbuAQCAwdKOk+HhYbW0tKi0tFQej0dDQ0NnrGNZlioqKpSfn6/6+nqNjo4uefzOO+9Ud3f3qocGAAAbV9pxkkgkFAgEZFnWso8PDg4qFAqpq6tL4+PjCgQCampq0uzsrCTp4Ycf1hVXXKErrrhibZMDAIANaVO6P9Dc3HzWt2N6enq0d+9etbe3S5L6+vr0yCOPqL+/X/v27dPzzz+vBx54QAcPHtTc3Jw++OADFRYWqrOzc9ntzc/Pa35+PnU/Ho+nOzIAAHCQjJ5zsrCwoLGxMQWDwY93kJOjYDCokZERSVJ3d7dmZmb073//W/fee6/27t27Yph8tL7P50vdysrKMjkyAAAwTEbj5OTJk0omk/L7/UuW+/1+RSKRVW1z//79isViqdvMzEwmRgUAAIZK+22dTLr11ls/dR2v1yuv13v+hwEAAEbI6CsnxcXFys3NVTQaXbI8Go2qpKRkTdu2LEuVlZWqra1d03YAAIDZMhoneXl5qq6uVjgcTi1bXFxUOBxWQ0PDmrbd0dGhyclJHTlyZK1jAgAAg6X9ts7c3JympqZS96enpzUxMaGioiKVl5crFAqpra1NNTU1qqurU29vrxKJROrqHQAAgLNJO06OHj2qxsbG1P1QKCRJamtr08DAgFpbW3XixAl1dnYqEomoqqpKhw4dOuMk2XRZliXLspRMJte0HQAAYDaPbdt2todIRzwel8/nUywWU2FhYeZ3cLfvtPuxzO8DAACXSef/d9a++A8AAGA5xAkAADCKY+KES4kBAHAHx8QJlxIDAOAOjokTAADgDsQJAAAwimPihHNOAABwB8fECeecAADgDo6JEwAA4A7ECQAAMApxAgAAjOKYOOGEWAAA3MExccIJsQAAuINj4gQAALgDcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oRLiQEAcAfHxAmXEgMA4A6OiRMAAOAOxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMIpj4oTPOQEAwB0cEyd8zgkAAO7gmDgBAADuQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjOKYOOHj6wEAcAfHxAkfXw8AgDs4Jk4AAIA7ECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwCnECAACMsu5x8t5776mmpkZVVVW6+uqr9X//93/rPQIAADDYpvXe4ebNmzU8PKyCggIlEgldffXV+s53vqPPfe5z6z0KAAAw0Lq/cpKbm6uCggJJ0vz8vGzblm3b6z0GAAAwVNpxMjw8rJaWFpWWlsrj8WhoaOiMdSzLUkVFhfLz81VfX6/R0dElj7/33nsKBALaunWr7rrrLhUXF6/6FwAAABtL2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Oxsap2LLrpIL7zwgqanp/WXv/xF0Wh09b8BAADYUNKOk+bmZh04cEB79uxZ9vGenh7t3btX7e3tqqysVF9fnwoKCtTf33/Gun6/X4FAQE8//fSK+5ufn1c8Hl9yAwAAG1dGzzlZWFjQ2NiYgsHgxzvIyVEwGNTIyIgkKRqN6tSpU5KkWCym4eFhbdu2bcVtdnd3y+fzpW5lZWWZHBkAABgmo3Fy8
uRJJZNJ+f3+Jcv9fr8ikYgk6c0339SOHTsUCAS0Y8cO/exnP9M111yz4jb379+vWCyWus3MzGRyZAAAYJh1v5S4rq5OExMT57y+1+uV1+s9fwMBAACjZPSVk+LiYuXm5p5xgms0GlVJScmatm1ZliorK1VbW7um7QAAALNlNE7y8vJUXV2tcDicWra4uKhwOKyGhoY1bbujo0OTk5M6cuTIWscEAAAGS/ttnbm5OU1NTaXuT09Pa2JiQkVFRSovL1coFFJbW5tqampUV1en3t5eJRIJtbe3Z3RwAACwMaUdJ0ePHlVjY2PqfigUkiS1tbVpYGBAra2tOnHihDo7OxWJRFRVVaVDhw6dcZIsAADAcjy2Qz473rIsWZalZDKp1157TbFYTIWFhZnf0d2+0+7HMr8PAABcJh6Py+fzndP/73X/bp3V4pwTAADcwTFxAgAA3MExccKlxAAAuINj4oS3dQAAcAfHxAkAAHAH4gQAABiFOAEAAEZxTJxwQiwAAO7gmDjhhFgAANzBMXECAADcgTgBAABGIU4AAIBRHBMnnBALAIA7OCZOOCEWAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYxTFxwtU6AAC4g2PihKt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYxTFxwuecAADgDo6JEz7nBAAAd3BMnAAAAHcgTgAAgFGIEwAAYBTiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfLcOAADu4Jg44bt1AABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFHWPU5mZma0c+dOVVZW6tprr9XBgwfXewQAAGCwTeu+w02b1Nvbq6qqKkUiEVVXV2vXrl367Gc/u96jAAAAA617nGzZskVbtmyRJJWUlKi4uFjvvPMOcQIAACSt4m2d4eFhtbS0qLS0VB6PR0NDQ2esY1mWKioqlJ+fr/r6eo2Oji67rbGxMSWTSZWVlaU9OAAA2JjSjpNEIqFAICDLspZ9fHBwUKFQSF1dXRofH1cgEFBTU5NmZ2eXrPfOO+/olltu0e9///vVTQ4AADaktN/WaW5uVnNz84qP9/T0aO/evWpvb5ck9fX16ZFHHlF/f7/27dsnSZqfn9fu3bu1b98+bd++/az7m5+f1/z8fOp+PB5Pd2QAAOAgGb1aZ2FhQWNjYwoGgx/vICdHwWBQIyMjkiTbtnXrrbfq61//um6++eZP3WZ3d7d8Pl/qxltAAABsbBmNk5MnTyqZTMrv9y9Z7vf7FYlEJEnPPvusBgcHNTQ0pKqqKlVVVen48eMrbnP//v2KxWKp28zMTCZHBgAAhln3q3VuuOEGLS4unvP6Xq9XXq/3PE4EAABMktFXToqLi5Wbm6toNLpkeTQaVUlJyZq2bVmWKisrVVtbu6btAAAAs2U0TvLy8lRdXa1wOJxatri4qHA4rIaGhjVtu6OjQ5OTkzpy5MhaxwQAAAZL+22dubk5TU1Npe5PT09rYmJCRUVFKi8vVygUUltbm2pqalRXV6fe3l4lEonU1TsAAABnk3acHD16VI2Njan7oVBIktTW1qaBgQG1trbqxIkT6uzsVCQSUVVVlQ4dOnTGSbLpsixLlmUpmUyuaTsAAMBsHtu27WwPkY54PC6fz6dYLKbCwsLM7+Bu32n3Y5nfBwAALpPO/+91/1ZiAACAsyFOAACAURwTJ1xKDACAOzgmTriUGAAAd3BMnAAAAHcgTgAAgFEcEyeccwIAgDs4Jk445wQAAHdwTJwAAAB3IE4AAIBRiBMAAGAU4gQAABjFMXHC1ToAALiDY+KEq3UAAHAHx8QJAABwB+IEAAAYhTgBAABGIU4AAIBRHBMnXK0DAIA7OCZOuFoHAAB3cEycAAAAdyBOAACAU
YgTAABgFOIEAAAYhTgBAABGIU4AAIBRHBMnfM4JAADu4Jg44XNOAABwB8fECQAAcAfiBAAAGIU4AQAARiFOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABjFMXHCd+sAAOAOjokTvlsHAAB3cEycAAAAdyBOAACAUYgTAABgFOIEAAAYhTgBAABGIU4AAIBRiBMAAGAU4gQAABiFOAEAAEYhTgAAgFGIEwAAYBTiBAAAGCUrcbJnzx5dfPHF+u53v5uN3QMAAINlJU5uv/12/eEPf8jGrgEAgOGyEic7d+7U5s2bs7FrAABguLTjZHh4WC0tLSotLZXH49HQ0NAZ61iWpYqKCuXn56u+vl6jo6OZmBUAALhA2nGSSCQUCARkWdayjw8ODioUCqmrq0vj4+MKBAJqamrS7Ozsqgacn59XPB5fcgMAABtX2nHS3NysAwcOaM+ePcs+3tPTo71796q9vV2VlZXq6+tTQUGB+vv7VzVgd3e3fD5f6lZWVraq7QAAAGfI6DknCwsLGhsbUzAY/HgHOTkKBoMaGRlZ1Tb379+vWCyWus3MzGRqXAAAYKBNmdzYyZMnlUwm5ff7lyz3+/165ZVXUveDwaBeeOEFJRIJbd26VQcPHlRDQ8Oy2/R6vfJ6vZkcEwAAGCyjcXKuHnvssbR/xrIsWZalZDJ5HiYCAMCl7vYtsyy2/nN8Qkbf1ikuLlZubq6i0eiS5dFoVCUlJWvadkdHhyYnJ3XkyJE1bQcAAJgto3GSl5en6upqhcPh1LLFxUWFw+EV37YBAAD4pLTf1pmbm9PU1FTq/vT0tCYmJlRUVKTy8nKFQiG1tbWppqZGdXV16u3tVSKRUHt7e0YHBwAAG1PacXL06FE1Njam7odCIUlSW1ubBgYG1NraqhMnTqizs1ORSERVVVU6dOjQGSfJpotzTgAAcAePbdt2todIRzwel8/nUywWU2FhYeZ3cPqJQVk+KQgAgPNqnU6ITef/d1a+WwcAAGAlxAkAADCKY+LEsixVVlaqtrY226MAAIDzyDFxwuecAADgDo6JEwAA4A7ECQAAMIpj4oRzTgAAcAfHxAnnnAAA4A6OiRMAAOAOxAkAADBK2t+tk20ffdp+PB4/PzuYP+3T/M/XfgAAMMHp//ek8/K/76P/2+fyrTmO+W6dj774b2FhQW+88Ua2xwEAAKswMzOjrVu3nnUdx8TJRxYXF/XWW29p8+bN8ng8Gd12PB5XWVmZZmZmzs+XCroIxzJzOJaZw7HMHI5l5rjlWNq2rVOnTqm0tFQ5OWc/q8Rxb+vk5OR8anGtVWFh4Yb+A1lPHMvM4VhmDscycziWmeOGY+nzLfMNyMvghFgAAGAU4gQAABiFOPkEr9errq4ueb3ebI/ieBzLzOFYZg7HMnM4lpnDsTyT406IBQAAGxuvnAAAAKMQJwAAwCjECQAAMApxAgAAjOK6OLEsSxUVFcrPz1d9fb1GR0fPuv7Bgwf15S9/Wfn5+brmmmv0j3/8Y50mNV86x3JgYEAej2fJLT8/fx2nNdPw8LBaWlpUWloqj8ejoaGhT/2ZJ598Utddd528Xq+++MUvamBg4LzP6QTpHssnn3zyjL9Jj8ejSCSyPgMbrLu7W7W1tdq8ebMuueQS7d69W6+++uqn/hzPl2dazbHk+dJlcTI4OKhQKKSuri6Nj48rEAioqalJs7Ozy67/3HPP6Qc/+IF++MMf6tixY9q9e7d2796tF198cZ0nN0+6x1L636cfvv3226nbm2++uY4TmymRSCgQCMiyrHNaf3p6WjfddJMaGxs1MTGhO+64Qz/60Y/06KOPnudJzZfusfzIq6++uuTv8pJLLjlPEzrHU089pY6ODj3//PM6fPiwPvjgA33zm99UIpFY8Wd4vlzeao6lxPOlbBepq6uzOzo6UveTyaRdWlpqd3d3L7v+9773Pfumm25asqy+vt7+yU9+cl7ndIJ0j+X9999v+
3y+dZrOmSTZDz300FnX+fnPf25fddVVS5a1trbaTU1N53Ey5zmXY/nEE0/Ykux33313XWZystnZWVuS/dRTT624Ds+X5+ZcjiXPl7btmldOFhYWNDY2pmAwmFqWk5OjYDCokZGRZX9mZGRkyfqS1NTUtOL6brGaYylJc3Nzuuyyy1RWVqZvf/vbeumll9Zj3A2Fv8nMq6qq0pYtW/SNb3xDzz77bLbHMVIsFpMkFRUVrbgOf5vn5lyOpcTzpWvi5OTJk0omk/L7/UuW+/3+Fd9jjkQiaa3vFqs5ltu2bVN/f78efvhh/elPf9Li4qK2b9+u//znP+sx8oax0t9kPB7Xf//73yxN5UxbtmxRX1+fHnzwQT344IMqKyvTzp07NT4+nu3RjLK4uKg77rhDX/3qV3X11VevuB7Pl5/uXI8lz5cO/FZiOFNDQ4MaGhpS97dv364rr7xSv/vd7/TLX/4yi5PBrbZt26Zt27al7m/fvl1vvPGG7rvvPv3xj3/M4mRm6ejo0Isvvqhnnnkm26M43rkeS54vXfTKSXFxsXJzcxWNRpcsj0ajKikpWfZnSkpK0lrfLVZzLE93wQUX6Ctf+YqmpqbOx4gb1kp/k4WFhfrMZz6Tpak2jrq6Ov4mP+G2227T3//+dz3xxBPaunXrWdfl+fLs0jmWp3Pj86Vr4iQvL0/V1dUKh8OpZYuLiwqHw0sK9ZMaGhqWrC9Jhw8fXnF9t1jNsTxdMpnU8ePHtWXLlvM15obE3+T5NTExwd+kJNu2ddttt+mhhx7S448/rs9//vOf+jP8bS5vNcfydK58vsz2Gbnr6YEHHrC9Xq89MDBgT05O2j/+8Y/tiy66yI5EIrZt2/bNN99s79u3L7X+s88+a2/atMm+99577Zdfftnu6uqyL7jgAvv48ePZ+hWMke6x/MUvfmE/+uij9htvvGGPjY3Z3//+9+38/Hz7pZdeytavYIRTp07Zx44ds48dO2ZLsnt6euxjx47Zb775pm3btr1v3z775ptvTq3/r3/9yy4oKLDvuusu++WXX7Yty7Jzc3PtQ4cOZetXMEa6x/K+++6zh4aG7Ndff90+fvy4ffvtt9s5OTn2Y489lq1fwRg//elPbZ/PZz/55JP222+/nbq9//77qXV4vjw3qzmWPF/atqvixLZt+ze/+Y1dXl5u5+Xl2XV1dfbzzz+feuzGG2+029ralqz/17/+1b7iiivsvLw8+6qrrrIfeeSRdZ7YXOkcyzvuuCO1rt/vt3ft2mWPj49nYWqzfHQ56+m3j45dW1ubfeONN57xM1VVVXZeXp59+eWX2/fff/+6z22idI/lr3/9a/sLX/iCnZ+fbxcVFdk7d+60H3/88ewMb5jljqOkJX9rPF+em9UcS54vbdtj27a9fq/TAAAAnJ1rzjkBAADOQJwAAACjECcAAMAoxAkAADAKcQIAAIxCnAAAAKMQJwAAwCjECQAAMApxAgAAjEKcAAAAoxAnAADAKMQJAAAwyv8D6KAeY7AISbEAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "plt.hist(abs(intermediate-outputs[0]).ravel(), bins = 100)\n", "plt.yscale('log')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 15, "id": "71cb219e-b91a-4629-99f6-00db786903c7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.1902e-03, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00,\n", " 2.7140e+00, 2.7140e+00, 2.7140e+00, 2.7140e+00])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sort(abs(intermediate-outputs[0]).ravel())[0][-10:]" ] }, { "cell_type": "code", "execution_count": null, "id": "92ba5920-5451-4bb1-af0e-5ea987841ab1", "metadata": {}, "outputs": [], "source": [ "import onnxruntime as ort\n", "\n", "session_options = ort.SessionOptions()\n", "session_options.log_severity_level = 0 # Verbose logging\n", "session = ort.InferenceSession(\"models_mask/preproc_test.onnx\", sess_options=session_options)" ] }, { "cell_type": "code", "execution_count": null, "id": "3277b343-245d-4ac8-a91c-373061dcbf53", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "plt.imshow(outputs[0][0,8,:,:])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "67e99ef8-e49a-4037-a818-244555b0bdc5", "metadata": {}, "outputs": [], "source": [ "import onnx\n", "\n", "# Path to your ONNX model\n", "model_path = \"models/model-47-99.125.onnx\"\n", "\n", "# Load the ONNX model\n", "onnx_model = onnx.load(model_path)\n", "\n", "# Check the model for validity\n", "onnx.checker.check_model(onnx_model)\n", "\n", "# Print model graph structure (optional)\n", "print(onnx.helper.printable_graph(onnx_model.graph))\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }