import hashlib
from typing import NamedTuple, Tuple, Union

import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from scipy.signal import savgol_filter as savgol
from torch.fft import fft, fftshift
from torch.utils.data import IterableDataset

from .transforms import compute_all_features
class WeightsBatch(NamedTuple):
    """A batch of INR weights/biases plus a label, stored as parallel tuples."""
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        assert len({len(t) for t in self}) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """Move every tensor in the batch to the given device."""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(b.to(device) for b in self.biases),
            label=self.label.to(device),
        )

    def __len__(self):
        return len(self.weights[0])
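

# Usage sketch (shapes below are illustrative, not taken from the real
# pipeline): construct a batch of two-layer INR parameters, move it to a
# device, and read back the batch size.
#   batch = WeightsBatch(
#       weights=(torch.randn(4, 16, 8), torch.randn(4, 8, 1)),
#       biases=(torch.randn(4, 16), torch.randn(4, 8)),
#       label=torch.zeros(4, dtype=torch.long),
#   )
#   batch = batch.to('cuda')  # moves every tensor in the batch
#   assert len(batch) == 4    # batch size = leading dim of the first weight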
class SplitDataset(IterableDataset):
    """Deterministically split a streaming dataset into train/validation views.

    Items are assigned in repeating blocks of 100: the first
    int(train_ratio * 100) positions of each block go to the train split,
    the remaining positions to the validation split.
    """
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        super().__init__()
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Position within the current block of 100 decides the split.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100
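

# Usage sketch: two views over the same (re-iterable) source. With
# train_ratio=0.8, positions 0-79 of every 100-item block stream to the
# train view and positions 80-99 to the validation view:
#   train_ds = SplitDataset(source, is_train=True, train_ratio=0.8)
#   val_ds = SplitDataset(source, is_train=False, train_ratio=0.8)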
class FFTDataset(IterableDataset):
    """Pad, resample, and normalize raw audio, then attach its FFT spectra."""
    def __init__(self, original_dataset,
                 max_len=72000,
                 orig_sample_rate=12000,
                 target_sample_rate=3000,
                 features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features
    def normalize_audio(self, audio):
        """Min-max normalize audio to the [0, 1] range."""
        audio_min = audio.min()
        audio_max = audio.max()
        # The epsilon guards against division by zero on constant signals.
        audio = (audio - audio_min) / (audio_max - audio_min + 1e-8)
        return audio
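
    # Example: normalize_audio(torch.tensor([-1., 0., 1.])) maps the signal
    # to ~[0.0, 0.5, 1.0] (up to the epsilon in the denominator).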
    def generate_unique_id(self, array):
        """Return the SHA-256 hex digest of the array's raw bytes."""
        array_bytes = array.tobytes()
        hash_object = hashlib.sha256(array_bytes)
        return hash_object.hexdigest()
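
    # Note: generate_unique_id expects an object with .tobytes() (e.g. the raw
    # numpy array, as in the commented-out call in __iter__ below) and returns
    # a 64-character hex digest that is stable for identical audio content.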
    def __iter__(self):
        for item in self.dataset:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()
            # Truncate long clips, then zero-pad up to max_len.
            audio_data = audio_data[:self.max_len]
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            audio_data = self.normalize_audio(audio_data)
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                item['audio']['features'] = features
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            # Zero out the DC component, which sits at the center after fftshift.
            magnitude_centered[len(magnitude_centered) // 2] = 0
            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            yield item
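

# After FFTDataset, each yielded item carries (in addition to its original
# fields):
#   item['audio']['array']     -- padded, resampled, [0, 1]-normalized waveform
#   item['audio']['fft_mag']   -- fftshift-ed magnitude spectrum, DC bin zeroed
#   item['audio']['fft_phase'] -- fftshift-ed phase spectrum
#   item['audio']['features']  -- extra features, only when features=True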
class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.
        Args:
            original_dataset: Original audio dataset
            max_len: Maximum length of audio to process
            sample_size: Number of points to sample per audio clip
            dim: Dimensionality of the time coordinates
            normalize: Whether to normalize the audio values to [0, 1]
        """
        super().__init__()
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size
    def get_coordinates(self, audio_len):
        """Generate normalized time coordinates in [0, 1]."""
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]
def sample_points(self, coords, values):
"""Randomly sample points from the audio"""
if len(coords) > self.sample_size:
idx = torch.randperm(len(coords))[:self.sample_size]
coords = coords[idx]
values = values[idx]
return coords, values
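
    # Example: subsample a full-resolution clip down to sample_size points:
    #   coords, values = self.sample_points(coords, audio_data)
    # (This call is currently disabled in __iter__, which yields the full
    # coordinate grid instead.)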
    def __iter__(self):
        for item in self.dataset:
            # Audio values (already a tensor if FFTDataset ran upstream).
            audio_data = torch.as_tensor(item['audio']['array']).float()
            # Attach one time coordinate per audio sample.
            coords = self.get_coordinates(len(audio_data))
            item['audio']['coords'] = coords
            # Random point subsampling is currently disabled:
            # coords, values = self.sample_points(coords, audio_data)
            yield item
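

if __name__ == "__main__":
    # Minimal smoke test on synthetic data (a sketch: the real pipeline
    # presumably streams HuggingFace-style items whose item['audio']['array']
    # is a numpy waveform sampled at orig_sample_rate). Run as a module, e.g.
    # `python -m <package>.<this_module>`, so the relative import resolves.
    dummy = [{'audio': {'array': torch.randn(12000).numpy()}} for _ in range(10)]
    train_stream = SplitDataset(dummy, is_train=True)
    inr_stream = AudioINRDataset(FFTDataset(train_stream))
    item = next(iter(inr_stream))
    print(item['audio']['array'].shape)    # resampled waveform, e.g. [18000]
    print(item['audio']['fft_mag'].shape)  # centered magnitude spectrum
    print(item['audio']['coords'].shape)   # time coordinates, [18000, dim]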