"""
S3PRL Upstream Collection and some utilities

Authors:
  * Leo 2022
"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from s3prl import hub
from s3prl.util.pseudo_data import get_pseudo_wavs

__all__ = [
    "S3PRLUpstream",
    "Featurizer",
    "UpstreamDownstreamModel",
]

MIN_SECOND = 0.05
SAMPLE_RATE = 16000


def randomize_upstream(upstream: nn.Module):
    """Randomly re-initialize all the parameters of the given upstream in-place."""

    def init_weights(m: nn.Module):
        for p in m.parameters():
            if p.dim() < 2:
                # 1-dim parameters (e.g. biases and norm scales): resample from a
                # normal distribution matching the parameter's own statistics
                torch.nn.init.normal_(p, mean=p.mean().item(), std=p.std().item())
            else:
                torch.nn.init.xavier_normal_(p)

    upstream.apply(init_weights)
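
# A minimal usage sketch for ``randomize_upstream``: re-initializing a built
# upstream keeps the architecture but discards the pre-trained weights, e.g.
# to serve as a randomly-initialized baseline. Passing ``randomize=True`` to
# :obj:`S3PRLUpstream` below has the same effect:
#
#     >>> model = S3PRLUpstream("hubert", randomize=True)
#     >>> # equivalent to: randomize_upstream(model.upstream)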


class S3PRLUpstream(nn.Module):
    """
    This is an easy interface for using all the models in S3PRL.
    See :doc:`../tutorial/upstream_collection` for the example usage and all the supported models.

    Args:
        name (str):
            can be "apc", "hubert", "wav2vec2". See :obj:`available_names` for all the supported names

        path_or_url (str):
            The source of the checkpoint. Can be a local path or a URL

        refresh (bool): (default, False)
            If False, only download the checkpoint when it has not been downloaded before.
            If True, force re-downloading the checkpoint.

        normalize (bool): (default, False)
            If True, apply layer norm to the hidden states of every layer

        extra_conf (dict): (default, None)
            The extra arguments for each specific upstream; the available options are
            shown in each upstream's section

        randomize (bool): (default, False)
            If True, randomize the upstream model's weights (see :obj:`randomize_upstream`)

    .. note::

        When using **S3PRLUpstream** with :code:`refresh=True` and multiprocessing (e.g. DDP),
        the checkpoint will only be downloaded once, and the other processes will simply
        re-use the newly downloaded checkpoint instead of re-downloading it in every
        process, which could be very time- and bandwidth-consuming.

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> for hs, hs_len in zip(all_hs, all_hs_len):
        ...     assert isinstance(hs, torch.FloatTensor)
        ...     assert isinstance(hs_len, torch.LongTensor)
        ...
        ...     batch_size, max_seq_len, hidden_size = hs.shape
        ...     assert hs_len.dim() == 1
    """

    @classmethod
    def available_names(cls, only_registered_ckpt: bool = False) -> List[str]:
        """
        All the available names supported by this S3PRLUpstream

        Args:
            only_registered_ckpt (bool):
                If True, ignore the entry names that require a `path_or_url` to be
                given, i.e. the entry names without registered checkpoint sources.
                These names end with :code:`_local` (for local path), :code:`_url`
                (for URL) or :code:`_custom` (auto-determine path or URL)
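
        Example (a small sanity check; "hubert" is among the documented names)::

            >>> names = S3PRLUpstream.available_names()
            >>> assert isinstance(names, list)
            >>> assert "hubert" in names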
        """
        return hub.options(only_registered_ckpt)

    def __init__(
        self,
        name: str,
        path_or_url: str = None,
        refresh: bool = False,
        normalize: bool = False,
        extra_conf: dict = None,
        randomize: bool = False,
    ):
        super().__init__()
        upstream_conf = {"refresh": refresh, **(extra_conf or {})}
        if path_or_url is not None:
            upstream_conf["ckpt"] = path_or_url

        self.upstream = getattr(hub, name)(**upstream_conf)

        if randomize:
            randomize_upstream(self.upstream)

        self.normalize = normalize

        # probe the upstream once with short pseudo waveforms to discover the
        # number of layers and the hidden size of each layer
        self.upstream.eval()
        with torch.no_grad():
            hs = self.upstream(get_pseudo_wavs())["hidden_states"]
        self.upstream.train()
        self._num_layers = len(hs)

        self._hidden_sizes = []
        for h in hs:
            self._hidden_sizes.append(h.size(-1))

        downsample_rates = self.upstream.get_downsample_rates("hidden_states")
        if isinstance(downsample_rates, int):
            self._downsample_rates = [downsample_rates] * self._num_layers
        elif isinstance(downsample_rates, (tuple, list)):
            self._downsample_rates = downsample_rates
        else:
            raise ValueError(
                f"Unsupported downsample rates type: {type(downsample_rates)}"
            )

    @property
    def num_layers(self) -> int:
        """
        The number of hidden-state layers. All the upstreams have a deterministic
        number of layers; that is, layer drop is turned off by default.
        """
        return self._num_layers

    @property
    def downsample_rates(self) -> List[int]:
        """
        The downsampling rate (in samples of 16000 Hz audio) of each layer.
        Usually all layers share the same downsampling rate, but this might
        not be the case for some advanced upstreams.
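
        For example, a HuBERT-style upstream producing one frame every 20 ms
        has a downsampling rate of 320, since 16000 * 0.02 = 320 samples per
        frame (an illustrative value; read this property for the actual rates).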
        """
        return self._downsample_rates

    @property
    def hidden_sizes(self) -> List[int]:
        """
        The hidden size of each layer
        """
        return self._hidden_sizes

    def _match_length(self, xs, target_max_len: int):
        xs_max_len = xs.size(1)

        if xs_max_len > target_max_len:
            # the two lengths are expected to differ by less than a factor of 2
            assert xs_max_len // target_max_len == 1, f"{xs_max_len}, {target_max_len}"
            xs = xs[:, :target_max_len, :]

        elif xs_max_len < target_max_len:
            assert target_max_len // xs_max_len == 1, f"{target_max_len}, {xs_max_len}"
            # pad by repeating the last frame
            xs = torch.cat(
                (xs, xs[:, -1:, :].repeat(1, target_max_len - xs_max_len, 1)), dim=1
            )

        return xs
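
    # _match_length sketch: with xs of shape (B, 101, H) and target_max_len=100,
    # the trailing frame is trimmed; with shape (B, 99, H), the last frame is
    # repeated once. Length ratios of 2 or more fail the assertions above.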

    def forward(self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor):
        """
        Args:
            wavs (torch.FloatTensor): (batch_size, seqlen) or (batch_size, seqlen, 1)
            wavs_len (torch.LongTensor): (batch_size, )

        Return:
            List[torch.FloatTensor], List[torch.LongTensor]

            1. all the layers of hidden states: List[ (batch_size, max_seq_len, hidden_size) ]
            2. the valid lengths for each layer of hidden states: List[ (batch_size, ) ]
        """
        if wavs.dim() == 3:
            wavs = wavs.squeeze(-1)

        # if every waveform is shorter than MIN_SECOND, zero-pad the batch up to
        # MIN_SECOND; the valid lengths are restored from original_wavs_len below
        original_wavs_len = wavs_len
        if max(original_wavs_len) < MIN_SECOND * SAMPLE_RATE:
            padded_samples = int(MIN_SECOND * SAMPLE_RATE) - max(original_wavs_len)
            wavs = torch.cat(
                (wavs, wavs.new_zeros(wavs.size(0), padded_samples)),
                dim=1,
            )
            wavs_len = wavs_len + padded_samples

        wavs_list = []
        for wav, wav_len in zip(wavs, wavs_len):
            wavs_list.append(wav[:wav_len])

        hidden_states = self.upstream(wavs_list)["hidden_states"]
        assert isinstance(hidden_states, (list, tuple))
        assert (
            len(hidden_states) == self.num_layers
        ), f"{len(hidden_states)}, {self.num_layers}"

        max_wav_len = int(max(wavs_len))
        all_hs = []
        all_lens = []
        for h, stride in zip(hidden_states, self.downsample_rates):
            # trim/pad each layer to the length expected from its stride
            expected_max_h_len = len(range(0, max_wav_len, stride))
            h = self._match_length(h, expected_max_h_len)
            assert h.size(1) == expected_max_h_len

            # compute the valid frame lengths from the original (unpadded)
            # waveform lengths
            h_len = torch.div(original_wavs_len - 1, stride, rounding_mode="floor") + 1
            h = h[:, : max(h_len), :]
            if self.normalize:
                h = F.layer_norm(h, h.shape[-1:])

            all_hs.append(h)
            all_lens.append(h_len)

        return all_hs, all_lens
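
    # Frame-length arithmetic sketch: with stride 320 (a HuBERT-style value) and
    # a 16000-sample waveform, h_len = (16000 - 1) // 320 + 1 = 50 frames, and
    # expected_max_h_len = len(range(0, 16000, 320)) = 50 as well.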


class Featurizer(nn.Module):
    """
    The Featurizer takes the :obj:`S3PRLUpstream`'s multiple layers of hidden states and
    reduces (standardizes) them into a single hidden state, to connect with downstream NNs.

    This basic Featurizer expects all the layers to have the same stride and hidden size.
    When the input upstream has only a single layer of hidden states, that layer is used
    directly. If multiple layers are present, a trainable weighted-sum is added on top of
    those layers.

    Args:
        upstream (:obj:`S3PRLUpstream`):
            the upstream to extract features from; this upstream is used only for
            initialization and will not be kept in this Featurizer object
        layer_selections (List[int]):
            Select a subset of hidden states from the given upstream by layer ids (0-indexed).
            If None (default), all the layers of hidden states are selected
        normalize (bool):
            Whether to apply layer norm on all the hidden states before the weighted-sum.
            This can help convergence in some cases, but it is not used in SUPERB, to ensure
            the fidelity of each upstream's extracted representation.

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream, Featurizer
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> featurizer = Featurizer(model)
        >>> hs, hs_len = featurizer(all_hs, all_hs_len)
        ...
        >>> assert isinstance(hs, torch.FloatTensor)
        >>> assert isinstance(hs_len, torch.LongTensor)
        >>> batch_size, max_seq_len, hidden_size = hs.shape
        >>> assert hs_len.dim() == 1
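
    After training, the per-layer importance can be inspected as a probability
    vector (a sketch; :obj:`weights` only exists when the upstream has more
    than one layer)::

        >>> import torch.nn.functional as F
        >>> layer_importance = F.softmax(featurizer.weights.detach(), dim=-1)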
    """

    def __init__(
        self,
        upstream: S3PRLUpstream,
        layer_selections: List[int] = None,
        normalize: bool = False,
    ):
        super().__init__()
        assert len(set(upstream.hidden_sizes)) == 1, "all layers must share the same hidden size"
        assert len(set(upstream.downsample_rates)) == 1, "all layers must share the same stride"
        self._output_size = upstream.hidden_sizes[0]
        self._downsample_rate = upstream.downsample_rates[0]
        self.normalize = normalize

        if upstream.num_layers > 1:
            if layer_selections is not None:
                assert upstream.num_layers >= len(layer_selections)
                self.layer_selections = sorted(layer_selections)
            else:
                self.layer_selections = list(range(upstream.num_layers))
            # one trainable scalar per selected layer, softmax-normalized in
            # _weighted_sum
            self.weights = nn.Parameter(torch.zeros(len(self.layer_selections)))

    @property
    def output_size(self) -> int:
        """
        The hidden size of the final weighted-sum output
        """
        return self._output_size

    @property
    def downsample_rate(self) -> int:
        """
        The downsample rate (from 16000 Hz waveform) of the final weighted-sum output
        """
        return self._downsample_rate

    def _weighted_sum(self, all_hs, all_lens):
        assert len(all_hs) == len(all_lens) > 1
        for l in all_lens[1:]:
            # all the selected layers should share the same valid lengths
            assert torch.allclose(all_lens[0], l)
        stacked_hs = torch.stack(all_hs, dim=0)

        if self.normalize:
            stacked_hs = F.layer_norm(stacked_hs, (stacked_hs.shape[-1],))

        _, *origin_shape = stacked_hs.shape
        stacked_hs = stacked_hs.view(len(self.layer_selections), -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_hs = (norm_weights.unsqueeze(-1) * stacked_hs).sum(dim=0)
        weighted_hs = weighted_hs.view(*origin_shape)

        return weighted_hs, all_lens[0]
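
    # Weighted-sum sketch: the learned scalars ``weights`` are softmax-normalized,
    # so the output is sum_i softmax(weights)[i] * all_hs[i], i.e. a convex
    # combination of the selected layers.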

    def forward(
        self, all_hs: List[torch.FloatTensor], all_lens: List[torch.LongTensor]
    ):
        """
        Args:
            all_hs (List[torch.FloatTensor]): List[ (batch_size, seq_len, hidden_size) ]
            all_lens (List[torch.LongTensor]): List[ (batch_size, ) ]

        Return:
            torch.FloatTensor, torch.LongTensor

            1. The weighted-sum result, (batch_size, seq_len, hidden_size)
            2. the valid length of the result, (batch_size, )
        """
        if len(all_hs) == 1:
            return all_hs[0], all_lens[0]

        all_hs = [h for idx, h in enumerate(all_hs) if idx in self.layer_selections]
        all_lens = [l for idx, l in enumerate(all_lens) if idx in self.layer_selections]
        hs, hs_len = self._weighted_sum(all_hs, all_lens)
        return hs, hs_len


class UpstreamDownstreamModel(nn.Module):
    """
    A simple wrapper chaining an :obj:`S3PRLUpstream`, a :obj:`Featurizer`, and a
    downstream model into a single :obj:`torch.nn.Module`.
    """

    def __init__(
        self,
        upstream: S3PRLUpstream,
        featurizer: Featurizer,
        downstream,
        upstream_trainable: bool = False,
    ):
        super().__init__()
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream
        self.upstream_trainable = upstream_trainable

    @property
    def input_size(self):
        # the model consumes raw waveforms, i.e. a single input channel
        return 1

    @property
    def downsample_rate(self):
        return self.featurizer.downsample_rate

    @property
    def output_size(self):
        return self.downstream.output_size

    def forward(self, wav, wav_len, *args, **kwargs):
        # keep the upstream frozen (no grad, eval mode) unless it is trainable
        with torch.set_grad_enabled(self.upstream_trainable):
            if not self.upstream_trainable:
                self.upstream.eval()
            hs, hs_len = self.upstream(wav, wav_len)

        h, h_len = self.featurizer(hs, hs_len)
        return self.downstream(h, h_len, *args, **kwargs)
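

# End-to-end sketch (``LinearDownstream`` is a hypothetical module exposing
# ``output_size`` and accepting ``(h, h_len, *args, **kwargs)``):
#
#     >>> upstream = S3PRLUpstream("hubert")
#     >>> featurizer = Featurizer(upstream)
#     >>> downstream = LinearDownstream(input_size=featurizer.output_size)
#     >>> model = UpstreamDownstreamModel(upstream, featurizer, downstream)
#     >>> # logits = model(wavs, wavs_len)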