Upload 14 files
- configs/__pycache__/data_configs.cpython-312.pyc +0 -0
- configs/__pycache__/hparams.cpython-312.pyc +0 -0
- configs/data_configs.py +58 -0
- configs/hparams.py +17 -0
- data/__pycache__/dataloader.cpython-312.pyc +0 -0
- data/__pycache__/dataset.cpython-312.pyc +0 -0
- data/dataloader.py +56 -0
- data/dataset.py +39 -0
- main.py +35 -0
- models/MoE_ECGFormer.py +144 -0
- models/__pycache__/MoE_ECGFormer.cpython-312.pyc +0 -0
- test.py +3 -0
- train.py +192 -0
- utils.py +204 -0
configs/__pycache__/data_configs.cpython-312.pyc
ADDED
Binary file (2.05 kB).
configs/__pycache__/hparams.cpython-312.pyc
ADDED
Binary file (1.09 kB).
configs/data_configs.py
ADDED
@@ -0,0 +1,58 @@
+def get_dataset_class(dataset_name):
+    """
+    :param dataset_name: name of the dataset configuration class
+    :return: the dataset configuration class with the given name
+    """
+    if dataset_name not in globals():
+        raise NotImplementedError("Dataset not found: {}".format(dataset_name))
+    return globals()[dataset_name]
+
+
+class mit:
+    def __init__(self):
+        super().__init__()
+        # data params
+        self.num_classes = 5
+        self.class_names = ['N', 'S', 'V', 'F', 'Q']
+        self.sequence_len = 186
+
+        # model configs
+        self.input_channels = 1
+        self.kernel_size = 8
+        self.stride = 1
+        self.dropout = 0.2
+
+        # features
+        self.mid_channels = 32
+        self.final_out_channels = 128
+
+        # MoE Transformer
+        self.trans_dim = 25
+        self.num_heads = 5
+        self.num_experts = 8
+
+
+class ptb:
+    def __init__(self):
+        super().__init__()
+        # data params
+        self.num_classes = 2
+        self.class_names = ['normal', 'abnormal']
+        self.sequence_len = 188
+
+        # model configs
+        self.input_channels = 1
+        self.kernel_size = 8
+        self.stride = 1
+        self.dropout = 0.2
+
+        # features
+        self.mid_channels = 32
+        self.final_out_channels = 128
+
+        # MoE Transformer
+        self.trans_dim = 25
+        self.num_heads = 5
+        self.num_experts = 8
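The lookup helper resolves a dataset name to one of the config classes defined in this module via globals(). A minimal usage sketch (not part of the commit):

    from configs.data_configs import get_dataset_class

    configs = get_dataset_class('mit')()   # instantiate the 'mit' config
    print(configs.num_classes, configs.class_names, configs.sequence_len)
    # 5 ['N', 'S', 'V', 'F', 'Q'] 186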
configs/hparams.py
ADDED
@@ -0,0 +1,17 @@
+def get_hparams_class(dataset_name):
+    """Return the hyperparameter class with the given name."""
+    if dataset_name not in globals():
+        raise NotImplementedError("Algorithm not found: {}".format(dataset_name))
+    return globals()[dataset_name]
+
+
+class Supervised:
+    def __init__(self):
+        super(Supervised, self).__init__()
+        self.train_params = {
+            'num_epochs': 100,
+            'batch_size': 128,
+            'weight_decay': 1e-4,
+            'learning_rate': 1e-3,
+            'feature_dim': 1 * 128
+        }
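Trainer in train.py consumes these hyperparameters through the same globals() lookup; a short sketch of that path (not part of the commit):

    from configs.hparams import get_hparams_class

    hparams = get_hparams_class('Supervised')().train_params
    print(hparams['batch_size'], hparams['learning_rate'])  # 128 0.001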
data/__pycache__/dataloader.cpython-312.pyc
ADDED
Binary file (3.6 kB).
data/__pycache__/dataset.cpython-312.pyc
ADDED
Binary file (1.86 kB).
data/dataloader.py
ADDED
@@ -0,0 +1,56 @@
+import torch
+from torch.utils.data import DataLoader
+from data.dataset import load_ECG_Dataset
+
+import os
+import numpy as np
+import math
+
+
+def get_class_weight(labels_dict):
+    total = sum(labels_dict.values())
+    max_num = max(labels_dict.values())
+    mu = 1.0 / (total / max_num)
+    class_weight = dict()
+    for key, value in labels_dict.items():
+        score = math.log(mu * total / float(value))
+        class_weight[key] = score if score > 1.0 else 1.0
+    return class_weight
+
+
+class ECGDataloader:
+    testdata_path: str
+    traindata_path: str
+    validdata_path: str
+
+    def __init__(self, data_path, data_type, hparams):
+        self.traindata_path = os.path.join(data_path, data_type, 'train.pt')
+        self.testdata_path = os.path.join(data_path, data_type, 'test.pt')
+        self.validdata_path = os.path.join(data_path, data_type, 'val.pt')
+        self.batch_size = hparams['batch_size']
+
+    def train_dataloader(self):
+        train_dataset = torch.load(self.traindata_path)
+        train_dataset = load_ECG_Dataset(train_dataset)
+        # count samples per class for the class weights
+        cw = train_dataset.y_data.numpy().tolist()
+        cw_dict = {}
+        for i in range(len(np.unique(train_dataset.y_data.numpy()))):
+            cw_dict[i] = cw.count(i)
+
+        train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True,
+                                  drop_last=True, num_workers=4)
+        return train_loader, get_class_weight(cw_dict)
+
+    def test_dataloader(self):
+        test_dataset = torch.load(self.testdata_path)
+        test_dataset = load_ECG_Dataset(test_dataset)
+        test_loader = DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False,
+                                 drop_last=False, num_workers=4)
+        return test_loader
+
+    def valid_dataloader(self):
+        valid_dataset = torch.load(self.validdata_path)
+        valid_dataset = load_ECG_Dataset(valid_dataset)
+        valid_loader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=True,
+                                  drop_last=False, num_workers=4)
+        return valid_loader
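train_dataloader expects each .pt file to hold a dict with 'samples' and 'labels' (see data/dataset.py below) and returns log-scaled class weights alongside the loader. A small sketch of the weighting arithmetic, with made-up class counts:

    from data.dataloader import get_class_weight

    # hypothetical 2-class imbalance: 90 majority beats vs 10 minority beats
    weights = get_class_weight({0: 90, 1: 10})
    # mu = max/total = 0.9, so the majority class gets log(1.0) -> clamped to 1.0
    # and the minority class gets log(0.9 * 100 / 10) = log(9) ~ 2.197
    print(weights)  # {0: 1.0, 1: 2.197...}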
data/dataset.py
ADDED
@@ -0,0 +1,39 @@
+import torch
+from torch.utils.data import Dataset
+
+import numpy as np
+
+
+class load_ECG_Dataset(Dataset):
+    # Initialize the dataset
+    def __init__(self, dataset):
+
+        # load samples
+        x_data = dataset["samples"]
+        # convert to a PyTorch tensor
+        if isinstance(x_data, np.ndarray):
+            x_data = torch.from_numpy(x_data)
+
+        # load labels
+        y_data = dataset.get("labels")
+        if y_data is not None and isinstance(y_data, np.ndarray):
+            y_data = torch.from_numpy(y_data)
+
+        self.x_data = x_data.float()
+        self.y_data = y_data.long() if y_data is not None else None
+
+        self.len = x_data.shape[0]
+
+    def get_labels(self):
+        return self.y_data
+
+    def __getitem__(self, idx):
+        sample = {
+            'samples': self.x_data[idx].squeeze(-1),
+        }
+        # only attach labels when they exist (y_data may be None for unlabeled data)
+        if self.y_data is not None:
+            sample['labels'] = int(self.y_data[idx])
+        return sample
+
+    def __len__(self):
+        return self.len
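Given the squeeze(-1) in __getitem__, the stored samples presumably carry a trailing singleton dimension, e.g. (N, 1, L, 1), so each item collates to the (batch, channels, length) layout the model expects. A sketch under that assumption, with made-up tensors:

    import torch
    from data.dataset import load_ECG_Dataset

    # hypothetical mini-dataset: 4 single-lead beats of length 186
    ds = load_ECG_Dataset({'samples': torch.randn(4, 1, 186, 1),
                           'labels': torch.tensor([0, 1, 2, 0])})
    item = ds[0]
    print(item['samples'].shape, item['labels'])  # torch.Size([1, 186]) 0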
main.py
ADDED
@@ -0,0 +1,35 @@
+import os
+import argparse
+import warnings
+from train import Trainer
+import sklearn.exceptions
+
+warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
+
+parser = argparse.ArgumentParser()
+
+# ======== Experiments Name ================
+parser.add_argument('--save_dir', default='experiments_logs', type=str,
+                    help='Directory containing all experiments')
+parser.add_argument('--experiment_description', default='Exp1', type=str, help='experiment name')
+parser.add_argument('--run_description', default='run1', type=str, help='run name')
+
+# ========= Select the DATASET ==============
+parser.add_argument('--dataset', default='mit', type=str, help='mit or ptb')
+parser.add_argument('--seed_id', default='0', type=str,
+                    help='seed to fix while training')
+
+# ========= Experiment settings ===============
+parser.add_argument('--data_path', default=r'/Users/splendor1811/datn/ECGTransForm/datasets',
+                    type=str, help='Path containing the datasets')
+
+parser.add_argument('--num_runs', default=1, type=int,
+                    help='Number of consecutive runs with different seeds')
+parser.add_argument('--device', default='mps', type=str,
+                    help='cpu, cuda or mps')
+
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+    trainer = Trainer(args)
+    trainer.train()
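For completeness, a programmatic equivalent of this CLI entry point (not part of the commit; the data_path value is a placeholder):

    from argparse import Namespace
    from train import Trainer

    args = Namespace(save_dir='experiments_logs', experiment_description='Exp1',
                     run_description='run1', dataset='mit', seed_id='0',
                     data_path='/path/to/datasets', num_runs=1, device='cpu')
    Trainer(args).train()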
models/MoE_ECGFormer.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+from torch import nn
+
+
+class MoE_ECGFormer(nn.Module):
+    def __init__(self, configs, hparams):
+        super().__init__()
+
+        filter_sizes = [5, 9, 11]
+        self.conv1 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[0],
+                               stride=configs.stride, bias=False, padding=(filter_sizes[0] // 2))
+
+        self.conv2 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[1],
+                               stride=configs.stride, bias=False, padding=(filter_sizes[1] // 2))
+
+        self.conv3 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[2],
+                               stride=configs.stride, bias=False, padding=(filter_sizes[2] // 2))
+
+        self.bn = nn.BatchNorm1d(configs.mid_channels)
+        self.relu = nn.ReLU()
+        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+        self.dropout = nn.Dropout(configs.dropout)
+
+        self.conv_block2 = nn.Sequential(
+            nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False,
+                      padding=4),
+            nn.BatchNorm1d(configs.mid_channels * 2),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+        )
+
+        self.conv_block3 = nn.Sequential(
+            nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
+                      padding=4),
+            nn.BatchNorm1d(configs.final_out_channels),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
+        )
+
+        self.inplanes = 128
+        self.crm = self._make_layer(SEBasicBlock, 128, 3)
+
+        # Transformer layer (a standard nn.TransformerEncoder;
+        # configs.num_experts is not used by this encoder)
+        self.encoder_layer = nn.TransformerEncoderLayer(d_model=configs.trans_dim,
+                                                        nhead=configs.num_heads,
+                                                        batch_first=True)
+        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
+        self.aap = nn.AdaptiveAvgPool1d(1)
+        self.clf = nn.Linear(hparams['feature_dim'], configs.num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv1d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm1d(planes * block.expansion),
+            )
+
+        layers = [block(self.inplanes, planes, stride, downsample)]
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x_in):
+
+        # Multi-scale convolutions, averaged across the three filter sizes
+        x1 = self.conv1(x_in)
+        x2 = self.conv2(x_in)
+        x3 = self.conv3(x_in)
+
+        x_concat = torch.mean(torch.stack([x1, x2, x3], dim=2), dim=2)
+        x_concat = self.dropout(self.maxpool(self.relu(self.bn(x_concat))))
+
+        x = self.conv_block2(x_concat)
+        x = self.conv_block3(x)
+
+        # Channel Recalibration Module
+        x = self.crm(x)
+
+        # Bidirectional MoE Transformer: the sequence is encoded forwards and
+        # time-reversed, and the two encodings are summed
+        x1 = self.transformer_encoder(x)
+        x2 = self.transformer_encoder(torch.flip(x, [2]))
+        x = x1 + x2
+
+        x = self.aap(x)
+        x_flat = x.reshape(x.size(0), -1)
+        x_out = self.clf(x_flat)
+        return x_out
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=4):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1)
+        return x * y.expand_as(x)
+
+
+class SEBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None,
+                 *, reduction=4):
+        super(SEBasicBlock, self).__init__()
+        # note: the third positional argument of nn.Conv1d is kernel_size, so this
+        # builds a pointwise conv when stride == 1 (the only way it is used here)
+        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=stride)
+        self.bn1 = nn.BatchNorm1d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv1d(planes, planes, 1)
+        self.bn2 = nn.BatchNorm1d(planes)
+        self.se = SELayer(planes, reduction)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
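A quick shape check under the 'mit' configuration (not part of the commit): a batch of 2 single-lead beats of length 186 comes out as logits over the 5 classes, since the conv stack reduces (2, 1, 186) to (2, 128, 25), matching trans_dim=25.

    import torch
    from configs.data_configs import get_dataset_class
    from configs.hparams import get_hparams_class
    from models.MoE_ECGFormer import MoE_ECGFormer

    configs = get_dataset_class('mit')()
    hparams = get_hparams_class('Supervised')().train_params
    model = MoE_ECGFormer(configs=configs, hparams=hparams).eval()
    with torch.no_grad():
        logits = model(torch.randn(2, configs.input_channels, configs.sequence_len))
    print(logits.shape)  # torch.Size([2, 5])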
models/__pycache__/MoE_ECGFormer.cpython-312.pyc
ADDED
Binary file (8.74 kB).
test.py
ADDED
@@ -0,0 +1,3 @@
+from deepface import DeepFace
+
+print(DeepFace.verify('/Users/splendor1811/Desktop/a1.jpg', '/Users/splendor1811/Desktop/a1.jpg'))
train.py
ADDED
@@ -0,0 +1,192 @@
+from models.MoE_ECGFormer import MoE_ECGFormer
+from data.dataloader import ECGDataloader
+from configs.data_configs import get_dataset_class
+from configs.hparams import get_hparams_class
+from utils import AverageMeter, to_device, _save_metrics, copy_files
+from utils import fix_randomness, starting_logs, save_checkpoint, _calc_metrics
+import torch
+import torch.nn.functional as F
+import datetime
+import os
+import collections
+import numpy as np
+
+import warnings
+import sklearn.exceptions
+
+warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
+class Trainer(object):
+    def __init__(self, args):
+        # dataset parameters
+        self.dataset = args.dataset
+        self.seed_id = args.seed_id
+        self.device = torch.device(args.device)
+
+        # experiment description
+        self.run_description = f"{args.run_description}_{datetime.datetime.now().strftime('%H_%M')}"
+        self.experiment_description = args.experiment_description
+
+        # paths
+        self.home_path = os.getcwd()
+        self.save_dir = os.path.join(os.getcwd(), "experiments_logs")
+        self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, self.run_description)
+        os.makedirs(self.exp_log_dir, exist_ok=True)
+
+        self.data_path = args.data_path
+
+        # number of runs
+        self.num_runs = args.num_runs
+
+        # get dataset and base model configs
+        self.dataset_configs, self.hparams_class = self.get_configs()
+
+        # specify hparams
+        self.hparams = self.hparams_class.train_params
+
+    def get_configs(self):
+        dataset_class = get_dataset_class(self.dataset)
+        hparams_class = get_hparams_class("Supervised")
+        return dataset_class(), hparams_class()
+
+    def load_data(self, data_type):
+        self.train_dl, self.cw_dict = ECGDataloader(self.data_path, data_type, self.hparams).train_dataloader()
+        self.test_dl = ECGDataloader(self.data_path, data_type, self.hparams).test_dataloader()
+        self.valid_dl = ECGDataloader(self.data_path, data_type, self.hparams).valid_dataloader()
+
+    def calc_results_per_run(self):
+        acc, f1 = _calc_metrics(self.pred_labels, self.true_labels, self.dataset_configs.class_names)
+        return acc, f1
+
+    def train(self):
+        copy_files(self.exp_log_dir)  # save a copy of the training files
+
+        self.metrics = {'accuracy': [], 'f1_score': []}
+
+        # fix the random seed
+        fix_randomness(int(self.seed_id))
+
+        # logging
+        self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.exp_log_dir, self.seed_id)
+        self.logger.debug(self.hparams)
+
+        # load data
+        self.load_data(self.dataset)
+
+        model = MoE_ECGFormer(configs=self.dataset_configs, hparams=self.hparams)
+        model.to(self.device)
+
+        # average meters
+        loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
+
+        self.optimizer = torch.optim.Adam(
+            model.parameters(),
+            lr=self.hparams["learning_rate"],
+            weight_decay=self.hparams["weight_decay"],
+            betas=(0.9, 0.99)
+        )
+
+        # class weights as a float32 tensor on the training device
+        weights = [float(value) for value in self.cw_dict.values()]
+        weights_array = np.array(weights).astype(np.float32)
+        weights_tensor = torch.tensor(weights_array).to(self.device)
+        self.cross_entropy = torch.nn.CrossEntropyLoss(weight=weights_tensor)
+
+        best_acc = 0
+        best_f1 = 0
+
+        # training
+        ts_acc = 0
+        ts_f1 = 0
+        for epoch in range(1, self.hparams["num_epochs"] + 1):
+            model.train()
+
+            for step, batches in enumerate(self.train_dl):
+                batches = to_device(batches, self.device)
+
+                data = batches['samples'].float()
+                labels = batches['labels'].long()
+
+                self.optimizer.zero_grad()
+
+                # forward pass
+                logits = model(data)
+
+                # cross-entropy loss
+                x_ent_loss = self.cross_entropy(logits, labels)
+
+                x_ent_loss.backward()
+                self.optimizer.step()
+
+                losses = {'Total_loss': x_ent_loss.item()}
+                for key, val in losses.items():
+                    loss_avg_meters[key].update(val, self.hparams["batch_size"])
+
+            # training-set metrics (the original evaluated the validation split
+            # here while logging it as TRAIN)
+            self.evaluate(model, self.train_dl)
+            tr_acc, tr_f1 = self.calc_results_per_run()
+            # logging
+            self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
+            for key, val in loss_avg_meters.items():
+                self.logger.debug(f'{key}\t: {val.avg:2.4f}')
+            self.logger.debug(f'TRAIN: Acc:{tr_acc:2.4f} \t F1:{tr_f1:2.4f}')
+
+            # validation
+            self.evaluate(model, self.valid_dl)
+            ts_acc, ts_f1 = self.calc_results_per_run()
+            if ts_f1 > best_f1:  # save the best model based on best F1
+                best_f1 = ts_f1
+                best_acc = ts_acc
+                save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "best")
+                _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_best")
+
+            # logging
+            self.logger.debug(f'VAL : Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f} (best: {best_f1:2.4f})')
+            self.logger.debug('-------------------------------------')
+
+        # last epoch
+        _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_last")
+        self.logger.debug("LAST EPOCH PERFORMANCE on validation set ...")
+        self.logger.debug(f'Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f}')
+
+        self.logger.debug(":::::::::::::")
+        # best epoch
+        self.logger.debug("BEST EPOCH PERFORMANCE on validation set ...")
+        self.logger.debug(f'Acc:{best_acc:2.4f} \t F1:{best_f1:2.4f}')
+        save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "last")
+
+        # testing
+        print(" === Evaluating on TEST set ===")
+        self.evaluate(model, self.test_dl)
+        test_acc, test_f1 = self.calc_results_per_run()
+        _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "test_last")
+        self.logger.debug(f'Acc:{test_acc:2.4f} \t F1:{test_f1:2.4f}')
+
+    def evaluate(self, model, dataset):
+        model.to(self.device).eval()
+
+        total_loss_ = []
+
+        self.pred_labels = np.array([])
+        self.true_labels = np.array([])
+
+        with torch.no_grad():
+            for batches in dataset:
+                batches = to_device(batches, self.device)
+                data = batches['samples'].float()
+                labels = batches['labels'].long()
+
+                # forward pass
+                predictions = model(data)
+
+                # compute loss
+                loss = F.cross_entropy(predictions, labels)
+                total_loss_.append(loss.item())
+                pred = predictions.detach().argmax(dim=1)  # index of the max log-probability
+
+                self.pred_labels = np.append(self.pred_labels, pred.cpu().numpy())
+                self.true_labels = np.append(self.true_labels, labels.data.cpu().numpy())
+
+        self.trg_loss = torch.tensor(total_loss_).mean()  # average loss
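save_checkpoint (in utils.py below) stores the dataset name, configs, hparams and state_dict in one file, so the best model can be reloaded as sketched here (not part of the commit; the run directory is a placeholder, since run names embed a timestamp):

    import torch
    from models.MoE_ECGFormer import MoE_ECGFormer
    from configs.data_configs import get_dataset_class

    ckpt = torch.load('experiments_logs/Exp1/<run_dir>/checkpoint_best.pt', map_location='cpu')
    configs = get_dataset_class(ckpt['dataset'])()
    model = MoE_ECGFormer(configs=configs, hparams=ckpt['hparams'])
    model.load_state_dict(ckpt['model'])
    model.eval()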
utils.py
ADDED
@@ -0,0 +1,204 @@
+import torch
+import random
+import os
+import sys
+import logging
+import numpy as np
+import pandas as pd
+from shutil import copy
+from datetime import datetime
+import matplotlib.pyplot as plt
+import collections
+import umap
+import umap.plot
+from matplotlib.colors import ListedColormap
+
+from sklearn.metrics import classification_report, accuracy_score
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def fix_randomness(SEED):
+    random.seed(SEED)
+    np.random.seed(SEED)
+    torch.manual_seed(SEED)
+    torch.cuda.manual_seed(SEED)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def _logger(logger_name, level=logging.DEBUG):
+    """
+    Return a custom logger with the given name and level.
+    :param logger_name: name of the logger (also used as the log file path)
+    :param level: logging level
+    :return: the configured logger
+    """
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(level)
+    format_string = "%(message)s"
+    log_format = logging.Formatter(format_string)
+
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(log_format)
+    logger.addHandler(console_handler)
+    # create and add the file handler
+    file_handler = logging.FileHandler(logger_name, mode='a')
+    file_handler.setFormatter(log_format)
+    logger.addHandler(file_handler)
+    return logger
+
+
+def starting_logs(data_type, exp_log_dir, seed_id):
+    log_dir = os.path.join(exp_log_dir, "_seed_" + str(seed_id))
+    os.makedirs(log_dir, exist_ok=True)
+    log_file_name = os.path.join(log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log")
+    logger = _logger(log_file_name)
+    logger.debug('=' * 45)
+    logger.debug(f'Logging data type {data_type}')
+    logger.debug("=" * 45)
+    logger.debug(f'Logging seed id {seed_id}')
+    logger.debug("=" * 45)
+    return logger, log_dir
+
+
+def save_checkpoint(exp_log_dir, model, dataset, dataset_configs, hparams, status):
+    save_dict = {
+        "dataset": dataset,
+        "configs": dataset_configs.__dict__,
+        "hparams": dict(hparams),
+        "model": model.state_dict(),
+    }
+    # save the checkpoint
+    save_path = os.path.join(exp_log_dir, f"checkpoint_{status}.pt")
+    torch.save(save_dict, save_path)
+
+
+def _calc_metrics(pred_labels, true_labels, classes_names):
+    pred_labels = np.array(pred_labels).astype(int)
+    true_labels = np.array(true_labels).astype(int)
+
+    r = classification_report(true_labels, pred_labels, target_names=classes_names, digits=6, output_dict=True)
+    accuracy = accuracy_score(true_labels, pred_labels)
+
+    return accuracy * 100, r["macro avg"]["f1-score"] * 100
+
+
+def _save_metrics(pred_labels, true_labels, log_dir, status):
+    pred_labels = np.array(pred_labels).astype(int)
+    true_labels = np.array(true_labels).astype(int)
+
+    r = classification_report(true_labels, pred_labels, digits=6, output_dict=True)
+
+    df = pd.DataFrame(r)
+    accuracy = accuracy_score(true_labels, pred_labels)
+    df["accuracy"] = accuracy
+    df = df * 100
+
+    # save the classification report
+    file_name = f"classification_report_{status}.xlsx"
+    report_save_path = os.path.join(log_dir, file_name)
+    df.to_excel(report_save_path)
+
+
+def to_device(input, device):
+    if torch.is_tensor(input):
+        return input.to(device=device)
+    elif isinstance(input, str):
+        return input
+    elif isinstance(input, collections.abc.Mapping):
+        return {k: to_device(sample, device=device) for k, sample in input.items()}
+    elif isinstance(input, collections.abc.Sequence):
+        return [to_device(sample, device=device) for sample in input]
+    else:
+        raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")
+
+
+def copy_files(destination):
+    destination_dir = os.path.join(destination, "MODEL_BACKUP_FILES")
+    os.makedirs(destination_dir, exist_ok=True)
+    copy("main.py", os.path.join(destination_dir, "main.py"))
+    copy("data/dataloader.py", os.path.join(destination_dir, "dataloader.py"))
+    copy("data/dataset.py", os.path.join(destination_dir, "dataset.py"))
+    copy("models/MoE_ECGFormer.py", os.path.join(destination_dir, "models.py"))
+    copy("configs/data_configs.py", os.path.join(destination_dir, "data_configs.py"))
+    copy("configs/hparams.py", os.path.join(destination_dir, "hparams.py"))
+    copy("train.py", os.path.join(destination_dir, "train.py"))
+    copy("utils.py", os.path.join(destination_dir, "utils.py"))
+
+
+def _plot_umap(model, data_loader, device, save_dir):
+
+    classes_names = ['N', 'S', 'V', 'F', 'Q']
+
+    font = {'family': 'Times New Roman',
+            'weight': 'bold',
+            'size': 17}
+    plt.rc('font', **font)
+
+    with torch.no_grad():
+        # extract features (assumes the model is split into a feature extractor
+        # at model[0] and a projection at model[1])
+        data = data_loader.dataset.x_data.float().to(device)
+        labels = data_loader.dataset.y_data.view((-1)).long()
+        out = model[0](data)
+        features = model[1](out)
+
+        if not os.path.exists(os.path.join(save_dir, "umap_plots")):
+            os.mkdir(os.path.join(save_dir, "umap_plots"))
+
+        model_reducer = umap.UMAP()  # n_neighbors=3, min_dist=0.3, metric='correlation', random_state=42
+        embedding = model_reducer.fit_transform(features.detach().cpu().numpy())
+
+        # normalize the labels to [0, 1] for the colormap
+        norm_labels = labels / 4.0
+
+        # build a 5-color colormap from the first colors of "Paired"
+        paired = plt.cm.get_cmap('Paired', 12)  # 12 distinct colors
+        new_colors = [paired(0), paired(1), paired(2), paired(4),
+                      paired(6)]  # skip every second color, but take both from the first pair
+        new_cmap = ListedColormap(new_colors)
+
+        print("Plotting UMAP ...")
+        plt.figure(figsize=(16, 10))
+        scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=norm_labels, cmap=new_cmap, s=15)
+
+        handles, _ = scatter.legend_elements(prop='colors')
+        plt.legend(handles, classes_names, title="Classes")
+        file_name = "umap_.png"
+        fig_save_name = os.path.join(save_dir, "umap_plots", file_name)
+        plt.xticks([])
+        plt.yticks([])
+        plt.savefig(fig_save_name, bbox_inches='tight')
+        plt.close()
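A quick sanity check of the metric helper (percent-scaled accuracy and macro F1), not part of the commit:

    from utils import _calc_metrics

    acc, f1 = _calc_metrics([0, 1, 1, 0], [0, 1, 1, 0], ['neg', 'pos'])
    print(acc, f1)  # 100.0 100.0 for perfect predictions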