Spaces:
Runtime error
Runtime error
File size: 6,148 Bytes
c914273 557fb53 c914273 557fb53 0030bc6 c914273 0030bc6 c914273 a8c0792 c914273 e82ec2b e6fd727 c914273 e82ec2b e6fd727 c914273 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b 0030bc6 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b c914273 e82ec2b 557fb53 0030bc6 557fb53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import pandas as pd
import numpy as np
import re
import json
from pathlib import Path
import glob
import os
import shutil
import torchaudio
import torch
from tqdm import tqdm
from preprocessing.utils import url_to_filename
def has_valid_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
    """
    Flag which audio URLs have a matching downloaded file in ``audio_dir``.

    Only the presence of a file named ``url_to_filename(url)`` is checked; the
    file contents are not inspected.

    Args:
        audio_urls: Series of URL strings. Non-string entries (NaN, None) and
            the placeholder ``"."`` are treated as "no audio".
        audio_dir: Directory containing the downloaded audio files.

    Returns:
        Boolean Series aligned with ``audio_urls``.
    """
    audio_files = {f.name for f in Path(audio_dir).iterdir()}
    # isinstance instead of an identity check against np.nan: missing values
    # may be None or a NaN object that is not the np.nan singleton.
    return audio_urls.apply(
        lambda url: isinstance(url, str)
        and url != "."
        and url_to_filename(url) in audio_files
    )
def validate_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
    """
    Check that each audio URL has a downloaded file whose contents are usable.

    A URL is valid when all of the following hold:

    - it is a string containing "http";
    - its file (``url_to_filename(url)``) exists in ``audio_dir``;
    - torchaudio can load that file;
    - the loaded waveform has no NaN/inf samples and more than two distinct
      sample values.

    Args:
        audio_urls: Series of URLs; non-string entries (e.g. NaN) are invalid.
        audio_dir: Directory holding the downloaded audio files.

    Returns:
        Boolean Series indexed like ``audio_urls``.
    """
    audio_files = {f.name for f in Path(audio_dir).iterdir()}

    def is_valid(url) -> bool:
        if not (isinstance(url, str) and "http" in url):
            return False
        filename = url_to_filename(url)
        if filename not in audio_files:
            return False
        try:
            waveform, _ = torchaudio.load(os.path.join(audio_dir, filename))
        except Exception:
            # Unreadable/corrupt files count as invalid. Catch Exception rather
            # than everything so Ctrl-C / SystemExit still stop a long run.
            return False
        # Two or fewer distinct sample values indicates e.g. silence or a
        # constant signal.
        contents_invalid = (
            torch.any(torch.isnan(waveform))
            or torch.any(torch.isinf(waveform))
            or len(torch.unique(waveform)) <= 2
        )
        return not contents_invalid

    validations = [
        is_valid(url)
        for url in tqdm(audio_urls, total=len(audio_urls), desc="Audio URLs Validated")
    ]
    return pd.Series(validations, index=audio_urls.index, dtype=bool)
def fix_dance_rating_counts(dance_ratings: pd.Series) -> pd.Series:
tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
dance_ratings = dance_ratings.apply(lambda v: json.loads(v.replace("'", '"')))
def fix_labels(labels: dict) -> dict | float:
new_labels = {}
for k, v in labels.items():
match = tag_pattern.search(k)
if match is None:
new_labels[k] = new_labels.get(k, 0) + v
else:
k = match[1]
sign = 1 if match[2] == "+" else -1
scale = int(match[3])
new_labels[k] = new_labels.get(k, 0) + v * scale * sign
valid = any(v > 0 for v in new_labels.values())
return new_labels if valid else np.nan
return dance_ratings.apply(fix_labels)
def get_unique_labels(dance_labels: pd.Series) -> list:
    """
    Collect every distinct label across a series of label collections.

    Each element of ``dance_labels`` is iterated for its labels; for a dict
    of label -> count, that means its keys.

    Returns:
        The distinct labels in sorted order.
    """
    return sorted({label for dances in dance_labels for label in dances})
def vectorize_label_probs(
    labels: dict[str, int], unique_labels: np.ndarray
) -> np.ndarray:
    """
    Turn a label-count dict into a probability vector over ``unique_labels``.

    - Each label's count is added at every position where that label occurs
      in ``unique_labels``.
    - Labels absent from ``unique_labels`` are ignored.
    - Negative totals are clipped to zero.
    - The result is normalized to sum to 1.

    Args:
        labels: Mapping of label -> (possibly negative) count.
        unique_labels: Label vocabulary, as an array or any sequence (a plain
            list is accepted).

    Returns:
        float32 array of shape ``(len(unique_labels),)``.

    Raises:
        ValueError: If the clipped counts do not have a finite, positive
            total, so no probability distribution can be formed.
    """
    # asarray so a list vocabulary compares element-wise; `list == str`
    # would otherwise evaluate to a single False and drop every count.
    unique_labels = np.asarray(unique_labels)
    label_vec = np.zeros((len(unique_labels),), dtype="float32")
    for label, count in labels.items():
        label_vec += (unique_labels == label) * count
    label_vec[label_vec < 0] = 0
    total = label_vec.sum()
    # Validate before dividing. Unlike `assert`, this survives `python -O`,
    # and it avoids a 0/0 RuntimeWarning filling the vector with NaNs.
    if not (np.isfinite(total) and total > 0):
        raise ValueError(f"Provided labels are invalid: {labels}")
    label_vec /= total
    return label_vec
def vectorize_multi_label(
    labels: dict[str, int], unique_labels: np.ndarray
) -> np.ndarray:
    """
    Convert a label-count dict into a multi-hot vector over ``unique_labels``.

    The vector is built from ``vectorize_label_probs``'s output:

    - entries with a positive probability become 1.0;
    - all other entries keep their value, i.e. 0.0.

    The vector returned by ``vectorize_label_probs`` is modified in place
    and returned.
    """
    multi_hot = vectorize_label_probs(labels, unique_labels)
    np.copyto(multi_hot, 1.0, where=multi_hot > 0.0)
    return multi_hot
def sort_yt_files(
    aliases_path="data/dance_aliases.json",
    all_dances_folder="data/best-ballroom-music",
    original_location="data/yt-ballroom-music/",
    output_folder="data/ballroom-songs",
    bad_files_path="data/bad_files.json",
):
    """
    Copy ``.wav`` files into one sub-folder per dance, matched by name aliases.

    How each file is assigned:

    1. A dance matches when one of its aliases occurs in the file's
       normalized name (lowercased, non-word characters removed).
    2. If nothing matches, the original download with the same name (minus
       ``.wav``) is searched for recursively under ``original_location``. If
       exactly one is found, its full path is matched instead.
    3. Files still matching zero or several dances are recorded as bad.

    Args:
        aliases_path: JSON file mapping dance ID -> list of alias strings.
        all_dances_folder: Folder containing the ``.wav`` files to sort.
        original_location: Root of the original downloads, used as a
            fallback when the file name alone does not identify the dance.
        output_folder: Destination root; files are copied into
            ``<output_folder>/<DANCE_ID>`` (normalized ID, upper-cased).
        bad_files_path: Where the JSON list of unsortable file names goes.
    """

    def normalize_string(s):
        # Lowercase string and remove special characters.
        return re.sub(r"\W+", "", s.lower())

    with open(aliases_path, "r", encoding="utf-8") as f:
        dances = json.load(f)

    # Normalize the dance IDs and their aliases once, up front.
    normalized_dances = {
        normalize_string(dance_id): [normalize_string(alias) for alias in aliases]
        for dance_id, aliases in dances.items()
    }

    def find_dance_ids(text):
        # Dance IDs with at least one alias contained in the normalized text.
        normalized_text = normalize_string(text)
        return [
            dance_id
            for dance_id, aliases in normalized_dances.items()
            if any(alias in normalized_text for alias in aliases)
        ]

    bad_files = []
    # Sorted so the bad-file report is reproducible regardless of filesystem order.
    progress_bar = tqdm(sorted(os.listdir(all_dances_folder)), unit="files moved")
    for file_name in progress_bar:
        if not file_name.endswith(".wav"):
            continue
        matching_dance_ids = find_dance_ids(file_name)
        if not matching_dance_ids:
            # See if the dance is in the original download's path. Escape the
            # name: titles often contain glob metacharacters such as "[...]".
            original_filename = file_name.removesuffix(".wav")
            matches = glob.glob(
                os.path.join(original_location, "**", glob.escape(original_filename)),
                recursive=True,
            )
            if len(matches) == 1:
                matching_dance_ids = find_dance_ids(matches[0])
        # Tie-breaks for overlapping aliases: drop "swz" when "vwz" also
        # matched, and drop "lhp" when it is one of several matches.
        if "swz" in matching_dance_ids and "vwz" in matching_dance_ids:
            matching_dance_ids.remove("swz")
        if len(matching_dance_ids) > 1 and "lhp" in matching_dance_ids:
            matching_dance_ids.remove("lhp")
        if len(matching_dance_ids) != 1:
            bad_files.append(file_name)
            progress_bar.set_description(f"bad files: {len(bad_files)}")
            continue
        dst = os.path.join(output_folder, matching_dance_ids[0].upper())
        os.makedirs(dst, exist_ok=True)
        shutil.copy(
            os.path.join(all_dances_folder, file_name), os.path.join(dst, file_name)
        )

    bad_files_dir = os.path.dirname(bad_files_path)
    if bad_files_dir:
        os.makedirs(bad_files_dir, exist_ok=True)
    with open(bad_files_path, "w", encoding="utf-8") as f:
        json.dump(bad_files, f)


if __name__ == "__main__":
    sort_yt_files()
|