def _resolve_groups(num_channels, max_groups=8):
    """Return the largest group count <= ``max_groups`` that divides ``num_channels``.

    ``nn.GroupNorm`` rejects a group count that does not divide the channel
    count evenly, so a fixed value of 8 breaks every width that is not a
    multiple of 8. When ``num_channels`` is a multiple of ``max_groups`` the
    result is ``max_groups`` itself, i.e. the historical behaviour.

    Raises:
        ValueError: if ``num_channels`` or ``max_groups`` is not positive.
    """
    for groups in range(min(max_groups, num_channels), 0, -1):
        if num_channels % groups == 0:
            return groups
    raise ValueError(
        "num_channels and max_groups must be positive, got "
        f"num_channels={num_channels}, max_groups={max_groups}"
    )


class DoubleConv(nn.Module):
    """(Conv1d => GroupNorm => ReLU) * 2.

    Both convolutions use ``kernel_size=3, padding=1``, so the length of the
    last dimension is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or falsy).
        num_groups: upper bound on the GroupNorm group count. The largest
            value not exceeding it that divides the channel count is used.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, num_groups=8):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.GroupNorm(
                num_groups=_resolve_groups(mid_channels, num_groups),
                num_channels=mid_channels,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(
                num_groups=_resolve_groups(out_channels, num_groups),
                num_channels=out_channels,
            ),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling: ``MaxPool1d(2)`` followed by a :class:`DoubleConv`."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool1d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Halve the length of ``x`` (floor division) and convolve it."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling (linear, x2) followed by a :class:`DoubleConv` on the skip concat.

    Args:
        in_channels: channel count of the concatenated tensor, i.e. the
            channels of the upsampled input plus those of the skip input.
        out_channels: channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(
            scale_factor=2, mode='linear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the length of ``x2`` and fuse both.

        Args:
            x1: tensor coming from the deeper (coarser) level.
            x2: skip-connection tensor from the matching encoder level.
        """
        x1 = self.up(x1)
        # Tensors are (batch, channels, length); dimension 2 is the length.
        diffX = x2.size()[2] - x1.size()[2]

        # Pooling floors odd lengths, so the upsampled tensor can be shorter
        # than the skip tensor; distribute the missing samples on both sides.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without changing its length."""
        return self.conv(x)


class UNet1d(nn.Module):
    """1-D U-Net with four down-sampling and four up-sampling stages.

    Args:
        n_channels: channels of the input tensor.
        n_classes: channels of the output tensor.
        nfilter: base width; the encoder uses ``nfilter * {1, 2, 4, 8, 8}``
            channels. Widths that are not multiples of 8 are supported
            because the GroupNorm group count adapts to the channel count.

    The output has the same length as the input. Each of the four pooling
    stages halves the length, so inputs need at least 16 samples.
    """

    def __init__(self, n_channels, n_classes, nfilter=24):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(n_channels, nfilter)
        self.down1 = Down(nfilter, nfilter * 2)
        self.down2 = Down(nfilter * 2, nfilter * 4)
        self.down3 = Down(nfilter * 4, nfilter * 8)
        self.down4 = Down(nfilter * 8, nfilter * 8)
        # Each Up receives the concat of the decoder tensor and the skip.
        self.up1 = Up(nfilter * 16, nfilter * 4)
        self.up2 = Up(nfilter * 8, nfilter * 2)
        self.up3 = Up(nfilter * 4, nfilter * 1)
        self.up4 = Up(nfilter * 2, nfilter)
        self.outc = OutConv(nfilter, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            Tensor with ``n_classes`` channels (raw outputs, no activation).
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
from scipy.signal import butter, sosfilt, sosfreqz


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        lowcut: lower cut-off frequency, in the same unit as ``fs``.
        highcut: upper cut-off frequency, in the same unit as ``fs``.
        fs: sampling frequency.
        order: order passed to ``scipy.signal.butter``.

    Returns:
        The filter as second-order sections (``output='sos'``).
    """
    # scipy expects the band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band, analog=False, btype='band', output='sos')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The filter is designed by :func:`butter_bandpass` and applied with
    ``scipy.signal.sosfilt`` (a single forward pass).

    Args:
        data: array-like signal to filter.
        lowcut: lower cut-off frequency, in the same unit as ``fs``.
        highcut: upper cut-off frequency, in the same unit as ``fs``.
        fs: sampling frequency.
        order: order of the Butterworth design.

    Returns:
        The filtered signal.
    """
    sections = butter_bandpass(lowcut, highcut, fs, order=order)
    return sosfilt(sections, data)