import math

import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode

from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer6
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)
freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
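    """Turn one dataset example into model-ready features and labels.

    Expects `example` to provide an "audio" dict (waveform tensor plus sampling
    rate), a "tja" chart string, and a per-difficulty list of onsets of the
    form (type, start_time, end_time, ...). Returns the mel spectrogram,
    energy-style Don/Ka/Drumroll label tracks aligned to the CNN output length,
    and per-song metadata (notes per second, difficulty id, level, duration).
    """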
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]

    # 1) Load & resample to the target sample rate.
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)

    # Peak-normalize the audio (epsilon avoids division by zero on silence).
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)

    # Augmentation: add light Gaussian noise half of the time.
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)

    # 2) Mel spectrogram: (1, N_MELS, T).
    mel = mel_transform(wav_tensor).unsqueeze(0)

    # SpecAugment-style frequency masking.
    mel = freq_mask(mel)
    _, _, T = mel.shape

    # 3) Build label sequences of length ceil(T / TIME_SUB).
    T_sub = math.ceil(T / TIME_SUB)

    # Energy-based label tracks for Don, Ka, and Drumroll.
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Exponential decay tail placed after each onset: exp(-k / decay_rate) for
    # k = 0 .. tail_length - 1, so the tail falls to ~exp(-39/8) ≈ 0.008 by its
    # last frame.
    tail_length = 40  # number of frames in the decay tail
    decay_rate = 8.0  # larger values give a slower decay
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH  # mel frames per second
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # N_TYPES in config is assumed to be set appropriately (e.g., 7 or more).
        if typ < 1 or typ > N_TYPES:  # filter out invalid types
            continue
        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Types 1 and 3 are Don, types 2 and 4 are Ka.
        if typ in (1, 2, 3, 4):
            exact_hit_time_sub = exact_frame_start / TIME_SUB
            current_labels = don_labels if typ in (1, 3) else ka_labels

            # Split the onset across its two neighbouring frames (linear
            # interpolation), unless it lands exactly on a frame.
            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)
            # Tolerance for float precision.
            if abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6:
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1
                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac
                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            # Stamp the decay tail, keeping the maximum where tails overlap.
            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )
        # Types 5, 6, and 7 are Drumroll.
        elif 5 <= typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Drumroll body: mark every frame spanned by the roll.
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)
            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Drumroll tail: decay starting at exact_end_time_sub, split across
            # the two neighbouring frames just like a hit onset.
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1
                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail
                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )
    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    # Parse the TJA chart to recover course-level metadata.
    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )

    # Map the course name to an integer id; 4 covers edit/ura courses.
    difficulty_to_id = {"easy": 0, "normal": 1, "hard": 2, "oni": 3}
    difficulty_id = difficulty_to_id.get(difficulty, 4)

    level = chart.level if chart else 0
    # --- CNN shape inference and label padding/truncation ---
    # Run a throwaway TaikoConformer6's CNN front-end to find the output time
    # length (T_cnn) that the labels must match.
    dummy_model = TaikoConformer6()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
    _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate each label track to T_cnn.
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)

    # Conformer input length is based on the original mel length (before the
    # CNN), clamped to the CNN output length.
    conformer_input_length = min(math.ceil(T / TIME_SUB), T_cnn)
    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, "
        f"NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(conformer_input_length, dtype=torch.long),  # for conformer
    }


def collate_fn(batch):
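    """Collate preprocessed examples into a padded batch.

    Mel spectrograms are padded along time to the longest example in the batch
    and reshaped to (B, 1, N_MELS, T_max); label tracks are padded/truncated to
    the same T_max, and per-song scalars are stacked into tensors.
    """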
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # each (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]

    # Pad mels along time, then restore the (B, 1, N_MELS, T_max) layout.
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
    T_max = padded_mels.shape[1]

    # Pad or truncate each label track to T_max.
    def pad_label(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        return label[:out_len]

    don_labels = torch.stack([pad_label(l, T_max) for l in don_labels_list])
    ka_labels = torch.stack([pad_label(l, T_max) for l in ka_labels_list])
    drumroll_labels = torch.stack([pad_label(l, T_max) for l in drumroll_labels_list])

    # Clamp lengths so they never exceed the padded time dimension.
    lengths = torch.tensor(
        [min(l.item(), T_max) for l in lengths_list], dtype=torch.long
    )
    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "lengths": lengths,  # for conformer
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
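

# --- Usage sketch (illustrative, not part of the training code) ---
# A minimal example of how `preprocess` and `collate_fn` might be wired into a
# PyTorch DataLoader. The `dataset` object and batch size are hypothetical; the
# dataset is assumed to expose the fields accessed above ("audio", "tja", and
# one onset list per difficulty).
#
#     from torch.utils.data import DataLoader
#
#     processed = [preprocess(example, difficulty="oni") for example in dataset]
#     loader = DataLoader(processed, batch_size=8, shuffle=True, collate_fn=collate_fn)
#     for batch in loader:
#         mels = batch["mel"]          # (B, 1, N_MELS, T_max)
#         lengths = batch["lengths"]   # (B,)
#         ...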