from sklearn.base import ClassifierMixin, BaseEstimator
import pandas as pd
from torch import nn
import torch
from typing import Iterator
import numpy as np
import json
from tqdm import tqdm
import librosa

DANCE_INFO_FILE = "data/dance_info.csv"
dance_info_df = pd.read_csv(
    DANCE_INFO_FILE,
    converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
)
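# Expected schema of dance_info.csv (illustrative; the values below are hypothetical):
# an "id" column with the dance identifier and a "tempoRange" column holding a
# dict-like string such as "{'min': 84, 'max': 96}". The converter above swaps
# the single quotes for double quotes so json.loads can parse each entry into a
# dict with "min" and "max" keys, which get_valid_dances_from_bpm compares
# against a song's detected bpm.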
class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
    """
    Trains one binary classifier per dance; each classifier is only trained on
    and scored for songs whose BPM falls inside that dance's tempo range.

    Features:
        - Spectrogram
        - BPM
    """

    def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
        self.device = device
        self.epochs = epochs
        self.verbose = verbose
        self.lr = lr
        self.classifiers = {}
        self.optimizers = {}
        self.criterion = nn.BCELoss()

    def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
        """Return the ids of all dances whose tempo range contains `bpm`."""
        mask = dance_info_df["tempoRange"].apply(
            lambda interval: interval["min"] <= bpm <= interval["max"]
        )
        return list(dance_info_df["id"][mask])
    def fit(self, x, y):
        """
        x: iterable of (spectrogram, bpm) pairs. Each spectrogram should have
           shape (channel, freq_bins, sr * time).
        y: iterable of dance ids, one label per sample.
        """
        progress_bar = tqdm(range(self.epochs))
        for _ in progress_bar:
            # TODO: Introduce batches
            epoch_loss = 0
            pred_count = 0
            step = 0
            for (spec, bpm), label in zip(x, y):
                step += 1
                # Only train the classifiers whose dance covers this bpm.
                matching_dances = self.get_valid_dances_from_bpm(bpm)
                spec = torch.from_numpy(spec).to(self.device)
                # Lazily create a classifier/optimizer pair for any dance seen
                # for the first time.
                for dance in matching_dances:
                    if dance not in self.classifiers or dance not in self.optimizers:
                        classifier = DanceCNN().to(self.device)
                        self.classifiers[dance] = classifier
                        self.optimizers[dance] = torch.optim.Adam(
                            classifier.parameters(), lr=self.lr
                        )
                models = [
                    (dance, model, self.optimizers[dance])
                    for dance, model in self.classifiers.items()
                    if dance in matching_dances
                ]
                for model_i, (dance, model, opt) in enumerate(models):
                    opt.zero_grad()
                    output = model(spec)
                    # Binary target: 1.0 iff this classifier's dance is the true label.
                    target = torch.tensor([float(dance == label)], device=self.device)
                    loss = self.criterion(output, target)
                    epoch_loss += loss.item()
                    pred_count += 1
                    loss.backward()
                    opt.step()
                    progress_bar.set_description(
                        f"Loss: {epoch_loss / pred_count}, Step: {step}, "
                        f"Model: {model_i + 1}/{len(models)}"
                    )
    def predict(self, x) -> list[str]:
        """
        x: (spectrograms, bpms) -- two parallel sequences, one entry per sample.
        Returns the highest-scoring dance id among those matching each sample's bpm.
        """
        results = []
        for spec, bpm in zip(*x):
            spec = torch.from_numpy(spec).to(self.device)
            matching_dances = self.get_valid_dances_from_bpm(bpm)
            with torch.no_grad():
                scores = torch.tensor(
                    [self.classifiers[dance](spec).item() for dance in matching_dances]
                )
            results.append(matching_dances[scores.argmax().item()])
        return results
class DanceCNN(nn.Module):
    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        kernel_size = (3, 9)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(16, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
        )

        # Flattened size of the CNN output; tied to the expected input shape.
        embedding_dimension = 16 * 6 * 8
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dimension, 200),
            nn.ReLU(),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.cnn(x)
        # Unbatched (3D) inputs flatten to a single vector; batched (4D) inputs
        # keep their batch dimension.
        x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
        return self.classifier(x)
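# Shape sanity check (a sketch, not part of the original module): with
# librosa's default 128 mel bins and the padding in features_from_path, each
# window fed to DanceCNN has shape (1, 128, sr * duration) = (1, 128, 96000),
# which the four conv/pool stages reduce to (16, 6, 8), i.e. 16 * 6 * 8 = 768
# features, matching embedding_dimension above. `_check_cnn_output_shape` is a
# hypothetical helper added for illustration.
def _check_cnn_output_shape(sr: int = 16000, duration: int = 6) -> torch.Size:
    model = DanceCNN(sr=sr, duration=duration)
    dummy_window = torch.zeros(1, 128, sr * duration)  # (channel, mel_bins, frames)
    with torch.no_grad():
        out = model(dummy_window)
    return out.shape  # expected: torch.Size([1]), a single sigmoid probability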
def features_from_path(
    paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
) -> Iterator[tuple[np.ndarray, float]]:
    """
    Loads each audio file and yields (spectrogram window, bpm) pairs,
    one pair per `audio_window_duration`-second window.
    """
    for path in paths:
        waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
        num_frames = audio_window_duration * sr
        tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
        spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
        mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)  # currently unused
        spec_normalized = (spec - spec.mean()) / spec.std()
        # Pad (or truncate) the time axis to sr * audio_duration columns so that
        # each num_frames-wide slice below matches DanceCNN's expected input width.
        spec_padded = librosa.util.fix_length(
            spec_normalized, size=sr * audio_duration, axis=1
        )
        batched_spec = np.expand_dims(spec_padded, axis=0)
        for i in range(audio_duration // audio_window_duration):
            spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
            yield (spec_window, tempo)
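# Usage sketch (assumptions: the audio paths and dance ids below are
# hypothetical placeholders, and the dance ids exist in dance_info.csv).
# Each file yields 30 // 6 = 5 windows, so labels are repeated once per window;
# the feature generator is materialised into a list so fit() can iterate it
# once per epoch.
if __name__ == "__main__":
    train_paths = ["data/audio/song_0.wav", "data/audio/song_1.wav"]  # hypothetical
    train_dances = ["waltz", "tango"]  # hypothetical dance ids

    windows_per_song = 30 // 6
    train_features = list(features_from_path(train_paths))
    train_labels = [dance for dance in train_dances for _ in range(windows_per_song)]

    model = DanceTreeClassifier(device="cpu", epochs=2)
    model.fit(train_features, train_labels)

    # predict() expects two parallel sequences: (spectrograms, bpms).
    test_specs, test_bpms = zip(*features_from_path(["data/audio/new_song.wav"]))
    print(model.predict((test_specs, test_bpms)))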