splendor1811 committed
Commit bcf646b · verified · 1 Parent(s): 265dca5

Upload 14 files

configs/__pycache__/data_configs.cpython-312.pyc ADDED
Binary file (2.05 kB).
 
configs/__pycache__/hparams.cpython-312.pyc ADDED
Binary file (1.09 kB).
 
configs/data_configs.py ADDED
@@ -0,0 +1,58 @@
+ def get_dataset_class(dataset_name):
+     """
+     :param dataset_name: name of a dataset-config class defined in this module
+     :return: the dataset-config class with the given name
+     """
+     if dataset_name not in globals():
+         raise NotImplementedError("Dataset not found: {}".format(dataset_name))
+     return globals()[dataset_name]
+
+
+ class mit:
+     def __init__(self):
+         super().__init__()
+         # data params
+         self.num_classes = 5
+         self.class_names = ['N', 'S', 'V', 'F', 'Q']
+         self.sequence_len = 186
+
+         # model configs
+         self.input_channels = 1
+         self.kernel_size = 8
+         self.stride = 1
+         self.dropout = 0.2
+
+         # features
+         self.mid_channels = 32
+         self.final_out_channels = 128
+
+         # MoE Transformer
+         self.trans_dim = 25
+         self.num_heads = 5
+         self.num_experts = 8
+
+
+ class ptb:
+     def __init__(self):
+         super().__init__()
+         # data params
+         self.num_classes = 2
+         self.class_names = ['normal', 'abnormal']
+         self.sequence_len = 188
+
+         # model configs
+         self.input_channels = 1
+         self.kernel_size = 8
+         self.stride = 1
+         self.dropout = 0.2
+
+         # features
+         self.mid_channels = 32
+         self.final_out_channels = 128
+
+         # MoE Transformer
+         self.trans_dim = 25
+         self.num_heads = 5
+         self.num_experts = 8
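Both config classes are resolved by name through `globals()`, so the `--dataset` string passed on the CLI maps directly to a class defined in this module. A minimal usage sketch:

    # Resolve a dataset-config class from its CLI name, then instantiate it.
    from configs.data_configs import get_dataset_class

    configs = get_dataset_class('mit')()
    print(configs.num_classes, configs.class_names)  # 5 ['N', 'S', 'V', 'F', 'Q']
    # An unknown name ('chapman' here is hypothetical) raises:
    # NotImplementedError("Dataset not found: chapman")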
configs/hparams.py ADDED
@@ -0,0 +1,17 @@
+ def get_hparams_class(dataset_name):
+     """Return the hyperparameter class with the given name."""
+     if dataset_name not in globals():
+         raise NotImplementedError("Hparams class not found: {}".format(dataset_name))
+     return globals()[dataset_name]
+
+
+ class Supervised:
+     def __init__(self):
+         super(Supervised, self).__init__()
+         self.train_params = {
+             'num_epochs': 100,
+             'batch_size': 128,
+             'weight_decay': 1e-4,
+             'learning_rate': 1e-3,
+             'feature_dim': 1 * 128
+         }
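The same `globals()` lookup pattern applies here, and `feature_dim = 1 * 128` matches the model's `final_out_channels` after adaptive pooling (see MoE_ECGFormer below). A short usage sketch:

    from configs.hparams import get_hparams_class

    hparams = get_hparams_class('Supervised')().train_params
    print(hparams['batch_size'], hparams['feature_dim'])  # 128 128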
data/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (3.6 kB).
 
data/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (1.86 kB).
 
data/dataloader.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from data.dataset import load_ECG_Dataset
+
+ import os
+ import numpy as np
+ import math
+
+
+ def get_class_weight(labels_dict):
+     """Log-smoothed inverse-frequency class weights from a {class_id: count} dict."""
+     total = sum(labels_dict.values())
+     max_num = max(labels_dict.values())
+     mu = 1.0 / (total / max_num)
+     class_weight = dict()
+     for key, value in labels_dict.items():
+         score = math.log(mu * total / float(value))
+         class_weight[key] = score if score > 1.0 else 1.0
+     return class_weight
+
+
+ class ECGDataloader:
+     traindata_path: str
+     testdata_path: str
+     validdata_path: str
+
+     def __init__(self, data_path, data_type, hparams):
+         self.traindata_path = os.path.join(data_path, data_type, 'train.pt')
+         self.testdata_path = os.path.join(data_path, data_type, 'test.pt')
+         self.validdata_path = os.path.join(data_path, data_type, 'val.pt')
+         self.batch_size = hparams['batch_size']
+
+     def train_dataloader(self):
+         train_dataset = torch.load(self.traindata_path)
+         train_dataset = load_ECG_Dataset(train_dataset)
+         # count samples per class for the class-weight computation
+         cw = train_dataset.y_data.numpy().tolist()
+         cw_dict = {}
+         for i in range(len(np.unique(train_dataset.y_data.numpy()))):
+             cw_dict[i] = cw.count(i)
+
+         train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True,
+                                   drop_last=True, num_workers=4)
+         return train_loader, get_class_weight(cw_dict)
+
+     def test_dataloader(self):
+         test_dataset = torch.load(self.testdata_path)
+         test_dataset = load_ECG_Dataset(test_dataset)
+         test_loader = DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False,
+                                  drop_last=False, num_workers=4)
+         return test_loader
+
+     def valid_dataloader(self):
+         valid_dataset = torch.load(self.validdata_path)
+         valid_dataset = load_ECG_Dataset(valid_dataset)
+         valid_loader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=True,
+                                   drop_last=False, num_workers=4)
+         return valid_loader
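`get_class_weight` gives each class `log(mu * total / count)` floored at 1.0, with `mu = max_count / total`; the majority class therefore scores `log(1) = 0` and always lands on the 1.0 floor, while rarer classes grow logarithmically. A worked example with hypothetical counts:

    import math

    counts = {0: 900, 1: 80, 2: 20}   # hypothetical; total = 1000, max = 900
    mu = 900 / 1000
    for k, v in counts.items():
        print(k, max(math.log(mu * 1000 / v), 1.0))
    # 0 -> 1.0 (floored), 1 -> log(11.25) ≈ 2.42, 2 -> log(45.0) ≈ 3.81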
data/dataset.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ from torch.utils.data import Dataset
+
+ import numpy as np
+
+
+ class load_ECG_Dataset(Dataset):
+     # Initialize dataset
+     def __init__(self, dataset):
+
+         # Load samples
+         x_data = dataset["samples"]
+         # Convert to a PyTorch tensor
+         if isinstance(x_data, np.ndarray):
+             x_data = torch.from_numpy(x_data)
+
+         # Load labels (optional)
+         y_data = dataset.get("labels")
+         if y_data is not None and isinstance(y_data, np.ndarray):
+             y_data = torch.from_numpy(y_data)
+
+         self.x_data = x_data.float()
+         self.y_data = y_data.long() if y_data is not None else None
+
+         self.len = x_data.shape[0]
+
+     def get_labels(self):
+         return self.y_data
+
+     def __getitem__(self, idx):
+         sample = {'samples': self.x_data[idx].squeeze(-1)}
+         # only attach a label when the dataset is labelled
+         if self.y_data is not None:
+             sample['labels'] = int(self.y_data[idx])
+         return sample
+
+     def __len__(self):
+         return self.len
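`load_ECG_Dataset` expects each `.pt` file to deserialize into a dict holding a `samples` array and, optionally, a `labels` array. A self-contained sketch with synthetic tensors (shapes assumed from the `mit` config above):

    import torch
    from data.dataset import load_ECG_Dataset

    # synthetic stand-in for the dict stored in train.pt / val.pt / test.pt
    fake = {
        'samples': torch.randn(16, 1, 186),   # (num_samples, channels, sequence_len)
        'labels': torch.randint(0, 5, (16,)),
    }
    ds = load_ECG_Dataset(fake)
    print(len(ds), ds[0]['samples'].shape, ds[0]['labels'])  # 16 torch.Size([1, 186]) <class id>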
main.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import argparse
+ import warnings
+ from train import Trainer
+ import sklearn.exceptions
+ warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
+
+ parser = argparse.ArgumentParser()
+
+ # ======== Experiments Name ================
+ parser.add_argument('--save_dir', default='experiments_logs', type=str,
+                     help='Directory containing all experiments')
+ parser.add_argument('--experiment_description', default='Exp1', type=str, help='experiment name')
+ parser.add_argument('--run_description', default='run1', type=str, help='run name')
+
+ # ========= Select the DATASET ==============
+ parser.add_argument('--dataset', default='mit', type=str, help='mit or ptb')
+ parser.add_argument('--seed_id', default='0', type=str,
+                     help='fix a seed while training')
+
+ # ========= Experiment settings ===============
+ parser.add_argument('--data_path', default=r'/Users/splendor1811/datn/ECGTransForm/datasets',
+                     type=str, help='Path containing the dataset')
+
+ parser.add_argument('--num_runs', default=1, type=int,
+                     help='Number of consecutive runs with different seeds')
+ parser.add_argument('--device', default='mps', type=str,
+                     help='cpu, cuda, or mps')
+
+
+ args = parser.parse_args()
+
+ if __name__ == "__main__":
+     trainer = Trainer(args)
+     trainer.train()
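Because `Trainer` only reads attributes off `args`, the pipeline can also be driven programmatically; a hedged sketch using `argparse.Namespace` with illustrative values in place of the CLI:

    import argparse
    from train import Trainer

    # hypothetical arguments mirroring the CLI defaults above
    args = argparse.Namespace(
        save_dir='experiments_logs', experiment_description='Exp1',
        run_description='run1', dataset='mit', seed_id='0',
        data_path='./datasets', num_runs=1, device='cpu',
    )
    Trainer(args).train()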
models/MoE_ECGFormer.py ADDED
@@ -0,0 +1,144 @@
+ import torch
+ from torch import nn
+
+
+ class MoE_ECGFormer(nn.Module):
+     def __init__(self, configs, hparams):
+         super().__init__()
+
+         # Multi-scale convolutional front end
+         filter_sizes = [5, 9, 11]
+         self.conv1 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[0],
+                                stride=configs.stride, bias=False, padding=(filter_sizes[0] // 2))
+         self.conv2 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[1],
+                                stride=configs.stride, bias=False, padding=(filter_sizes[1] // 2))
+         self.conv3 = nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=filter_sizes[2],
+                                stride=configs.stride, bias=False, padding=(filter_sizes[2] // 2))
+
+         self.bn = nn.BatchNorm1d(configs.mid_channels)
+         self.relu = nn.ReLU()
+         self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+         self.dropout = nn.Dropout(configs.dropout)
+
+         self.conv_block2 = nn.Sequential(
+             nn.Conv1d(configs.mid_channels, configs.mid_channels * 2, kernel_size=8, stride=1, bias=False,
+                       padding=4),
+             nn.BatchNorm1d(configs.mid_channels * 2),
+             nn.ReLU(),
+             nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+         )
+
+         self.conv_block3 = nn.Sequential(
+             nn.Conv1d(configs.mid_channels * 2, configs.final_out_channels, kernel_size=8, stride=1, bias=False,
+                       padding=4),
+             nn.BatchNorm1d(configs.final_out_channels),
+             nn.ReLU(),
+             nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
+         )
+
+         # Channel Recalibration Module: a stack of SE residual blocks
+         self.inplanes = 128
+         self.crm = self._make_layer(SEBasicBlock, 128, 3)
+
+         # Transformer layer
+         self.encoder_layer = nn.TransformerEncoderLayer(d_model=configs.trans_dim,
+                                                         nhead=configs.num_heads,
+                                                         batch_first=True)
+         self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
+         self.aap = nn.AdaptiveAvgPool1d(1)
+         self.clf = nn.Linear(hparams['feature_dim'], configs.num_classes)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv1d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm1d(planes * block.expansion),
+             )
+
+         layers = [block(self.inplanes, planes, stride, downsample)]
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x_in):
+
+         # Multi-scale convolutions, averaged across the three branches
+         x1 = self.conv1(x_in)
+         x2 = self.conv2(x_in)
+         x3 = self.conv3(x_in)
+
+         x_concat = torch.mean(torch.stack([x1, x2, x3], dim=2), dim=2)
+         x_concat = self.dropout(self.maxpool(self.relu(self.bn(x_concat))))
+
+         x = self.conv_block2(x_concat)
+         x = self.conv_block3(x)
+
+         # Channel Recalibration Module
+         x = self.crm(x)
+
+         # Bidirectional MoE Transformer: encode the sequence and its reverse
+         x1 = self.transformer_encoder(x)
+         x2 = self.transformer_encoder(torch.flip(x, [2]))
+         x = x1 + x2
+
+         x = self.aap(x)
+         x_flat = x.reshape(x.size(0), -1)
+         x_out = self.clf(x_flat)
+         return x_out
+
+
+ class SELayer(nn.Module):
+     """Squeeze-and-Excitation channel gating."""
+
+     def __init__(self, channel, reduction=4):
+         super(SELayer, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool1d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         b, c, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1)
+         return x * y.expand_as(x)
+
+
+ class SEBasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None,
+                  *, reduction=4):
+         super(SEBasicBlock, self).__init__()
+         # pointwise convolutions keep the sequence length unchanged
+         self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride)
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv1d(planes, planes, 1)
+         self.bn2 = nn.BatchNorm1d(planes)
+         self.se = SELayer(planes, reduction)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.se(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
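A quick shape smoke test: with the `mit` config (`sequence_len=186`), the three conv/pool stages reduce the temporal dimension to 25, which equals `trans_dim`, and adaptive pooling over the 128 channels yields exactly `feature_dim=128`. A minimal sketch, assuming the repo layout above:

    import torch
    from configs.data_configs import mit
    from configs.hparams import Supervised
    from models.MoE_ECGFormer import MoE_ECGFormer

    configs, hparams = mit(), Supervised().train_params
    model = MoE_ECGFormer(configs, hparams).eval()
    x = torch.randn(4, configs.input_channels, configs.sequence_len)  # (batch, 1, 186)
    with torch.no_grad():
        print(model(x).shape)  # torch.Size([4, 5]) -- one logit per class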
models/__pycache__/MoE_ECGFormer.cpython-312.pyc ADDED
Binary file (8.74 kB).
 
test.py ADDED
@@ -0,0 +1,3 @@
+ from deepface import DeepFace
+
+ print(DeepFace.verify('/Users/splendor1811/Desktop/a1.jpg', '/Users/splendor1811/Desktop/a1.jpg'))
train.py ADDED
@@ -0,0 +1,192 @@
+ from models.MoE_ECGFormer import MoE_ECGFormer
+ from data.dataloader import ECGDataloader
+ from configs.data_configs import get_dataset_class
+ from configs.hparams import get_hparams_class
+ from utils import AverageMeter, to_device, _save_metrics, copy_files
+ from utils import fix_randomness, starting_logs, save_checkpoint, _calc_metrics
+ import torch
+ import torch.nn.functional as F
+ import datetime
+ import os
+ import collections
+ import numpy as np
+
+ import warnings
+ import sklearn.exceptions
+
+ warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)
+ warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
+ class Trainer(object):
+     def __init__(self, args):
+         # dataset parameters
+         self.dataset = args.dataset
+         self.seed_id = args.seed_id
+         self.device = torch.device(args.device)
+
+         # experiment description
+         self.run_description = f"{args.run_description}_{datetime.datetime.now().strftime('%H_%M')}"
+         self.experiment_description = args.experiment_description
+
+         # paths
+         self.home_path = os.getcwd()
+         self.save_dir = os.path.join(os.getcwd(), "experiments_logs")
+         self.exp_log_dir = os.path.join(self.save_dir, self.experiment_description, self.run_description)
+         os.makedirs(self.exp_log_dir, exist_ok=True)
+
+         self.data_path = args.data_path
+
+         # number of runs
+         self.num_runs = args.num_runs
+
+         # dataset and base model configs
+         self.dataset_configs, self.hparams_class = self.get_configs()
+
+         # hparams
+         self.hparams = self.hparams_class.train_params
+
+     def get_configs(self):
+         dataset_class = get_dataset_class(self.dataset)
+         hparams_class = get_hparams_class("Supervised")
+         return dataset_class(), hparams_class()
+
+     def load_data(self, data_type):
+         dataloader = ECGDataloader(self.data_path, data_type, self.hparams)
+         self.train_dl, self.cw_dict = dataloader.train_dataloader()
+         self.test_dl = dataloader.test_dataloader()
+         self.valid_dl = dataloader.valid_dataloader()
+
+     def calc_results_per_run(self):
+         acc, f1 = _calc_metrics(self.pred_labels, self.true_labels, self.dataset_configs.class_names)
+         return acc, f1
+
+     def train(self):
+         copy_files(self.exp_log_dir)  # save a copy of the training files
+
+         self.metrics = {'accuracy': [], 'f1_score': []}
+
+         # fix the random seed
+         fix_randomness(int(self.seed_id))
+
+         # logging
+         self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.exp_log_dir, self.seed_id)
+         self.logger.debug(self.hparams)
+
+         # load data
+         self.load_data(self.dataset)
+
+         model = MoE_ECGFormer(configs=self.dataset_configs, hparams=self.hparams)
+         model.to(self.device)
+
+         # average meters
+         loss_avg_meters = collections.defaultdict(lambda: AverageMeter())
+
+         self.optimizer = torch.optim.Adam(
+             model.parameters(),
+             lr=self.hparams["learning_rate"],
+             weight_decay=self.hparams["weight_decay"],
+             betas=(0.9, 0.99)
+         )
+
+         # class-weighted cross-entropy from the training-set label counts
+         weights = [float(value) for value in self.cw_dict.values()]
+         weights_tensor = torch.tensor(np.array(weights, dtype=np.float32)).to(self.device)
+         self.cross_entropy = torch.nn.CrossEntropyLoss(weight=weights_tensor)
+
+         best_acc = 0
+         best_f1 = 0
+
+         # training
+         ts_acc = 0
+         ts_f1 = 0
+         for epoch in range(1, self.hparams["num_epochs"] + 1):
+             model.train()
+
+             for step, batches in enumerate(self.train_dl):
+                 batches = to_device(batches, self.device)
+
+                 data = batches['samples'].float()
+                 labels = batches['labels'].long()
+
+                 self.optimizer.zero_grad()
+
+                 # forward pass
+                 logits = model(data)
+
+                 # cross-entropy loss
+                 x_ent_loss = self.cross_entropy(logits, labels)
+
+                 x_ent_loss.backward()
+                 self.optimizer.step()
+
+                 losses = {'Total_loss': x_ent_loss.item()}
+                 for key, val in losses.items():
+                     loss_avg_meters[key].update(val, self.hparams["batch_size"])
+
+             # TRAINING metrics
+             self.evaluate(model, self.train_dl)
+             tr_acc, tr_f1 = self.calc_results_per_run()
+             # logging
+             self.logger.debug(f'[Epoch : {epoch}/{self.hparams["num_epochs"]}]')
+             for key, val in loss_avg_meters.items():
+                 self.logger.debug(f'{key}\t: {val.avg:2.4f}')
+             self.logger.debug(f'TRAIN: Acc:{tr_acc:2.4f} \t F1:{tr_f1:2.4f}')
+
+             # VALIDATION
+             self.evaluate(model, self.valid_dl)
+             ts_acc, ts_f1 = self.calc_results_per_run()
+             if ts_f1 > best_f1:  # save the best model based on F1
+                 best_f1 = ts_f1
+                 best_acc = ts_acc
+                 save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "best")
+                 _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_best")
+
+             # logging
+             self.logger.debug(f'VAL : Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f} (best: {best_f1:2.4f})')
+             self.logger.debug('-------------------------------------')
+
+         # LAST EPOCH
+         _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "validation_last")
+         self.logger.debug("LAST EPOCH PERFORMANCE on validation set ...")
+         self.logger.debug(f'Acc:{ts_acc:2.4f} \t F1:{ts_f1:2.4f}')
+
+         self.logger.debug(":::::::::::::")
+         # BEST EPOCH
+         self.logger.debug("BEST EPOCH PERFORMANCE on validation set ...")
+         self.logger.debug(f'Acc:{best_acc:2.4f} \t F1:{best_f1:2.4f}')
+         save_checkpoint(self.exp_log_dir, model, self.dataset, self.dataset_configs, self.hparams, "last")
+
+         # TESTING
+         print(" === Evaluating on TEST set ===")
+         self.evaluate(model, self.test_dl)
+         test_acc, test_f1 = self.calc_results_per_run()
+         _save_metrics(self.pred_labels, self.true_labels, self.exp_log_dir, "test_last")
+         self.logger.debug(f'Acc:{test_acc:2.4f} \t F1:{test_f1:2.4f}')
+
+     def evaluate(self, model, dataset):
+         model.to(self.device).eval()
+
+         total_loss_ = []
+
+         self.pred_labels = np.array([])
+         self.true_labels = np.array([])
+
+         with torch.no_grad():
+             for batches in dataset:
+                 batches = to_device(batches, self.device)
+                 data = batches['samples'].float()
+                 labels = batches['labels'].long()
+
+                 # forward pass
+                 predictions = model(data)
+
+                 # compute loss
+                 loss = F.cross_entropy(predictions, labels)
+                 total_loss_.append(loss.item())
+                 pred = predictions.detach().argmax(dim=1)  # index of the max logit
+
+                 self.pred_labels = np.append(self.pred_labels, pred.cpu().numpy())
+                 self.true_labels = np.append(self.true_labels, labels.data.cpu().numpy())
+
+         self.trg_loss = torch.tensor(total_loss_).mean()  # average loss
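The loss weights are taken from `get_class_weight` in insertion order; since `cw_dict` is built with keys `0..C-1`, iterating `.values()` lines the weights up with class indices. A standalone illustration with hypothetical weight values:

    import torch

    cw_dict = {0: 1.0, 1: 2.42, 2: 3.81}  # hypothetical class weights
    weights = torch.tensor([cw_dict[i] for i in sorted(cw_dict)], dtype=torch.float32)
    criterion = torch.nn.CrossEntropyLoss(weight=weights)
    logits, labels = torch.randn(8, 3), torch.randint(0, 3, (8,))
    print(criterion(logits, labels))  # weighted mean over the batch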
utils.py ADDED
@@ -0,0 +1,204 @@
+ import torch
+ import random
+ import os
+ import sys
+ import logging
+ import numpy as np
+ import pandas as pd
+ from shutil import copy
+ from datetime import datetime
+ import matplotlib.pyplot as plt
+ import collections
+ import umap
+ import umap.plot
+ from matplotlib.colors import ListedColormap
+
+ from sklearn.metrics import classification_report, accuracy_score
+
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value"""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def fix_randomness(SEED):
+     random.seed(SEED)
+     np.random.seed(SEED)
+     torch.manual_seed(SEED)
+     torch.cuda.manual_seed(SEED)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def _logger(logger_name, level=logging.DEBUG):
+     """
+     Return a custom logger with the given name and level,
+     writing to stdout and to a file of the same name.
+     """
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(level)
+     format_string = "%(message)s"
+     log_format = logging.Formatter(format_string)
+
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setFormatter(log_format)
+     logger.addHandler(console_handler)
+     # create and add the file handler
+     file_handler = logging.FileHandler(logger_name, mode='a')
+     file_handler.setFormatter(log_format)
+     logger.addHandler(file_handler)
+     return logger
+
+
+ def starting_logs(data_type, exp_log_dir, seed_id):
+     log_dir = os.path.join(exp_log_dir, "_seed_" + str(seed_id))
+     os.makedirs(log_dir, exist_ok=True)
+     log_file_name = os.path.join(log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log")
+     logger = _logger(log_file_name)
+     logger.debug('=' * 45)
+     logger.debug(f'Logging data type {data_type}')
+     logger.debug("=" * 45)
+     logger.debug(f'Logging seed id {seed_id}')
+     logger.debug("=" * 45)
+     return logger, log_dir
+
+
+ def save_checkpoint(exp_log_dir, model, dataset, dataset_configs, hparams, status):
+     save_dict = {
+         "dataset": dataset,
+         "configs": dataset_configs.__dict__,
+         "hparams": dict(hparams),
+         "model": model.state_dict(),
+     }
+     # save the checkpoint
+     save_path = os.path.join(exp_log_dir, f"checkpoint_{status}.pt")
+     torch.save(save_dict, save_path)
+
+
+ def _calc_metrics(pred_labels, true_labels, classes_names):
+     pred_labels = np.array(pred_labels).astype(int)
+     true_labels = np.array(true_labels).astype(int)
+
+     r = classification_report(true_labels, pred_labels, target_names=classes_names, digits=6, output_dict=True)
+     accuracy = accuracy_score(true_labels, pred_labels)
+
+     return accuracy * 100, r["macro avg"]["f1-score"] * 100
+
+
+ def _save_metrics(pred_labels, true_labels, log_dir, status):
+     pred_labels = np.array(pred_labels).astype(int)
+     true_labels = np.array(true_labels).astype(int)
+
+     r = classification_report(true_labels, pred_labels, digits=6, output_dict=True)
+
+     df = pd.DataFrame(r)
+     accuracy = accuracy_score(true_labels, pred_labels)
+     df["accuracy"] = accuracy
+     df = df * 100
+
+     # save the classification report
+     file_name = f"classification_report_{status}.xlsx"
+     report_save_path = os.path.join(log_dir, file_name)
+     df.to_excel(report_save_path)
+
+
+ def to_device(input, device):
+     if torch.is_tensor(input):
+         return input.to(device=device)
+     elif isinstance(input, str):
+         return input
+     elif isinstance(input, collections.abc.Mapping):
+         return {k: to_device(sample, device=device) for k, sample in input.items()}
+     elif isinstance(input, collections.abc.Sequence):
+         return [to_device(sample, device=device) for sample in input]
+     else:
+         raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}")
+
+
+ def copy_files(destination):
+     destination_dir = os.path.join(destination, "MODEL_BACKUP_FILES")
+     os.makedirs(destination_dir, exist_ok=True)
+     copy("main.py", os.path.join(destination_dir, "main.py"))
+     copy("data/dataloader.py", os.path.join(destination_dir, "dataloader.py"))
+     copy("data/dataset.py", os.path.join(destination_dir, "dataset.py"))
+     copy("models/MoE_ECGFormer.py", os.path.join(destination_dir, "models.py"))
+     copy("configs/data_configs.py", os.path.join(destination_dir, "data_configs.py"))
+     copy("configs/hparams.py", os.path.join(destination_dir, "hparams.py"))
+     copy("train.py", os.path.join(destination_dir, "train.py"))
+     copy("utils.py", os.path.join(destination_dir, "utils.py"))
+
+
+ def _plot_umap(model, data_loader, device, save_dir):
+
+     classes_names = ['N', 'S', 'V', 'F', 'Q']
+
+     font = {'family': 'Times New Roman',
+             'weight': 'bold',
+             'size': 17}
+     plt.rc('font', **font)
+
+     with torch.no_grad():
+         data = data_loader.dataset.x_data.float().to(device)
+         labels = data_loader.dataset.y_data.view((-1)).long()
+         out = model[0](data)
+         features = model[1](out)
+
+     if not os.path.exists(os.path.join(save_dir, "umap_plots")):
+         os.mkdir(os.path.join(save_dir, "umap_plots"))
+
+     model_reducer = umap.UMAP()  # n_neighbors=3, min_dist=0.3, metric='correlation', random_state=42
+     embedding = model_reducer.fit_transform(features.detach().cpu().numpy())
+
+     # normalize the labels to [0, 1] for the colormap
+     norm_labels = labels / 4.0
+
+     # build a 5-color colormap from the first colors of "Paired"
+     paired = plt.cm.get_cmap('Paired', 12)  # 12 distinct colors
+     new_colors = [paired(0), paired(1), paired(2), paired(4), paired(6)]
+     new_cmap = ListedColormap(new_colors)
+
+     print("Plotting UMAP ...")
+     plt.figure(figsize=(16, 10))
+     scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=norm_labels, cmap=new_cmap, s=15)
+
+     handles, _ = scatter.legend_elements(prop='colors')
+     plt.legend(handles, classes_names, title="Classes")
+     file_name = "umap_.png"
+     fig_save_name = os.path.join(save_dir, "umap_plots", file_name)
+     plt.xticks([])
+     plt.yticks([])
+     plt.savefig(fig_save_name, bbox_inches='tight')
+     plt.close()
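`_calc_metrics` returns accuracy and macro-F1 as percentages, which is what Trainer logs each epoch. A toy check:

    import numpy as np
    from utils import _calc_metrics

    true = np.array([0, 0, 1, 1, 2])
    pred = np.array([0, 1, 1, 1, 2])
    acc, f1 = _calc_metrics(pred, true, ['N', 'S', 'V'])  # 3-class toy labels
    print(acc, f1)  # 80.0 and the macro-averaged F1 * 100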