File size: 574 Bytes
b55d767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import pandas as pd
import torch

from utmosv2.dataset import MultiSpecDataset, SSLExtDataset


class SSLLMultiSpecExtDataset(torch.utils.data.Dataset):
    """Combine an ``SSLExtDataset`` and a ``MultiSpecDataset`` row by row.

    Both sub-datasets are built over the same ``data`` frame and ``phase``.
    Indexing this dataset fetches the same row from each and merges the
    results into one tuple.

    Args:
        cfg: Configuration object, forwarded unchanged to both sub-datasets.
        data: Table of samples; ``len(data)`` is the length of this dataset.
        phase: Phase name, forwarded unchanged to both sub-datasets.
        transform: Optional transform, forwarded only to ``MultiSpecDataset``.
    """

    def __init__(self, cfg, data: pd.DataFrame, phase: str, transform=None):
        self.data = data
        # The SSL side is constructed before the spectrogram side; keep this
        # order in case either constructor has side effects.
        self.ssl = SSLExtDataset(cfg, data, phase)
        self.multi_spec = MultiSpecDataset(cfg, data, phase, transform)

    def __len__(self):
        """Return the number of rows in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(ssl_input, spec_input, ssl_extra, target)`` for ``idx``.

        ``ssl_input``, ``ssl_extra`` and ``target`` are the three elements of
        ``self.ssl[idx]``. ``spec_input`` is the first of the two elements of
        ``self.multi_spec[idx]``; its second element is discarded.
        """
        # Fetch the SSL sample first, then the spectrogram sample, so any
        # random augmentation inside either dataset runs in the same order.
        ssl_input, ssl_extra, target = self.ssl[idx]
        spec_sample = self.multi_spec[idx]
        spec_input, _unused = spec_sample
        return (ssl_input, spec_input, ssl_extra, target)