import torch
from torch.utils.data import IterableDataset
from torch.fft import fft, fftshift
import torch.nn.functional as F
import torchaudio.transforms as T
import hashlib
from typing import NamedTuple, Tuple, Union
from .transforms import compute_all_features

# Kept for the optional Savitzky-Golay smoothing step (commented out below)
from scipy.signal import savgol_filter as savgol


class WeightsBatch(NamedTuple):
    weights: Tuple
    biases: Tuple
    label: Union[torch.Tensor, int]

    def _assert_same_len(self):
        """Sanity check: all fields report the same length."""
        assert len(set([len(t) for t in self])) == 1

    def as_dict(self):
        return self._asdict()

    def to(self, device):
        """Move every tensor in the batch to the given device."""
        return self.__class__(
            weights=tuple(w.to(device) for w in self.weights),
            biases=tuple(b.to(device) for b in self.biases),
            label=self.label.to(device) if torch.is_tensor(self.label) else self.label,
        )

    def __len__(self):
        return len(self.weights[0])
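
# Usage sketch (the shapes below are made up for illustration): a batch of
# four two-layer MLPs, stored as per-layer stacked tensors plus labels.
#
#   batch = WeightsBatch(
#       weights=(torch.randn(4, 8, 16), torch.randn(4, 1, 8)),
#       biases=(torch.randn(4, 8), torch.randn(4, 1)),
#       label=torch.zeros(4, dtype=torch.long),
#   )
#   batch = batch.to("cpu")  # moves every tensor field
#   len(batch)               # 4, the leading (batch) dimension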

class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio
        
    def __iter__(self):
        count = 0
        for item in self.dataset:
            # Deterministic interleaved split: within every 100-item cycle,
            # the first train_ratio * 100 items go to the train split and
            # the remainder go to validation.
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100
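
# Usage sketch (illustrative; `base_stream` is a hypothetical iterable of
# items): wrap the same base dataset twice to get a deterministic,
# interleaved 80/20 split without materializing the stream.
#
#   train_ds = SplitDataset(base_stream, is_train=True, train_ratio=0.8)
#   val_ds = SplitDataset(base_stream, is_train=False, train_ratio=0.8)
#
# Note: each wrapper re-iterates the base dataset from the start, so the
# two splits are disjoint only if the base stream yields items in the same
# order on every pass.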


class FFTDataset(IterableDataset):
    def __init__(self, original_dataset,
                 max_len=72000,
                 orig_sample_rate=12000,
                 target_sample_rate=3000,
                 features=False):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.target_sample_rate = target_sample_rate
        self.max_len = max_len
        self.features = features

    def normalize_audio(self, audio):
        """Min-max normalize audio to the [0, 1] range."""
        audio_min = audio.min()
        audio_max = audio.max()
        # A small epsilon guards against division by zero on constant
        # (e.g. silent) signals; any NaNs that still slip through are
        # scrubbed by the nan_to_num calls in __iter__.
        return (audio - audio_min) / (audio_max - audio_min + 1e-8)

    def generate_unique_id(self, array):
        # Convert the array to bytes
        array_bytes = array.tobytes()
        # Hash the bytes using SHA256
        hash_object = hashlib.sha256(array_bytes)
        # Return the hexadecimal representation of the hash
        return hash_object.hexdigest()

    def __iter__(self):
        for item in self.dataset:
            # Optional Savitzky-Golay smoothing of the raw waveform:
            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
            audio_data = item['audio']['array']
            # item['id'] = self.generate_unique_id(audio_data)
            audio_data = torch.tensor(audio_data).float()

            # Pad to max_len (a negative pad_len truncates from the end)
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)

            audio_data = self.normalize_audio(audio_data)
            fft_data = fft(audio_data)
            magnitude = torch.abs(fft_data)
            phase = torch.angle(fft_data)
            if self.features:
                features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
                # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
                item['audio']['features'] = features
            magnitude_centered = fftshift(magnitude)
            phase_centered = fftshift(phase)
            # cwt = features['cwt_power']

            # Zero out the DC component (the center bin after fftshift)
            magnitude_centered[len(magnitude_centered) // 2] = 0

            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
            # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
            yield item


class AudioINRDataset(IterableDataset):
    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
        """
        Convert audio data into coordinate-value pairs for INR training.

        Args:
            original_dataset: Original audio dataset
            max_len: Maximum length of audio to process
            batch_size: Number of points to sample per audio clip
            normalize: Whether to normalize the audio values to [0, 1]
        """
        self.dataset = original_dataset
        self.max_len = max_len
        self.dim = dim
        self.normalize = normalize
        self.sample_size = sample_size

    def get_coordinates(self, audio_len):
        """Generate time coordinates"""
        # Create normalized time coordinates in [0, 1]
        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
        return coords  # Shape: [audio_len, dim]

    def sample_points(self, coords, values):
        """Randomly sample points from the audio"""
        if len(coords) > self.sample_size:
            idx = torch.randperm(len(coords))[:self.sample_size]
            coords = coords[idx]
            values = values[idx]
        return coords, values

    def __iter__(self):
        for item in self.dataset:
            # Get audio data (as_tensor avoids an extra copy when the
            # upstream dataset already yields tensors)
            audio_data = torch.as_tensor(item['audio']['array']).float()

            # Generate coordinates
            coords = self.get_coordinates(len(audio_data))

            item['audio']['coords'] = coords

            # Optionally subsample a random set of points:
            # coords, values = self.sample_points(coords, audio_data)

            # Create the INR training sample
            yield item
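

if __name__ == "__main__":
    # Smoke-test sketch, not part of the library: runs the pipeline on a
    # tiny synthetic dataset so the transforms above can be exercised
    # without the real audio corpus. The dict layout mirrors what the
    # __iter__ methods expect. Run it as a module (python -m <package>.<module>)
    # so the relative import at the top of the file resolves.
    import numpy as np

    def fake_dataset(n=4, length=72000):
        for _ in range(n):
            yield {'audio': {'array': np.random.randn(length)}}

    ds = FFTDataset(fake_dataset())        # pad, resample, normalize, FFT
    ds = AudioINRDataset(ds)               # attach INR time coordinates
    for item in SplitDataset(ds, is_train=True):
        print(item['audio']['array'].shape,    # torch.Size([18000])
              item['audio']['fft_mag'].shape,  # torch.Size([18000])
              item['audio']['coords'].shape)   # torch.Size([18000, 1])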