JacobLinCool commited on
Commit
29b2705
·
0 Parent(s):

Create Space

Browse files
README.md ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: BCGunet
3
+ emoji: 🧠
4
+ colorFrom: yellow
5
+ colorTo: orange
6
+ sdk: gradio
7
+ app_file: bcgrun_web.py
8
+ python_version: "3.10"
9
+ sdk_version: "3.27.0"
10
+ pinned: false
11
+ ---
12
+
13
+ # BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner
14
+
15
+ Ballistocardiogram (BCG) is the induced electric potentials caused by heartbeats when the EEG data are collected with a strong ambient magnetic field. This scenario is common in concurrent EEG-MRI acquisitions. One particular application of concurrent EEG-MRI is to delineate the irritative zones responsible for generating inter-ictal spikes (IIS) in medically refractory epilepsy patients. Specifically, EEG is used to detect onsets of spikes and these timing are used to inform functional MRI time series analysis. However, with strong BCG artifacts, the EEG data can be seriously corrupted and thus make spike annotation difficult.
16
+
17
+ This project aims at using machine learning approaches to suppress BCG artifacts. We will use Unet as the artificual neural network structure to tackle this challenge.
18
+
19
+ ![](https://github.com/fahsuanlin/BCGunet/blob/main/images/alpha_annot.png)
20
+
21
+ ## Data
22
+
23
+ Data were EEG time series collected inside a 3T MRI scanner (Skyra, Siemens). EEG were sampled by a 32-channel systemm (Brain Products) with electrodes arranged by the international 10-20 standard. EEG were sampled at 5,000 Hz.
24
+
25
+ ### Eyes open/closed in healthy control subjects
26
+ A tar ball is [here](https://drive.google.com/file/d/1Te94WlQ4nGCT3rnij_w0pbPFhRcaphGJ/view?usp=share_link). Each subject had two sessions of data. One was "eyes-open" and the other was "eyes-closed", where subjects were instructed laying in the MRI without falling sleep but keeping their eyes open and closed, respectively. This is a resting-state recording.
27
+ During the recording, the MRI scanner did not collect any images. No so-called "gradient artifacts" caused by the swithcing of the imaging gradient coils of MRI was present.
28
+
29
+ [This tar ball](https://drive.google.com/file/d/1Hfu5w0-CT6p3g82yXIp-7wIi6921DW4m/view?usp=share_link) includes the EEG data taken *outside* MRI, including "eyes-open" and "eyes-closed" conditions. In other words, these data can be taken as the gold standard to see how much alpha oscillation power increased when eyes were closed.
30
+
31
+ The more complicated case is that when MRI was collecting the data. This imposed a strong "gradient artifacts" over EEG. Thus it takes extra efforts to deal with both gradient artifacts and BCG artifacts at the same time. [Here](https://drive.google.com/file/d/1oZAjCnec73ErwkuMxulUv_v3XipMZ7N_/view?usp=sharing) is the file with concurrent MRI-EEG when echo-planar imaging was used. Note that the gradient artifacts in these data have been suppressed already.
32
+
33
+ ### Steady-state visual evoked potential (SSVEP)
34
+ Check [this page](https://github.com/fahsuanlin/labmanual/wiki/21.-Sample-data:-Steady-state-visual-potential) for some details.
35
+
36
+ ## Code
37
+ - [Data input (Matlab)](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/read_eeg.m): An example of reading EEG data. Each EEG recording has three files with .eeg, .vmrk, and .vhdr file suffix. Supply the .vmrk and .vmrk file names to read data into Matlab. Need functions at [bvaloader](https://github.com/stefanSchinkel/bvaloader).
38
+
39
+ **NOTE**: Do not change the file names because data are associated with the file name.
40
+
41
+ - [Unet basic structure and BCG suppression (Python)](https://github.com/fahsuanlin/BCGunet/blob/main/bcg_unet/unet1d-simple.ipynb): perform BCG suppression by Unet, including training and testing of data from the same subject.
42
+
43
+ - Assessment (Matlab):
44
+
45
+ -- Alpha oscillations in eyes-closed vs. eyes-opened conditions: Calculate the alpha-band (10-Hz) power at all EEG electrodes. We expect that stronger alpha-band neural oscillations are found at the parietal lobe of the subject when he/she closed eyes than opened eyes after successful BCG artifact suppression. Download [our toolbox](https://github.com/fahsuanlin/fhlin_toolbox) to use the function in the following codes to calculate the average of 10-Hz oscillatory power across all EEG channels (in columns; with data stored in `EEG`) using the Morlet wavelet transform with 5-cycle. EEG data were sampled at 5,000 Hz denoted by `sfreq` variable.
46
+
47
+ ```
48
+ sfreq=5000;
49
+ mean(abs(inverse_waveletcoef(10,double(EEG),sfreq,5)),2);
50
+ ```
51
+ [A sample script](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/calc_alpha_unet.m) of calculating alpha-band oscillations across subjects and between conditions.
52
+
53
+
54
+ - Rendering (Matlab): tools to render EEG data over a scalp.
55
+
56
+ Use the [EEG topolgoy definition fiile](https://github.com/fahsuanlin/BCGunet/blob/main/matlab/bem.mat) to draw 10-Hz power distribution. Download [our toolbox](https://github.com/fahsuanlin/fhlin_toolbox) to use the function in the following codes.
57
+ ```
58
+ load bem.mat;
59
+ verts_osc_electrode_idx(end-2:end,:)=[]; %last three channels are not needed.
60
+ etc_render_topo('vol_vertex',verts_osc,'vol_face',faces_osc-1,'topo_vertex',verts_osc_electrode_idx-1,'topo_stc',mean(EEG_unet_close,2)./mean(EEG_unet_open,2),'topo_smooth',10,'topo_threshold',[1.25 1.5],'topo_stc_timevec_unit','Hz','view_angle',[0 50]);
61
+ ```
62
+
63
+ ## Resources
64
+ - [Our lab routine](https://github.com/fahsuanlin/labmanual/wiki/18.-Suppression-of-ballistocardiography-artifacts-in-EEG-collected-inside-MRI) in BCG suppression.
65
+ - A conventional PCA-based BCG suppression method using [Optimal Basis Sets (OBS)](https://www.sciencedirect.com/science/article/abs/pii/S1053811905004726?via%3Dihub).
66
+ - An RNN-type BCG artifact suppression method (BCGnet) can be found [here](https://github.com/jiaangyao/BCGNet)
67
+ - [Our BCG suppression protocol](https://github.com/fahsuanlin/labmanual/wiki/18.-Suppression-of-ballistocardiography-artifacts-in-EEG-collected-inside-MRI)
bcgrun_web.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.io import savemat
3
+ import h5py
4
+ from bcgunet import bcgunet
5
+ import platform
6
+ import os
7
+ import time
8
+ import os
9
+ import gradio as gr
10
+ import matplotlib
11
+ import matplotlib.pyplot as plt
12
+
13
+ matplotlib.use("agg")
14
+
15
+ dir = os.path.dirname(os.path.realpath(__file__)) + "/tmp"
16
+ os.makedirs(dir, exist_ok=True)
17
+
18
+
19
+ def run(
20
+ files: list[bytes],
21
+ lr: float,
22
+ winsec: int,
23
+ iters: int,
24
+ onecycle: bool,
25
+ ecg: str,
26
+ bce: str,
27
+ eeg: str,
28
+ ) -> tuple[list[str], str]:
29
+ task = os.path.join(dir, str(int(time.time())))
30
+ os.makedirs(task)
31
+
32
+ outputs = []
33
+
34
+ for i, file in enumerate(files):
35
+ input = os.path.join(task, str(i) + ".mat")
36
+ with open(input, "wb") as o:
37
+ o.write(file)
38
+
39
+ output = os.path.join(task, str(i) + "_clean.mat")
40
+
41
+ mat = h5py.File(input, "r")
42
+ ECG = np.array(mat[ecg]).flatten()
43
+ EEG = np.array(mat[bce]).T
44
+
45
+ EEG_unet = bcgunet.run(
46
+ EEG,
47
+ ECG,
48
+ iter_num=iters,
49
+ winsize_sec=winsec,
50
+ lr=lr,
51
+ onecycle=onecycle,
52
+ )
53
+ result = dict()
54
+ result[eeg] = EEG_unet
55
+
56
+ savemat(output, result, do_compression=True)
57
+ outputs.append(output)
58
+
59
+ if i == 0:
60
+ plt.figure(figsize=(12, 6), dpi=300)
61
+ plt.plot(EEG[19, :10000], "b.-", label="Orig EEG")
62
+ plt.plot(EEG_unet[19, :10000], "g.-", label="U-Net")
63
+ plt.legend()
64
+ plt.title("BCG Unet")
65
+ plt.xlabel("Time (samples)")
66
+ plot = os.path.join(task, str(i) + ".png")
67
+ plt.savefig(plot)
68
+
69
+ return outputs, plot
70
+
71
+
72
+ def main():
73
+ app = gr.Interface(
74
+ title="BCG Unet",
75
+ description="BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner",
76
+ fn=run,
77
+ inputs=[
78
+ gr.File(
79
+ label="Input Files (.mat)",
80
+ type="binary",
81
+ file_types=["mat"],
82
+ file_count=["multiple", "directory"],
83
+ ),
84
+ gr.Slider(
85
+ label="Learning Rate", minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-3
86
+ ),
87
+ gr.Slider(
88
+ label="Window Size (seconds)", minimum=1, maximum=10, step=1, value=2
89
+ ),
90
+ gr.Slider(
91
+ label="Number of Iterations",
92
+ minimum=1000,
93
+ maximum=10000,
94
+ step=1000,
95
+ value=5000,
96
+ ),
97
+ gr.Checkbox(
98
+ label="One Cycle Scheduler",
99
+ value=True,
100
+ ),
101
+ gr.Textbox(
102
+ label="Variable name for ECG (input)",
103
+ value="ECG",
104
+ ),
105
+ gr.Textbox(
106
+ label="Variable name for BCG corropted EEG (input)",
107
+ value="EEG_before_bcg",
108
+ ),
109
+ gr.Textbox(
110
+ label="Variable name for clean EEG (output)",
111
+ value="EEG_clean",
112
+ ),
113
+ ],
114
+ outputs=[
115
+ gr.File(label="Output File", file_count="multiple"),
116
+ gr.Image(label="Output Image", type="filepath"),
117
+ ],
118
+ allow_flagging="never",
119
+ )
120
+
121
+ app.launch()
122
+
123
+
124
+ if __name__ == "__main__":
125
+ main()
126
+ if platform.system() == "Windows":
127
+ os.system("pause")
bcgunet/__init__.py ADDED
File without changes
bcgunet/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
bcgunet/__pycache__/bcgunet.cpython-310.pyc ADDED
Binary file (3.44 kB). View file
 
bcgunet/__pycache__/unet.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
bcgunet/bcgunet.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import *
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import time
7
+ import tqdm
8
+ from scipy.signal import butter, sosfilt
9
+ from .unet import UNet1d
10
+
11
+
12
+ def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
13
+ nyq = 0.5 * fs
14
+ low = lowcut / nyq
15
+ high = highcut / nyq
16
+ sos = butter(order, [low, high], analog=False, btype="band", output="sos")
17
+ y = sosfilt(sos, data)
18
+ return y
19
+
20
+
21
+ def norm(ecg):
22
+ min1, max1 = np.percentile(ecg, [1, 99])
23
+ ecg[ecg > max1] = max1
24
+ ecg[ecg < min1] = min1
25
+ ecg = (ecg - min1) / (max1 - min1)
26
+ return ecg
27
+
28
+
29
+ def run(
30
+ input_eeg,
31
+ input_ecg=None,
32
+ sfreq=5000,
33
+ iter_num=5000,
34
+ winsize_sec=2,
35
+ lr=1e-3,
36
+ onecycle=True,
37
+ ):
38
+ window = winsize_sec * sfreq
39
+ eeg_raw = input_eeg
40
+ eeg_channel = eeg_raw.shape[0]
41
+
42
+ eeg_filtered = eeg_raw * 0
43
+ t = time.time()
44
+ for ii in range(eeg_channel):
45
+ eeg_filtered[ii, ...] = butter_bandpass_filter(
46
+ eeg_raw[ii, :], 0.5, sfreq * 0.4, sfreq
47
+ )
48
+
49
+ baseline = eeg_raw - eeg_filtered
50
+
51
+ if input_ecg is None:
52
+ from sklearn.decomposition import PCA
53
+
54
+ pca = PCA(n_components=1)
55
+ ecg = norm(pca.fit_transform(eeg_filtered.T)[:, 0].flatten())
56
+ else:
57
+ ecg = norm(input_ecg.flatten())
58
+
59
+ torch.cuda.empty_cache()
60
+ device = "cuda" if torch.cuda.is_available() else "cpu"
61
+ NET = UNet1d(n_channels=1, n_classes=eeg_channel, nfilter=8).to(device)
62
+ optimizer = torch.optim.Adam(NET.parameters(), lr=lr)
63
+ optimizer.zero_grad()
64
+ maxlen = ecg.size
65
+ if onecycle:
66
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
67
+ optimizer, lr, total_steps=iter_num
68
+ )
69
+ else:
70
+ # constant learning rate
71
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)
72
+
73
+ loss_list = []
74
+
75
+ # randomly get windows in ECG signal
76
+
77
+ index_all = (np.random.random_sample(iter_num) * (maxlen - window)).astype(int)
78
+
79
+ pbar = tqdm.tqdm(index_all)
80
+ count = 0
81
+ for index in pbar:
82
+ count += 1
83
+ ECG = ecg[index : (index + window)]
84
+ EEG = eeg_filtered[:, index : (index + window)]
85
+ ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()
86
+ EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()
87
+
88
+ # step 3: forward path of UNET
89
+ logits = NET(ECG_d)
90
+ loss = nn.MSELoss()(logits, EEG_d)
91
+ loss_list.append(loss.item())
92
+
93
+ # Step 5: Perform back-propagation
94
+ loss.backward() # accumulate the gradients
95
+ optimizer.step() # Update network weights according to the optimizer
96
+ optimizer.zero_grad() # empty the gradients
97
+ scheduler.step()
98
+
99
+ if count % 50 == 0:
100
+ pbar.set_description(
101
+ f"Loss {np.mean(loss_list):.3f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
102
+ )
103
+ loss_list = []
104
+
105
+ EEG = eeg_filtered
106
+ # ECG = norm(butter_bandpass_filter(data['ECG'], 0.5, 20, sfreq))
107
+ ECG = ecg
108
+ ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()
109
+ EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()
110
+ with torch.no_grad():
111
+ logits = NET(ECG_d)
112
+ BCG_pred = logits.cpu().detach().numpy()[0, ...]
113
+
114
+ neweeg = EEG - BCG_pred + baseline
115
+
116
+ return neweeg
117
+
118
+
119
+ def morlet_psd(signal, sample_rate=5000, freq=10, wavelet="morl"):
120
+ import pywt
121
+
122
+ # Define the wavelet and scales to be used
123
+
124
+ scales = np.arange(sample_rate)
125
+ freqs = pywt.scale2frequency("morl", scales) * sample_rate
126
+ indx = np.argmin(abs(freqs - freq))
127
+
128
+ scale = scales[indx]
129
+
130
+ # scale = pywt.frequency2scale('morl', freq/sample_rate)
131
+
132
+ # Calculate the wavelet coefficients
133
+ coeffs, freq = pywt.cwt(signal, scale, wavelet, 1 / sample_rate)
134
+ # Calculate the power (magnitude squared) of the coefficients
135
+ power = np.abs(coeffs) ** 2
136
+
137
+ # Average the power across time to get the power spectral density
138
+ psd = np.mean(power, axis=1)
139
+
140
+ return psd
141
+
142
+
143
+ def get_psd(eeg, sfreq=5000, freq=10):
144
+ psd = []
145
+ for ii in tqdm.tqdm(range(eeg.shape[0])):
146
+ psd.append(morlet_psd(eeg[ii], sample_rate=sfreq, freq=freq))
147
+
148
+ return np.array(psd)
bcgunet/unet.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DoubleConv(nn.Module):
7
+ """(convolution => [BN] => ReLU) * 2"""
8
+
9
+ def __init__(self, in_channels, out_channels, mid_channels=None):
10
+ super().__init__()
11
+ if not mid_channels:
12
+ mid_channels = out_channels
13
+ self.double_conv = nn.Sequential(
14
+ nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1),
15
+ nn.GroupNorm(num_groups=4, num_channels=mid_channels),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1),
18
+ nn.GroupNorm(num_groups=4, num_channels=out_channels),
19
+ nn.ReLU(inplace=True),
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.double_conv(x)
24
+
25
+
26
+ class DoubleConvX(nn.Module):
27
+ """(convolution => [BN] => ReLU) * 2"""
28
+
29
+ def __init__(self, in_channels, out_channels, mid_channels=None):
30
+ super().__init__()
31
+ if not mid_channels:
32
+ mid_channels = out_channels
33
+ self.double_conv = nn.Sequential(
34
+ nn.Conv1d(in_channels, mid_channels, kernel_size=15, padding=7),
35
+ nn.GroupNorm(num_groups=8, num_channels=mid_channels),
36
+ nn.ReLU(inplace=True),
37
+ nn.Conv1d(mid_channels, out_channels, kernel_size=15, padding=7),
38
+ nn.GroupNorm(num_groups=8, num_channels=out_channels),
39
+ nn.ReLU(inplace=True),
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.double_conv(x)
44
+
45
+
46
+ class Down(nn.Module):
47
+ """Downscaling with maxpool then double conv"""
48
+
49
+ def __init__(self, in_channels, out_channels):
50
+ super().__init__()
51
+ self.maxpool_conv = nn.Sequential(
52
+ nn.MaxPool1d(2), DoubleConv(in_channels, out_channels)
53
+ )
54
+
55
+ def forward(self, x):
56
+ return self.maxpool_conv(x)
57
+
58
+
59
+ class Up(nn.Module):
60
+ """Upscaling then double conv"""
61
+
62
+ def __init__(self, in_channels, out_channels):
63
+ super().__init__()
64
+
65
+ self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
66
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
67
+
68
+ def forward(self, x1, x2):
69
+ x1 = self.up(x1)
70
+ # input is CHW
71
+ diffX = x2.size()[2] - x1.size()[2]
72
+
73
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2])
74
+ # if you have padding issues, see
75
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
76
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
77
+ x = torch.cat([x2, x1], dim=1)
78
+ return self.conv(x)
79
+
80
+
81
+ class OutConv(nn.Module):
82
+ def __init__(self, in_channels, out_channels):
83
+ super(OutConv, self).__init__()
84
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
85
+
86
+ def forward(self, x):
87
+ return self.conv(x)
88
+
89
+
90
+ class UNet1d(nn.Module):
91
+ def __init__(self, n_channels, n_classes, nfilter=24):
92
+ super(UNet1d, self).__init__()
93
+ self.n_channels = n_channels
94
+ self.n_classes = n_classes
95
+
96
+ self.inc = DoubleConv(n_channels, nfilter)
97
+ self.down1 = Down(nfilter, nfilter * 2)
98
+ self.down2 = Down(nfilter * 2, nfilter * 4)
99
+ self.down3 = Down(nfilter * 4, nfilter * 8)
100
+ self.down4 = Down(nfilter * 8, nfilter * 8)
101
+ self.up1 = Up(nfilter * 16, nfilter * 4)
102
+ self.up2 = Up(nfilter * 8, nfilter * 2)
103
+ self.up3 = Up(nfilter * 4, nfilter * 1)
104
+ self.up4 = Up(nfilter * 2, nfilter)
105
+ self.outc = OutConv(nfilter, n_classes)
106
+
107
+ def forward(self, x):
108
+ x1 = self.inc(x)
109
+ x2 = self.down1(x1)
110
+ x3 = self.down2(x2)
111
+ x4 = self.down3(x3)
112
+ x5 = self.down4(x4)
113
+ x = self.up1(x5, x4)
114
+ x = self.up2(x, x3)
115
+ x = self.up3(x, x2)
116
+ x = self.up4(x, x1)
117
+ logits = self.outc(x)
118
+ return logits
119
+
120
+
121
+ if __name__ == "__main__":
122
+ model = UNet1d(1, 1)
123
+ print(model)
bcgunet/unet1d-simple.ipynb ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "62e7e36e",
6
+ "metadata": {},
7
+ "source": [
8
+ "Requirements: Pytorch, mat73, numpy\n",
9
+ "\n",
10
+ "```pip install mat73```\n",
11
+ "\n",
12
+ "相關論文: https://ieeexplore.ieee.org/document/9124646"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "2962af4f",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import glob\n",
23
+ "from os.path import *\n",
24
+ "import numpy as np\n",
25
+ "import random\n",
26
+ "import torch\n",
27
+ "import torch.nn as nn\n",
28
+ "import torch.nn.functional as F\n",
29
+ "import time\n",
30
+ "import sys\n",
31
+ "import mat73\n",
32
+ "import matplotlib.pyplot as plt\n",
33
+ "from scipy.io import savemat\n",
34
+ "import os\n",
35
+ "import tqdm"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "6151e083",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "\n",
46
+ "class DoubleConv(nn.Module):\n",
47
+ " \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n",
48
+ "\n",
49
+ " def __init__(self, in_channels, out_channels, mid_channels=None):\n",
50
+ " super().__init__()\n",
51
+ " if not mid_channels:\n",
52
+ " mid_channels = out_channels\n",
53
+ " self.double_conv = nn.Sequential(\n",
54
+ " nn.Conv1d(in_channels, mid_channels, kernel_size=3, padding=1),\n",
55
+ " nn.GroupNorm(num_groups=8, num_channels=mid_channels),\n",
56
+ " nn.ReLU(inplace=True),\n",
57
+ " nn.Conv1d(mid_channels, out_channels, kernel_size=3, padding=1),\n",
58
+ " nn.GroupNorm(num_groups=8, num_channels=out_channels),\n",
59
+ " nn.ReLU(inplace=True)\n",
60
+ " )\n",
61
+ "\n",
62
+ " def forward(self, x):\n",
63
+ " return self.double_conv(x)\n",
64
+ "\n",
65
+ "\n",
66
+ "class Down(nn.Module):\n",
67
+ " \"\"\"Downscaling with maxpool then double conv\"\"\"\n",
68
+ "\n",
69
+ " def __init__(self, in_channels, out_channels):\n",
70
+ " super().__init__()\n",
71
+ " self.maxpool_conv = nn.Sequential(\n",
72
+ " nn.MaxPool1d(2),\n",
73
+ " DoubleConv(in_channels, out_channels)\n",
74
+ " )\n",
75
+ "\n",
76
+ " def forward(self, x):\n",
77
+ " return self.maxpool_conv(x)\n",
78
+ "\n",
79
+ "\n",
80
+ "class Up(nn.Module):\n",
81
+ " \"\"\"Upscaling then double conv\"\"\"\n",
82
+ "\n",
83
+ " def __init__(self, in_channels, out_channels):\n",
84
+ " super().__init__()\n",
85
+ "\n",
86
+ " self.up = nn.Upsample(\n",
87
+ " scale_factor=2, mode='linear', align_corners=True)\n",
88
+ " self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)\n",
89
+ "\n",
90
+ " \n",
91
+ " def forward(self, x1, x2):\n",
92
+ " x1 = self.up(x1)\n",
93
+ " # input is CHW\n",
94
+ " diffX = x2.size()[2] - x1.size()[2]\n",
95
+ "\n",
96
+ " x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2])\n",
97
+ " # if you have padding issues, see\n",
98
+ " # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n",
99
+ " # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n",
100
+ " x = torch.cat([x2, x1], dim=1)\n",
101
+ " return self.conv(x)\n",
102
+ "\n",
103
+ "\n",
104
+ "class OutConv(nn.Module):\n",
105
+ " def __init__(self, in_channels, out_channels):\n",
106
+ " super(OutConv, self).__init__()\n",
107
+ " self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)\n",
108
+ "\n",
109
+ " def forward(self, x):\n",
110
+ " return self.conv(x)\n",
111
+ "\n",
112
+ "\n",
113
+ "class UNet1d(nn.Module):\n",
114
+ " def __init__(self, n_channels, n_classes, nfilter=24):\n",
115
+ " super(UNet1d, self).__init__()\n",
116
+ " self.n_channels = n_channels\n",
117
+ " self.n_classes = n_classes\n",
118
+ "\n",
119
+ " self.inc = DoubleConv(n_channels, nfilter)\n",
120
+ " self.down1 = Down(nfilter, nfilter * 2)\n",
121
+ " self.down2 = Down(nfilter * 2, nfilter * 4)\n",
122
+ " self.down3 = Down(nfilter * 4, nfilter * 8)\n",
123
+ " self.down4 = Down(nfilter * 8, nfilter * 8)\n",
124
+ " self.up1 = Up(nfilter * 16, nfilter * 4)\n",
125
+ " self.up2 = Up(nfilter * 8, nfilter * 2)\n",
126
+ " self.up3 = Up(nfilter * 4, nfilter * 1)\n",
127
+ " self.up4 = Up(nfilter * 2, nfilter)\n",
128
+ " self.outc = OutConv(nfilter, n_classes)\n",
129
+ "\n",
130
+ " def forward(self, x):\n",
131
+ " x1 = self.inc(x)\n",
132
+ " x2 = self.down1(x1)\n",
133
+ " x3 = self.down2(x2)\n",
134
+ " x4 = self.down3(x3)\n",
135
+ " x5 = self.down4(x4)\n",
136
+ " x = self.up1(x5, x4)\n",
137
+ " x = self.up2(x, x3)\n",
138
+ " x = self.up3(x, x2)\n",
139
+ " x = self.up4(x, x1)\n",
140
+ " logits = self.outc(x)\n",
141
+ " return logits\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "1f21a6a9",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "from scipy.signal import butter, sosfilt, sosfreqz\n",
152
+ "\n",
153
+ "def butter_bandpass(lowcut, highcut, fs, order=5):\n",
154
+ " nyq = 0.5 * fs\n",
155
+ " low = lowcut / nyq\n",
156
+ " high = highcut / nyq\n",
157
+ " sos = butter(order, [low, high], analog=False, btype='band', output='sos')\n",
158
+ " return sos\n",
159
+ "\n",
160
+ "def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):\n",
161
+ " sos = butter_bandpass(lowcut, highcut, fs, order=order)\n",
162
+ " y = sosfilt(sos, data)\n",
163
+ " return y"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "71b4e84a",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "f = r'/NFS/tyhuang/fhlinXBCG/noscan/170320_CLY/analysis/EyeClose1_noscan.mat'\n",
174
+ "\n",
175
+ "def norm_ecg(ecg):\n",
176
+ " min1, max1 = np.percentile(ecg, [1, 99])\n",
177
+ " ecg[ecg>max1] = max1\n",
178
+ " ecg[ecg<min1] = min1\n",
179
+ " ecg = (ecg - min1)/(max1-min1)\n",
180
+ " return ecg\n",
181
+ "\n",
182
+ "\n",
183
+ "data = mat73.loadmat(f)\n",
184
+ "eeg_filtered = data['EEG_before_bcg'] * 0\n",
185
+ "t = time.time()\n",
186
+ "for ii in range(31):\n",
187
+ " eeg_filtered[ii, ...] = butter_bandpass_filter(data['EEG_before_bcg'][ii,:], 1, 40, 5000)\n",
188
+ "\n",
189
+ "torch.cuda.empty_cache()\n",
190
+ "device = ('cuda' if torch.cuda.is_available() else 'cpu')\n",
191
+ "NET = UNet1d(n_channels=1, n_classes=31, nfilter=8).to(device)\n",
192
+ "#NET = torch.load('pretrainbcg_f8.pt')\n",
193
+ "#NET.outc = OutConv(8, 31)\n",
194
+ "NET = NET.to(device)\n",
195
+ "optimizer = torch.optim.Adam(NET.parameters(), lr=5e-4)\n",
196
+ "optimizer.zero_grad()\n",
197
+ "maxlen = data['ECG'].size\n",
198
+ "\n",
199
+ "loss_list = []\n",
200
+ "count = 0\n",
201
+ "ecg = norm_ecg(data['ECG'])\n",
202
+ "for ii in range(5000):\n",
203
+ " if ii % 10 == 0:\n",
204
+ " sys.stdout.write('.')\n",
205
+ " index = random.randrange(maxlen-20000)\n",
206
+ " ECG = ecg[index:(index+20000)]\n",
207
+ " EEG = eeg_filtered[:, index:(index+20000)]\n",
208
+ " ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
209
+ " EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
210
+ "\n",
211
+ " # step 3: forward path of UNET\n",
212
+ " logits = NET(ECG_d)\n",
213
+ " loss = nn.MSELoss()(logits, EEG_d)\n",
214
+ " loss_list.append(loss.item())\n",
215
+ "\n",
216
+ "\n",
217
+ " # Step 5: Perform back-propagation\n",
218
+ " loss.backward() #accumulate the gradients\n",
219
+ " optimizer.step() #Update network weights according to the optimizer\n",
220
+ " optimizer.zero_grad() #empty the gradients\n",
221
+ "\n",
222
+ "\n",
223
+ " if (ii + 1) % 500 == 0: #plot results per 500 iterations\n",
224
+ " print('mse loss: ', np.mean(loss_list))\n",
225
+ " loss_list = []\n",
226
+ " EEG = eeg_filtered[:, :50000]\n",
227
+ " ECG = data['ECG'][:50000]\n",
228
+ " ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
229
+ " EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
230
+ " logits = NET(ECG_d) \n",
231
+ " EEG_pred = logits.cpu().detach().numpy()\n",
232
+ " plt.figure(figsize=(12, 6), dpi=300)\n",
233
+ "\n",
234
+ " plt.plot(EEG[0, ...], 'g')\n",
235
+ " plt.plot(EEG[0, ...] - EEG_pred[0, 0, ...], 'r')\n",
236
+ " time1 = round(time.time() - t, 1)\n",
237
+ " plt.title(f' {time1} seconds')\n",
238
+ " plt.show()\n",
239
+ "\n",
240
+ "\n",
241
+ "# remove BCG from the whole dataset\n",
242
+ "EEG = eeg_filtered\n",
243
+ "ECG = data['ECG']\n",
244
+ "ECG_d = torch.from_numpy(ECG[None, ...][None, ...]).to(device).float()\n",
245
+ "EEG_d = torch.from_numpy(EEG[None, ...]).to(device).float()\n",
246
+ "logits = NET(ECG_d)\n",
247
+ "BCG_pred = logits.cpu().detach().numpy()[0, ...]\n",
248
+ "EEG_removeBCG_unet = eeg_filtered - BCG_pred\n",
249
+ "\n"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": null,
255
+ "id": "4dc44936",
256
+ "metadata": {},
257
+ "outputs": [],
258
+ "source": []
259
+ }
260
+ ],
261
+ "metadata": {
262
+ "kernelspec": {
263
+ "display_name": "Python 3 (ipykernel)",
264
+ "language": "python",
265
+ "name": "python3"
266
+ },
267
+ "language_info": {
268
+ "codemirror_mode": {
269
+ "name": "ipython",
270
+ "version": 3
271
+ },
272
+ "file_extension": ".py",
273
+ "mimetype": "text/x-python",
274
+ "name": "python",
275
+ "nbconvert_exporter": "python",
276
+ "pygments_lexer": "ipython3",
277
+ "version": "3.7.10"
278
+ }
279
+ },
280
+ "nbformat": 4,
281
+ "nbformat_minor": 5
282
+ }
requirements.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==22.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
2
+ aiohttp==3.8.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
3
+ aiosignal==1.3.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
4
+ altair==4.2.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
5
+ anyio==3.6.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
6
+ async-timeout==4.0.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
7
+ attrs==23.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
8
+ certifi==2022.12.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
9
+ charset-normalizer==3.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
10
+ click==8.1.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
11
+ cmake==3.26.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
12
+ colorama==0.4.6 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0" and platform_system == "Windows"
13
+ contourpy==1.0.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
14
+ cycler==0.11.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
15
+ entrypoints==0.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
16
+ fastapi==0.95.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
17
+ ffmpy==0.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
18
+ filelock==3.12.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
19
+ fonttools==4.39.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
20
+ frozenlist==1.3.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
21
+ fsspec==2023.4.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
22
+ gradio-client==0.1.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
23
+ gradio==3.27.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
24
+ h11==0.14.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
25
+ h5py==3.8.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
26
+ httpcore==0.17.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
27
+ httpx==0.24.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
28
+ huggingface-hub==0.14.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
29
+ idna==3.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
30
+ jinja2==3.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
31
+ joblib==1.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
32
+ jsonschema==4.17.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
33
+ kiwisolver==1.4.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
34
+ linkify-it-py==2.0.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
35
+ lit==16.0.2 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
36
+ markdown-it-py==2.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
37
+ markdown-it-py[linkify]==2.2.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
38
+ markupsafe==2.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
39
+ matplotlib==3.7.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
40
+ mdit-py-plugins==0.3.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
41
+ mdurl==0.1.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
42
+ mpmath==1.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
43
+ multidict==6.0.4 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
44
+ networkx==3.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
45
+ numpy==1.24.3 ; python_version >= "3.10" and python_full_version < "3.11.0"
46
+ nvidia-cublas-cu11==11.10.3.66 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
47
+ nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
48
+ nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
49
+ nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
50
+ nvidia-cudnn-cu11==8.5.0.96 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
51
+ nvidia-cufft-cu11==10.9.0.58 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
52
+ nvidia-curand-cu11==10.2.10.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
53
+ nvidia-cusolver-cu11==11.4.0.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
54
+ nvidia-cusparse-cu11==11.7.4.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
55
+ nvidia-nccl-cu11==2.14.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
56
+ nvidia-nvtx-cu11==11.7.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
57
+ orjson==3.8.10 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
58
+ packaging==23.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
59
+ pandas==2.0.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
60
+ pillow==9.5.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
61
+ pydantic==1.10.7 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
62
+ pydub==0.25.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
63
+ pyparsing==3.0.9 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
64
+ pyrsistent==0.19.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
65
+ python-dateutil==2.8.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
66
+ python-multipart==0.0.6 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
67
+ pytz==2023.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
68
+ pywavelets==1.4.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
69
+ pyyaml==6.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
70
+ requests==2.28.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
71
+ scikit-learn==1.2.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
72
+ scipy==1.10.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
73
+ semantic-version==2.10.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
74
+ setuptools==67.7.2 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
75
+ six==1.16.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
76
+ sniffio==1.3.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
77
+ starlette==0.26.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
78
+ sympy==1.11.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
79
+ threadpoolctl==3.1.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
80
+ toolz==0.12.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
81
+ torch==2.0.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
82
+ tqdm==4.65.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
83
+ triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
84
+ typing-extensions==4.5.0 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
85
+ tzdata==2023.3 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
86
+ uc-micro-py==1.0.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
87
+ urllib3==1.26.15 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
88
+ uvicorn==0.21.1 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
89
+ websockets==11.0.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"
90
+ wheel==0.40.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.10.0" and python_full_version < "3.11.0"
91
+ yarl==1.9.2 ; python_full_version >= "3.10.0" and python_full_version < "3.11.0"