Spaces:
Runtime error
Runtime error
File size: 4,757 Bytes
75c6e9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from typing import Dict, List, Tuple

import torch
class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets for
        training. If there are multiple target source types, the waveforms of
        those sources are concatenated along the channel dimension, in the
        order given by ``target_source_types``.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]

        Raises:
            ValueError: if ``target_source_types`` is empty.
        """
        if not target_source_types:
            # Fail early with a clear message instead of a later
            # ``torch.cat`` error on an empty list.
            raise ValueError("target_source_types must not be empty.")
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis
        # (dim=1), in the order of self.target_source_types.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict
class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. For every example in a batch, one target source is
        selected as the training target (cycling through
        ``target_source_types`` by example index), and a one-hot conditional
        vector identifying that source is prepared.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]

        Raises:
            ValueError: if ``target_source_types`` is empty.
        """
        if not target_source_types:
            # An empty list would otherwise surface later as a
            # ZeroDivisionError in ``batch_size % target_sources_num``.
            raise ValueError("target_source_types must not be empty.")
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms, conditions and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }

        Raises:
            ValueError: if batch_size is not divisible by the number of
                target sources.
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        batch_size = len(mixtures)
        target_sources_num = len(self.target_source_types)

        # Explicit check rather than ``assert`` (which is stripped under -O).
        # Divisibility makes every source appear equally often in the batch.
        if batch_size % target_sources_num != 0:
            raise ValueError(
                f"Batch size ({batch_size}) should be evenly divisible by the "
                f"number of target sources ({target_sources_num})."
            )

        # Example n is assigned source index k = n % target_sources_num, so
        # sources cycle through the batch in order.
        source_indexes = [n % target_sources_num for n in range(batch_size)]

        targets = torch.stack(
            [
                batch_data_dict[self.target_source_types[k]][n]
                for n, k in enumerate(source_indexes)
            ],
            dim=0,
        )
        # targets: (batch_size, channels_num, segment_samples)

        # One-hot conditions, allocated directly on the mixture's device and
        # filled with a single vectorized indexed write. They look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]
        conditions = torch.zeros(batch_size, target_sources_num, device=mixtures.device)
        rows = torch.arange(batch_size, device=mixtures.device)
        conditions[rows, rows % target_sources_num] = 1
        # conditions: (batch_size, target_sources_num)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }
        target_dict = {'waveform': targets}

        return input_dict, target_dict
def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> type:
    r"""Get batch data preprocessor class by name.

    Args:
        batch_data_preprocessor_type: str, either 'BasicBatchDataPreprocessor'
            or 'ConditionalSisoBatchDataPreprocessor'.

    Returns:
        The matching preprocessor class itself (not an instance).

    Raises:
        NotImplementedError: if the name is not one of the supported types.
    """
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    if batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    # Name the rejected value so configuration typos are easy to spot.
    raise NotImplementedError(
        f"Unsupported batch data preprocessor type: {batch_data_preprocessor_type!r}"
    )
|