Commit
·
29b2705
0
Parent(s):
Create Space
Browse files- README.md +67 -0
- bcgrun_web.py +127 -0
- bcgunet/__init__.py +0 -0
- bcgunet/__pycache__/__init__.cpython-310.pyc +0 -0
- bcgunet/__pycache__/bcgunet.cpython-310.pyc +0 -0
- bcgunet/__pycache__/unet.cpython-310.pyc +0 -0
- bcgunet/bcgunet.py +148 -0
- bcgunet/unet.py +123 -0
- bcgunet/unet1d-simple.ipynb +282 -0
- requirements.txt +91 -0
README.md
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: BCGunet
|
3 |
+
emoji: 🧠
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: orange
|
6 |
+
sdk: gradio
|
7 |
+
app_file: bcgrun_web.py
|
8 |
+
python_version: "3.10"
|
9 |
+
sdk_version: "3.27.0"
|
10 |
+
pinned: false
|
11 |
+
---
|
12 |
+
|
13 |
+
# BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner
|
14 |
+
|
15 |
+
Ballistocardiogram (BCG) is the induced electric potentials caused by heartbeats when the EEG data are collected with a strong ambient magnetic field. This scenario is common in concurrent EEG-MRI acquisitions. One particular application of concurrent EEG-MRI is to delineate the irritative zones responsible for generating inter-ictal spikes (IIS) in medically refractory epilepsy patients. Specifically, EEG is used to detect onsets of spikes and these timing are used to inform functional MRI time series analysis. However, with strong BCG artifacts, the EEG data can be seriously corrupted and thus make spike annotation difficult.
|
16 |
+
|
17 |
+
This project aims at using machine learning approaches to suppress BCG artifacts. We will use Unet as the artificual neural network structure to tackle this challenge.
|
18 |
+
|
19 |
+

|
20 |
+
|
21 |
+
## Data
|
22 |
+
|
23 |
+
Data were EEG time series collected inside a 3T MRI scanner (Skyra, Siemens). EEG were sampled by a 32-channel systemm (Brain Products) with electrodes arranged by the international 10-20 standard. EEG were sampled at 5,000 Hz.
|
24 |
+
|
25 |
+
### Eyes open/closed in healthy control subjects
|
26 |
+
A tar ball is [here](https://drive.google.com/file/d/1Te94WlQ4nGCT3rnij_w0pbPFhRcaphGJ/view?usp=share_link). Each subject had two sessions of data. One was "eyes-open" and the other was "eyes-closed", where subjects were instructed laying in the MRI without falling sleep but keeping their eyes open and closed, respectively. This is a resting-state recording.
|
27 |
+
During the recording, the MRI scanner did not collect any images. No so-called "gradient artifacts" caused by the swithcing of the imaging gradient coils of MRI was present.
|
28 |
+
|
29 |
+
[This tar ball](https://drive.google.com/file/d/1Hfu5w0-CT6p3g82yXIp-7wIi6921DW4m/view?usp=share_link) includes the EEG data taken *outside* MRI, including "eyes-open" and "eyes-closed" conditions. In other words, these data can be taken as the gold standard to see how much alpha oscillation power increased when eyes were closed.
|
30 |
+
|
31 |
+
The more complicated case is that when MRI was collecting the data. This imposed a strong "gradient artifacts" over EEG. Thus it takes extra efforts to deal with both gradient artifacts and BCG artifacts at the same time. [Here](https://drive.google.com/file/d/1oZAjCnec73ErwkuMxulUv_v3XipMZ7N_/view?usp=sharing) is the file with concurrent MRI-EEG when echo-planar imaging was used. Note that the gradient artifacts in these data have been suppressed already.
|
32 |
+
|
33 |
+
### Steady-state visual evoked potential (SSVEP)
|
34 |
+
Check [this page](https://github.com/fahsuanlin/labmanual/wiki/21.-Sample-data:-Steady-state-visual-potential) for some details.
|
35 |
+
|
36 |
+
## Code
|
37 |
+
- [Data input (Matlab)](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/read_eeg.m): An example of reading EEG data. Each EEG recording has three files with .eeg, .vmrk, and .vhdr file suffix. Supply the .vmrk and .vmrk file names to read data into Matlab. Need functions at [bvaloader](https://github.com/stefanSchinkel/bvaloader).
|
38 |
+
|
39 |
+
**NOTE**: Do not change the file names because data are associated with the file name.
|
40 |
+
|
41 |
+
- [Unet basic structure and BCG suppression (Python)](https://github.com/fahsuanlin/BCGunet/blob/main/bcg_unet/unet1d-simple.ipynb): perform BCG suppression by Unet, including training and testing of data from the same subject.
|
42 |
+
|
43 |
+
- Assessment (Matlab):
|
44 |
+
|
45 |
+
-- Alpha oscillations in eyes-closed vs. eyes-opened conditions: Calculate the alpha-band (10-Hz) power at all EEG electrodes. We expect that stronger alpha-band neural oscillations are found at the parietal lobe of the subject when he/she closed eyes than opened eyes after successful BCG artifact suppression. Download [our toolbox](https://github.com/fahsuanlin/fhlin_toolbox) to use the function in the following codes to calculate the average of 10-Hz oscillatory power across all EEG channels (in columns; with data stored in `EEG`) using the Morlet wavelet transform with 5-cycle. EEG data were sampled at 5,000 Hz denoted by `sfreq` variable.
|
46 |
+
|
47 |
+
```
|
48 |
+
sfreq=5000;
|
49 |
+
mean(abs(inverse_waveletcoef(10,double(EEG),sfreq,5)),2);
|
50 |
+
```
|
51 |
+
[A sample script](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/calc_alpha_unet.m) of calculating alpha-band oscillations across subjects and between conditions.
|
52 |
+
|
53 |
+
|
54 |
+
- Rendering (Matlab): tools to render EEG data over a scalp.
|
55 |
+
|
56 |
+
Use the [EEG topolgoy definition fiile](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/bem.mat) to draw 10-Hz power distribution. Download [our toolbox](https://github.com/fahsuanlin/fhlin_toolbox) to use the function in the following codes.
|
57 |
+
```
|
58 |
+
load bem.mat;
|
59 |
+
verts_osc_electrode_idx(end-2:end,:)=[]; %last three channels are not needed.
|
60 |
+
etc_render_topo('vol_vertex',verts_osc,'vol_face',faces_osc-1,'topo_vertex',verts_osc_electrode_idx-1,'topo_stc',mean(EEG_unet_close,2)./mean(EEG_unet_open,2),'topo_smooth',10,'topo_threshold',[1.25 1.5],'topo_stc_timevec_unit','Hz','view_angle',[0 50]);
|
61 |
+
```
|
62 |
+
|
63 |
+
## Resources
|
64 |
+
- [Our lab routine](https://github.com/fahsuanlin/labmanual/wiki/18.-Suppression-of-ballistocardiography-artifacts-in-EEG-collected-inside-MRI) in BCG suppression.
|
65 |
+
- A conventional PCA-based BCG suppression method using [Optimal Basis Sets (OBS)](https://www.sciencedirect.com/science/article/abs/pii/S1053811905004726?via%3Dihub).
|
66 |
+
- An RNN-type BCG artifact suppression method (BCGnet) can be found [here](https://github.com/jiaangyao/BCGNet)
|
67 |
+
- [Our BCG suppression protocol](https://github.com/fahsuanlin/labmanual/wiki/18.-Suppression-of-ballistocardiography-artifacts-in-EEG-collected-inside-MRI)
|
bcgrun_web.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.io import savemat
|
3 |
+
import h5py
|
4 |
+
from bcgunet import bcgunet
|
5 |
+
import platform
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import os
|
9 |
+
import gradio as gr
|
10 |
+
import matplotlib
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
matplotlib.use("agg")
|
14 |
+
|
15 |
+
dir = os.path.dirname(os.path.realpath(__file__)) + "/tmp"
|
16 |
+
os.makedirs(dir, exist_ok=True)
|
17 |
+
|
18 |
+
|
19 |
+
def run(
|
20 |
+
files: list[bytes],
|
21 |
+
lr: float,
|
22 |
+
winsec: int,
|
23 |
+
iters: int,
|
24 |
+
onecycle: bool,
|
25 |
+
ecg: str,
|
26 |
+
bce: str,
|
27 |
+
eeg: str,
|
28 |
+
) -> tuple[list[str], str]:
|
29 |
+
task = os.path.join(dir, str(int(time.time())))
|
30 |
+
os.makedirs(task)
|
31 |
+
|
32 |
+
outputs = []
|
33 |
+
|
34 |
+
for i, file in enumerate(files):
|
35 |
+
input = os.path.join(task, str(i) + ".mat")
|
36 |
+
with open(input, "wb") as o:
|
37 |
+
o.write(file)
|
38 |
+
|
39 |
+
output = os.path.join(task, str(i) + "_clean.mat")
|
40 |
+
|
41 |
+
mat = h5py.File(input, "r")
|
42 |
+
ECG = np.array(mat[ecg]).flatten()
|
43 |
+
EEG = np.array(mat[bce]).T
|
44 |
+
|
45 |
+
EEG_unet = bcgunet.run(
|
46 |
+
EEG,
|
47 |
+
ECG,
|
48 |
+
iter_num=iters,
|
49 |
+
winsize_sec=winsec,
|
50 |
+
lr=lr,
|
51 |
+
onecycle=onecycle,
|
52 |
+
)
|
53 |
+
result = dict()
|
54 |
+
result[eeg] = EEG_unet
|
55 |
+
|
56 |
+
savemat(output, result, do_compression=True)
|
57 |
+
outputs.append(output)
|
58 |
+
|
59 |
+
if i == 0:
|
60 |
+
plt.figure(figsize=(12, 6), dpi=300)
|
61 |
+
plt.plot(EEG[19, :10000], "b.-", label="Orig EEG")
|
62 |
+
plt.plot(EEG_unet[19, :10000], "g.-", label="U-Net")
|
63 |
+
plt.legend()
|
64 |
+
plt.title("BCG Unet")
|
65 |
+
plt.xlabel("Time (samples)")
|
66 |
+
plot = os.path.join(task, str(i) + ".png")
|
67 |
+
plt.savefig(plot)
|
68 |
+
|
69 |
+
return outputs, plot
|
70 |
+
|
71 |
+
|
72 |
+
def main():
|
73 |
+
app = gr.Interface(
|
74 |
+
title="BCG Unet",
|
75 |
+
description="BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner",
|
76 |
+
fn=run,
|
77 |
+
inputs=[
|
78 |
+
gr.File(
|
79 |
+
label="Input Files (.mat)",
|
80 |
+
type="binary",
|
81 |
+
file_types=["mat"],
|
82 |
+
file_count=["multiple", "directory"],
|
83 |
+
),
|
84 |
+
gr.Slider(
|
85 |
+
label="Learning Rate", minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-3
|
86 |
+
),
|
87 |
+
gr.Slider(
|
88 |
+
label="Window Size (seconds)", minimum=1, maximum=10, step=1, value=2
|
89 |
+
),
|
90 |
+
gr.Slider(
|
91 |
+
label="Number of Iterations",
|
92 |
+
minimum=1000,
|
93 |
+
maximum=10000,
|
94 |
+
step=1000,
|
95 |
+
value=5000,
|
96 |
+
),
|
97 |
+
gr.Checkbox(
|
98 |
+
label="One Cycle Scheduler",
|
99 |
+
value=True,
|
100 |
+
),
|
101 |
+
gr.Textbox(
|
102 |
+
label="Variable name for ECG (input)",
|
103 |
+
value="ECG",
|
104 |
+
),
|
105 |
+
gr.Textbox(
|
106 |
+
label="Variable name for BCG corropted EEG (input)",
|
107 |
+
value="EEG_before_bcg",
|
108 |
+
),
|
109 |
+
gr.Textbox(
|
110 |
+
label="Variable name for clean EEG (output)",
|
111 |
+
value="EEG_clean",
|
112 |
+
),
|
113 |
+
],
|
114 |
+
outputs=[
|
115 |
+
gr.File(label="Output File", file_count="multiple"),
|
116 |
+
gr.Image(label="Output Image", type="filepath"),
|
117 |
+
],
|
118 |
+
allow_flagging="never",
|
119 |
+
)
|
120 |
+
|
121 |
+
app.launch()
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
main()
|
126 |
+
if platform.system() == "Windows":
|
127 |
+
os.system("pause")
|
bcgunet/__init__.py
ADDED
File without changes
|
bcgunet/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes). View file
|
|
bcgunet/__pycache__/bcgunet.cpython-310.pyc
ADDED
Binary file (3.44 kB). View file
|
|
bcgunet/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (3.99 kB). View file
|
|
bcgunet/bcgunet.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import *
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import time
|
7 |
+
import tqdm
|
8 |
+
from scipy.signal import butter, sosfilt
|
9 |
+
from .unet import UNet1d
|
10 |
+
|
11 |
+
|
12 |
+
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
|
13 |
+
nyq = 0.5 * fs
|
14 |
+
low = lowcut / nyq
|
15 |
+
high = highcut / nyq
|
16 |
+
sos = butter(order, [low, high], analog=False, btype="band", output="sos")
|
17 |
+
y = sosfilt(sos, data)
|
18 |
+
return y
|
19 |
+
|
20 |
+
|
21 |
+
def norm(ecg):
|
22 |
+
min1, max1 = np.percentile(ecg, [1, 99])
|
23 |
+
ecg[ecg > max1] = max1
|
24 |
+
ecg[ecg < min1] = min1
|
25 |
+
ecg = (ecg - min1) / (max1 - min1)
|
26 |
+
return ecg
|
27 |
+
|
28 |
+
|
29 |
+
def run(
|
30 |
+
input_eeg,
|
31 |
+
input_ecg=None,
|
32 |
+
sfreq=5000,
|
33 |
+
iter_num=5000,
|
34 |
+
winsize_sec=2,
|
35 |
+
lr=1e-3,
|
36 |
+
onecycle=True,
|
37 |
+
):
|
38 |
+
window = winsize_sec * sfreq
|
39 |
+
eeg_raw = input_eeg
|
40 |
+
eeg_channel = eeg_raw.shape[0]
|
41 |
+
|
42 |
+
eeg_filtered = eeg_raw * 0
|
43 |
+
t = time.time()
|
44 |
+
for ii in range(eeg_channel):
|
45 |
+
eeg_filtered[ii, ...] = butter_bandpass_filter(
|
46 |
+
eeg_raw[ii, :], 0.5, sfreq * 0.4, sfreq
|
47 |
+
)
|
48 |
+
|
49 |
+
baseline = eeg_raw - eeg_filtered
|
50 |
+
|
51 |
+
if input_ecg is None:
|
52 |
+
from sklearn.decomposition import PCA
|
53 |
+
|
54 |
+
pca = PCA(n_components=1)
|
55 |
+
ecg = norm(pca.fit_transform(eeg_filtered.T)[:, 0].flatten())
|
56 |
+
else:
|
57 |
+
ecg = norm(input_ecg.flatten())
|
58 |
+
|
59 |
+
torch.cuda.empty_cache()
|
60 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
61 |
+
NET = UNet1d(n_channels=1, n_classes=eeg_channel, nfilter=8).to(device)
|
62 |
+
optimizer = torch.optim.Adam(NET.parameters(), lr=lr)
|
63 |
+
optimizer.zero_grad()
|
64 |
+
maxlen = ecg.size
|
65 |
+
if onecycle:
|
66 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
67 |
+
optimizer, lr, total_steps=iter_num
|
68 |
+
)
|
69 |
+
else:
|
70 |
+
# constant learning rate
|
71 |
+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)
|
72 |
+
|
73 |
+
loss_list = []
|
74 |
+
|
75 |
+
# randomly get windows in ECG signal
|
76 |
+
|
77 |
+
index_all = (np.random.random_sample(iter_num) * (maxlen - window)).astype(int)
|
78 |
+
|
79 |
+
pbar = tqdm.tqdm(index_all)
|
80 |
+
count = 0
|
81 |
+
for index in pbar:
|
82 |
+
count += 1
|
83 |
+
ECG = ecg[index : (index + window)]
|
84 |
+
EEG = eeg_filtered[:, index : (index + window)]
|
85 |
+
ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()
|
86 |
+
EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()
|
87 |
+
|
88 |
+
# step 3: forward path of UNET
|
89 |
+
logits = NET(ECG_d)
|
90 |
+
loss = nn.MSELoss()(logits, EEG_d)
|
91 |
+
loss_list.append(loss.item())
|
92 |
+
|
93 |
+
# Step 5: Perform back-propagation
|
94 |
+
loss.backward() # accumulate the gradients
|
95 |
+
optimizer.step() # Update network weights according to the optimizer
|
96 |
+
optimizer.zero_grad() # empty the gradients
|
97 |
+
scheduler.step()
|
98 |
+
|
99 |
+
if count % 50 == 0:
|
100 |
+
pbar.set_description(
|
101 |
+
f"Loss {np.mean(loss_list):.3f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
|
102 |
+
)
|
103 |
+
loss_list = []
|
104 |
+
|
105 |
+
EEG = eeg_filtered
|
106 |
+
# ECG = norm(butter_bandpass_filter(data['ECG'], 0.5, 20, sfreq))
|
107 |
+
ECG = ecg
|
108 |
+
ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()
|
109 |
+
EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()
|
110 |
+
with torch.no_grad():
|
111 |
+
logits = NET(ECG_d)
|
112 |
+
BCG_pred = logits.cpu().detach().numpy()[0, ...]
|
113 |
+
|
114 |
+
neweeg = EEG - BCG_pred + baseline
|
115 |
+
|
116 |
+
return neweeg
|
117 |
+
|
118 |
+
|
119 |
+
def morlet_psd(signal, sample_rate=5000, freq=10, wavelet="morl"):
|
120 |
+
import pywt
|
121 |
+
|
122 |
+
# Define the wavelet and scales to be used
|
123 |
+
|
124 |
+
scales = np.arange(sample_rate)
|
125 |
+
freqs = pywt.scale2frequency("morl", scales) * sample_rate
|
126 |
+
indx = np.argmin(abs(freqs - freq))
|
127 |
+
|
128 |
+
scale = scales[indx]
|
129 |
+
|
130 |
+
# scale = pywt.frequency2scale('morl', freq/sample_rate)
|
131 |
+
|
132 |
+
# Calculate the wavelet coefficients
|
133 |
+
coeffs, freq = pywt.cwt(signal, scale, wavelet, 1 / sample_rate)
|
134 |
+
# Calculate the power (magnitude squared) of the coefficients
|
135 |
+
power = np.abs(coeffs) ** 2
|
136 |
+
|
137 |
+
# Average the power across time to get the power spectral density
|
138 |
+
psd = np.mean(power, axis=1)
|
139 |
+
|
140 |
+
return psd
|
141 |
+
|
142 |
+
|
143 |
+
def get_psd(eeg, sfreq=5000, freq=10):
|
144 |
+
psd = []
|
145 |
+
for ii in tqdm.tqdm(range(eeg.shape[0])):
|
146 |
+
psd.append(morlet_psd(eeg[ii], sample_rate=sfreq, freq=freq))
|
147 |
+
|
148 |
+
return np.array(psd)
|
bcgunet/unet.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class DoubleConv(nn.Module):
|
7 |
+
"""(convolution => [BN] => ReLU) * 2"""
|
8 |
+
|
9 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
10 |
+
super().__init__()
|
11 |
+
if not mid_channels:
|
12 |
+
mid_channels = out_channels
|
13 |
+
self.double_conv = nn.Sequential(
|
14 |
+
nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1),
|
15 |
+
nn.GroupNorm(num_groups=4, num_channels=mid_channels),
|
16 |
+
nn.ReLU(inplace=True),
|
17 |
+
nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1),
|
18 |
+
nn.GroupNorm(num_groups=4, num_channels=out_channels),
|
19 |
+
nn.ReLU(inplace=True),
|
20 |
+
)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
return self.double_conv(x)
|
24 |
+
|
25 |
+
|
26 |
+
class DoubleConvX(nn.Module):
|
27 |
+
"""(convolution => [BN] => ReLU) * 2"""
|
28 |
+
|
29 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
30 |
+
super().__init__()
|
31 |
+
if not mid_channels:
|
32 |
+
mid_channels = out_channels
|
33 |
+
self.double_conv = nn.Sequential(
|
34 |
+
nn.Conv1d(in_channels, mid_channels, kernel_size=15, padding=7),
|
35 |
+
nn.GroupNorm(num_groups=8, num_channels=mid_channels),
|
36 |
+
nn.ReLU(inplace=True),
|
37 |
+
nn.Conv1d(mid_channels, out_channels, kernel_size=15, padding=7),
|
38 |
+
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
39 |
+
nn.ReLU(inplace=True),
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return self.double_conv(x)
|
44 |
+
|
45 |
+
|
46 |
+
class Down(nn.Module):
|
47 |
+
"""Downscaling with maxpool then double conv"""
|
48 |
+
|
49 |
+
def __init__(self, in_channels, out_channels):
|
50 |
+
super().__init__()
|
51 |
+
self.maxpool_conv = nn.Sequential(
|
52 |
+
nn.MaxPool1d(2), DoubleConv(in_channels, out_channels)
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return self.maxpool_conv(x)
|
57 |
+
|
58 |
+
|
59 |
+
class Up(nn.Module):
|
60 |
+
"""Upscaling then double conv"""
|
61 |
+
|
62 |
+
def __init__(self, in_channels, out_channels):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
|
66 |
+
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
|
67 |
+
|
68 |
+
def forward(self, x1, x2):
|
69 |
+
x1 = self.up(x1)
|
70 |
+
# input is CHW
|
71 |
+
diffX = x2.size()[2] - x1.size()[2]
|
72 |
+
|
73 |
+
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2])
|
74 |
+
# if you have padding issues, see
|
75 |
+
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
|
76 |
+
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
|
77 |
+
x = torch.cat([x2, x1], dim=1)
|
78 |
+
return self.conv(x)
|
79 |
+
|
80 |
+
|
81 |
+
class OutConv(nn.Module):
|
82 |
+
def __init__(self, in_channels, out_channels):
|
83 |
+
super(OutConv, self).__init__()
|
84 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
return self.conv(x)
|
88 |
+
|
89 |
+
|
90 |
+
class UNet1d(nn.Module):
|
91 |
+
def __init__(self, n_channels, n_classes, nfilter=24):
|
92 |
+
super(UNet1d, self).__init__()
|
93 |
+
self.n_channels = n_channels
|
94 |
+
self.n_classes = n_classes
|
95 |
+
|
96 |
+
self.inc = DoubleConv(n_channels, nfilter)
|
97 |
+
self.down1 = Down(nfilter, nfilter * 2)
|
98 |
+
self.down2 = Down(nfilter * 2, nfilter * 4)
|
99 |
+
self.down3 = Down(nfilter * 4, nfilter * 8)
|
100 |
+
self.down4 = Down(nfilter * 8, nfilter * 8)
|
101 |
+
self.up1 = Up(nfilter * 16, nfilter * 4)
|
102 |
+
self.up2 = Up(nfilter * 8, nfilter * 2)
|
103 |
+
self.up3 = Up(nfilter * 4, nfilter * 1)
|
104 |
+
self.up4 = Up(nfilter * 2, nfilter)
|
105 |
+
self.outc = OutConv(nfilter, n_classes)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
x1 = self.inc(x)
|
109 |
+
x2 = self.down1(x1)
|
110 |
+
x3 = self.down2(x2)
|
111 |
+
x4 = self.down3(x3)
|
112 |
+
x5 = self.down4(x4)
|
113 |
+
x = self.up1(x5, x4)
|
114 |
+
x = self.up2(x, x3)
|
115 |
+
x = self.up3(x, x2)
|
116 |
+
x = self.up4(x, x1)
|
117 |
+
logits = self.outc(x)
|
118 |
+
return logits
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
model = UNet1d(1, 1)
|
123 |
+
print(model)
|
bcgunet/unet1d-simple.ipynb
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "62e7e36e",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"Requirements: Pytorch, mat73, numpy\n",
|
9 |
+
"\n",
|
10 |
+
"```pip install mat73```\n",
|
11 |
+
"\n",
|
12 |
+
"相關論文: https://ieeexplore.ieee.org/document/9124646"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": null,
|
18 |
+
"id": "2962af4f",
|
19 |
+
"metadata": {},
|
20 |
+
"outputs": [],
|
21 |
+
"source": [
|
22 |
+
"import glob\n",
|
23 |
+
"from os.path import *\n",
|
24 |
+
"import numpy as np\n",
|
25 |
+
"import random\n",
|
26 |
+
"import torch\n",
|
27 |
+
"import torch.nn as nn\n",
|
28 |
+
"import torch.nn.functional as F\n",
|
29 |
+
"import time\n",
|
30 |
+
"import sys\n",
|
31 |
+
"import mat73\n",
|
32 |
+
"import matplotlib.pyplot as plt\n",
|
33 |
+
"from scipy.io import savemat\n",
|
34 |
+
"import os\n",
|
35 |
+
"import tqdm"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
+
"id": "6151e083",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"\n",
|
46 |
+
"class DoubleConv(nn.Module):\n",
|
47 |
+
" \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n",
|
48 |
+
"\n",
|
49 |
+
" def __init__(self, in_channels, out_channels, mid_channels=None):\n",
|
50 |
+
" super().__init__()\n",
|
51 |
+
" if not mid_channels:\n",
|
52 |
+
" mid_channels = out_channels\n",
|
53 |
+
" self.double_conv = nn.Sequential(\n",
|
54 |
+
" nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1),\n",
|
55 |
+
" nn.GroupNorm(num_groups=8, num_channels=mid_channels),\n",
|
56 |
+
" nn.ReLU(inplace=True),\n",
|
57 |
+
" nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1),\n",
|
58 |
+
" nn.GroupNorm(num_groups=8, num_channels=out_channels),\n",
|
59 |
+
" nn.ReLU(inplace=True)\n",
|
60 |
+
" )\n",
|
61 |
+
"\n",
|
62 |
+
" def forward(self, x):\n",
|
63 |
+
" return self.double_conv(x)\n",
|
64 |
+
"\n",
|
65 |
+
"\n",
|
66 |
+
"class Down(nn.Module):\n",
|
67 |
+
" \"\"\"Downscaling with maxpool then double conv\"\"\"\n",
|
68 |
+
"\n",
|
69 |
+
" def __init__(self, in_channels, out_channels):\n",
|
70 |
+
" super().__init__()\n",
|
71 |
+
" self.maxpool_conv = nn.Sequential(\n",
|
72 |
+
" nn.MaxPool1d(2),\n",
|
73 |
+
" DoubleConv(in_channels, out_channels)\n",
|
74 |
+
" )\n",
|
75 |
+
"\n",
|
76 |
+
" def forward(self, x):\n",
|
77 |
+
" return self.maxpool_conv(x)\n",
|
78 |
+
"\n",
|
79 |
+
"\n",
|
80 |
+
"class Up(nn.Module):\n",
|
81 |
+
" \"\"\"Upscaling then double conv\"\"\"\n",
|
82 |
+
"\n",
|
83 |
+
" def __init__(self, in_channels, out_channels):\n",
|
84 |
+
" super().__init__()\n",
|
85 |
+
"\n",
|
86 |
+
" self.up = nn.Upsample(\n",
|
87 |
+
" scale_factor=2, mode='linear', align_corners=True)\n",
|
88 |
+
" self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)\n",
|
89 |
+
"\n",
|
90 |
+
" \n",
|
91 |
+
" def forward(self, x1, x2):\n",
|
92 |
+
" x1 = self.up(x1)\n",
|
93 |
+
" # input is CHW\n",
|
94 |
+
" diffX = x2.size()[2] - x1.size()[2]\n",
|
95 |
+
"\n",
|
96 |
+
" x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2])\n",
|
97 |
+
" # if you have padding issues, see\n",
|
98 |
+
" # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n",
|
99 |
+
" # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n",
|
100 |
+
" x = torch.cat([x2, x1], dim=1)\n",
|
101 |
+
" return self.conv(x)\n",
|
102 |
+
"\n",
|
103 |
+
"\n",
|
104 |
+
"class OutConv(nn.Module):\n",
|
105 |
+
" def __init__(self, in_channels, out_channels):\n",
|
106 |
+
" super(OutConv, self).__init__()\n",
|
107 |
+
" self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)\n",
|
108 |
+
"\n",
|
109 |
+
" def forward(self, x):\n",
|
110 |
+
" return self.conv(x)\n",
|
111 |
+
"\n",
|
112 |
+
"\n",
|
113 |
+
"class UNet1d(nn.Module):\n",
|
114 |
+
" def __init__(self, n_channels, n_classes, nfilter=24):\n",
|
115 |
+
" super(UNet1d, self).__init__()\n",
|
116 |
+
" self.n_channels = n_channels\n",
|
117 |
+
" self.n_classes = n_classes\n",
|
118 |
+
"\n",
|
119 |
+
" self.inc = DoubleConv(n_channels, nfilter)\n",
|
120 |
+
" self.down1 = Down(nfilter, nfilter * 2)\n",
|
121 |
+
" self.down2 = Down(nfilter * 2, nfilter * 4)\n",
|
122 |
+
" self.down3 = Down(nfilter * 4, nfilter * 8)\n",
|
123 |
+
" self.down4 = Down(nfilter * 8, nfilter * 8)\n",
|
124 |
+
" self.up1 = Up(nfilter * 16, nfilter * 4)\n",
|
125 |
+
" self.up2 = Up(nfilter * 8, nfilter * 2)\n",
|
126 |
+
" self.up3 = Up(nfilter * 4, nfilter * 1)\n",
|
127 |
+
" self.up4 = Up(nfilter * 2, nfilter)\n",
|
128 |
+
" self.outc = OutConv(nfilter, n_classes)\n",
|
129 |
+
"\n",
|
130 |
+
" def forward(self, x):\n",
|
131 |
+
" x1 = self.inc(x)\n",
|
132 |
+
" x2 = self.down1(x1)\n",
|
133 |
+
" x3 = self.down2(x2)\n",
|
134 |
+
" x4 = self.down3(x3)\n",
|
135 |
+
" x5 = self.down4(x4)\n",
|
136 |
+
" x = self.up1(x5, x4)\n",
|
137 |
+
" x = self.up2(x, x3)\n",
|
138 |
+
" x = self.up3(x, x2)\n",
|
139 |
+
" x = self.up4(x, x1)\n",
|
140 |
+
" logits = self.outc(x)\n",
|
141 |
+
" return logits\n"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": null,
|
147 |
+
"id": "1f21a6a9",
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [],
|
150 |
+
"source": [
|
151 |
+
"from scipy.signal import butter, sosfilt, sosfreqz\n",
|
152 |
+
"\n",
|
153 |
+
"def butter_bandpass(lowcut, highcut, fs, order=5):\n",
|
154 |
+
" nyq = 0.5 * fs\n",
|
155 |
+
" low = lowcut / nyq\n",
|
156 |
+
" high = highcut / nyq\n",
|
157 |
+
" sos = butter(order, [low, high], analog=False, btype='band', output='sos')\n",
|
158 |
+
" return sos\n",
|
159 |
+
"\n",
|
160 |
+
"def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):\n",
|
161 |
+
" sos = butter_bandpass(lowcut, highcut, fs, order=order)\n",
|
162 |
+
" y = sosfilt(sos, data)\n",
|
163 |
+
" return y"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": null,
|
169 |
+
"id": "71b4e84a",
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [],
|
172 |
+
"source": [
|
173 |
+
"f = r'/NFS/tyhuang/fhlinXBCG/noscan/170320_CLY/analysis/EyeClose1_noscan.mat'\n",
|
174 |
+
"\n",
|
175 |
+
"def norm_ecg(ecg):\n",
|
176 |
+
" min1, max1 = np.percentile(ecg, [1, 99])\n",
|
177 |
+
" ecg[ecg>max1] = max1\n",
|
178 |
+
" ecg[ecg<min1] = min1\n",
|
179 |
+
" ecg = (ecg - min1)/(max1-min1)\n",
|
180 |
+
" return ecg\n",
|
181 |
+
"\n",
|
182 |
+
"\n",
|
183 |
+
"data = mat73.loadmat(f)\n",
|
184 |
+
"eeg_filtered = data['EEG_before_bcg'] * 0\n",
|
185 |
+
"t = time.time()\n",
|
186 |
+
"for ii in range(31):\n",
|
187 |
+
" eeg_filtered[ii, ...] = butter_bandpass_filter(data['EEG_before_bcg'][ii,:], 1, 40, 5000)\n",
|
188 |
+
"\n",
|
189 |
+
"torch.cuda.empty_cache()\n",
|
190 |
+
"device = ('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
191 |
+
"NET = UNet1d(n_channels=1, n_classes=31, nfilter=8).to(device)\n",
|
192 |
+
"#NET = torch.load('pretrainbcg_f8.pt')\n",
|
193 |
+
"#NET.outc = OutConv(8, 31)\n",
|
194 |
+
"NET = NET.to(device)\n",
|
195 |
+
"optimizer = torch.optim.Adam(NET.parameters(), lr=5e-4)\n",
|
196 |
+
"optimizer.zero_grad()\n",
|
197 |
+
"maxlen = data['ECG'].size\n",
|
198 |
+
"\n",
|
199 |
+
"loss_list = []\n",
|
200 |
+
"count = 0\n",
|
201 |
+
"ecg = norm_ecg(data['ECG'])\n",
|
202 |
+
"for ii in range(5000):\n",
|
203 |
+
" if ii % 10 == 0:\n",
|
204 |
+
" sys.stdout.write('.')\n",
|
205 |
+
" index = random.randrange(maxlen-20000)\n",
|
206 |
+
" ECG = ecg[index:(index+20000)]\n",
|
207 |
+
" EEG = eeg_filtered[:, index:(index+20000)]\n",
|
208 |
+
" ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
|
209 |
+
" EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
|
210 |
+
"\n",
|
211 |
+
" # step 3: forward path of UNET\n",
|
212 |
+
" logits = NET(ECG_d)\n",
|
213 |
+
" loss = nn.MSELoss()(logits, EEG_d)\n",
|
214 |
+
" loss_list.append(loss.item())\n",
|
215 |
+
"\n",
|
216 |
+
"\n",
|
217 |
+
" # Step 5: Perform back-propagation\n",
|
218 |
+
" loss.backward() #accumulate the gradients\n",
|
219 |
+
" optimizer.step() #Update network weights according to the optimizer\n",
|
220 |
+
" optimizer.zero_grad() #empty the gradients\n",
|
221 |
+
"\n",
|
222 |
+
"\n",
|
223 |
+
" if (ii + 1) % 500 == 0: #plot results per 500 iterations\n",
|
224 |
+
" print('mse loss: ', np.mean(loss_list))\n",
|
225 |
+
" loss_list = []\n",
|
226 |
+
" EEG = eeg_filtered[:, :50000]\n",
|
227 |
+
" ECG = data['ECG'][:50000]\n",
|
228 |
+
" ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
|
229 |
+
" EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
|
230 |
+
" logits = NET(ECG_d) \n",
|
231 |
+
" EEG_pred = logits.cpu().detach().numpy()\n",
|
232 |
+
" plt.figure(figsize=(12, 6), dpi=300)\n",
|
233 |
+
"\n",
|
234 |
+
" plt.plot(EEG[0, ...], 'g')\n",
|
235 |
+
" plt.plot(EEG[0, ...] - EEG_pred[0, 0, ...], 'r')\n",
|
236 |
+
" time1 = round(time.time() - t, 1)\n",
|
237 |
+
" plt.title(f' {time1} seconds')\n",
|
238 |
+
" plt.show()\n",
|
239 |
+
"\n",
|
240 |
+
"\n",
|
241 |
+
"# remove BCG from the whole dataset\n",
|
242 |
+
"EEG = eeg_filtered\n",
|
243 |
+
"ECG = data['ECG']\n",
|
244 |
+
"ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
|
245 |
+
"EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
|
246 |
+
"logits = NET(ECG_d)\n",
|
247 |
+
"BCG_pred = logits.cpu().detach().numpy()[0, ...]\n",
|
248 |
+
"EEG_removeBCG_unet = eeg_filtered - BCG_pred\n",
|
249 |
+
"\n"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "code",
|
254 |
+
"execution_count": null,
|
255 |
+
"id": "4dc44936",
|
256 |
+
"metadata": {},
|
257 |
+
"outputs": [],
|
258 |
+
"source": []
|
259 |
+
}
|
260 |
+
],
|
261 |
+
"metadata": {
|
262 |
+
"kernelspec": {
|
263 |
+
"display_name": "Python 3 (ipykernel)",
|
264 |
+
"language": "python",
|
265 |
+
"name": "python3"
|
266 |
+
},
|
267 |
+
"language_info": {
|
268 |
+
"codemirror_mode": {
|
269 |
+
"name": "ipython",
|
270 |
+
"version": 3
|
271 |
+
},
|
272 |
+
"file_extension": ".py",
|
273 |
+
"mimetype": "text/x-python",
|
274 |
+
"name": "python",
|
275 |
+
"nbconvert_exporter": "python",
|
276 |
+
"pygments_lexer": "ipython3",
|
277 |
+
"version": "3.7.10"
|
278 |
+
}
|
279 |
+
},
|
280 |
+
"nbformat": 4,
|
281 |
+
"nbformat_minor": 5
|
282 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==22.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
2 |
+
aiohttp==3.8.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
3 |
+
aiosignal==1.3.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
4 |
+
altair==4.2.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
5 |
+
anyio==3.6.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
6 |
+
async-timeout==4.0.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
7 |
+
attrs==23.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
8 |
+
certifi==2022.12.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
9 |
+
charset-normalizer==3.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
10 |
+
click==8.1.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
11 |
+
cmake==3.26.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
12 |
+
colorama==0.4.6 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0" and platform_system == "Windows"
|
13 |
+
contourpy==1.0.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
14 |
+
cycler==0.11.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
15 |
+
entrypoints==0.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
16 |
+
fastapi==0.95.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
17 |
+
ffmpy==0.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
18 |
+
filelock==3.12.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
19 |
+
fonttools==4.39.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
20 |
+
frozenlist==1.3.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
21 |
+
fsspec==2023.4.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
22 |
+
gradio-client==0.1.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
23 |
+
gradio==3.27.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
24 |
+
h11==0.14.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
25 |
+
h5py==3.8.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
26 |
+
httpcore==0.17.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
27 |
+
httpx==0.24.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
28 |
+
huggingface-hub==0.14.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
29 |
+
idna==3.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
30 |
+
jinja2==3.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
31 |
+
joblib==1.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
32 |
+
jsonschema==4.17.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
33 |
+
kiwisolver==1.4.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
34 |
+
linkify-it-py==2.0.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
35 |
+
lit==16.0.2 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
36 |
+
markdown-it-py==2.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
37 |
+
markdown-it-py[linkify]==2.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
38 |
+
markupsafe==2.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
39 |
+
matplotlib==3.7.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
40 |
+
mdit-py-plugins==0.3.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
41 |
+
mdurl==0.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
42 |
+
mpmath==1.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
43 |
+
multidict==6.0.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
44 |
+
networkx==3.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
45 |
+
numpy==1.24.3 ; python_version >= "3.10" and python_full_version < "3.11.0"
|
46 |
+
nvidia-cublas-cu11==11.10.3.66 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
47 |
+
nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
48 |
+
nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
49 |
+
nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
50 |
+
nvidia-cudnn-cu11==8.5.0.96 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
51 |
+
nvidia-cufft-cu11==10.9.0.58 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
52 |
+
nvidia-curand-cu11==10.2.10.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
53 |
+
nvidia-cusolver-cu11==11.4.0.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
54 |
+
nvidia-cusparse-cu11==11.7.4.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
55 |
+
nvidia-nccl-cu11==2.14.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
56 |
+
nvidia-nvtx-cu11==11.7.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
57 |
+
orjson==3.8.10 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
58 |
+
packaging==23.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
59 |
+
pandas==2.0.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
60 |
+
pillow==9.5.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
61 |
+
pydantic==1.10.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
62 |
+
pydub==0.25.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
63 |
+
pyparsing==3.0.9 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
64 |
+
pyrsistent==0.19.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
65 |
+
python-dateutil==2.8.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
66 |
+
python-multipart==0.0.6 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
67 |
+
pytz==2023.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
68 |
+
pywavelets==1.4.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
69 |
+
pyyaml==6.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
70 |
+
requests==2.28.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
71 |
+
scikit-learn==1.2.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
72 |
+
scipy==1.10.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
73 |
+
semantic-version==2.10.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
74 |
+
setuptools==67.7.2 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
75 |
+
six==1.16.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
76 |
+
sniffio==1.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
77 |
+
starlette==0.26.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
78 |
+
sympy==1.11.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
79 |
+
threadpoolctl==3.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
80 |
+
toolz==0.12.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
81 |
+
torch==2.0.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
82 |
+
tqdm==4.65.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
83 |
+
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
84 |
+
typing-extensions==4.5.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
85 |
+
tzdata==2023.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
86 |
+
uc-micro-py==1.0.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
87 |
+
urllib3==1.26.15 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
88 |
+
uvicorn==0.21.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
89 |
+
websockets==11.0.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
90 |
+
wheel==0.40.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|
91 |
+
yarl==1.9.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
|