# NOTE(review): removed non-Python scrape residue ("Spaces:" / "Sleeping")
# left over from copying this file out of a hosted-app page.
import random
from itertools import tee

import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.fft import fft
from torch.utils.data import IterableDataset
class SplitDataset(IterableDataset):
    """Deterministically split a streaming dataset into train/validation views.

    Items are routed by their position modulo 100: within every window of
    100 consecutive items, the first ``int(train_ratio * 100)`` go to the
    train stream and the rest to validation.  Two instances wrapping the
    same underlying iterable (one with ``is_train=True``, one with
    ``is_train=False``) therefore yield disjoint, complementary subsets
    without materializing the data.

    Note: because the threshold is ``int(train_ratio * 100)``, ratios are
    effectively quantized to hundredths.
    """

    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset          # any iterable of items
        self.is_train = is_train        # True -> yield train items, False -> validation items
        self.train_ratio = train_ratio  # fraction of each 100-item window routed to train

    def __iter__(self):
        # Hoist the per-window threshold out of the loop; it is invariant.
        threshold = int(self.train_ratio * 100)
        count = 0
        for item in self.dataset:
            # Position within the current 100-item window decides the split.
            if (count < threshold) == self.is_train:
                yield item
            count = (count + 1) % 100
class FFTDataset(IterableDataset):
    """Wrap a streaming audio dataset: fix length, resample, and attach an FFT.

    For each input item, the waveform is zero-padded (or truncated) to
    ``max_len`` samples, resampled from ``orig_sample_rate`` to
    ``target_sample_rate``, and its complex FFT is computed.  The yielded
    item is the input item mutated in place:

    - ``item['audio']['array']`` is replaced by the fixed-length,
      resampled float tensor, and
    - ``item['audio']['fft']`` is added (complex FFT of that tensor).

    Assumes each item is a dict shaped like
    ``{'audio': {'array': <1-D sequence of samples>}}`` (HuggingFace audio
    dataset convention) — TODO confirm against the actual loader.
    """

    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
        self.dataset = original_dataset
        # Built once and reused for every item; T.Resample precomputes its kernel.
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.max_len = max_len  # target length in samples at orig_sample_rate, pre-resampling

    def __iter__(self):
        for item in self.dataset:
            audio_data = torch.tensor(item['audio']['array']).float()
            # Fix the length to max_len: right-pad short clips with zeros,
            # truncate long ones.  (The original passed a negative pad to
            # F.pad for long clips, which truncates implicitly — same
            # behavior, but easy to misread; made explicit here.)
            if len(audio_data) < self.max_len:
                audio_data = F.pad(audio_data, (0, self.max_len - len(audio_data)), mode='constant')
            else:
                audio_data = audio_data[:self.max_len]
            audio_data = self.resampler(audio_data)
            fft_data = fft(audio_data)
            # Mutate the item in place so all other fields pass through untouched.
            item['audio']['fft'] = fft_data
            item['audio']['array'] = audio_data
            yield item