import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn
import torchaudio
from torch.utils.data import Dataset

import audiosr.utilities.audio as Audio
from audiosr.utilities.tools import load_json

class AudioDataset(Dataset):
def __init__(
self,
config=None,
split="train",
waveform_only=False,
add_ons=[],
        dataset_json_path=None,
):
"""
Dataset that manages audio recordings
:param audio_conf: Dictionary containing the audio loading and preprocessing settings
:param dataset_json_file
"""
self.config = config
self.split = split
        self.pad_wav_start_sample = 0  # if None, pad from a random start position
        self.trim_wav_enabled = False  # whether to strip leading/trailing silence on load
self.waveform_only = waveform_only
self.add_ons = [eval(x) for x in add_ons]
print("Add-ons:", self.add_ons)
self.build_setting_parameters()
# For an external dataset
if dataset_json_path is not None:
            assert isinstance(dataset_json_path, str)
print("Load metadata from %s" % dataset_json_path)
self.data = load_json(dataset_json_path)["data"]
self.id2label, self.index_dict, self.num2label = {}, {}, {}
else:
self.metadata_root = load_json(self.config["metadata_root"])
self.dataset_name = self.config["data"][self.split]
assert split in self.config["data"].keys(), (
"The dataset split %s you specified is not present in the config. You can choose from %s"
% (split, self.config["data"].keys())
)
self.build_dataset()
self.build_id_to_label()
self.build_dsp()
self.label_num = len(self.index_dict)
print("Dataset initialize finished")
def __getitem__(self, index):
(
fname,
waveform,
stft,
log_mel_spec,
            label_vector,  # the one-hot representation of the audio class
            # the metadata of the sampled audio file and of the mixup audio file (if any)
(datum, mix_datum),
random_start,
) = self.feature_extraction(index)
text = self.get_sample_text_caption(datum, mix_datum, label_vector)
data = {
"text": text, # list
"fname": self.text_to_filename(text)
if (len(fname) == 0)
else fname, # list
# tensor, [batchsize, class_num]
"label_vector": "" if (label_vector is None) else label_vector.float(),
# tensor, [batchsize, 1, samples_num]
"waveform": "" if (waveform is None) else waveform.float(),
# tensor, [batchsize, t-steps, f-bins]
"stft": "" if (stft is None) else stft.float(),
# tensor, [batchsize, t-steps, mel-bins]
"log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
"duration": self.duration,
"sampling_rate": self.sampling_rate,
"random_start_sample_in_original_audio_file": random_start,
}
for add_on in self.add_ons:
data.update(add_on(self.config, data, self.data[index]))
if data["text"] is None:
print("Warning: The model return None on key text", fname)
data["text"] = ""
return data
def text_to_filename(self, text):
return text.replace(" ", "_").replace("'", "_").replace('"', "_")
def get_dataset_root_path(self, dataset):
assert dataset in self.metadata_root.keys()
return self.metadata_root[dataset]
    def get_dataset_metadata_path(self, dataset, key):
        # key: train, test, val, class_label_indices
        try:
            if dataset in self.metadata_root["metadata"]["path"].keys():
                return self.metadata_root["metadata"]["path"][dataset][key]
        except (KeyError, TypeError):
            raise ValueError(
                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
            )
        return None
def __len__(self):
return len(self.data)
def feature_extraction(self, index):
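        # Load the sample at `index`; on a read/decoding error, advance to the next
        # index (wrapping around) until a file loads successfully.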
if index > len(self.data) - 1:
print(
"The index of the dataloader is out of range: %s/%s"
% (index, len(self.data))
)
index = random.randint(0, len(self.data) - 1)
# Read wave file and extract feature
while True:
try:
label_indices = np.zeros(self.label_num, dtype=np.float32)
datum = self.data[index]
(
log_mel_spec,
stft,
mix_lambda,
waveform,
random_start,
) = self.read_audio_file(datum["wav"])
mix_datum = None
if self.label_num > 0 and "labels" in datum.keys():
for label_str in datum["labels"].split(","):
label_indices[int(self.index_dict[label_str])] = 1.0
                # If the key "labels" is not in the metadata, return an all-zero vector
label_indices = torch.FloatTensor(label_indices)
break
except Exception as e:
index = (index + 1) % len(self.data)
                print(
                    "Error encountered during audio feature extraction:", e, datum["wav"]
                )
continue
# The filename of the wav file
fname = datum["wav"]
# t_step = log_mel_spec.size(0)
# waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
waveform = torch.FloatTensor(waveform)
return (
fname,
waveform,
stft,
log_mel_spec,
label_indices,
(datum, mix_datum),
random_start,
)
# def augmentation(self, log_mel_spec):
# assert torch.min(log_mel_spec) < 0
# log_mel_spec = log_mel_spec.exp()
# log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
# # this is just to satisfy new torchaudio version.
# log_mel_spec = log_mel_spec.unsqueeze(0)
# if self.freqm != 0:
# log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm)
# if self.timem != 0:
# log_mel_spec = self.time_masking(
# log_mel_spec, self.timem) # self.timem=0
# log_mel_spec = (log_mel_spec + 1e-7).log()
# # squeeze back
# log_mel_spec = log_mel_spec.squeeze(0)
# log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
# return log_mel_spec
def build_setting_parameters(self):
# Read from the json config
self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
# self.freqm = self.config["preprocessing"]["mel"]["freqm"]
# self.timem = self.config["preprocessing"]["mel"]["timem"]
self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
self.duration = self.config["preprocessing"]["audio"]["duration"]
self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
self.mixup = self.config["augmentation"]["mixup"]
# Calculate parameter derivations
# self.waveform_sample_length = int(self.target_length * self.hopsize)
# if (self.config["balance_sampling_weight"]):
# self.samples_weight = np.loadtxt(
# self.config["balance_sampling_weight"], delimiter=","
# )
if "train" not in self.split:
self.mixup = 0.0
# self.freqm = 0
# self.timem = 0
def _relative_path_to_absolute_path(self, metadata, dataset_name):
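        # Prefix each relative "wav" path in the metadata with the dataset root directory.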
root_path = self.get_dataset_root_path(dataset_name)
for i in range(len(metadata["data"])):
assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
assert metadata["data"][i]["wav"][0] != "/", (
"The dataset metadata should only contain relative path to the audio file: "
+ str(metadata["data"][i]["wav"])
)
metadata["data"][i]["wav"] = os.path.join(
root_path, metadata["data"][i]["wav"]
)
return metadata
def build_dataset(self):
self.data = []
print("Build dataset split %s from %s" % (self.split, self.dataset_name))
        if isinstance(self.dataset_name, str):
data_json = load_json(
self.get_dataset_metadata_path(self.dataset_name, key=self.split)
)
data_json = self._relative_path_to_absolute_path(
data_json, self.dataset_name
)
self.data = data_json["data"]
        elif isinstance(self.dataset_name, list):
for dataset_name in self.dataset_name:
data_json = load_json(
self.get_dataset_metadata_path(dataset_name, key=self.split)
)
data_json = self._relative_path_to_absolute_path(
data_json, dataset_name
)
self.data += data_json["data"]
else:
            raise TypeError("Invalid dataset name format: %s" % type(self.dataset_name))
print("Data size: {}".format(len(self.data)))
def build_dsp(self):
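        # Build the STFT / mel-spectrogram front-end from the preprocessing config
        # (kept aligned with HiFi-GAN feature extraction; see the note in read_audio_file).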
self.STFT = Audio.stft.TacotronSTFT(
self.config["preprocessing"]["stft"]["filter_length"],
self.config["preprocessing"]["stft"]["hop_length"],
self.config["preprocessing"]["stft"]["win_length"],
self.config["preprocessing"]["mel"]["n_mel_channels"],
self.config["preprocessing"]["audio"]["sampling_rate"],
self.config["preprocessing"]["mel"]["mel_fmin"],
self.config["preprocessing"]["mel"]["mel_fmax"],
)
# self.stft_transform = torchaudio.transforms.Spectrogram(
# n_fft=1024, hop_length=160
# )
# self.melscale_transform = torchaudio.transforms.MelScale(
# sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64
# )
def build_id_to_label(self):
id2label = {}
id2num = {}
num2label = {}
class_label_indices_path = self.get_dataset_metadata_path(
dataset=self.config["data"]["class_label_indices"],
key="class_label_indices",
)
if class_label_indices_path is not None:
df = pd.read_csv(class_label_indices_path)
for _, row in df.iterrows():
index, mid, display_name = row["index"], row["mid"], row["display_name"]
id2label[mid] = display_name
id2num[mid] = index
num2label[index] = display_name
self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
else:
self.id2label, self.index_dict, self.num2label = {}, {}, {}
def resample(self, waveform, sr):
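        # Resample from the file's native rate `sr` to the configured target rate
        # (torchaudio returns the input unchanged when the two rates already match).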
waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
# waveform = librosa.resample(waveform, sr, self.sampling_rate)
return waveform
# if sr == 16000:
# return waveform
# if sr == 32000 and self.sampling_rate == 16000:
# waveform = waveform[::2]
# return waveform
# if sr == 48000 and self.sampling_rate == 16000:
# waveform = waveform[::3]
# return waveform
# else:
# raise ValueError(
# "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s"
# % (sr, self.sampling_rate)
# )
def normalize_wav(self, waveform):
waveform = waveform - np.mean(waveform)
waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5  # manually limit the maximum amplitude to 0.5
def random_segment_wav(self, waveform, target_length):
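        # Crop a random target_length-sample window; returns (segment, start_index).
        # Waveforms shorter than target_length are returned unchanged with start 0.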
waveform_length = waveform.shape[-1]
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
# Too short
if (waveform_length - target_length) <= 0:
return waveform, 0
random_start = int(self.random_uniform(0, waveform_length - target_length))
return waveform[:, random_start : random_start + target_length], random_start
def pad_wav(self, waveform, target_length):
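        # Zero-pad the waveform up to target_length samples; the placement offset is
        # random when self.pad_wav_start_sample is None, otherwise it is the start (0).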
waveform_length = waveform.shape[-1]
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
if waveform_length == target_length:
return waveform
# Pad
temp_wav = np.zeros((1, target_length), dtype=np.float32)
if self.pad_wav_start_sample is None:
rand_start = int(self.random_uniform(0, target_length - waveform_length))
else:
rand_start = 0
temp_wav[:, rand_start : rand_start + waveform_length] = waveform
return temp_wav
def trim_wav(self, waveform):
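        # Strip leading/trailing silence (|x| < 1e-4), scanning in 1000-sample chunks;
        # an entirely silent waveform is returned unchanged.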
if np.max(np.abs(waveform)) < 0.0001:
return waveform
def detect_leading_silence(waveform, threshold=0.0001):
chunk_size = 1000
waveform_length = waveform.shape[0]
start = 0
while start + chunk_size < waveform_length:
if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
start += chunk_size
else:
break
return start
def detect_ending_silence(waveform, threshold=0.0001):
chunk_size = 1000
waveform_length = waveform.shape[0]
start = waveform_length
while start - chunk_size > 0:
if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
start -= chunk_size
else:
break
if start == waveform_length:
return start
else:
return start + chunk_size
start = detect_leading_silence(waveform)
end = detect_ending_silence(waveform)
return waveform[start:end]
def read_wav_file(self, filename):
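        # Pipeline: load -> random fixed-duration crop -> resample -> normalize
        # -> optional silence trim -> zero-pad to sampling_rate * duration samples.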
# waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
waveform, sr = torchaudio.load(filename)
waveform, random_start = self.random_segment_wav(
waveform, target_length=int(sr * self.duration)
)
waveform = self.resample(waveform, sr)
# random_start = int(random_start * (self.sampling_rate / sr))
waveform = waveform.numpy()[0, ...]
waveform = self.normalize_wav(waveform)
        if self.trim_wav_enabled:
            waveform = self.trim_wav(waveform)
waveform = waveform[None, ...]
waveform = self.pad_wav(
waveform, target_length=int(self.sampling_rate * self.duration)
)
return waveform, random_start
def mix_two_waveforms(self, waveform1, waveform2):
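        # Mixup augmentation: blend two waveforms with lambda ~ Beta(5, 5), then renormalize.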
mix_lambda = np.random.beta(5, 5)
mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
return self.normalize_wav(mix_waveform), mix_lambda
def read_audio_file(self, filename, filename2=None):
if os.path.exists(filename):
waveform, random_start = self.read_wav_file(filename)
else:
            print(
                'Warning [dataset.py]: The wav path "%s" was not found in the metadata. Using an empty waveform instead.'
                % filename
            )
target_length = int(self.sampling_rate * self.duration)
waveform = torch.zeros((1, target_length))
random_start = 0
mix_lambda = 0.0
# log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN
if not self.waveform_only:
log_mel_spec, stft = self.wav_feature_extraction(waveform)
else:
# Load waveform data only
# Use zero array to keep the format unified
log_mel_spec, stft = None, None
return log_mel_spec, stft, mix_lambda, waveform, random_start
def get_sample_text_caption(self, datum, mix_datum, label_indices):
text = self.label_indices_to_text(datum, label_indices)
if mix_datum is not None:
text += " " + self.label_indices_to_text(mix_datum, label_indices)
return text
# This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
def wav_feature_extraction(self, waveform):
waveform = waveform[0, ...]
waveform = torch.FloatTensor(waveform)
log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)
log_mel_spec = torch.FloatTensor(log_mel_spec.T)
stft = torch.FloatTensor(stft.T)
log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
return log_mel_spec, stft
# @profile
# def wav_feature_extraction_torchaudio(self, waveform):
# waveform = waveform[0, ...]
# waveform = torch.FloatTensor(waveform)
# stft = self.stft_transform(waveform)
# mel_spec = self.melscale_transform(stft)
# log_mel_spec = torch.log(mel_spec + 1e-7)
# log_mel_spec = torch.FloatTensor(log_mel_spec.T)
# stft = torch.FloatTensor(stft.T)
# log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
# return log_mel_spec, stft
def pad_spec(self, log_mel_spec):
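        # Zero-pad or crop along the time axis to exactly self.target_length frames,
        # then drop the last frequency bin if needed so the bin count is even.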
n_frames = log_mel_spec.shape[0]
p = self.target_length - n_frames
# cut and pad
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
log_mel_spec = m(log_mel_spec)
elif p < 0:
log_mel_spec = log_mel_spec[0 : self.target_length, :]
if log_mel_spec.size(-1) % 2 != 0:
log_mel_spec = log_mel_spec[..., :-1]
return log_mel_spec
def _read_datum_caption(self, datum):
caption_keys = [x for x in datum.keys() if ("caption" in x)]
random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
return datum[caption_keys[random_index]]
def _is_contain_caption(self, datum):
caption_keys = [x for x in datum.keys() if ("caption" in x)]
return len(caption_keys) > 0
def label_indices_to_text(self, datum, label_indices):
if self._is_contain_caption(datum):
return self._read_datum_caption(datum)
elif "label" in datum.keys():
name_indices = torch.where(label_indices > 0.1)[0]
# description_header = "This audio contains the sound of "
description_header = ""
labels = ""
            for i, each in enumerate(name_indices):
                if i == len(name_indices) - 1:
labels += "%s." % self.num2label[int(each)]
else:
labels += "%s, " % self.num2label[int(each)]
return description_header + labels
else:
return "" # TODO, if both label and caption are not provided, return empty string
def random_uniform(self, start, end):
val = torch.rand(1).item()
return start + (end - start) * val
def frequency_masking(self, log_mel_spec, freqm):
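        # SpecAugment-style frequency masking: zero out a random band of mel bins
        # whose width is drawn from [freqm // 8, freqm).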
bs, freq, tsteps = log_mel_spec.size()
mask_len = int(self.random_uniform(freqm // 8, freqm))
mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
return log_mel_spec
def time_masking(self, log_mel_spec, timem):
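        # SpecAugment-style time masking: zero out a random span of time steps
        # whose length is drawn from [timem // 8, timem).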
bs, freq, tsteps = log_mel_spec.size()
mask_len = int(self.random_uniform(timem // 8, timem))
mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
return log_mel_spec
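

if __name__ == "__main__":
    # Minimal usage sketch (a hedged illustration, not part of the original
    # training pipeline). Every path and preprocessing value below is a
    # placeholder: a real run needs a metadata_root JSON that maps dataset names
    # to their audio root directories plus a "metadata" -> "path" table pointing
    # at the per-split metadata JSONs.
    from torch.utils.data import DataLoader

    example_config = {
        "metadata_root": "metadata_root.json",  # hypothetical file
        "data": {"train": "my_dataset", "class_label_indices": "my_dataset"},
        "preprocessing": {
            "audio": {"sampling_rate": 48000, "duration": 10.24},
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "augmentation": {"mixup": 0.0},
    }
    dataset = AudioDataset(config=example_config, split="train", waveform_only=True)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch["fname"], batch["waveform"].shape)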