# OSUM/wenet/ssl/init_dataset.py
from functools import partial
import sys
import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.dataset import processor
from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource


def padding(data):
    """ Pad a list of samples into a training batch.

    Args:
        data: List[{key, feat, label}]

    Returns:
        batch(dict): {keys, feats, feats_lengths, target, target_lengths}
    """
sample = data
assert isinstance(sample, list)
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    # Sort samples by feature length (descending) and reuse the same
    # tensor for the sorted lengths instead of recomputing them.
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = feats_length[order]
    sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
batch = {
"keys": sorted_keys,
"feats": padded_feats,
"feats_lengths": feats_lengths,
        # NOTE(Mddct): cv needs targets; refine later
"target": padded_feats,
"target_lengths": feats_lengths,
}
return batch
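

# A minimal sketch (shapes assumed for illustration) of what padding()
# produces: feats of shape (7, 80) and (5, 80) are sorted by length,
# zero-padded, and stacked, so batch["feats"] has shape (2, 7, 80) and
# batch["feats_lengths"] is tensor([7, 5], dtype=torch.int32); a
# runnable version of this check lives in the __main__ block below.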


def Dataset(data_type, data_list_file, conf=None, partition=True):
    """ Construct a dataset from arguments for an SSL model.

    There are two shuffle stages in the Dataset. The first is a global
    shuffle at the shard/raw file list level. The second is a buffered
    shuffle at the training-sample level.

    Args:
        data_type(str): raw/shard
        data_list_file(str): path to the data list file
        conf(dict): dataset configuration
        partition(bool): whether to do data partition in terms of rank
    """
assert conf is not None
assert data_type in ['raw', 'shard']
    # how many times to cycle over the data list source
cycle = conf.get('cycle', 1)
# stage1 shuffle: source
list_shuffle = conf.get('list_shuffle', True)
list_shuffle_size = sys.maxsize
if list_shuffle:
list_shuffle_conf = conf.get('list_shuffle_conf', {})
list_shuffle_size = list_shuffle_conf.get('shuffle_size',
list_shuffle_size)
if data_type == 'raw':
dataset = WenetRawDatasetSource(data_list_file,
partition=partition,
shuffle=list_shuffle,
shuffle_size=list_shuffle_size,
cycle=cycle)
dataset = dataset.map(processor.parse_json)
else:
dataset = WenetTarShardDatasetSource(data_list_file,
partition=partition,
shuffle=list_shuffle,
shuffle_size=list_shuffle_size,
cycle=cycle)
dataset = dataset.map_ignore_error(processor.decode_wav)
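    # NOTE: 'singal_channel' (sic) is the spelling used by the wenet
    # processor API, so it is deliberately left unchanged here.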
singal_channel_conf = conf.get('singal_channel_conf', {})
dataset = dataset.map(
partial(processor.singal_channel, **singal_channel_conf))
filter_conf = conf.get('filter_conf', {})
dataset = dataset.filter(partial(processor.filter, **filter_conf))
resample_conf = conf.get('resample_conf', {})
dataset = dataset.map(partial(processor.resample, **resample_conf))
speed_perturb = conf.get('speed_perturb', False)
if speed_perturb:
        dataset = dataset.map(processor.speed_perturb)
feats_type = conf.get('feats_type', 'fbank')
assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
if feats_type == 'fbank':
fbank_conf = conf.get('fbank_conf', {})
dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
elif feats_type == 'mfcc':
mfcc_conf = conf.get('mfcc_conf', {})
dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
elif feats_type == 'log_mel_spectrogram':
log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
dataset = dataset.map(
partial(processor.compute_log_mel_spectrogram,
**log_mel_spectrogram_conf))
spec_aug = conf.get('spec_aug', True)
spec_sub = conf.get('spec_sub', False)
spec_trim = conf.get('spec_trim', False)
if spec_aug:
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf))
if spec_sub:
spec_sub_conf = conf.get('spec_sub_conf', {})
dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf))
if spec_trim:
spec_trim_conf = conf.get('spec_trim_conf', {})
dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))
shuffle = conf.get('shuffle', True)
if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size'])
sort = conf.get('sort', True)
if sort:
sort_conf = conf.get('sort_conf', {})
dataset = dataset.sort(buffer_size=sort_conf['sort_size'],
key_func=processor.sort_by_feats)
batch_conf = conf.get('batch_conf', {})
batch_type = batch_conf.get('batch_type', 'static')
assert batch_type in ['static', 'bucket', 'dynamic']
    if batch_type == 'static':
        assert 'batch_size' in batch_conf
        dataset = dataset.batch(batch_conf['batch_size'], wrapper_class=padding)
elif batch_type == 'bucket':
assert 'bucket_boundaries' in batch_conf
assert 'bucket_batch_sizes' in batch_conf
dataset = dataset.bucket_by_sequence_length(
processor.feats_length_fn,
batch_conf['bucket_boundaries'],
batch_conf['bucket_batch_sizes'],
wrapper_class=padding)
else:
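        # 'dynamic' batching accumulates utterances until adding the next
        # one would push the total frame count past max_frames_in_batch.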
max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
dataset = dataset.dynamic_batch(
processor.DynamicBatchWindow(max_frames_in_batch),
wrapper_class=padding,
)
return dataset


def init_dataset(data_type, data_list_file, conf=None, partition=True):
    """Thin wrapper that forwards to Dataset()."""
    return Dataset(data_type, data_list_file, conf, partition)
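

if __name__ == '__main__':
    # A minimal, hedged sketch of how this module might be exercised.
    # The conf keys mirror the lookups in Dataset() above; the values
    # and the 'data.list' path are illustrative assumptions only.

    # Quick shape check for padding() with fake fbank-like features.
    fake = [{'key': 'utt1', 'feat': torch.randn(7, 80)},
            {'key': 'utt2', 'feat': torch.randn(5, 80)}]
    batch = padding(fake)
    assert batch['feats'].shape == (2, 7, 80)
    assert batch['feats_lengths'].tolist() == [7, 5]

    demo_conf = {
        # Assumed typical kwargs for processor.compute_fbank.
        'fbank_conf': {'num_mel_bins': 80, 'frame_length': 25,
                       'frame_shift': 10, 'dither': 0.0},
        'shuffle': True,
        # shuffle_conf['shuffle_size'] and sort_conf['sort_size'] are
        # indexed directly in Dataset(), so they are required whenever
        # the corresponding stage is enabled.
        'shuffle_conf': {'shuffle_size': 1500},
        'sort': True,
        'sort_conf': {'sort_size': 500},
        'batch_conf': {'batch_type': 'static', 'batch_size': 16},
    }
    dataset = init_dataset('raw', 'data.list', conf=demo_conf)
    # Iterating requires a real wenet-format data list:
    # for batch in dataset:
    #     print(batch['keys'], batch['feats'].shape)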