import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
import torchaudio as ta
from .pipelines import AudioTrainingPipeline
import pytorch_lightning as pl
from .preprocess import get_examples
from sklearn.model_selection import train_test_split
from torchaudio import transforms as taT
from torch import nn
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

class SongDataset(Dataset):
    """
    Loads fixed-length windows of audio from a list of song files and returns
    (spectrogram, dance_labels) pairs. Each song contributes
    audio_duration // audio_window_duration examples.
    """

    def __init__(
        self,
        audio_paths: list[str],
        dance_labels: list[np.ndarray],
        audio_duration=30,  # seconds
        audio_window_duration=6,  # seconds
        audio_window_jitter=0.0,  # seconds
        audio_pipeline_kwargs={},
        resample_frequency=16000,
    ):
        assert (
            audio_duration % audio_window_duration == 0
        ), "Audio window should divide duration evenly."
        assert (
            audio_window_duration > audio_window_jitter
        ), "Jitter should be a small fraction of the audio window duration."

        self.audio_paths = audio_paths
        self.dance_labels = dance_labels
        audio_info = ta.info(audio_paths[0])
        self.sample_rate = audio_info.sample_rate
        self.audio_window_duration = int(audio_window_duration)
        self.audio_window_jitter = audio_window_jitter
        self.audio_duration = int(audio_duration)

        self.audio_pipeline = AudioTrainingPipeline(
            self.sample_rate,
            resample_frequency,
            audio_window_duration,
            **audio_pipeline_kwargs,
        )

    def __len__(self):
        return len(self.audio_paths) * self.audio_duration // self.audio_window_duration

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        waveform = self._waveform_from_index(idx)
        assert (
            waveform.shape[1] > 10
        ), f"No data found: {self._backtrace_audio_path(idx)}"

        spectrogram = self.audio_pipeline(waveform)
        dance_labels = self._label_from_index(idx)

        example_is_valid = self._validate_output(spectrogram, dance_labels)
        if example_is_valid:
            return spectrogram, dance_labels
        else:
            # Try the previous one
            # This happens when some of the audio recordings are really quiet
            # This WILL NOT leak into other data partitions because songs belong
            # entirely to a partition
            return self[idx - 1]

    def _convert_idx(self, idx: int) -> int:
        # Map a window index back to the index of the song it came from.
        return idx * self.audio_window_duration // self.audio_duration

    def _backtrace_audio_path(self, index: int) -> str:
        return self.audio_paths[self._convert_idx(index)]

    def _validate_output(self, x, y):
        is_finite = not torch.any(torch.isinf(x))
        is_numerical = not torch.any(torch.isnan(x))
        has_data = torch.any(x != 0.0)
        is_binary = len(torch.unique(y)) < 3
        return all((is_finite, is_numerical, has_data, is_binary))

    def _waveform_from_index(self, idx: int) -> torch.Tensor:
        audio_filepath = self.audio_paths[self._convert_idx(idx)]
        num_windows = self.audio_duration // self.audio_window_duration
        frame_index = idx % num_windows

        # Jitter the window start, but never slide past the start or end of the song.
        jitter_start = -self.audio_window_jitter if frame_index > 0 else 0.0
        jitter_end = self.audio_window_jitter if frame_index != num_windows - 1 else 0.0
        jitter = int(
            torch.FloatTensor(1).uniform_(jitter_start, jitter_end) * self.sample_rate
        )

        frame_offset = (
            frame_index * self.audio_window_duration * self.sample_rate + jitter
        )
        num_frames = self.sample_rate * self.audio_window_duration
        waveform, sample_rate = ta.load(
            audio_filepath, frame_offset=frame_offset, num_frames=num_frames
        )
        assert (
            sample_rate == self.sample_rate
        ), f"Expected sample rate of {self.sample_rate}. Found {sample_rate}"
        return waveform

    def _label_from_index(self, idx: int) -> torch.Tensor:
        return torch.from_numpy(self.dance_labels[self._convert_idx(idx)])
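
# Usage sketch (kept as a comment so nothing runs at import time). The path and
# label vector below are illustrative placeholders, not files that ship with the
# repo:
#
#   ds = SongDataset(
#       audio_paths=["data/samples/example_song.wav"],
#       dance_labels=[np.array([1.0, 0.0, 0.0])],
#       audio_window_jitter=0.5,
#   )
#   spectrogram, labels = ds[0]  # one 6-second window and its multi-hot labels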

class WaveformSongDataset(SongDataset):
    """
    Outputs raw waveforms of the data instead of a spectrogram.
    """

    def __init__(self, *args, resample_frequency=16000, **kwargs):
        super().__init__(*args, **kwargs)
        self.resample_frequency = resample_frequency
        self.resampler = taT.Resample(self.sample_rate, self.resample_frequency)
        self.pipeline = []

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        waveform = self._waveform_from_index(idx)
        assert (
            waveform.shape[1] > 10
        ), f"No data found: {self._backtrace_audio_path(idx)}"

        # Resample the waveform and downmix to mono.
        waveform = self.resampler(waveform)
        waveform = waveform.mean(0)
        dance_labels = self._label_from_index(idx)
        return waveform, dance_labels

class HuggingFaceWaveformSongDataset(WaveformSongDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pipeline = []

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        x, y = super().__getitem__(idx)
        for fn in self.pipeline:
            x = fn(x)

        dance_labels = y.argmax()
        return {
            "input_values": x["input_values"][0] if hasattr(x, "input_values") else x,
            "label": dance_labels,
        }

    def map(self, fn):
        """
        NOTE: this mutates the dataset in place; it does not return a copy like a
        normal `map`.
        """
        self.pipeline.append(fn)
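
# Usage sketch (comment only). It assumes the transformers library, which this
# class is designed to feed but does not import here; `paths` and `labels` are
# placeholders:
#
#   from transformers import AutoFeatureExtractor
#
#   extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
#   hf_ds = HuggingFaceWaveformSongDataset(paths, labels)
#   hf_ds.map(lambda wav: extractor(wav, sampling_rate=16000, return_tensors="pt"))
#   example = hf_ds[0]  # {"input_values": Tensor, "label": Tensor}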

class DanceDataModule(pl.LightningDataModule):
    """
    Splits the labeled songs into train/val/test partitions and wraps each
    partition in the configured dataset class.
    """

    def __init__(
        self,
        song_data_path="data/songs_cleaned.csv",
        song_audio_path="data/samples",
        test_proportion=0.15,
        val_proportion=0.1,
        target_classes: list[str] = None,
        min_votes=1,
        batch_size: int = 64,
        num_workers=10,
        dataset_cls=None,
        dataset_kwargs={},
    ):
        super().__init__()
        self.song_data_path = song_data_path
        self.song_audio_path = song_audio_path
        self.val_proportion = val_proportion
        self.test_proportion = test_proportion
        self.train_proportion = 1.0 - test_proportion - val_proportion
        self.target_classes = target_classes
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_kwargs = dataset_kwargs
        self.dataset_cls = dataset_cls if dataset_cls is not None else SongDataset

        df = pd.read_csv(song_data_path)
        self.x, self.y = get_examples(
            df,
            self.song_audio_path,
            class_list=self.target_classes,
            multi_label=True,
            min_votes=min_votes,
        )

    def setup(self, stage: str):
        train_i, val_i, test_i = random_split(
            np.arange(len(self.x)),
            [self.train_proportion, self.val_proportion, self.test_proportion],
        )
        self.train_ds = self._dataset_from_indices(train_i)
        self.val_ds = self._dataset_from_indices(val_i)
        self.test_ds = self._dataset_from_indices(test_i)

    def _dataset_from_indices(self, idx: list[int]) -> SongDataset:
        return self.dataset_cls(self.x[idx], self.y[idx], **self.dataset_kwargs)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def get_label_weights(self):
        n_examples, n_classes = self.y.shape
        return torch.from_numpy(n_examples / (n_classes * sum(self.y)))
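
# Usage sketch (comment only). The dance names are illustrative; the real class
# list comes from the training config, and the default CSV/sample paths must
# exist on disk:
#
#   data = DanceDataModule(
#       target_classes=["Waltz", "Tango", "Salsa"],
#       batch_size=32,
#       dataset_kwargs={"audio_window_jitter": 0.5},
#   )
#   weights = data.get_label_weights()  # inverse-frequency weights for the loss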

class WaveformTrainingEnvironment(pl.LightningModule):
    """
    LightningModule that trains a waveform-based classifier (a model exposing
    `.logits`, e.g. a HuggingFace audio classifier) on multi-label dance data.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        feature_extractor,
        config: dict,
        learning_rate=1e-4,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.config = config
        self.feature_extractor = feature_extractor
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def preprocess_inputs(self, x):
        device = x.device
        x = list(x.squeeze(1).cpu().numpy())
        x = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000)
        return x["input_values"].to(device)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        features, labels = batch
        features = self.preprocess_inputs(features)
        outputs = self.model(features).logits
        outputs = nn.Sigmoid()(
            outputs
        )  # good for multi-label classification; should be softmax otherwise
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(outputs, labels, prefix="train/", multi_label=True)
        self.log_dict(metrics, prog_bar=True)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
        x, y = batch
        x = self.preprocess_inputs(x)
        preds = self.model(x).logits
        preds = nn.Sigmoid()(preds)
        metrics = calculate_metrics(preds, y, prefix="val/", multi_label=True)
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        x, y = batch
        x = self.preprocess_inputs(x)
        preds = self.model(x).logits
        preds = nn.Sigmoid()(preds)
        self.log_dict(
            calculate_metrics(preds, y, prefix="test/", multi_label=True), prog_bar=True
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        # To enable it, also return [{"scheduler": scheduler, "monitor": "val/loss"}].
        return [optimizer]
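
# Usage sketch (comment only). The model name and BCE criterion are assumptions,
# not pinned by this module; `data` stands for a DanceDataModule built with
# dataset_cls=WaveformSongDataset so that batches are raw waveforms:
#
#   from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
#
#   model = AutoModelForAudioClassification.from_pretrained(
#       "facebook/wav2vec2-base", num_labels=3
#   )
#   extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
#   env = WaveformTrainingEnvironment(model, nn.BCELoss(), extractor, config={})
#   pl.Trainer(max_epochs=5).fit(env, datamodule=data)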

def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    target = target.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()
    params = {
        "y_true": target if multi_label else target.argmax(1),
        "y_pred": np.array(pred > threshold, dtype=float)
        if multi_label
        else pred.argmax(1),
        "zero_division": 0,
        "average": "macro",
    }
    metrics = {
        "precision": precision_score(**params),
        "recall": recall_score(**params),
        "f1": f1_score(**params),
        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
    }
    return {
        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
    }
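
# Usage sketch (comment only), for a batch of 4 examples over 3 classes:
#
#   preds = torch.sigmoid(torch.randn(4, 3))
#   labels = torch.randint(0, 2, (4, 3)).float()
#   calculate_metrics(preds, labels, prefix="train/")
#   # -> {"train/precision": ..., "train/recall": ..., "train/f1": ..., "train/accuracy": ...}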