# demo/bcgunet/bcgunet.py
# Author: JacobLinCool — commit 29b2705 ("Create Space"), 4.21 kB.
# (Header reconstructed from a repository web-view export.)
from os.path import *
import numpy as np
import random
import torch
import torch.nn as nn
import time
import tqdm
from scipy.signal import butter, sosfilt
from .unet import UNet1d
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a digital Butterworth filter.

    The filter is designed in second-order-sections form and applied once,
    forward only (``sosfilt``), along the last axis of ``data``.

    Args:
        data: Signal array; filtering runs along its last axis.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz (must be below ``fs / 2``).
        fs: Sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered signal, same shape as ``data``.
    """
    nyquist = 0.5 * fs
    # butter() expects edges as a fraction of the Nyquist frequency.
    band_edges = [edge / nyquist for edge in (lowcut, highcut)]
    sections = butter(order, band_edges, analog=False, btype="band", output="sos")
    return sosfilt(sections, data)
def norm(ecg):
    """Robustly rescale a signal to the range [0, 1].

    Values below the 1st percentile or above the 99th percentile are clipped
    to those percentiles first, so isolated spikes do not dominate the
    scaling; the clipped signal is then min-max scaled. The input is NOT
    modified (the previous version clipped the caller's array in place).

    Args:
        ecg: Array-like signal of any shape.

    Returns:
        np.ndarray of the same shape with values in [0, 1]. If the 1st and
        99th percentiles coincide (e.g. a constant signal) the result is all
        zeros instead of NaN from a 0/0 division.
    """
    ecg = np.asarray(ecg)
    low, high = np.percentile(ecg, [1, 99])
    # np.clip returns a new array, leaving the caller's data untouched.
    clipped = np.clip(ecg, low, high)
    span = high - low
    if span == 0:
        return np.zeros_like(clipped, dtype=float)
    return (clipped - low) / span
def run(
    input_eeg,
    input_ecg=None,
    sfreq=5000,
    iter_num=5000,
    winsize_sec=2,
    lr=1e-3,
    onecycle=True,
):
    """Remove the ECG-predictable (BCG) component from multichannel EEG.

    A fresh 1-D U-Net is trained on this recording to map a normalised cardiac
    reference signal onto the band-passed EEG, using ``iter_num`` randomly
    placed windows. The network's prediction for the full recording is then
    subtracted from the band-passed EEG, and the part removed by the band-pass
    ("baseline") is added back.

    Args:
        input_eeg: 2-D array, channels x samples.
        input_ecg: Optional reference signal with the same number of samples as
            ``input_eeg`` (any shape; it is flattened). When ``None``, the first
            principal component of the band-passed EEG is used instead.
        sfreq: Sampling rate in Hz. It sets the filter band (0.5 Hz to
            ``0.4 * sfreq``) and converts ``winsize_sec`` to samples.
        iter_num: Number of training steps; each step uses one random window.
        winsize_sec: Training window length in seconds (rounded to whole samples).
        lr: Learning rate (the peak rate when ``onecycle`` is true).
        onecycle: Use ``OneCycleLR``; otherwise the learning rate stays constant.

    Returns:
        np.ndarray with the same shape as ``input_eeg``: band-passed EEG minus
        the network prediction, plus the removed baseline.

    Raises:
        ValueError: if ``input_eeg`` is not 2-D, the reference length differs
            from the EEG length, or the window is empty or longer than the
            recording.
    """
    eeg_raw = np.asarray(input_eeg)
    if eeg_raw.ndim != 2:
        raise ValueError(
            f"input_eeg must be 2-D (channels x samples), got shape {eeg_raw.shape}"
        )
    n_channels, n_samples = eeg_raw.shape
    # Round to an int so float winsize_sec/sfreq still give a valid slice length.
    window = int(round(winsize_sec * sfreq))
    if not 1 <= window <= n_samples:
        raise ValueError(
            f"window of {window} samples must be between 1 and the recording "
            f"length ({n_samples})"
        )

    # Band-pass all channels at once (sosfilt works along the last axis). Keep a
    # floating dtype so integer recordings are not truncated by the filter.
    float_dtype = np.result_type(eeg_raw.dtype, np.float32)
    eeg_filtered = butter_bandpass_filter(eeg_raw, 0.5, sfreq * 0.4, sfreq).astype(
        float_dtype, copy=False
    )
    # Everything the band-pass removed; restored in the returned signal.
    baseline = eeg_raw - eeg_filtered

    if input_ecg is None:
        from sklearn.decomposition import PCA

        # No reference given: use the first principal component across channels.
        pca = PCA(n_components=1)
        ecg = norm(pca.fit_transform(eeg_filtered.T)[:, 0].flatten())
    else:
        ecg = norm(np.asarray(input_ecg).flatten())
    if ecg.size != n_samples:
        raise ValueError(
            f"input_ecg has {ecg.size} samples but input_eeg has {n_samples}"
        )

    torch.cuda.empty_cache()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet1d(n_channels=1, n_classes=n_channels, nfilter=8).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    if onecycle:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, lr, total_steps=iter_num
        )
    else:
        # gamma=1 keeps the learning rate constant.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)
    criterion = nn.MSELoss()

    # Random window starts, uniform over every valid offset 0..n_samples-window
    # (inclusive; the last offset was previously never drawn).
    starts = (
        np.random.random_sample(iter_num) * (n_samples - window + 1)
    ).astype(int)
    losses = []
    pbar = tqdm.tqdm(starts)
    for step, offset in enumerate(pbar, start=1):
        ecg_win = ecg[offset : offset + window]
        eeg_win = eeg_filtered[:, offset : offset + window]
        # Input (1, 1, window) reference -> target (1, n_channels, window) EEG.
        ecg_t = torch.from_numpy(ecg_win[None, None, :]).to(device).float()
        eeg_t = torch.from_numpy(eeg_win[None, ...]).to(device).float()

        loss = criterion(net(ecg_t), eeg_t)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if step % 50 == 0:
            pbar.set_description(
                f"Loss {np.mean(losses):.3f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
            )
            losses = []

    # Predict the reference-driven component for the whole recording in one pass.
    # NOTE(review): the network is left in training mode, as before; if UNet1d has
    # BatchNorm/Dropout layers, confirm whether net.eval() is intended here.
    ecg_full = torch.from_numpy(ecg[None, None, :]).to(device).float()
    with torch.no_grad():
        bcg_pred = net(ecg_full).cpu().numpy()[0]
    return eeg_filtered - bcg_pred + baseline
def morlet_psd(signal, sample_rate=5000, freq=10, wavelet="morl"):
    """Return the mean wavelet power of ``signal`` at a single frequency.

    The function picks the integer CWT scale whose pseudo-frequency (for
    ``wavelet``) is closest to ``freq``. It computes the continuous wavelet
    transform at that one scale and averages ``|coefficient|**2`` over the
    output. Despite the name, this is a single-frequency power estimate, not
    a full spectrum.

    Args:
        signal: Signal to analyse. Averaging is over axis 1 of the CWT output,
            which is the time axis for 1-D input — pass one channel at a time.
        sample_rate: Sampling rate in Hz. Candidate scales are
            ``1 .. sample_rate - 1``.
        freq: Target frequency in Hz.
        wavelet: PyWavelets continuous wavelet name. It is used both for scale
            selection and for the transform; previously scale selection always
            assumed ``"morl"``.

    Returns:
        np.ndarray of shape ``(1,)`` holding the mean power at the chosen scale.

    Raises:
        ValueError: if ``sample_rate`` leaves no candidate scales (< 2).
    """
    import pywt

    # Start at 1: scale 0 maps to an infinite pseudo-frequency (divide-by-zero).
    scales = np.arange(1, sample_rate)
    if scales.size == 0:
        raise ValueError(f"sample_rate must be at least 2, got {sample_rate}")
    scale_freqs = pywt.scale2frequency(wavelet, scales) * sample_rate
    scale = scales[np.argmin(np.abs(scale_freqs - freq))]
    coeffs, _ = pywt.cwt(signal, scale, wavelet, 1 / sample_rate)
    power = np.abs(coeffs) ** 2
    return np.mean(power, axis=1)
def get_psd(eeg, sfreq=5000, freq=10):
    """Apply ``morlet_psd`` to every row (channel) of ``eeg``.

    Args:
        eeg: Array indexed by channel along its first axis.
        sfreq: Sampling rate in Hz, forwarded as ``sample_rate``.
        freq: Target frequency in Hz, forwarded unchanged.

    Returns:
        np.ndarray stacking the per-channel results along a new first axis.
    """
    n_rows = eeg.shape[0]
    per_channel = [
        morlet_psd(eeg[row], sample_rate=sfreq, freq=freq)
        for row in tqdm.tqdm(range(n_rows))
    ]
    return np.array(per_channel)