File size: 6,728 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import logging
from collections import OrderedDict
from copy import deepcopy
from functools import partial
from typing import Any, List, Union

import numpy as np
import torch
from speechbrain.dataio.dataset import DynamicItemDataset
from speechbrain.utils.data_pipeline import DynamicItem
from torch.nn.utils.rnn import pad_sequence

logger = logging.getLogger(__name__)


class AugmentedDynamicItemDataset(DynamicItemDataset):
    """
    A :obj:`DynamicItemDataset` augmented with named "tools".

    A tool is an auxiliary object (e.g. a tokenizer) stored on the dataset and
    exposed as a dynamic item, so that other dynamic items can take it as an
    input by name (see :obj:`add_tool`).
    """

    def __init__(
        self,
        data,
        dynamic_items=None,
        output_keys=None,
        tools: dict = None,
    ):
        """
        Args:
            data (OrderedDict): mapping from data-point id to its static items.
            dynamic_items: forwarded to :obj:`DynamicItemDataset.__init__`.
            output_keys: forwarded to :obj:`DynamicItemDataset.__init__`.
            tools (dict): initial ``name: value`` pairs, each registered via
                :obj:`add_tool`.
        """
        # None defaults replace the original mutable defaults ([], [], {}),
        # which are evaluated once and shared across every call.
        super().__init__(
            data,
            dynamic_items if dynamic_items is not None else [],
            output_keys if output_keys is not None else [],
        )
        assert isinstance(data, OrderedDict)
        self._tools = {}
        for name, item in (tools or {}).items():
            self.add_tool(name, item)

    def _dynamic_tools(self, id, name):
        # Dynamic-item function: ignores the data-point id and yields the
        # registered tool, making each tool addressable like a regular key.
        return self._tools[name]

    def add_tool(self, name: str, item: Any) -> None:
        """
        Store the :code:`item` in this dataset with the name :code:`name` so it can be used in
        :code:`__getitem__`. That is, you can retrieve the :code:`item` with the :code:`takes` argument
        of :obj:`add_dynamic_item`.

        .. code-block:: python

            def tokenize_func(tokenizer, text):
                return tokenizer(text)

            self.add_tool("tokenizer", tokenizer)
            self.add_dynamic_item(tokenize_func, takes=["tokenizer", "text"], provides="tokenized_ids")

        You can also later retrieve this tool by :obj:`get_tool` or :obj:`all_tools`
        """
        self._tools[name] = item
        self.add_dynamic_item(
            partial(self._dynamic_tools, name=name), takes="id", provides=name
        )

    def add_tools(self, tools: dict) -> None:
        """
        Store each key-value pair in :code:`tools` as a tool. See :obj:`add_tool` for more information
        """
        for key, value in tools.items():
            self.add_tool(key, value)

    def get_tool(self, key) -> Any:
        """
        Return the tool registered under :code:`key`.
        See :obj:`add_tool` for more information

        Raises:
            KeyError: if no tool named :code:`key` exists.
        """
        return self._tools[key]

    def has_tool(self, key) -> bool:
        """
        Checks whether has a tool named :code:`key`.
        """
        return key in self._tools

    def all_tools(self, copy=True) -> dict:
        """
        Args:
            copy (bool): when True (default) return a deep copy, so callers
                cannot mutate the registered tools.

        Return:
            dict

            Containing all the tools in :code:`name: value` pairs.
            See :obj:`add_tool` for more information
        """
        return deepcopy(self._tools) if copy else self._tools

    def update_output_keys(self, keys: dict) -> None:
        """
        Compared to :obj:`set_output_keys`, this method updates the output keys mapping
        instead of replacing it with a new dictionary. This can be useful when you only
        want to replace a few mappings and leave others unchanged.
        """
        mapping = self.pipeline.output_mapping.copy()
        mapping.update(keys)
        self.set_output_keys(mapping)

    def keys(self) -> List[str]:
        """
        List all the :code:`static_item` and :code:`dynamic_item` in the dataset.
        :code:`static_item` resides directly in the memory and are given by the dataset
        initialization dictionary. :code:`dynamic_item` are content computed
        on-the-fly basing on :code:`static_item`.

        Keys starting with an underscore and keys that are tools are excluded.
        """
        available_keys: List[str] = list(self.pipeline.key_to_node.keys())
        for dynamic_item in self.pipeline.dynamic_items:
            provides = dynamic_item.provides
            assert isinstance(provides, (list, tuple))
            available_keys += provides
        available_keys = [
            key
            for key in available_keys
            if not key.startswith("_") and key not in self._tools
        ]
        return available_keys

    def set_info(self, info):
        """Remember :code:`info` (an output-keys mapping) for :obj:`get_info`."""
        self._info = info

    def get_info(self, index):
        """
        Get item :code:`index` using the output keys previously stored with
        :obj:`set_info`.

        Raises:
            AttributeError: if :obj:`set_info` was never called.
        """
        with self.output_keys_as(self._info):
            return self.__getitem__(index)

    def __getitem__(self, index):
        """
        This remains all the usage of the original SpeechBrain DynamicItemDataset.__getitem__,
        except that by default it uses :obj:`keys` as the default :code:`output_keys`
        """
        if len(self.pipeline.output_mapping) == 0:
            # No explicit output keys were set: expose every available key.
            with self.output_keys_as(self.keys()):
                return super().__getitem__(index)
        else:
            return super().__getitem__(index)


class DataPipe:
    """
    Base class for one dataset-processing step.

    Subclasses implement :obj:`forward`, which receives an
    :obj:`AugmentedDynamicItemDataset`, attaches dynamic items / output keys to
    it, and returns it. Calling the pipe itself additionally accepts a plain
    data dict and an optional ``tools`` dict for convenience.
    """

    def __call__(
        self, dataset: Union[dict, AugmentedDynamicItemDataset], tools: Union[dict, None] = None
    ) -> Any:
        # Accept raw data and wrap it into the augmented dataset.
        # NOTE(review): AugmentedDynamicItemDataset asserts its data is an
        # OrderedDict, so a plain dict here would fail that assert — presumably
        # callers pass an OrderedDict; verify against call sites.
        if isinstance(dataset, dict):
            dataset = AugmentedDynamicItemDataset(dataset)

        if tools is not None:
            dataset.add_tools(tools)

        return self.forward(dataset)

    def forward(
        self, dataset: AugmentedDynamicItemDataset
    ) -> AugmentedDynamicItemDataset:
        # Subclass hook: transform the dataset and return it.
        raise NotImplementedError

    def __getattribute__(self, name):
        value = super().__getattribute__(name)
        if isinstance(value, DynamicItem):
            # If the attribute is a DynamicItem declared on the class, bind its
            # wrapped function to this instance so it can use `self` like a
            # regular method when the pipeline later calls it.
            # NOTE(review): this mutates the class-level DynamicItem in place,
            # so with several instances of the same subclass the func stays
            # bound to whichever instance accessed it first — confirm intended.
            value.func = value.func.__get__(self)
        return value


class SequentialDataPipe(DataPipe):
    """
    Compose several :obj:`DataPipe` steps into one: :obj:`forward` applies each
    pipe to the dataset in the given order.
    """

    def __init__(self, *pipes: DataPipe) -> None:
        """
        Args:
            *pipes: the pipes to run, in order.
        """
        # Fix: the original annotated `*pipes` as List[DataPipe], but with a
        # star parameter each individual argument is a single DataPipe.
        self._pipes = pipes

    def forward(
        self, dataset: AugmentedDynamicItemDataset
    ) -> AugmentedDynamicItemDataset:
        for pipe in self._pipes:
            dataset = pipe(dataset)
        return dataset


def default_collate_fn(samples, padding_value: int = 0):
    """
    Collate a batch of dicts produced by **DynamicItemDataset**.

    For each key, the column of per-sample values is combined according to the
    type of the first sample's value: ints become a LongTensor, floats a
    FloatTensor, numpy arrays are converted to float tensors and padded,
    tensors are padded, and anything else becomes an object-dtype numpy array.

    Args:
        samples (List[dict]): Suppose each Container is in

            .. code-block:: yaml

                wav: a single waveform
                label: a single string

        padding_value (int): the value used to pad variable-length sequences.

    Return:
        dict

        .. code-block:: yaml

            wav: padded waveforms
            label: np.array([a list of string labels])
    """
    assert isinstance(samples[0], dict)
    batch = dict()
    for key in samples[0].keys():
        column = [sample[key] for sample in samples]
        first = column[0]
        if isinstance(first, int):
            collated = torch.LongTensor(column)
        elif isinstance(first, float):
            collated = torch.FloatTensor(column)
        elif isinstance(first, np.ndarray):
            # Convert to float tensors first, then pad to a uniform length.
            tensors = [torch.from_numpy(array).float() for array in column]
            collated = pad_sequence(
                tensors, batch_first=True, padding_value=padding_value
            )
        elif isinstance(first, torch.Tensor):
            collated = pad_sequence(
                column, batch_first=True, padding_value=padding_value
            )
        else:
            # Fallback for strings and other objects.
            collated = np.array(column, dtype="object")
        batch[key] = collated
    return batch