File size: 2,126 Bytes
bcf646b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch.utils.data import DataLoader
from data.dataset import load_ECG_Dataset

import os
import numpy as np
import math


def get_class_weight(labels_dict):
    """Derive a log-scaled weight for every class from its sample count.

    For each class the weight is ``ln(mu * total / count)`` where ``total``
    is the sum of all counts and ``mu`` is ``max_count / total``; weights
    that would not exceed 1.0 are set to exactly 1.0.

    Args:
        labels_dict: Mapping of class label -> number of samples.

    Returns:
        A new dict with the same keys (same order) mapping to float weights,
        each at least 1.0.

    Raises:
        ValueError: If ``labels_dict`` is empty (raised by ``max``).
        ZeroDivisionError: If any count is zero.
    """
    sample_total = sum(labels_dict.values())
    largest_count = max(labels_dict.values())
    mu = 1.0 / (sample_total / largest_count)

    def _weight_for(count):
        # Floor at 1.0 so frequent classes are never down-weighted.
        raw = math.log(mu * sample_total / float(count))
        return max(1.0, raw)

    return {label: _weight_for(count) for label, count in labels_dict.items()}


class ECGDataloader:
    """Create ``DataLoader`` objects for the train, test and validation splits.

    The splits are expected at ``<data_path>/<data_type>/train.pt``,
    ``test.pt`` and ``val.pt``. Each ``*_dataloader`` call reads its file with
    ``torch.load`` and wraps the result with ``load_ECG_Dataset``.

    NOTE: ``torch.load`` unpickles the file contents, so these paths must only
    ever point at trusted data.
    """

    testdata_path: str
    traindata_path: str
    # Matches the attribute actually assigned in __init__ and read by
    # valid_dataloader.
    validdata_path: str

    def __init__(self, data_path, data_type, hparams):
        """Store the split file paths and the batch size.

        Args:
            data_path: Root directory of the datasets.
            data_type: Sub-directory of ``data_path`` that holds the ``.pt`` files.
            hparams: Mapping providing the ``'batch_size'`` entry.

        Raises:
            KeyError: If ``hparams`` has no ``'batch_size'`` key.
        """
        self.traindata_path = os.path.join(data_path, data_type, 'train.pt')
        self.testdata_path = os.path.join(data_path, data_type, 'test.pt')
        self.validdata_path = os.path.join(data_path, data_type, 'val.pt')
        self.batch_size = hparams['batch_size']

    @staticmethod
    def _label_counts(labels):
        """Return ``{label: occurrences}`` for every distinct value in ``labels``.

        Only labels that are really present are counted, so label sets that
        are not exactly ``0..k-1`` never produce a zero-count entry. Keys come
        out in ascending order; integral float labels are turned into ``int``
        keys.
        """
        values, counts = np.unique(np.asarray(labels), return_counts=True)
        tally = {}
        for value, count in zip(values, counts):
            label = value.item()
            if isinstance(label, float) and label.is_integer():
                label = int(label)
            tally[label] = int(count)
        return tally

    def train_dataloader(self):
        """Build the shuffled training loader and the per-class weights.

        Returns:
            A ``(DataLoader, dict)`` tuple: the training loader (incomplete
            final batch dropped) and the result of ``get_class_weight``
            applied to the label counts of the training set.
        """
        train_dataset = torch.load(self.traindata_path)
        train_dataset = load_ECG_Dataset(train_dataset)
        # Single pass over the labels instead of one list scan per class.
        cw_dict = self._label_counts(train_dataset.y_data.numpy())

        train_loader = DataLoader(dataset=train_dataset, batch_size=self.batch_size, shuffle=True,
                                  drop_last=True, num_workers=4)
        return train_loader, get_class_weight(cw_dict)

    def test_dataloader(self):
        """Build the test loader: fixed order, every sample kept."""
        test_dataset = torch.load(self.testdata_path)
        test_dataset = load_ECG_Dataset(test_dataset)
        test_loader = DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False,
                                 drop_last=False, num_workers=4)
        return test_loader

    def valid_dataloader(self):
        """Build the validation loader: every sample kept."""
        valid_dataset = torch.load(self.validdata_path)
        valid_dataset = load_ECG_Dataset(valid_dataset)
        # NOTE(review): shuffle=True here, unlike the test loader - confirm
        # that shuffling the validation split is intentional.
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=self.batch_size, shuffle=True,
                                  drop_last=False, num_workers=4)
        return valid_loader