from functools import partial
import sys
import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.dataset import processor
from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource

def padding(data):
    """ Pad a list of samples into a training batch.

    Args:
        data: List[{key, feat, label}]

    Returns:
        dict with keys: keys, feats, feats_lengths, target, target_lengths
    """
    sample = data
    assert isinstance(sample, list)
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
        # NOTE(Mddct): cv needs targets, refine later
        "target": padded_feats,
        "target_lengths": feats_lengths,
    }
    return batch
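
# Illustrative behavior (assumed shapes, not from the original file): given two
# samples whose 'feat' tensors have shapes (7, 80) and (5, 80), padding(...)
# returns feats of shape (2, 7, 80) with the longer sample first,
# feats_lengths == tensor([7, 5], dtype=torch.int32), and the shorter sample
# zero-padded along the time axis.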

def Dataset(data_type, data_list_file, conf=None, partition=True):
    """ Construct a dataset from arguments for the ssl model.

    There are two shuffle stages in the Dataset. The first is a global
    shuffle at the shard tar/raw file level. The second is a buffer-based
    shuffle at the training-sample level.

    Args:
        data_type(str): raw/shard
        data_list_file(str): path to the data list file
        conf(dict): dataset configuration (required)
        partition(bool): whether to do data partition in terms of rank
    """
    assert conf is not None
    assert data_type in ['raw', 'shard']
    # cycle dataset
    cycle = conf.get('cycle', 1)
    # stage1 shuffle: source
    list_shuffle = conf.get('list_shuffle', True)
    list_shuffle_size = sys.maxsize
    if list_shuffle:
        list_shuffle_conf = conf.get('list_shuffle_conf', {})
        list_shuffle_size = list_shuffle_conf.get('shuffle_size',
                                                  list_shuffle_size)
    if data_type == 'raw':
        dataset = WenetRawDatasetSource(data_list_file,
                                        partition=partition,
                                        shuffle=list_shuffle,
                                        shuffle_size=list_shuffle_size,
                                        cycle=cycle)
        dataset = dataset.map(processor.parse_json)
    else:
        dataset = WenetTarShardDatasetSource(data_list_file,
                                             partition=partition,
                                             shuffle=list_shuffle,
                                             shuffle_size=list_shuffle_size,
                                             cycle=cycle)
    dataset = dataset.map_ignore_error(processor.decode_wav)
    # 'singal_channel' follows the (misspelled) name of the processor in
    # wenet.dataset.processor, which selects a single audio channel.
    singal_channel_conf = conf.get('singal_channel_conf', {})
    dataset = dataset.map(
        partial(processor.singal_channel, **singal_channel_conf))
    filter_conf = conf.get('filter_conf', {})
    dataset = dataset.filter(partial(processor.filter, **filter_conf))

    resample_conf = conf.get('resample_conf', {})
    dataset = dataset.map(partial(processor.resample, **resample_conf))

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = dataset.map(partial(processor.speed_perturb))

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = dataset.map(
            partial(processor.compute_log_mel_spectrogram,
                    **log_mel_spectrogram_conf))

    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf))
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf))
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))
    shuffle = conf.get('shuffle', True)
    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size'])

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = dataset.sort(buffer_size=sort_conf['sort_size'],
                               key_func=processor.sort_by_feats)

    batch_conf = conf.get('batch_conf', {})
    batch_type = batch_conf.get('batch_type', 'static')
    assert batch_type in ['static', 'bucket', 'dynamic']
    if batch_type == 'static':
        assert 'batch_size' in batch_conf
        batch_size = batch_conf.get('batch_size', 16)
        dataset = dataset.batch(batch_size, wrapper_class=padding)
    elif batch_type == 'bucket':
        assert 'bucket_boundaries' in batch_conf
        assert 'bucket_batch_sizes' in batch_conf
        dataset = dataset.bucket_by_sequence_length(
            processor.feats_length_fn,
            batch_conf['bucket_boundaries'],
            batch_conf['bucket_batch_sizes'],
            wrapper_class=padding)
    else:
        max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
        dataset = dataset.dynamic_batch(
            processor.DynamicBatchWindow(max_frames_in_batch),
            wrapper_class=padding,
        )
    return dataset
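
# The three batch types consume batch_conf differently. Illustrative key sets
# (the bucket values below are assumptions, not defaults from this file; note
# bucket_batch_sizes needs one more entry than bucket_boundaries):
#   static:  {'batch_type': 'static', 'batch_size': 16}
#   bucket:  {'batch_type': 'bucket',
#             'bucket_boundaries': [500, 1000, 1500],
#             'bucket_batch_sizes': [32, 24, 16, 8]}
#   dynamic: {'batch_type': 'dynamic', 'max_frames_in_batch': 12000}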

def init_dataset(data_type, data_list_file, conf=None, partition=True):
    return Dataset(data_type, data_list_file, conf, partition)
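
if __name__ == '__main__':
    # Minimal usage sketch. The conf below only sets the keys that the
    # pipeline above reads without a safe default; every value and the
    # 'train.list' path are illustrative assumptions, not project defaults.
    conf = {
        'fbank_conf': {'num_mel_bins': 80},
        'shuffle_conf': {'shuffle_size': 1500},
        'sort_conf': {'sort_size': 500},
        'batch_conf': {'batch_type': 'static', 'batch_size': 16},
    }
    dataset = init_dataset('raw', 'train.list', conf)
    for batch in dataset:
        # Each batch is the dict produced by padding() above.
        print(batch['keys'], batch['feats'].shape, batch['feats_lengths'])
        break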