import hashlib
from typing import NamedTuple, Tuple, Union

import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.fft import fft, fftshift
from torch.utils.data import IterableDataset
from scipy.signal import savgol_filter as savgol

from .transforms import compute_all_features


class WeightsBatch(NamedTuple):
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        assert len(set(len(t) for t in self)) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """Move batch to device."""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(b.to(device) for b in self.biases),
            label=self.label.to(device),
        )

    def __len__(self):
        return len(self.weights[0])


class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Items are assigned in repeating blocks of 100: the first
            # train_ratio * 100 items of each block go to the train split,
            # the remainder to validation.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100


class FFTDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000,
                 target_sample_rate=3000, features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features

    def normalize_audio(self, audio):
        """Normalize audio to the [0, 1] range."""
        audio_min = audio.min()
        audio_max = audio.max()
        # Epsilon guards against division by zero on constant signals.
        audio = (audio - audio_min) / (audio_max - audio_min + 1e-8)
        return audio

    def generate_unique_id(self, array):
        """Return a SHA256 hex digest of the raw array bytes."""
        array_bytes = array.tobytes()
        hash_object = hashlib.sha256(array_bytes)
        return hash_object.hexdigest()

    def __iter__(self):
        for item in self.dataset:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()
            # Pad (or trim, via negative padding) to a fixed length, then resample.
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            audio_data = self.normalize_audio(audio_data)
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
                item['audio']['features'] = features
            # Shift the zero-frequency bin to the center of the spectrum.
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            # cwt = features['cwt_power']
            # Optionally, remove the DC component.
            magnitude_centered[len(magnitude_centered) // 2] = 0  # Set DC component to zero
            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            yield item


class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.

        Args:
            original_dataset: Original audio dataset
            max_len: Maximum length of audio to process
            sample_size: Number of points to sample per audio clip
            dim: Dimensionality of the time coordinates
            normalize: Whether to normalize the audio values to [0, 1]
        """
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size

    def get_coordinates(self, audio_len):
        """Generate normalized time coordinates in [0, 1]."""
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]

    def sample_points(self, coords, values):
        """Randomly subsample up to sample_size (coordinate, value) pairs."""
        if len(coords) > self.sample_size:
            idx = torch.randperm(len(coords))[:self.sample_size]
            coords = coords[idx]
            values = values[idx]
        return coords, values

    def __iter__(self):
        for item in self.dataset:
            # Get audio data
            audio_data = torch.tensor(item['audio']['array']).float()
            # Generate one coordinate per audio sample
            coords = self.get_coordinates(len(audio_data))
            item['audio']['coords'] = coords
            # Optionally subsample random points:
            # coords, values = self.sample_points(coords, audio_data)
            # Yield the INR training sample
            yield item
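

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): chains the
# datasets above over a tiny synthetic source. The item layout
# {'audio': {'array': ...}} matches what FFTDataset expects; in real use a
# streaming audio dataset yielding the same structure would replace
# synthetic_dataset(), whose name and contents are assumptions here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    def synthetic_dataset(n_items=5, length=60000):
        # Each item mimics the expected {'audio': {'array': ...}} layout.
        rng = np.random.default_rng(0)
        for _ in range(n_items):
            yield {'audio': {'array': rng.standard_normal(length)}}

    fft_ds = FFTDataset(synthetic_dataset(), max_len=72000,
                        orig_sample_rate=12000, target_sample_rate=3000)
    inr_ds = AudioINRDataset(fft_ds, sample_size=1024)
    train_ds = SplitDataset(inr_ds, is_train=True, train_ratio=0.8)

    for item in train_ds:
        # Expect fft_mag of length max_len * target/orig rate (here 18000)
        # and coords of shape [18000, dim].
        print(item['audio']['fft_mag'].shape, item['audio']['coords'].shape)
        break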