import torch
from torch.utils.data import IterableDataset
from torch.fft import fft, fftshift
import torch.nn.functional as F
from itertools import tee
import random
import torchaudio.transforms as T
import hashlib
from typing import NamedTuple, Tuple, Union
from .transforms import compute_all_features
from scipy.signal import savgol_filter as savgol
class WeightsBatch(NamedTuple):
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        assert len(set([len(t) for t in self])) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """move batch to device"""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(w.to(device) for w in self.biases),
            label=self.label.to(device),
        )

    def __len__(self):
        return len(self.weights[0])
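
# Usage sketch for WeightsBatch (illustrative only; the tensor shapes and the
# target device are assumptions, not part of this module):
#
#   batch = WeightsBatch(
#       weights=(torch.randn(4, 16, 32), torch.randn(4, 32, 1)),
#       biases=(torch.randn(4, 32), torch.randn(4, 1)),
#       label=torch.zeros(4, dtype=torch.long),
#   )
#   batch = batch.to("cuda")   # moves every weight/bias tensor and the label
#   print(len(batch))          # batch size, taken from the first weight tensor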
class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        super().__init__()
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Split the stream into repeating blocks of 100 items: the first
            # train_ratio * 100 items of each block go to the training split,
            # the remaining items go to the validation split.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100
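
# Usage sketch (assumes `stream` is any iterable or IterableDataset of items;
# the name is illustrative):
#
#   train_ds = SplitDataset(stream, is_train=True, train_ratio=0.8)
#   val_ds = SplitDataset(stream, is_train=False, train_ratio=0.8)
#
# Both views iterate over the full underlying stream and filter items by
# position, so each pass still reads every item once.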
class FFTDataset(IterableDataset):
    def __init__(self, original_dataset,
                 max_len=72000,
                 orig_sample_rate=12000,
                 target_sample_rate=3000,
                 features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features

    def normalize_audio(self, audio):
        """Normalize audio to [0, 1] range"""
        audio_min = audio.min()
        audio_max = audio.max()
        audio = (audio - audio_min) / (audio_max - audio_min)
        return audio

    def generate_unique_id(self, array):
        # Convert the array to bytes
        array_bytes = array.tobytes()
        # Hash the bytes using SHA256
        hash_object = hashlib.sha256(array_bytes)
        # Return the hexadecimal representation of the hash
        return hash_object.hexdigest()

    def __iter__(self):
        for item in self.dataset:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()
            # Pad to max_len samples (a negative pad_len crops instead)
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            # Downsample to the target rate and rescale to [0, 1]
            audio_data = self.resampler(audio_data)
            audio_data = self.normalize_audio(audio_data)
            # Complex FFT, split into magnitude and phase
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
                item['audio']['features'] = features
            # Shift the zero-frequency bin to the center of the spectrum
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            # cwt = features['cwt_power']
            # Optionally, remove the DC component
            magnitude_centered[len(magnitude_centered) // 2] = 0  # Set DC component to zero
            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            # item['audio']['features'] = features
            yield item
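
# Usage sketch (assumes `raw_ds` yields dicts shaped like
# {'audio': {'array': <1-D array>, ...}}, e.g. a Hugging Face audio dataset;
# the names below are illustrative, not part of this module):
#
#   fft_ds = FFTDataset(raw_ds, max_len=72000,
#                       orig_sample_rate=12000, target_sample_rate=3000)
#   sample = next(iter(fft_ds))
#   mag = sample['audio']['fft_mag']      # centered magnitude spectrum, DC zeroed
#   phase = sample['audio']['fft_phase']  # centered phase spectrum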
class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.
        Args:
            original_dataset: Original audio dataset
            max_len: Maximum length of audio to process
            sample_size: Number of points to sample per audio clip
            dim: Dimensionality of the time coordinates
            normalize: Whether to normalize the audio values to [0, 1]
        """
        super().__init__()
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size

    def get_coordinates(self, audio_len):
        """Generate time coordinates"""
        # Create normalized time coordinates in [0, 1]
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]

    def sample_points(self, coords, values):
        """Randomly sample points from the audio"""
        if len(coords) > self.sample_size:
            idx = torch.randperm(len(coords))[:self.sample_size]
            coords = coords[idx]
            values = values[idx]
        return coords, values

    def __iter__(self):
        for item in self.dataset:
            # Get audio data
            audio_data = torch.tensor(item['audio']['array']).float()
            # Generate one time coordinate per audio sample
            coords = self.get_coordinates(len(audio_data))
            item['audio']['coords'] = coords
            # Sample random points
            # coords, values = self.sample_points(coords, audio_data)
            # Create the INR training sample
            yield item
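
# Usage sketch (names are illustrative; chains the datasets defined above):
#
#   inr_ds = AudioINRDataset(FFTDataset(raw_ds), sample_size=1024, dim=1)
#   sample = next(iter(inr_ds))
#   coords = sample['audio']['coords']   # [audio_len, dim] time coordinates in [0, 1]
#   values = sample['audio']['array']    # corresponding audio amplitudes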