|
from os.path import * |
|
import numpy as np |
|
import random |
|
import torch |
|
import torch.nn as nn |
|
import time |
|
import tqdm |
|
from scipy.signal import butter, sosfilt |
|
from .unet import UNet1d |
|
|
|
|
|
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Apply a Butterworth band-pass filter to ``data`` (single forward pass).

    The filter is designed in second-order-sections form and applied once,
    forward only, with ``scipy.signal.sosfilt`` along its default (last) axis.

    Args:
        data: Input samples.
        lowcut: Lower cutoff frequency, in the same units as ``fs``.
        highcut: Upper cutoff frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Butterworth design order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal as returned by ``sosfilt``.
    """
    nyquist = fs / 2.0
    # butter() expects digital cutoffs normalised to the Nyquist frequency.
    band_edges = [lowcut / nyquist, highcut / nyquist]
    sections = butter(order, band_edges, analog=False, btype="band", output="sos")
    return sosfilt(sections, data)
|
|
|
|
|
def norm(ecg):
    """Robustly rescale a signal so its 1st-99th percentile range maps to [0, 1].

    Values are clipped to the 1st and 99th percentiles of the input (computed
    over all elements) and then linearly mapped so that the 1st percentile
    becomes 0 and the 99th becomes 1. The input array is not modified.

    Args:
        ecg: Array-like signal.

    Returns:
        A new array with the same shape as ``ecg``. If the 1st and 99th
        percentiles coincide (e.g. a constant signal) the result is all zeros
        rather than NaN from a 0/0 division.
    """
    ecg = np.asarray(ecg)
    lo, hi = np.percentile(ecg, [1, 99])
    # np.clip returns a new array, so the caller's data is left untouched.
    clipped = np.clip(ecg, lo, hi)
    span = hi - lo
    if span == 0:
        # Degenerate range: every clipped value equals lo, so map them all to 0.
        return np.zeros_like(clipped, dtype=float)
    return (clipped - lo) / span
|
|
|
|
|
def run(
    input_eeg,
    input_ecg=None,
    sfreq=5000,
    iter_num=5000,
    winsize_sec=2,
    lr=1e-3,
    onecycle=True,
):
    """Remove the reference-locked ("BCG") component from multichannel EEG.

    Steps:
      1. Band-pass filter each EEG channel (cutoffs 0.5 and 0.4 * ``sfreq``,
         in the same units as ``sfreq``, presumably Hz).
      2. Build a normalized reference signal: ``input_ecg`` if given,
         otherwise the first principal component of the filtered EEG.
      3. Train a ``UNet1d`` (1 input channel -> one output per EEG channel)
         to predict the filtered EEG from the reference, using ``iter_num``
         randomly placed windows of ``winsize_sec`` seconds.
      4. Predict on the full-length reference, subtract the prediction from
         the filtered EEG and add back the out-of-band part removed in step 1.

    Args:
        input_eeg: 2-D array indexed as ``[channel, sample]``.
        input_ecg: Optional reference signal; flattened to 1-D and required to
            have as many samples as ``input_eeg``.
        sfreq: Sampling rate.
        iter_num: Number of training iterations (one random window each).
        winsize_sec: Training window length in seconds.
        lr: Adam learning rate (peak rate when ``onecycle`` is True).
        onecycle: Use ``OneCycleLR``; otherwise keep ``lr`` constant.

    Returns:
        Array shaped like ``input_eeg`` with the predicted component removed.

    Raises:
        ValueError: if the reference length differs from the EEG length, or
            the signal is shorter than one training window.
    """
    eeg_raw = np.asarray(input_eeg)
    n_channels = eeg_raw.shape[0]
    # Slice bounds must be integers; a fractional winsize_sec (e.g. 1.5)
    # would otherwise yield a float window and break the slicing below.
    window = int(round(winsize_sec * sfreq))

    # Floating output buffer: ``eeg_raw * 0`` would inherit an integer dtype
    # for integer recordings and silently truncate the filtered values.
    eeg_filtered = np.empty(
        eeg_raw.shape, dtype=np.result_type(eeg_raw.dtype, np.float32)
    )
    for ch in range(n_channels):
        eeg_filtered[ch, ...] = butter_bandpass_filter(
            eeg_raw[ch, :], 0.5, sfreq * 0.4, sfreq
        )
    # Content outside the pass band; added back at the end.
    baseline = eeg_raw - eeg_filtered

    if input_ecg is None:
        from sklearn.decomposition import PCA  # only needed without an ECG

        # First principal component across channels serves as the reference.
        pca = PCA(n_components=1)
        ecg = norm(pca.fit_transform(eeg_filtered.T)[:, 0].flatten())
    else:
        # flatten() copies, so normalization never touches the caller's array.
        ecg = norm(np.asarray(input_ecg).flatten())

    n_samples = eeg_filtered.shape[-1]
    if ecg.size != n_samples:
        raise ValueError(
            f"reference signal has {ecg.size} samples but the EEG has {n_samples}"
        )
    if n_samples < window:
        raise ValueError(
            f"signal length ({n_samples} samples) is shorter than the training "
            f"window ({window} samples)"
        )

    torch.cuda.empty_cache()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = UNet1d(n_channels=1, n_classes=n_channels, nfilter=8).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    optimizer.zero_grad()
    if onecycle:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, lr, total_steps=iter_num
        )
    else:
        # gamma=1 keeps the learning rate constant so the loop is mode-agnostic.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1)
    criterion = nn.MSELoss()

    # Window starts uniform over every valid offset 0..n_samples-window
    # inclusive; scaling by (span + 1) makes the final offset reachable.
    starts = (
        np.random.random_sample(iter_num) * (n_samples - window + 1)
    ).astype(int)

    recent_losses = []
    pbar = tqdm.tqdm(starts)
    for step, start in enumerate(pbar, start=1):
        stop = start + window
        ecg_t = torch.from_numpy(ecg[None, None, start:stop]).to(device).float()
        eeg_t = torch.from_numpy(eeg_filtered[None, :, start:stop]).to(device).float()

        loss = criterion(net(ecg_t), eeg_t)
        recent_losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if step % 50 == 0:
            pbar.set_description(
                f"Loss {np.mean(recent_losses):.3f}, lr: {optimizer.param_groups[0]['lr']:.5f}"
            )
            recent_losses = []

    # Only the reference needs to go to the device for the full-length pass.
    # NOTE(review): ``net`` is still in training mode here; if UNet1d contains
    # BatchNorm/Dropout layers, ``net.eval()`` may be wanted — confirm.
    ecg_full = torch.from_numpy(ecg[None, None, :]).to(device).float()
    with torch.no_grad():
        bcg_pred = net(ecg_full).cpu().numpy()[0, ...]

    return eeg_filtered - bcg_pred + baseline
|
|
|
|
|
def morlet_psd(signal, sample_rate=5000, freq=10, wavelet="morl"):
    """Return the mean CWT power of ``signal`` at the scale nearest ``freq``.

    Candidate scales ``1 .. sample_rate - 1`` are converted to
    pseudo-frequencies for ``wavelet`` (via ``pywt.scale2frequency`` times
    ``sample_rate``). The scale whose pseudo-frequency is closest to ``freq``
    is used for a single-scale continuous wavelet transform. The squared
    magnitude of the coefficients is then averaged along axis 1.

    Args:
        signal: Input samples (assumed 1-D; with more dimensions, axis 1 of
            the coefficients is no longer time — TODO confirm intended use).
        sample_rate: Sampling rate, in the same units as ``freq``.
        freq: Target frequency.
        wavelet: PyWavelets continuous-wavelet name. It is used both for
            scale selection and for the transform, so the two stay consistent.

    Returns:
        ``np.mean(|coeffs|**2, axis=1)``; for a 1-D signal this is an array
        with one entry (the single analysed scale).
    """
    import pywt  # imported lazily: only needed for PSD estimation

    # Start at 1: scale 0 maps to an infinite pseudo-frequency (divide-by-zero).
    scales = np.arange(1, sample_rate)
    pseudo_freqs = pywt.scale2frequency(wavelet, scales) * sample_rate
    scale = scales[np.argmin(np.abs(pseudo_freqs - freq))]

    coeffs, _ = pywt.cwt(signal, scale, wavelet, 1 / sample_rate)
    power = np.abs(coeffs) ** 2
    return np.mean(power, axis=1)
|
|
|
|
|
def get_psd(eeg, sfreq=5000, freq=10):
    """Apply ``morlet_psd`` to each entry along the first axis of ``eeg``.

    Args:
        eeg: Array indexed as ``eeg[channel]``; one result per first-axis entry.
        sfreq: Sampling rate, forwarded to ``morlet_psd`` as ``sample_rate``.
        freq: Target frequency forwarded to ``morlet_psd``.

    Returns:
        ``np.array`` of the per-channel results, stacked along a new first axis.
    """
    channel_count = eeg.shape[0]
    per_channel = [
        morlet_psd(eeg[row], sample_rate=sfreq, freq=freq)
        for row in tqdm.tqdm(range(channel_count))
    ]
    return np.array(per_channel)
|
|