File size: 1,468 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

# Names exported by ``from <module> import *``.
__all__ = ["default_collate_fn"]


def _pad_or_stack(tensors, padding_value):
    """Batch a list of tensors: stack 0-dim tensors, pad everything else."""
    # pad_sequence reads dimension 0 of every tensor, so it cannot batch
    # 0-dim (scalar) tensors; those are stacked into a 1-D batch instead.
    if all(tensor.dim() == 0 for tensor in tensors):
        return torch.stack(tensors)
    return pad_sequence(tensors, batch_first=True, padding_value=padding_value)


def default_collate_fn(samples, padding_value: int = 0):
    """
    Collate a list of sample dicts into a single batch dict.

    The keys of the first sample define the keys of the batch. For each key,
    the values of all samples are gathered and batched according to the type
    of the FIRST sample's value:

    - ``int``: ``torch.LongTensor`` holding one entry per sample
    - ``float``: ``torch.FloatTensor`` holding one entry per sample
    - ``np.ndarray``: each array is converted to a float32 tensor, then
      padded along the first dimension with ``padding_value`` (batch first);
      0-dim arrays are stacked into a 1-D tensor instead
    - ``torch.Tensor``: padded along the first dimension with
      ``padding_value`` (batch first); 0-dim tensors are stacked into a 1-D
      tensor instead
    - anything else (e.g. ``str``): ``np.array(values, dtype="object")``

    Args:
        samples (List[dict]): non-empty list of dicts. Every sample must
            contain all keys of the first sample, e.g.

            .. code-block:: yaml

                wav: a single waveform
                label: a single string

        padding_value (int): fill value used when padding sequences of
            unequal length.

    Return:
        dict

        .. code-block:: yaml

            wav: padded waveforms
            label: np.array([a list of string labels])
    """
    assert isinstance(samples[0], dict)
    padded_samples = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        # Only the first value is inspected; the remaining values under the
        # same key are handled the same way.
        first = values[0]
        if isinstance(first, int):
            values = torch.LongTensor(values)
        elif isinstance(first, float):
            values = torch.FloatTensor(values)
        elif isinstance(first, np.ndarray):
            values = [torch.from_numpy(value).float() for value in values]
            values = _pad_or_stack(values, padding_value)
        elif isinstance(first, torch.Tensor):
            values = _pad_or_stack(values, padding_value)
        else:
            # Non-numeric items (e.g. strings) are kept as Python objects.
            values = np.array(values, dtype="object")
        padded_samples[key] = values
    return padded_samples