File size: 2,721 Bytes
ff522d1
e9321a8
ff522d1
 
e9321a8
ff522d1
e9321a8
 
ff522d1
e9321a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff522d1
e9321a8
ff522d1
 
 
 
 
 
e9321a8
0cdd4d6
e9321a8
 
 
 
 
 
ff522d1
 
 
 
e9321a8
 
 
 
 
 
 
 
ff522d1
e9321a8
ff522d1
 
 
 
 
 
 
 
e9321a8
 
 
 
 
 
 
ff522d1
e9321a8
ff522d1
e9321a8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from scipy.signal import resample, butter, filtfilt
from baseline_wander_removal import bw_remover
import pywt
import matplotlib.pyplot as plt

def normalize(sig, val=2):
    """Linearly rescale ``sig`` so its values span ``[0, val]``.

    The minimum maps to 0 and the maximum to ``val``; the min/max are taken
    over the whole array (all axes).

    Args:
        sig: array-like of numbers (must be non-empty).
        val: upper bound of the output range (default 2).

    Returns:
        A float numpy array with the same shape as ``sig``. A constant input
        (max == min) has no range to rescale and yields all zeros instead of
        the NaNs a zero-division would produce.
    """
    sig = np.asarray(sig)
    lo = np.min(sig)
    span = np.max(sig) - lo
    if span == 0:
        # Avoid 0/0 -> NaN, which would silently propagate downstream.
        return np.zeros(sig.shape, dtype=float)
    return val * ((sig - lo) / span)

def butter_lowpass_filter(data, cutoff, fs, order):
    """Low-pass ``data`` with a digital Butterworth filter, run forward and backward.

    Args:
        data: signal to filter.
        cutoff: cutoff frequency, in the same units as ``fs``.
        fs: sampling frequency of ``data``.
        order: Butterworth filter order passed to ``scipy.signal.butter``.

    Returns:
        The output of ``scipy.signal.filtfilt`` (same length as ``data``).
    """
    # butter() expects the cutoff as a fraction of the Nyquist frequency.
    wn = cutoff / (fs / 2.0)
    numerator, denominator = butter(order, wn, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)

def visualize_sig(sig, filename="test.png"):
    """Plot items 0, 1 and 2 of ``sig`` in three stacked subplots and save the figure.

    Args:
        sig: indexable container with at least three entries; each entry is
            passed to ``Axes.plot``.
        filename: output path passed to ``Figure.savefig`` (default "test.png").
    """
    fig, axes = plt.subplots(3, 1, figsize=(10, 10))
    try:
        for idx, ax in enumerate(axes):
            ax.plot(sig[idx])
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate figures and leak memory.
        plt.close(fig)

def preprocess_one_chunk(chunk, lvl=2, n_samples=1000, wavelet='db6'):
    """Resample one signal window, keep its wavelet approximation, and normalize it.

    Args:
        chunk: 1-D signal window (must be non-empty for ``resample``).
        lvl: decomposition level passed to ``pywt.wavedec``.
        n_samples: number of samples the window is resampled to before the
            wavelet step (default 1000, the previously hard-coded value).
        wavelet: wavelet name passed to ``pywt.wavedec`` (default 'db6').

    Returns:
        The level-``lvl`` approximation coefficients, rescaled by ``normalize``
        with its default range.
    """
    resampled = resample(chunk, n_samples)
    # wavedec returns [cA_lvl, cD_lvl, ..., cD_1]; index 0 is the approximation.
    approx = pywt.wavedec(resampled, wavelet, level=lvl)[0]
    return normalize(approx)

def prepare_all_leads(path, butter_filter=False):
    """Load a three-lead recording and return per-lead preprocessed arrays.

    Two input formats are handled:

    * ``.npy``: loaded with ``np.load(..., allow_pickle=True)``. Entries 0, 1
      and 2 are each reduced to their level-2 'db6' wavelet approximation and
      returned with a leading axis of length 1. No baseline removal,
      filtering, windowing or normalization is applied on this path.
    * ``.txt``: comma-separated columns loaded with
      ``np.loadtxt(..., unpack=True)``, so ``signal[k]`` is column ``k``.
      ``freq`` is estimated as ``n_rows // 60`` (presumably recordings are
      60 s long — TODO confirm). Each lead then goes through ``bw_remover``
      and an optional low-pass filter. It is then cut into non-overlapping
      windows of ``2 * freq`` samples, and each window is passed through
      ``preprocess_one_chunk``.

    Args:
        path: file path ending in ``.txt`` or ``.npy``.
        butter_filter: if True, apply a 20 Hz, order-1 Butterworth low-pass
            filter to each lead after baseline removal (``.txt`` path only).

    Returns:
        Tuple ``(lead_1, lead_2, lead_3)`` of numpy arrays: one row per
        complete window for ``.txt`` input, a single row for ``.npy`` input.

    Raises:
        ValueError: if ``path`` has an unsupported extension, or the ``.txt``
            signal has fewer than 60 rows (so ``freq`` would be 0).
    """
    if path.endswith(".npy"):
        # NOTE(review): allow_pickle=True can execute arbitrary code when
        # loading an untrusted file; only use with trusted inputs.
        sig = np.load(path, allow_pickle=True)
        return tuple(
            pywt.wavedec(sig[k], 'db6', level=2)[0][None, :] for k in range(3)
        )
    if not path.endswith(".txt"):
        raise ValueError(f"Unsupported file type (expected .txt or .npy): {path}")

    signal = np.loadtxt(path, delimiter=',', unpack=True)
    freq = signal.shape[1] // 60
    if freq == 0:
        raise ValueError(
            f"Signal in {path} has {signal.shape[1]} rows; at least 60 are needed"
        )
    lvl = 1 if freq < 250 else 2

    # bw_remover is a project helper (baseline wander removal, per its module name).
    leads = [bw_remover(freq, signal[k]) for k in range(3)]

    if butter_filter:
        cutoff = 20  # low-pass cutoff frequency, Hz
        order = 1    # first-order Butterworth; filtfilt applies it forward and backward
        leads = [butter_lowpass_filter(lead, cutoff, freq, order) for lead in leads]

    # Window length: 2 * freq samples (two seconds if freq is the true sampling rate).
    sig_length = freq * 2
    # Count complete windows of sig_length samples. The old `// 1000` only
    # matched at freq == 500: lower rates dropped data, and higher rates
    # produced empty slices that crashed resample().
    n_windows = leads[0].shape[0] // sig_length

    per_lead = [
        [
            preprocess_one_chunk(lead[i * sig_length:(i + 1) * sig_length], lvl=lvl)
            for i in range(n_windows)
        ]
        for lead in leads
    ]
    return tuple(np.asarray(windows) for windows in per_lead)


def pad(sig, length=258):
    """Zero-pad a 1-D signal at the end so it has exactly ``length`` samples.

    Args:
        sig: 1-D array-like signal.
        length: target length (default 258, the previously hard-coded value).

    Returns:
        A new numpy array: ``sig`` followed by zeros, ``length`` samples long.

    Raises:
        ValueError: if ``sig`` already has more than ``length`` samples.
    """
    sig = np.asarray(sig)
    shortfall = length - sig.shape[0]
    if shortfall < 0:
        raise ValueError(
            f"signal has {sig.shape[0]} samples, more than the target length {length}"
        )
    return np.pad(sig, (0, shortfall), 'constant')