# Source: IlayMalinyak/kan @ 49ebc1f (Hugging Face)
import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
import torch.nn.functional as F
from itertools import tee
import random
import torchaudio.transforms as T
class SplitDataset(IterableDataset):
    """Deterministically split an iterable dataset into train/validation streams.

    Items are assigned by their position within repeating windows of 100:
    the first ``int(train_ratio * 100)`` items of each window go to the
    training split, the remainder to validation. Iterating with
    ``is_train=True`` yields only training items; ``is_train=False`` yields
    only validation items.
    """

    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        for position, sample in enumerate(self.dataset):
            # Position inside the current 100-item window decides the split.
            in_train_window = (position % 100) < int(self.train_ratio * 100)
            if in_train_window == self.is_train:
                yield sample
class FFTDataset(IterableDataset):
    """Wrap an audio dataset: fix clip length, resample, and attach the FFT.

    Each yielded item is the source item with ``item['audio']['array']``
    replaced by the padded/truncated + resampled waveform and
    ``item['audio']['fft']`` added with its (complex) FFT.
    """

    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
        self.dataset = original_dataset
        # torchaudio resampler from the raw rate down to the target rate.
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.max_len = max_len  # fixed pre-resampling length, in samples

    def __iter__(self):
        for item in self.dataset:
            # NOTE(review): assumes item['audio']['array'] is a 1-D numeric
            # sequence (list/ndarray/tensor) — confirm against the loader.
            # as_tensor avoids a redundant copy when it is already a tensor.
            audio_data = torch.as_tensor(item['audio']['array'], dtype=torch.float32)
            # Fix: clips longer than max_len previously produced a negative
            # pad width, implicitly trimming via F.pad's negative-padding
            # behavior. Truncate explicitly, then zero-pad up to max_len.
            if audio_data.shape[-1] > self.max_len:
                audio_data = audio_data[..., :self.max_len]
            pad_len = self.max_len - audio_data.shape[-1]
            if pad_len > 0:
                audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            fft_data = fft(audio_data)
            # Attach both the resampled waveform and its spectrum.
            item['audio']['fft'] = fft_data
            item['audio']['array'] = audio_data
            yield item