lmzjms's picture
Upload 1162 files
0b32ad6 verified
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
# Public API of this module.
__all__ = ["default_collate_fn"]
def default_collate_fn(samples, padding_value: float = 0):
    """
    Collate a list of sample dicts into a single batch dict.

    Each item in **DynamicItemDataset** is a dict. This function pads (or
    transforms into a numpy object array) every field of a batch of dicts.
    The keys of the first sample define the fields of the batch, and the
    type of the first sample's value decides how a field is collated:

    - ``int``: ``torch.LongTensor`` of shape ``(batch,)``
    - ``float``: ``torch.FloatTensor`` of shape ``(batch,)``
    - ``np.ndarray``: each array is converted to a float tensor, then the
      tensors are padded along their first dimension (batch first)
    - ``torch.Tensor``: padded along their first dimension (batch first)
    - anything else (e.g. ``str``): ``np.array(values, dtype="object")``

    Fields whose arrays/tensors are all 0-dimensional (scalars) have no
    length to pad, so they are stacked into a ``(batch,)`` tensor instead.

    Args:
        samples (List[dict]): the samples of one batch. Suppose each sample is

            .. code-block:: yaml

                wav: a single waveform
                label: a single string

        padding_value (float): the value used to fill the padded positions
            of variable-length arrays/tensors. Default: 0

    Return:
        dict

        .. code-block:: yaml

            wav: padded waveforms
            label: np.array([a list of string labels])

        An empty ``samples`` yields an empty dict.
    """
    # Nothing to collate: avoid the bare IndexError from samples[0].
    if not samples:
        return {}

    assert isinstance(samples[0], dict), (
        f"Each sample must be a dict, got {type(samples[0]).__name__}"
    )

    padded_samples = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        first = values[0]

        # NOTE: bool is a subclass of int, so bool fields take this branch.
        if isinstance(first, int):
            values = torch.LongTensor(values)
        elif isinstance(first, float):
            values = torch.FloatTensor(values)
        elif isinstance(first, np.ndarray):
            tensors = [torch.from_numpy(value).float() for value in values]
            values = _pad_or_stack(tensors, padding_value)
        elif isinstance(first, torch.Tensor):
            values = _pad_or_stack(values, padding_value)
        else:
            # NOTE(review): when every value is an equal-length sequence
            # (e.g. lists of the same size), np.array builds a
            # multi-dimensional object array instead of a 1-D array of
            # sequences -- confirm callers handle both shapes.
            values = np.array(values, dtype="object")

        padded_samples[key] = values

    return padded_samples


def _pad_or_stack(tensors, padding_value):
    """Stack 0-dim tensors into shape (batch,); otherwise pad along dim 0."""
    # pad_sequence needs a length dimension, which scalar tensors lack.
    if all(tensor.dim() == 0 for tensor in tensors):
        return torch.stack(tensors)
    return pad_sequence(tensors, batch_first=True, padding_value=padding_value)