import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer6


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


# SpecAugment-style augmentation: each call zeroes one randomly placed band of
# up to 15 consecutive mel bins.
freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) resample to the target rate if needed
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # peak-normalize; the epsilon avoids division by zero on silent clips
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # augmentation: add low-level Gaussian noise to roughly half the examples
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )
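    # e.g. with decay_rate=8.0 the kernel is exp(-k/8) for k = 0..39:
    # 1.000, 0.883, 0.779, ... tapering to roughly 0.008 by the final frame.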

    fps = SAMPLE_RATE / HOP_LENGTH  # mel-spectrogram frames per second
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Types 1 and 3 are Don; types 2 and 4 are Ka
        if typ in (1, 2, 3, 4):
            exact_hit_time_sub = exact_frame_start / TIME_SUB

            current_labels = don_labels if typ in (1, 3) else ka_labels

            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)
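            # Soft-label the onset: split its weight between the two nearest
            # label frames by fractional position, e.g. a hit at 10.3 puts
            # weight 0.7 on frame 10 and 0.3 on frame 11, while near-integer
            # hits collapse to a single frame with weight 1.0.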

            if (
                abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
            ):  # Tolerance for float precision
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1

                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac

                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

        # Types 5, 6, and 7 are Drumroll
        elif 5 <= typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Improved drumroll body
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)
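            # e.g. a roll spanning label frames 12.4..17.2 marks frames 12-17
            # with 1.0 (floor of the start up to, but excluding, ceil of the end).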

            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Improved drumroll tail (starts from exact_end_time_sub)
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1

                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail

                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    # Map course name to an integer id; other courses (edit/ura) map to 4
    difficulty_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}.get(difficulty, 4)
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    # Run an untrained model's CNN front-end once, solely to obtain the output
    # time length (T_cnn); the weights themselves are irrelevant here.
    dummy_model = TaikoConformer6()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape
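
    # The model's CNN front-end presumably downsamples time by TIME_SUB, so
    # T_cnn should track ceil(T / TIME_SUB); the min() for `length` further
    # down guards against small off-by-one differences between the two.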

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)

    # For conformer input lengths: based on original mel shape (before CNN)
    conformer_input_length = min(math.ceil(T / TIME_SUB), T_cnn)

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(
            conformer_input_length, dtype=torch.long
        ),  # for conformer
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)  # (B, 1, N_MELS, T_max)
    T_max = padded_mels.shape[1]

    # Pad labels to T_max
    def pad_label(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = torch.stack([pad_label(l, T_max) for l in don_labels_list])
    ka_labels = torch.stack([pad_label(l, T_max) for l in ka_labels_list])
    drumroll_labels = torch.stack([pad_label(l, T_max) for l in drumroll_labels_list])
    lengths = torch.tensor(
        [min(l.item(), T_max) for l in lengths_list], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "lengths": lengths,  # for conformer
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
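

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only): one plausible way to wire `preprocess` and
# `collate_fn` into a DataLoader. The dataset repo id below is a hypothetical
# placeholder, not something this project ships.
#
#   from datasets import load_dataset
#   from torch.utils.data import DataLoader
#
#   ds = load_dataset("user/taiko-dataset", split="train")  # hypothetical repo
#   ds.set_format("torch")  # so example["audio"]["array"] is a torch.Tensor
#   ds = ds.map(preprocess)  # defaults to difficulty="oni"
#   loader = DataLoader(ds, batch_size=4, collate_fn=collate_fn)
#   batch = next(iter(loader))
#   # batch["mel"]:        (B, 1, N_MELS, T_max)
#   # batch["don_labels"]: (B, T_max) soft targets in [0, 1]
# -----------------------------------------------------------------------------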