nabeelraza commited on
Commit
e9321a8
·
1 Parent(s): 661e23b

Add: included butter filter

Browse files
app.py CHANGED
@@ -3,10 +3,6 @@ from make_predictions import make_prediction
3
  import gradio as gr
4
 
5
  def predict(file):
6
- # print(file.name)
7
- # print(file)
8
- with open(file.name, 'rb') as f:
9
- print(f.read())
10
  print("=====================================")
11
  ans = make_prediction(file.name)
12
  print(ans)
 
3
  import gradio as gr
4
 
5
  def predict(file):
 
 
 
 
6
  print("=====================================")
7
  ans = make_prediction(file.name)
8
  print(ans)
baseline_wander_removal.py CHANGED
@@ -1,6 +1,6 @@
1
  import numpy as np
2
  from scipy.signal import medfilt
3
- # import matplotlib.pyplot as plt
4
  # import seaborn as sns
5
 
6
  # sns.set_theme()
@@ -80,3 +80,19 @@ def band_pass_filter(signal):
80
  # result = result/max_val
81
 
82
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  from scipy.signal import medfilt
3
+ import matplotlib.pyplot as plt
4
  # import seaborn as sns
5
 
6
  # sns.set_theme()
 
80
  # result = result/max_val
81
 
82
  return result
83
+
84
+
85
+ if __name__ == '__main__':
86
+ signal = np.loadtxt("./Lead 2/n1_lead2.txt")[:1000]
87
+ signal_bs = bw_remover(BASIC_SRATE, signal)
88
+ signal_flt = band_pass_filter(signal_bs)
89
+ plt.figure(figsize=(16, 9))
90
+ plt.subplot(2, 1, 1)
91
+ plt.plot(signal)
92
+ plt.grid(True)
93
+ plt.title("RAW signal")
94
+ plt.subplot(2, 1, 2)
95
+ plt.plot(signal_flt)
96
+ plt.title("baseline removed signal")
97
+ plt.grid(True)
98
+ plt.show()
make_predictions.py CHANGED
@@ -1,22 +1,17 @@
1
  import numpy as np
2
- from read_data import prepare_all_leads
3
- from postprocessing import make_predictions_indi, labels_map
4
  from scipy import stats as st
5
  import warnings
6
 
7
  warnings.filterwarnings("ignore")
8
 
9
- def make_prediction(path):
10
- leads = prepare_all_leads(path)
11
- # print(leads[0][0].shape)
12
- # print(leads[1][0].shape)
13
- # print(leads[2][0].shape)
14
- # visualize_sig([leads[0][0], leads[1][0], leads[2][0]])
15
- x = make_predictions_indi(*leads)
16
- # print(x.mean(axis=0))
17
- # print(np.argmax(x, axis=1))
18
- index = st.mode(np.argmax(x, axis=1))[0][0]
19
- confidence = x.mean(axis=0)[index]
20
  return labels_map[index], float(confidence)
21
 
22
  if __name__ == "__main__":
 
1
  import numpy as np
2
+ from read_data import prepare_all_leads, visualize_sig
3
+ from postprocessing import predict_disease, labels_map
4
  from scipy import stats as st
5
  import warnings
6
 
7
  warnings.filterwarnings("ignore")
8
 
9
+ def make_prediction(path, visualize=False, butter_filter=True):
10
+ leads = prepare_all_leads(path, butter_filter=butter_filter)
11
+ visualize_sig([leads[0][0], leads[1][0], leads[2][0]]) if visualize else None
12
+ all, x, y, z = predict_disease(*leads)
13
+ index = st.mode(np.argmax(all, axis=1))[0][0]
14
+ confidence = all.mean(axis=0)[index]
 
 
 
 
 
15
  return labels_map[index], float(confidence)
16
 
17
  if __name__ == "__main__":
model/{ResNet-lead-0.pth → seven-diseases/ResNet-lead-0-BEST.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f573373d1fb5d86be738594aebd2ac0a42dde9c3358da27324277cc88ce0fd03
3
- size 2081879
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33a3ffa62adedac7bfeb52161baaff9ebd3532663ddb6b25818e97a76def0c75
3
+ size 2056539
model/{ResNet-lead-1.pth → seven-diseases/ResNet-lead-1-BEST.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e649f7c20301dfd614f1504ff8c23da284a1e422acbdbb74f63acbece25c3f97
3
- size 2081879
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7180c82754cd4f66f378204a29e9bd8b0b2c628495673114e68a55e0e31fac8
3
+ size 2056539
model/{ResNet-lead-2.pth → seven-diseases/ResNet-lead-2-BEST.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bb6a7635539ee96d4dd6011afb46e06a94d54a49e2f70234cfbfb990f1149918
3
- size 2081879
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:167cd4161b3fac862bedfce0132e13cd4c3947834b89d26bfe55a49db319fe91
3
+ size 2056539
model_def.py CHANGED
@@ -1,28 +1,37 @@
1
  import torch
2
  from torch import nn
3
  import torch.nn.functional as F
4
-
5
-
6
  class NeuralNetwork(nn.Module):
7
  def __init__(self):
8
  super().__init__()
9
  n_filters = 64
10
  self.conv_1 = nn.Conv1d( 1, n_filters, 8, stride=1, padding='same')
 
11
  self.conv_2 = nn.Conv1d(n_filters, n_filters, 5, stride=1, padding='same')
 
12
  self.conv_3 = nn.Conv1d(n_filters, n_filters, 3, stride=1, padding='same')
 
13
  self.conv_4 = nn.Conv1d( 1, n_filters, 1, stride=1, padding='same') # Expanding for addition
 
14
 
15
  self.conv_5 = nn.Conv1d( n_filters, n_filters*2, 8, stride=1, padding='same')
 
16
  self.conv_6 = nn.Conv1d(n_filters*2, n_filters*2, 5, stride=1, padding='same')
 
17
  self.conv_7 = nn.Conv1d(n_filters*2, n_filters*2, 3, stride=1, padding='same')
 
18
  self.conv_8 = nn.Conv1d( n_filters, n_filters*2, 1, stride=1, padding='same')
19
-
 
20
  self.conv_9 = nn.Conv1d(n_filters*2, n_filters*2, 8, stride=1, padding='same')
 
21
  self.conv_10 = nn.Conv1d(n_filters*2, n_filters*2, 5, stride=1, padding='same')
 
22
  self.conv_11 = nn.Conv1d(n_filters*2, n_filters*2, 3, stride=1, padding='same')
23
- self.conv_12 = nn.Conv1d(n_filters*2, n_filters*2, 1, stride=1, padding='same')
24
-
25
- self.classifier = nn.Linear(128, 5)
 
26
  self.log_softmax = nn.LogSoftmax(dim=1)
27
 
28
  def forward(self, x):
@@ -30,38 +39,57 @@ class NeuralNetwork(nn.Module):
30
 
31
  # Block 1
32
  a = self.conv_1(x)
 
33
  a = F.relu(a)
 
34
  b = self.conv_2(a)
 
35
  b = F.relu(b)
 
36
  c = self.conv_3(b)
 
 
37
  shortcut = self.conv_4(x)
 
 
38
  output_1 = torch.add(c, shortcut)
39
  output_1 = F.relu(output_1)
40
 
41
  #Block 2
42
  a = self.conv_5(output_1)
 
43
  a = F.relu(a)
 
44
  b = self.conv_6(a)
 
45
  b = F.relu(b)
 
46
  c = self.conv_7(b)
 
47
  shortcut = self.conv_8(output_1)
 
 
48
  output_2 = torch.add(c, shortcut)
49
  output_2 = F.relu(output_2)
50
 
51
  #Block 3
52
  a = self.conv_9(output_2)
 
53
  a = F.relu(a)
 
54
  b = self.conv_10(a)
 
55
  b = F.relu(b)
 
56
  c = self.conv_11(b)
57
- shortcut = self.conv_12(output_2)
 
 
 
 
58
  output_3 = torch.add(c, shortcut)
59
  output_3 = F.relu(output_3)
60
- res = self.classifier(output_3.mean((2,)))
61
- logits = self.log_softmax(res)
62
- return logits
63
-
64
-
65
- if __name__ == '__main__':
66
- model = NeuralNetwork()
67
- print(model)
 
1
  import torch
2
  from torch import nn
3
  import torch.nn.functional as F
 
 
4
  class NeuralNetwork(nn.Module):
5
  def __init__(self):
6
  super().__init__()
7
  n_filters = 64
8
  self.conv_1 = nn.Conv1d( 1, n_filters, 8, stride=1, padding='same')
9
+ self.norm_1 = nn.BatchNorm1d(n_filters)
10
  self.conv_2 = nn.Conv1d(n_filters, n_filters, 5, stride=1, padding='same')
11
+ self.norm_2 = nn.BatchNorm1d(n_filters)
12
  self.conv_3 = nn.Conv1d(n_filters, n_filters, 3, stride=1, padding='same')
13
+ self.norm_3 = nn.BatchNorm1d(n_filters)
14
  self.conv_4 = nn.Conv1d( 1, n_filters, 1, stride=1, padding='same') # Expanding for addition
15
+ self.norm_4 = nn.BatchNorm1d(n_filters)
16
 
17
  self.conv_5 = nn.Conv1d( n_filters, n_filters*2, 8, stride=1, padding='same')
18
+ self.norm_5 = nn.BatchNorm1d(n_filters*2)
19
  self.conv_6 = nn.Conv1d(n_filters*2, n_filters*2, 5, stride=1, padding='same')
20
+ self.norm_6 = nn.BatchNorm1d(n_filters*2)
21
  self.conv_7 = nn.Conv1d(n_filters*2, n_filters*2, 3, stride=1, padding='same')
22
+ self.norm_7 = nn.BatchNorm1d(n_filters*2)
23
  self.conv_8 = nn.Conv1d( n_filters, n_filters*2, 1, stride=1, padding='same')
24
+ self.norm_8 = nn.BatchNorm1d(n_filters*2)
25
+
26
  self.conv_9 = nn.Conv1d(n_filters*2, n_filters*2, 8, stride=1, padding='same')
27
+ self.norm_9 = nn.BatchNorm1d(n_filters*2)
28
  self.conv_10 = nn.Conv1d(n_filters*2, n_filters*2, 5, stride=1, padding='same')
29
+ self.norm_10 = nn.BatchNorm1d(n_filters*2)
30
  self.conv_11 = nn.Conv1d(n_filters*2, n_filters*2, 3, stride=1, padding='same')
31
+ self.norm_11 = nn.BatchNorm1d(n_filters*2)
32
+ # self.conv_12 = nn.Conv1d(n_filters*2, n_filters*2, 1, stride=1, padding='same')
33
+ self.norm_12 = nn.BatchNorm1d(n_filters*2)
34
+ self.classifier = nn.Linear(128, 7)
35
  self.log_softmax = nn.LogSoftmax(dim=1)
36
 
37
  def forward(self, x):
 
39
 
40
  # Block 1
41
  a = self.conv_1(x)
42
+ a = self.norm_1(a)
43
  a = F.relu(a)
44
+
45
  b = self.conv_2(a)
46
+ b = self.norm_2(b)
47
  b = F.relu(b)
48
+
49
  c = self.conv_3(b)
50
+ c = self.norm_3(c)
51
+
52
  shortcut = self.conv_4(x)
53
+ shortcut = self.norm_4(shortcut)
54
+
55
  output_1 = torch.add(c, shortcut)
56
  output_1 = F.relu(output_1)
57
 
58
  #Block 2
59
  a = self.conv_5(output_1)
60
+ a = self.norm_5(a)
61
  a = F.relu(a)
62
+
63
  b = self.conv_6(a)
64
+ b = self.norm_6(b)
65
  b = F.relu(b)
66
+
67
  c = self.conv_7(b)
68
+ c = self.norm_7(c)
69
  shortcut = self.conv_8(output_1)
70
+ shortcut = self.norm_8(shortcut)
71
+
72
  output_2 = torch.add(c, shortcut)
73
  output_2 = F.relu(output_2)
74
 
75
  #Block 3
76
  a = self.conv_9(output_2)
77
+ a = self.norm_9(a)
78
  a = F.relu(a)
79
+
80
  b = self.conv_10(a)
81
+ b = self.norm_10(b)
82
  b = F.relu(b)
83
+
84
  c = self.conv_11(b)
85
+ c = self.norm_11(c)
86
+
87
+ # shortcut = self.conv_12(output_2)
88
+ shortcut = self.norm_12(shortcut)
89
+
90
  output_3 = torch.add(c, shortcut)
91
  output_3 = F.relu(output_3)
92
+
93
+ logits = self.classifier(output_3.mean((2,)))
94
+ res = self.log_softmax(logits)
95
+ return res
 
 
 
 
postprocessing.py CHANGED
@@ -4,23 +4,33 @@ from torch import nn
4
  from model_def import NeuralNetwork
5
 
6
  labels_map = {
7
- 0 : "atrial fibrillation",
8
- 1 : "sinus arrhythmia",
9
- 2 : "bradycardia",
10
- 3 : "1st degree av block",
11
- 4 : "sinus rhythm",
 
 
12
  }
13
 
14
- PATH = 'model/' #ResNet-lead-0.pth'
 
 
 
 
 
 
 
 
15
 
16
  lead_1_model = NeuralNetwork()
17
- lead_1_model.load_state_dict(torch.load(f"{PATH}/ResNet-lead-0.pth", map_location=torch.device('cpu')))
18
 
19
  lead_2_model = NeuralNetwork()
20
- lead_2_model.load_state_dict(torch.load(f"{PATH}/ResNet-lead-1.pth", map_location=torch.device('cpu')))
21
 
22
  lead_3_model = NeuralNetwork()
23
- lead_3_model.load_state_dict(torch.load(f"{PATH}/ResNet-lead-2.pth", map_location=torch.device('cpu')))
24
 
25
  def helper(sig, model):
26
  inpt = sig[:, np.newaxis, :]
@@ -30,9 +40,14 @@ def helper(sig, model):
30
  prediction_scores = res.numpy()
31
  return prediction_scores
32
 
33
- def make_predictions_indi(lead1, lead2, lead3):
34
  p1 = helper(lead1, lead_1_model)
35
  p2 = helper(lead2, lead_2_model)
36
  p3 = helper(lead3, lead_3_model)
 
 
 
37
  p_avg = (p1 + p2 + p3)/3
38
- return p_avg
 
 
 
4
  from model_def import NeuralNetwork
5
 
6
  labels_map = {
7
+ 0: 'sinus_rhythm',
8
+ 1: 'atrial_fibrillation',
9
+ 2: 'av_block',
10
+ 3: 'bradycardia',
11
+ 4: 'sinus_arrhythmia',
12
+ 5: 'sinus_rhythm-sinus_arrhythmia',
13
+ 6: 'sinus_rhythm-av_block'
14
  }
15
 
16
+ PATH = 'model/seven-diseases/' #ResNet-lead-0.pth'
17
+ lead_1 = f"{PATH}/ResNet-lead-0-BEST.pth"
18
+ lead_2 = f"{PATH}/ResNet-lead-1-BEST.pth"
19
+ lead_3 = f"{PATH}/ResNet-lead-2-BEST.pth"
20
+
21
+ # PATH = 'model/' #ResNet-lead-0.pth'
22
+ # lead_1 = f"{PATH}/ResNet-lead-0.pth"
23
+ # lead_2 = f"{PATH}/ResNet-lead-1.pth"
24
+ # lead_3 = f"{PATH}/ResNet-lead-2.pth"
25
 
26
  lead_1_model = NeuralNetwork()
27
+ lead_1_model.load_state_dict(torch.load(lead_1, map_location=torch.device('cpu')))
28
 
29
  lead_2_model = NeuralNetwork()
30
+ lead_2_model.load_state_dict(torch.load(lead_2, map_location=torch.device('cpu')))
31
 
32
  lead_3_model = NeuralNetwork()
33
+ lead_3_model.load_state_dict(torch.load(lead_3, map_location=torch.device('cpu')))
34
 
35
  def helper(sig, model):
36
  inpt = sig[:, np.newaxis, :]
 
40
  prediction_scores = res.numpy()
41
  return prediction_scores
42
 
43
+ def predict_disease(lead1, lead2, lead3):
44
  p1 = helper(lead1, lead_1_model)
45
  p2 = helper(lead2, lead_2_model)
46
  p3 = helper(lead3, lead_3_model)
47
+ # print(p1.argmax(axis=1))
48
+ # print(p2.argmax(axis=1))
49
+ # print(p3.argmax(axis=1))
50
  p_avg = (p1 + p2 + p3)/3
51
+ # print(p_avg.argmax(axis=1))
52
+ # print("-----------------")
53
+ return p_avg, p1, p2, p3
read_data.py CHANGED
@@ -1,53 +1,86 @@
1
  import numpy as np
2
- from scipy.signal import resample
3
  from baseline_wander_removal import bw_remover
4
  import pywt
5
- # import matplotlib.pyplot as plt
6
- import codecs
7
 
8
- def normalize(sig):
9
- return 2*((sig-np.min(sig))/(np.max(sig)-np.min(sig)))
10
 
11
- def prepare_all_leads(path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  if path.endswith(".txt"):
13
- sig = np.loadtxt(path, delimiter=',', unpack=True)
14
  elif path.endswith(".npy"):
15
  sig = np.load(path, allow_pickle=True)
16
  x = pywt.wavedec(sig[0], 'db6', level=2)[0]
17
  y = pywt.wavedec(sig[1], 'db6', level=2)[0]
18
  z = pywt.wavedec(sig[2], 'db6', level=2)[0]
19
- # visualize_sig([x, y, z])
20
  return x[None, :], y[None, :], z[None, :]
21
- # return sig
22
- freq = sig.shape[1] // 60
 
 
 
 
 
23
  sig[0] = bw_remover(freq, sig[0])
24
  sig[1] = bw_remover(freq, sig[1])
25
  sig[2] = bw_remover(freq, sig[2])
26
 
27
- # visualize_sig(sig)
 
 
 
 
 
 
 
28
  sig_length = freq*2
29
- total_samples = sig[0].shape[0] // sig_length
30
 
31
  lead_1 = []
32
  lead_2 = []
33
  lead_3 = []
34
  for i in range(total_samples):
35
  x = sig[0][i*sig_length:(i+1)*sig_length]
36
- x = pywt.wavedec(x, 'db6', level=1)[0]
37
- x = resample(x, 259)
38
- x = normalize(x)
39
- lead_1.append(x)
40
-
41
  y = sig[1][i*sig_length:(i+1)*sig_length]
42
- y = pywt.wavedec(y, 'db6', level=1)[0]
43
- y = resample(y, 259)
44
- y = normalize(y)
45
- lead_2.append(y)
46
-
47
  z = sig[2][i*sig_length:(i+1)*sig_length]
48
- z = pywt.wavedec(z, 'db6', level=1)[0]
49
- z = resample(z, 259)
50
- z = normalize(z)
 
 
 
 
51
  lead_3.append(z)
52
-
53
  return np.asarray(lead_1), np.asarray(lead_2), np.asarray(lead_3)
 
 
 
 
 
 
1
  import numpy as np
2
+ from scipy.signal import resample, butter, filtfilt
3
  from baseline_wander_removal import bw_remover
4
  import pywt
5
+ import matplotlib.pyplot as plt
 
6
 
7
+ def normalize(sig, val=2):
8
+ return val*((sig-np.min(sig))/(np.max(sig)-np.min(sig)))
9
 
10
+ def butter_lowpass_filter(data, cutoff, fs, order):
11
+ nyq = 0.5 * fs
12
+ normal_cutoff = cutoff / nyq
13
+ # Get the filter coefficients
14
+ b, a = butter(order, normal_cutoff, btype='low', analog=False)
15
+ y = filtfilt(b, a, data)
16
+ return y
17
+
18
+ def visualize_sig(sig, filename="test.png"):
19
+ fig, ax = plt.subplots(3, 1, figsize=(10, 10))
20
+ ax[0].plot(sig[0])
21
+ ax[1].plot(sig[1])
22
+ ax[2].plot(sig[2])
23
+ # plt.show()
24
+ plt.savefig(filename)
25
+
26
+ def preprocess_one_chunk(chunk, lvl=2):
27
+ chunk = resample(chunk, 1000)
28
+ chunk = pywt.wavedec(chunk, 'db6', level=lvl)[0]
29
+ # print(x.shape)
30
+ # x = pad(x)
31
+ chunk = normalize(chunk)
32
+ return chunk
33
+
34
+ def prepare_all_leads(path, butter_filter=False):
35
  if path.endswith(".txt"):
36
+ signal = np.loadtxt(path, delimiter=',', unpack=True)
37
  elif path.endswith(".npy"):
38
  sig = np.load(path, allow_pickle=True)
39
  x = pywt.wavedec(sig[0], 'db6', level=2)[0]
40
  y = pywt.wavedec(sig[1], 'db6', level=2)[0]
41
  z = pywt.wavedec(sig[2], 'db6', level=2)[0]
 
42
  return x[None, :], y[None, :], z[None, :]
43
+ freq = signal.shape[1] // 60
44
+ if freq < 250:
45
+ lvl = 1
46
+ else:
47
+ lvl = 2
48
+ sig = [signal[0], signal[1], signal[2]]
49
+
50
  sig[0] = bw_remover(freq, sig[0])
51
  sig[1] = bw_remover(freq, sig[1])
52
  sig[2] = bw_remover(freq, sig[2])
53
 
54
+ if butter_filter:
55
+ cutoff = 20 # desired cutoff frequency of the filter, Hz , slightly higher than actual 1.2 Hz
56
+ order = 1 # sin wave can be approx represented as quadratic
57
+
58
+ sig[0] = butter_lowpass_filter(sig[0], cutoff, freq, order)
59
+ sig[1] = butter_lowpass_filter(sig[1], cutoff, freq, order)
60
+ sig[2] = butter_lowpass_filter(sig[2], cutoff, freq, order)
61
+
62
  sig_length = freq*2
63
+ total_samples = sig[0].shape[0] // 1000
64
 
65
  lead_1 = []
66
  lead_2 = []
67
  lead_3 = []
68
  for i in range(total_samples):
69
  x = sig[0][i*sig_length:(i+1)*sig_length]
 
 
 
 
 
70
  y = sig[1][i*sig_length:(i+1)*sig_length]
 
 
 
 
 
71
  z = sig[2][i*sig_length:(i+1)*sig_length]
72
+
73
+ x = preprocess_one_chunk(x, lvl=lvl)
74
+ y = preprocess_one_chunk(y, lvl=lvl)
75
+ z = preprocess_one_chunk(z, lvl=lvl)
76
+
77
+ lead_1.append(x)
78
+ lead_2.append(y)
79
  lead_3.append(z)
80
+
81
  return np.asarray(lead_1), np.asarray(lead_2), np.asarray(lead_3)
82
+
83
+
84
+ def pad(sig):
85
+ sig = np.pad(sig, (0, 258-sig.shape[0]), 'constant')
86
+ return sig