"""
S3PRL Upstream Collection and some utilities

Authors:
  * Leo 2022
"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from s3prl import hub
from s3prl.util.pseudo_data import get_pseudo_wavs

__all__ = [
    "S3PRLUpstream",
    "Featurizer",
    "UpstreamDownstreamModel",
]

MIN_SECOND = 0.05  # batches shorter than this (in seconds) are zero-padded in forward
SAMPLE_RATE = 16000  # all the s3prl upstreams expect 16 kHz waveforms


def randomize_upstream(upstream: nn.Module):
    """Re-initialize all the parameters of the upstream, keeping its architecture."""

    def init_weights(m: nn.Module):
        # `apply` already visits every submodule, so only touch the parameters
        # directly owned by `m` to avoid re-initializing each one multiple times
        for p in m.parameters(recurse=False):
            if p.dim() < 2:
                # 1-d parameters (biases, norm scales): resample around their own statistics
                torch.nn.init.normal_(p, mean=p.mean().item(), std=p.std().item())
            else:
                torch.nn.init.xavier_normal_(p)

    upstream.apply(init_weights)


class S3PRLUpstream(nn.Module):
    """
    This is an easy interface for using all the models in S3PRL.
    See :doc:`../tutorial/upstream_collection` for the example usage and all the supported models.

    Args:
        name (str):
            can be "apc", "hubert", "wav2vec2". See :obj:`available_names` for all the supported names

        path_or_url (str):
            The source of the checkpoint. Might be a local path or a URL

        refresh (bool): (default, False)
            If false, only download the checkpoint when it has not been downloaded before.
            If true, force re-downloading the checkpoint.

        normalize (bool): (default, False)
            If true, apply layer norm on the hidden states of every layer

        extra_conf (dict): (default, None)
            The extra arguments for each specific upstream; the available options are
            shown in each upstream section

        randomize (bool): (default, False)
            If true, re-initialize (randomize) the upstream model's weights

    .. note::

        When using **S3PRLUpstream** with :code:`refresh=True` and multiprocessing (e.g. DDP),
        the checkpoint is downloaded only once, and the other processes simply re-use
        the newly downloaded checkpoint instead of re-downloading it in every process,
        which could be very time/bandwidth consuming.

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> for hs, hs_len in zip(all_hs, all_hs_len):
        ...     assert isinstance(hs, torch.FloatTensor)
        ...     assert isinstance(hs_len, torch.LongTensor)
        ...
        ...     batch_size, max_seq_len, hidden_size = hs.shape
        ...     assert hs_len.dim() == 1
    """

    @classmethod
    def available_names(cls, only_registered_ckpt: bool = False) -> List[str]:
        """
        All the available names supported by this S3PRLUpstream

        Args:
            only_registered_ckpt (bool):
                ignore the entry names which require `path_or_url` to be given,
                i.e. the entry names without registered checkpoint sources.
                These names end with :code:`_local` (for local path), :code:`_url`
                (for URL) or :code:`_custom` (auto-determined path or URL)
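
        Example (a quick sanity check; the exact entries depend on the installed
        s3prl version)::

            >>> from s3prl.nn import S3PRLUpstream
            >>> names = S3PRLUpstream.available_names()
            >>> isinstance(names, list) and "hubert" in names
            True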
        """
        return hub.options(only_registered_ckpt)

    def __init__(
        self,
        name: str,
        path_or_url: str = None,
        refresh: bool = False,
        normalize: bool = False,
        extra_conf: dict = None,
        randomize: bool = False,
    ):
        super().__init__()
        upstream_conf = {"refresh": refresh, **(extra_conf or {})}
        if path_or_url is not None:
            upstream_conf["ckpt"] = path_or_url

        self.upstream = getattr(hub, name)(**upstream_conf)

        if randomize:
            randomize_upstream(self.upstream)

        self.normalize = normalize

        # probe the upstream once with short pseudo waveforms to discover the
        # number of layers and the hidden size of each layer
        self.upstream.eval()
        with torch.no_grad():
            hs = self.upstream(get_pseudo_wavs())["hidden_states"]
        self.upstream.train()
        self._num_layers = len(hs)
        self._hidden_sizes = [h.size(-1) for h in hs]

        downsample_rates = self.upstream.get_downsample_rates("hidden_states")
        if isinstance(downsample_rates, int):
            self._downsample_rates = [downsample_rates] * self._num_layers
        elif isinstance(downsample_rates, (tuple, list)):
            self._downsample_rates = list(downsample_rates)
        else:
            raise ValueError(
                f"Unsupported downsample_rates type: {type(downsample_rates)}"
            )

    @property
    def num_layers(self) -> int:
        """
        Number of hidden sizes. All the upstream have a deterministic
        number of layers. That is, layer drop is turned off by default.
        """
        return self._num_layers

    @property
    def downsample_rates(self) -> List[int]:
        """
        The downsampling rate of each layer relative to 16000 Hz audio.
        Usually all layers share the same downsampling rate, but this might
        not be the case for some advanced upstreams.
        """
        return self._downsample_rates

    @property
    def hidden_sizes(self) -> List[int]:
        """
        The hidden size of each layer
        """
        return self._hidden_sizes

    def _match_length(self, xs, target_max_len: int):
        # Align the time axis of xs (batch_size, seq_len, hidden_size) with
        # target_max_len: truncate when slightly too long, repeat the last
        # frame when slightly too short. The asserts ensure the mismatch is
        # below a factor of two, i.e. only boundary effects are fixed here.
        xs_max_len = xs.size(1)

        if xs_max_len > target_max_len:
            assert xs_max_len // target_max_len == 1, f"{xs_max_len}, {target_max_len}"
            xs = xs[:, :target_max_len, :]

        elif xs_max_len < target_max_len:
            assert target_max_len // xs_max_len == 1, f"{target_max_len}, {xs_max_len}"
            xs = torch.cat(
                (xs, xs[:, -1:, :].repeat(1, target_max_len - xs_max_len, 1)), dim=1
            )

        return xs

    def forward(self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor):
        """
        Args:
            wavs (torch.FloatTensor): (batch_size, seqlen) or (batch_size, seqlen, 1)
            wavs_len (torch.LongTensor): (batch_size, )

        Return:
            List[torch.FloatTensor], List[torch.LongTensor]

            1. all the layers of hidden states: List[ (batch_size, max_seq_len, hidden_size) ]
            2. the valid length of each layer's hidden states: List[ (batch_size, ) ]
        """
        if wavs.dim() == 3:
            wavs = wavs.squeeze(-1)

        original_wavs_len = wavs_len
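        # some upstreams cannot handle extremely short inputs, so when even the
        # longest waveform is below MIN_SECOND seconds, zero-pad the whole batch
        # up to MIN_SECOND; the original lengths are kept to compute the valid
        # hidden-state lengths below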
        if max(original_wavs_len) < MIN_SECOND * SAMPLE_RATE:
            padded_samples = int(MIN_SECOND * SAMPLE_RATE) - max(original_wavs_len)
            wavs = torch.cat(
                (wavs, wavs.new_zeros(wavs.size(0), padded_samples)),
                dim=1,
            )
            wavs_len = wavs_len + padded_samples

        wavs_list = []
        for wav, wav_len in zip(wavs, wavs_len):
            wavs_list.append(wav[:wav_len])

        hidden_states = self.upstream(wavs_list)["hidden_states"]
        assert isinstance(hidden_states, (list, tuple))
        assert (
            len(hidden_states) == self.num_layers
        ), f"{len(hidden_states)}, {self.num_layers}"

        max_wav_len = int(max(wavs_len))
        all_hs = []
        all_lens = []
        for h, stride in zip(hidden_states, self.downsample_rates):
            expected_max_h_len = len(range(0, max_wav_len, stride))
            h = self._match_length(h, expected_max_h_len)
            assert h.size(1) == expected_max_h_len

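            # valid length = ceil(wav_len / stride), computed as
            # floor((wav_len - 1) / stride) + 1; e.g. wav_len=16000 with
            # stride=320 gives 50 valid frames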
            h_len = torch.div(original_wavs_len - 1, stride, rounding_mode="floor") + 1
            h = h[:, : max(h_len), :]
            if self.normalize:
                h = F.layer_norm(h, h.shape[-1:])

            all_hs.append(h)
            all_lens.append(h_len)

        return all_hs, all_lens


class Featurizer(nn.Module):
    """
    Featurizer takes the :obj:`S3PRLUpstream`'s multiple layers of hidden_states and
    reduces (standardizes) them into a single hidden_states, to connect with downstream NNs.

    This basic Featurizer expects all the layers to have the same stride and hidden_size.
    When the input upstream has only a single layer of hidden states, that layer is used
    directly. If multiple layers are present, a trainable weighted-sum is added on top of
    those layers.
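
    Concretely, given the selected layers :math:`h_1, \dots, h_L` and a trainable
    vector :math:`w \in \mathbb{R}^{L}` (initialized to zeros), the output is
    :math:`\sum_{i} \alpha_i h_i` with :math:`\alpha = \mathrm{softmax}(w)`.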

    Args:
        upstream (:obj:`S3PRLUpstream`):
            the upstream to extract features from. This upstream is used only for
            initialization and will not be kept in this Featurizer object
        layer_selections (List[int]):
            select a subset of hidden states from the given upstream by layer ids (0-indexed).
            If None (default), then all the layers of hidden states are selected
        normalize (bool):
            whether to apply layer norm on all the hidden states before the weighted-sum.
            This can help convergence in some cases, but it is not used in SUPERB, to ensure
            the fidelity of each upstream's extracted representation.

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream, Featurizer
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> featurizer = Featurizer(model)
        >>> hs, hs_len = featurizer(all_hs, all_hs_len)
        ...
        >>> assert isinstance(hs, torch.FloatTensor)
        >>> assert isinstance(hs_len, torch.LongTensor)
        >>> batch_size, max_seq_len, hidden_size = hs.shape
        >>> assert hs_len.dim() == 1
    """

    def __init__(
        self,
        upstream: S3PRLUpstream,
        layer_selections: List[int] = None,
        normalize: bool = False,
    ):
        super().__init__()
        assert (
            len(set(upstream.hidden_sizes)) == 1
        ), "Featurizer requires all layers to share the same hidden size"
        assert (
            len(set(upstream.downsample_rates)) == 1
        ), "Featurizer requires all layers to share the same downsample rate"
        self._output_size = upstream.hidden_sizes[0]
        self._downsample_rate = upstream.downsample_rates[0]
        self.normalize = normalize

        # the weighted-sum is needed only when there are multiple layers; with a
        # single layer, forward returns that layer directly
        if upstream.num_layers > 1:
            if layer_selections is not None:
                assert upstream.num_layers >= len(layer_selections)
                self.layer_selections = sorted(layer_selections)
            else:
                self.layer_selections = list(range(upstream.num_layers))
            self.weights = nn.Parameter(torch.zeros(len(self.layer_selections)))

    @property
    def output_size(self) -> int:
        """
        The hidden size of the final weighted-sum output
        """
        return self._output_size

    @property
    def downsample_rate(self) -> int:
        """
        The downsample rate (from 16k Hz waveform) of the final weighted-sum output
        """
        return self._downsample_rate

    def _weighted_sum(self, all_hs, all_lens):
        assert len(all_hs) == len(all_lens) > 1
        for l in all_lens[1:]:
            # all layers must share the same valid lengths for the weighted-sum
            assert torch.allclose(all_lens[0], l)
        stacked_hs = torch.stack(all_hs, dim=0)

        if self.normalize:
            stacked_hs = F.layer_norm(stacked_hs, (stacked_hs.shape[-1],))

        _, *origin_shape = stacked_hs.shape
        stacked_hs = stacked_hs.view(len(self.layer_selections), -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_hs = (norm_weights.unsqueeze(-1) * stacked_hs).sum(dim=0)
        weighted_hs = weighted_hs.view(*origin_shape)

        return weighted_hs, all_lens[0]

    def forward(
        self, all_hs: List[torch.FloatTensor], all_lens: List[torch.LongTensor]
    ):
        """
        Args:
            all_hs (List[torch.FloatTensor]): List[ (batch_size, seq_len, hidden_size) ]
            all_lens (List[torch.LongTensor]): List[ (batch_size, ) ]

        Return:
            torch.FloatTensor, torch.LongTensor

            1. The weighted-sum result, (batch_size, seq_len, hidden_size)
            2. the valid length of the result, (batch_size, )
        """
        if len(all_hs) == 1:
            return all_hs[0], all_lens[0]

        all_hs = [h for idx, h in enumerate(all_hs) if idx in self.layer_selections]
        all_lens = [l for idx, l in enumerate(all_lens) if idx in self.layer_selections]
        hs, hs_len = self._weighted_sum(all_hs, all_lens)
        return hs, hs_len


class UpstreamDownstreamModel(nn.Module):
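    """
    Combine an :obj:`S3PRLUpstream`, a :obj:`Featurizer` and a task-specific downstream
    model into a single module. When :code:`upstream_trainable` is False (default), the
    upstream is kept in eval mode and no gradient flows through it.

    A minimal usage sketch, assuming a :code:`downstream` module that exposes
    :obj:`output_size` and accepts :code:`(hs, hs_len, *args, **kwargs)`::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream, Featurizer
        ...
        >>> upstream = S3PRLUpstream("hubert")
        >>> featurizer = Featurizer(upstream)
        >>> model = UpstreamDownstreamModel(upstream, featurizer, downstream)
        ...
        >>> wavs = torch.randn(2, 16000 * 2)
        >>> wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        >>> output = model(wavs, wavs_len)
    """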
    def __init__(
        self,
        upstream: S3PRLUpstream,
        featurizer: Featurizer,
        downstream,
        upstream_trainable: bool = False,
    ):
        super().__init__()
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream
        self.upstream_trainable = upstream_trainable

    @property
    def input_size(self):
        # raw single-channel waveform input
        return 1

    @property
    def downsample_rate(self):
        return self.featurizer.downsample_rate

    @property
    def output_size(self):
        return self.downstream.output_size

    def forward(self, wav, wav_len, *args, **kwargs):
        # run the upstream without gradients and in eval mode unless it is trainable
        with torch.set_grad_enabled(self.upstream_trainable):
            if not self.upstream_trainable:
                self.upstream.eval()
            hs, hs_len = self.upstream(wav, wav_len)

        h, h_len = self.featurizer(hs, hs_len)
        return self.downstream(h, h_len, *args, **kwargs)