import argparse
import os
import pickle

import h5py

from bytesep.utils import read_yaml


def create_indexes(args) -> None:
r"""Create and write out training indexes into disk. The indexes may contain
information from multiple datasets. During training, training indexes will
be shuffled and iterated for selecting segments to be mixed. E.g., the
training indexes_dict looks like: {
'vocals': [
{'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
{'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
...
]
'accompaniment': [
{'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
{'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
...
]
}
"""
    # Arguments & parameters
    workspace = args.workspace
    config_yaml = args.config_yaml

    # Only create indexes for training, because evaluation is performed on
    # entire pieces.
    split = "train"

    # Read config file.
    configs = read_yaml(config_yaml)
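    # Keys read from the config (as used below): 'sample_rate',
    # 'segment_seconds', and a 'train' section containing 'indexes' and
    # 'source_types'; each source type maps dataset names to
    # 'hdf5s_directory', 'hop_seconds', and 'key_in_hdf5'.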
    sample_rate = configs["sample_rate"]
    segment_samples = int(configs["segment_seconds"] * sample_rate)
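    # E.g., segment_seconds=3.0 at sample_rate=44100 gives segment_samples =
    # 3.0 * 44100 = 132300, the segment length used in the docstring example.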

    # Path to write out the indexes.
    indexes_path = os.path.join(workspace, configs[split]["indexes"])
    os.makedirs(os.path.dirname(indexes_path), exist_ok=True)

    source_types = configs[split]["source_types"].keys()
    # E.g., ['vocals', 'accompaniment']

    indexes_dict = {source_type: [] for source_type in source_types}
    # indexes_dict will be filled with the structure shown in the docstring
    # example above.

    # Get training indexes for each source type, e.g., 'vocals', 'bass', ...
    for source_type in source_types:

        print("--- {} ---".format(source_type))

        dataset_types = configs[split]["source_types"][source_type]
        # E.g., ['musdb18', ...]

        # Each source can come from multiple datasets.
        for dataset_type in dataset_types:

            hdf5s_dir = os.path.join(
                workspace, dataset_types[dataset_type]["hdf5s_directory"]
            )

            hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)

            key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
            # E.g., 'vocals'

            hdf5_names = sorted(os.listdir(hdf5s_dir))
            print("Hdf5 files num: {}".format(len(hdf5_names)))

            # Traverse all packed hdf5 files of a dataset.
            for n, hdf5_name in enumerate(hdf5_names):

                print(n, hdf5_name)
                hdf5_path = os.path.join(hdf5s_dir, hdf5_name)

                with h5py.File(hdf5_path, "r") as hf:
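                    # Slide a segment_samples-long window over the recording
                    # with a stride of hop_samples; each window position
                    # becomes one index entry.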
                    bgn_sample = 0
                    while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': bgn_sample,
                            'end_sample': bgn_sample + segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

                        bgn_sample += hop_samples
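
                    # E.g., a 60 s recording at 44100 Hz with a 3 s segment
                    # and a 0.1 s hop yields 570 index entries for this file.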

                    # If the audio is shorter than the segment length, use the
                    # entire audio as one segment. Note that 'end_sample' is
                    # still segment_samples, which may exceed the actual
                    # audio length.
                    if bgn_sample == 0:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': 0,
                            'end_sample': segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

        print(
            "Total indexes for {}: {}".format(
                source_type, len(indexes_dict[source_type])
            )
        )

    # Write all indexes out as a single pickle file.
    with open(indexes_path, "wb") as f:
        pickle.dump(indexes_dict, f)

    print("Write index dict to {}".format(indexes_path))
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml", type=str, required=True, help="User defined config file."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Create training indexes.
    create_indexes(args)