UTMOSv2 / utmosv2 /dataset /ssl_multispec.py
kAIto47802
Resolved conflict in README.md
b55d767
raw
history blame contribute delete
574 Bytes
import pandas as pd
import torch
from utmosv2.dataset import MultiSpecDataset, SSLExtDataset
class SSLLMultiSpecExtDataset(torch.utils.data.Dataset):
    """Dataset that pairs SSL-extractor inputs with multi-spectrogram inputs.

    Internally holds an ``SSLExtDataset`` and a ``MultiSpecDataset``, both
    built from the same ``data`` and ``phase``. Indexing returns both views of
    the same sample together. The auxiliary value and the target come from the
    SSL dataset; the second item yielded by the multi-spec dataset is dropped.
    """

    def __init__(self, cfg, data: pd.DataFrame, phase: str, transform=None):
        """Create both underlying datasets.

        Args:
            cfg: Configuration object, forwarded unchanged to both sub-datasets.
            data: Sample table; its ``len`` defines the length of this dataset.
            phase: Phase string, forwarded unchanged to both sub-datasets.
            transform: Optional transform, forwarded only to
                ``MultiSpecDataset`` (the SSL dataset does not receive it).
        """
        self.data = data
        # The SSL view is constructed before the multi-spec view; keep this
        # order in case either constructor has side effects.
        self.ssl = SSLExtDataset(cfg, data, phase)
        self.multi_spec = MultiSpecDataset(cfg, data, phase, transform)

    def __len__(self) -> int:
        """Return the number of samples, i.e. ``len(self.data)``."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return ``(ssl_input, spec_input, aux, label)`` for sample ``idx``.

        ``ssl_input``, ``aux`` and ``label`` are the three items produced by
        the SSL dataset. ``spec_input`` is the first of the exactly two items
        produced by the multi-spec dataset. The second multi-spec item is
        discarded (presumably a duplicate target; TODO confirm).
        """
        # The SSL dataset is indexed before the multi-spec dataset, matching
        # the original call order.
        ssl_input, aux, label = self.ssl[idx]
        spec_input, _unused = self.multi_spec[idx]
        return (ssl_input, spec_input, aux, label)