Runtime error
Runtime error
File size: 4,075 Bytes
c914273 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import pandas as pd
import numpy as np
import re
import json
from pathlib import Path
import os
import torch
import torchaudio.transforms as taT
def url_to_filename(url:str) -> str:
    """Build the local ``.wav`` file name from the last ``/``-separated segment of *url*."""
    last_segment = url.rsplit("/", 1)[-1]
    return last_segment + ".wav"
def get_songs_with_audio(df:pd.DataFrame, audio_dir:str) -> pd.DataFrame:
    """Return the rows of *df* whose "Sample" URL has a downloaded audio file.

    A row is kept when its "Sample" value is a string other than the "."
    placeholder and the file name derived by ``url_to_filename`` exists
    directly inside *audio_dir*.

    Args:
        df: Song table containing a "Sample" column of audio URLs.
        audio_dir: Directory whose entries are compared against the derived
            file names.

    Returns:
        An independent copy of the matching rows, so callers may assign to
        its columns without touching *df*.
    """
    audio_files = {entry.name for entry in Path(audio_dir).iterdir()}

    def has_audio(url) -> bool:
        # Missing values can show up as None, float("nan") or pd.NA, not only
        # as the np.nan singleton, so test for "is a real URL string" instead
        # of comparing identity against np.nan.
        if not isinstance(url, str) or url == ".":
            return False
        return url_to_filename(url) in audio_files

    # astype(bool) keeps the mask a boolean indexer even when df is empty
    # (an empty object-dtype Series would not be treated as a boolean mask).
    valid_audio = df["Sample"].map(has_audio).astype(bool)
    return df.loc[valid_audio].copy()
def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
dance_ratings = dance_ratings.apply(lambda v : json.loads(v.replace("'", "\"")))
def fix_labels(labels:dict) -> dict | float:
new_labels = {}
for k, v in labels.items():
match =
if match is None:
new_labels[k] = new_labels.get(k, 0) + v
k = match[1]
sign = 1 if match[2] == '+' else -1
scale = int(match[3])
new_labels[k] = new_labels.get(k, 0) + v * scale * sign
valid = any(v > 0 for v in new_labels.values())
return new_labels if valid else np.nan
return dance_ratings.apply(fix_labels)
def get_unique_labels(dance_labels:pd.Series) -> list:
    """Collect every distinct label found across *dance_labels*, in sorted order."""
    distinct = {label for dances in dance_labels for label in dances}
    return sorted(distinct)
def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
    """Turn a label dict into a probability distribution vector.

    Each position of the result corresponds to the entry of *unique_labels*
    at the same index. Counts are accumulated per label, negative totals are
    clipped to zero, and the vector is normalised to sum to 1.

    Args:
        labels: Mapping of label -> count. Labels not present in
            *unique_labels* contribute nothing.
        unique_labels: Label vocabulary; any sequence is accepted and
            converted to an array.

    Returns:
        A ``float32`` vector with one entry per element of *unique_labels*.

    Raises:
        AssertionError: If no label ends up with a positive, finite weight,
            so the vector cannot be normalised.
    """
    # A plain list compared with `==` yields a single bool rather than an
    # element-wise mask, which would silently produce an all-zero vector.
    unique_labels = np.asarray(unique_labels)
    label_vec = np.zeros((len(unique_labels),), dtype="float32")
    for k, v in labels.items():
        label_vec += (unique_labels == k) * v
    label_vec[label_vec<0] = 0
    total = label_vec.sum()
    # Validate before dividing: avoids numpy's invalid-value warning and keeps
    # the check active under `python -O`, where `assert` is stripped.
    if not (np.isfinite(total) and total > 0):
        raise AssertionError(f"Provided labels are invalid: {labels}")
    label_vec /= total
    return label_vec
def vectorize_multi_label(labels: dict[str,int], unique_labels:np.ndarray) -> np.ndarray:
    """Turn a label dict into a binary label vector for multi-label classification.

    Every position that receives a positive probability from
    ``vectorize_label_probs`` is set to 1.0; the rest are left as they are.
    """
    indicator = vectorize_label_probs(labels, unique_labels)
    # Overwrite the positive entries in place.
    np.putmask(indicator, indicator > 0.0, 1.0)
    return indicator
def get_examples(df:pd.DataFrame, audio_dir:str, class_list=None) -> tuple[list[str], list[np.ndarray]]:
    """Build (audio path, multi-label vector) training examples from *df*.

    Rows are limited to songs with audio in *audio_dir*, their "DanceRating"
    values are normalised with ``fix_dance_rating_counts``, and rows left
    without a usable rating are dropped.

    Args:
        df: Song table with "Sample" and "DanceRating" columns.
        audio_dir: Directory that holds the audio files.
        class_list: Optional iterable of labels to keep. When given, ratings
            are restricted to these labels and rows with no positively rated
            label among them are dropped.

    Returns:
        A pair ``(audio_paths, label_vectors)`` of equal length. Vector
        positions follow the sorted labels that remain in the data.
    """
    # Work on an explicit copy: assigning into the filtered frame directly
    # triggers pandas' SettingWithCopyWarning and leaves the write target
    # ambiguous.
    sampled_songs = get_songs_with_audio(df, audio_dir).copy()
    sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])

    if class_list is not None:
        allowed = set(class_list)

        def keep_allowed(labels):
            # Ratings already rejected upstream are NaN rather than a dict.
            if not isinstance(labels, dict):
                return np.nan
            if not any(name in allowed and amt > 0 for name, amt in labels.items()):
                return np.nan
            return {name: amt for name, amt in labels.items() if name in allowed}

        sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(keep_allowed)

    sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
    ratings = sampled_songs["DanceRating"]
    unique_labels = np.array(get_unique_labels(ratings))
    label_vectors = ratings.apply(lambda rating : vectorize_multi_label(rating, unique_labels))
    audio_paths = [os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]]
    return audio_paths, list(label_vectors)
class AudioPipeline(torch.nn.Module):
    """Convert a raw waveform into a decibel-scaled mel spectrogram.

    The pipeline resamples the signal, computes a 64-band mel spectrogram
    with a 1024-point FFT, and converts amplitudes to decibels.
    """

    # NOTE(review): the parameter list of __init__ was missing in the reviewed
    # source. The names are taken from their use in the body; the order and the
    # 16000 defaults are assumptions - confirm against the original file.
    def __init__(
        self,
        input_freq=16000,
        resample_freq=16000,
    ):
        # nn.Module must be initialised before submodules are assigned,
        # otherwise attribute registration raises.
        super().__init__()
        self.resample = taT.Resample(orig_freq=input_freq, new_freq=resample_freq)
        self.spec = taT.MelSpectrogram(sample_rate=resample_freq, n_mels=64, n_fft=1024)
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the dB mel spectrogram of *waveform*.

        When the first dimension is larger than 1 it is averaged down to
        size 1 first (presumably mixing channels to mono - assumes a
        channels-first layout; TODO confirm).
        """
        if waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)
        waveform = self.resample(waveform)
        spectrogram = self.spec(waveform)
        spectrogram = self.to_db(spectrogram)
        return spectrogram