import hashlib
from typing import NamedTuple, Tuple, Union

import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.fft import fft, fftshift
from torch.utils.data import IterableDataset
from scipy.signal import savgol_filter as savgol  # used by the optional smoothing step below

from .transforms import compute_all_features

class WeightsBatch(NamedTuple):
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        assert len(set(len(t) for t in self)) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """Move the batch to the given device."""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(b.to(device) for b in self.biases),
            label=self.label.to(device),
        )

    def __len__(self):
        return len(self.weights[0])
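
# A minimal usage sketch (not from the original code) of how a WeightsBatch
# might be built and moved to a device; the shapes are hypothetical, e.g.
# per-layer parameters for a batch of 8 INRs:
#
#     batch = WeightsBatch(
#         weights=(torch.randn(8, 16, 32), torch.randn(8, 32, 1)),
#         biases=(torch.randn(8, 32), torch.randn(8, 1)),
#         label=torch.zeros(8, dtype=torch.long),
#     )
#     batch = batch.to('cuda')  # returns a new WeightsBatch on the GPU
#     assert len(batch) == 8    # batch size read from the first weight tensor
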
class SplitDataset(IterableDataset):
    """Deterministically split an iterable dataset into train/validation streams."""

    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Within each block of 100 consecutive items, the first
            # train_ratio * 100 items go to the train stream and the
            # remainder to the validation stream.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100
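
# Usage sketch: two views over one source. `raw_stream` is a placeholder name;
# any iterable of items works, but it must be re-iterable since both views
# walk the full underlying dataset:
#
#     train_ds = SplitDataset(raw_stream, is_train=True, train_ratio=0.8)
#     val_ds = SplitDataset(raw_stream, is_train=False, train_ratio=0.8)
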
class FFTDataset(IterableDataset):
    def __init__(self, original_dataset,
                 max_len=72000,
                 orig_sample_rate=12000,
                 target_sample_rate=3000,
                 features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features

    def normalize_audio(self, audio):
        """Normalize audio to the [0, 1] range."""
        audio_min = audio.min()
        audio_max = audio.max()
        # Small epsilon guards against division by zero on constant (silent) clips.
        return (audio - audio_min) / (audio_max - audio_min + 1e-8)

    def generate_unique_id(self, array):
        """Return a SHA-256 hex digest of the raw sample bytes as a stable id."""
        array_bytes = array.tobytes()
        hash_object = hashlib.sha256(array_bytes)
        return hash_object.hexdigest()

    def __iter__(self):
        for item in self.dataset:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()
            # Zero-pad to max_len samples (a negative pad_len truncates instead).
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            audio_data = self.normalize_audio(audio_data)
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
                item['audio']['features'] = features
            # Shift the zero-frequency bin to the center of the spectrum.
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            # cwt = features['cwt_power']
            # Remove the DC component (at the center after fftshift).
            magnitude_centered[len(magnitude_centered) // 2] = 0
            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            yield item
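
# Usage sketch, assuming Hugging Face-style audio items where
# item['audio']['array'] holds the raw samples:
#
#     fft_ds = FFTDataset(raw_stream, max_len=72000,
#                         orig_sample_rate=12000, target_sample_rate=3000)
#     item = next(iter(fft_ds))
#     item['audio']['fft_mag']    # centered magnitude spectrum, DC bin zeroed
#     item['audio']['fft_phase']  # centered phase spectrum
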
class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.

        Args:
            original_dataset: Original audio dataset.
            max_len: Maximum length of audio to process.
            sample_size: Number of points to sample per audio clip.
            dim: Dimensionality of the time coordinates.
            normalize: Whether to normalize the audio values to [0, 1].
        """
        super().__init__()
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size

    def get_coordinates(self, audio_len):
        """Generate normalized time coordinates in [0, 1]."""
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]

    def sample_points(self, coords, values):
        """Randomly sample up to sample_size (coordinate, value) pairs."""
        if len(coords) > self.sample_size:
            idx = torch.randperm(len(coords))[:self.sample_size]
            coords = coords[idx]
            values = values[idx]
        return coords, values

    def __iter__(self):
        for item in self.dataset:
            # Get the audio samples (already resampled/normalized upstream).
            audio_data = torch.tensor(item['audio']['array']).float()
            # Attach one time coordinate per sample.
            coords = self.get_coordinates(len(audio_data))
            item['audio']['coords'] = coords
            # Optionally subsample random (coordinate, value) pairs:
            # coords, values = self.sample_points(coords, audio_data)
            yield item
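
# End-to-end sketch composing the three datasets; the dataset name below is a
# placeholder, not necessarily the one this project trains on:
#
#     from datasets import load_dataset
#     raw = load_dataset('user/audio-dataset', split='train', streaming=True)
#     train_ds = AudioINRDataset(FFTDataset(SplitDataset(raw, is_train=True)))
#     for item in train_ds:
#         coords = item['audio']['coords']  # [num_samples, dim] time coordinates
#         values = item['audio']['array']   # resampled, normalized waveform
#         break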