File size: 13,196 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 |
"""
S3PRL Upstream Collection and some utilities
Authors:
* Leo 2022
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from s3prl import hub
from s3prl.util.pseudo_data import get_pseudo_wavs
# Public API of this module.
__all__ = [
    "S3PRLUpstream",
    "Featurizer",
    "UpstreamDownstreamModel",
]
# Minimum input duration in seconds: S3PRLUpstream.forward zero-pads a batch whose
# longest waveform is shorter than MIN_SECOND * SAMPLE_RATE samples.
MIN_SECOND = 0.05
# Waveform sampling rate in Hz (the rate the downsample rates are relative to).
SAMPLE_RATE = 16000
def randomize_upstream(upstream: nn.Module):
    """
    Re-initialize every parameter of ``upstream`` in place with random values.

    Parameters with two or more dimensions are re-drawn with Xavier-normal
    initialization. Lower-dimensional parameters with at least two elements
    are re-drawn from a normal distribution whose mean and standard deviation
    match their current values, so their overall scale is kept.

    Args:
        upstream (nn.Module): the model whose parameters are overwritten in place
    """
    # Visit each parameter exactly once. Combining ``Module.apply`` with the
    # recursive ``parameters()`` would re-initialize every parameter once per
    # ancestor module.
    for p in upstream.parameters():
        if p.dim() >= 2:
            torch.nn.init.xavier_normal_(p)
        elif p.numel() >= 2:
            torch.nn.init.normal_(p, mean=p.mean().item(), std=p.std().item())
        # Parameters with fewer than two elements have an undefined (NaN)
        # sample std, which ``normal_`` rejects. Re-drawing them around their
        # own mean with zero spread would be a no-op, so they are skipped.
class S3PRLUpstream(nn.Module):
    """
    This is an easy interface for using all the models in S3PRL.
    See :doc:`../tutorial/upstream_collection` for the example usage and all the supported models.

    Args:
        name (str):
            can be "apc", "hubert", "wav2vec2". See :obj:`available_names` for all the supported names
        path_or_url (str):
            The source of the checkpoint. Might be a local path or a URL
        refresh (bool): (default, False)
            If false, only download checkpoint if not yet downloaded before.
            If true, force to re-download the checkpoint.
        normalize (bool): (default, False)
            If True, apply layer norm (without affine parameters) over the hidden
            dimension of every returned hidden state
        extra_conf (dict): (default, None)
            The extra arguments for each specific upstream, the available options are
            shown in each upstream section
        randomize (bool): (default, False)
            If True, randomize the upstream model

    .. note::

        When using **S3PRLUpstream** with :code:`refresh=True` and multiprocessing (e.g. DDP),
        the checkpoint will only be downloaded once, and the other processes will simply
        re-use the newly downloaded checkpoint, instead of re-downloading on every processes,
        which can be very time/bandwidth consuming.

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> for hs, hs_len in zip(all_hs, all_hs_len):
        ...     assert isinstance(hs, torch.FloatTensor)
        ...     assert isinstance(hs_len, torch.LongTensor)
        ...
        ...     batch_size, max_seq_len, hidden_size = hs.shape
        ...     assert hs_len.dim() == 1
    """

    @classmethod
    def available_names(cls, only_registered_ckpt: bool = False) -> List[str]:
        """
        All the available names supported by this S3PRLUpstream

        Args:
            only_registered_ckpt (bool):
                ignore entry names which require to give `path_or_url`.
                That is, the entry names without the registered checkpoint sources.
                These names end with :code:`_local` (for local path), :code:`_url`
                (for URL) or :code:`_custom` (auto-determine path or URL)
        """
        return hub.options(only_registered_ckpt)

    def __init__(
        self,
        name: str,
        path_or_url: str = None,
        refresh: bool = False,
        normalize: bool = False,
        extra_conf: dict = None,
        randomize: bool = False,
    ):
        super().__init__()
        upstream_conf = {"refresh": refresh, **(extra_conf or {})}
        if path_or_url is not None:
            upstream_conf["ckpt"] = path_or_url
        self.upstream = getattr(hub, name)(**upstream_conf)

        if randomize:
            randomize_upstream(self.upstream)

        self.normalize = normalize

        # Probe the upstream once with pseudo waveforms to discover how many
        # hidden states it returns and the feature size of each of them.
        self.upstream.eval()
        with torch.no_grad():
            hs = self.upstream(get_pseudo_wavs())["hidden_states"]
        self.upstream.train()

        self._num_layers = len(hs)
        self._hidden_sizes = [h.size(-1) for h in hs]
        self._downsample_rates = self._expand_downsample_rates(
            self.upstream.get_downsample_rates("hidden_states"), self._num_layers
        )

    @staticmethod
    def _expand_downsample_rates(downsample_rates, num_layers: int) -> List[int]:
        """
        Turn the upstream's reported downsample rate(s) into one rate per layer.

        Args:
            downsample_rates (int or tuple/list of int):
                a single rate shared by all layers, or one rate per layer
            num_layers (int): number of hidden states the upstream returns

        Returns:
            List[int]: exactly ``num_layers`` downsample rates

        Raises:
            ValueError: if the rates have an unsupported type, or a per-layer
                sequence does not contain exactly ``num_layers`` entries
        """
        if isinstance(downsample_rates, int):
            return [downsample_rates] * num_layers
        if isinstance(downsample_rates, (tuple, list)):
            # A length mismatch would make ``zip`` in ``forward`` silently drop layers
            if len(downsample_rates) != num_layers:
                raise ValueError(
                    f"Expected {num_layers} downsample rates (one per hidden state), "
                    f"but got {len(downsample_rates)}: {downsample_rates}"
                )
            return list(downsample_rates)
        raise ValueError(
            "downsample_rates should be an int or a tuple/list of int, "
            f"but got {type(downsample_rates).__name__}: {downsample_rates}"
        )

    @property
    def num_layers(self) -> int:
        """
        Number of hidden states (layers) the upstream returns. All the upstreams
        have a deterministic number of layers. That is, layer drop is turned off by default.
        """
        return self._num_layers

    @property
    def downsample_rates(self) -> List[int]:
        """
        Downsampling rate from 16000 Hz audio of each layer.
        Usually, all layers have the same downsampling rate,
        but might not be the case for some advanced upstreams.
        """
        return self._downsample_rates

    @property
    def hidden_sizes(self) -> List[int]:
        """
        The hidden size of each layer
        """
        return self._hidden_sizes

    def _match_length(self, xs, target_max_len: int):
        """
        Trim or pad ``xs`` along dim 1 so it has exactly ``target_max_len`` frames.

        Padding repeats the last frame. The asserts require the two lengths to
        differ by less than a factor of two (floor ratio == 1).
        """
        xs_max_len = xs.size(1)
        if xs_max_len > target_max_len:
            assert xs_max_len // target_max_len == 1, f"{xs_max_len}, {target_max_len}"
            xs = xs[:, :target_max_len, :]
        elif xs_max_len < target_max_len:
            assert target_max_len // xs_max_len == 1, f"{target_max_len}, {xs_max_len}"
            xs = torch.cat(
                (xs, xs[:, -1:, :].repeat(1, target_max_len - xs_max_len, 1)), dim=1
            )
        return xs

    def forward(self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor):
        """
        Args:
            wavs (torch.FloatTensor): (batch_size, seqlen) or (batch_size, seqlen, 1)
            wavs_len (torch.LongTensor): (batch_size, )

        Return:
            List[torch.FloatTensor], List[torch.LongTensor]

            1. all the layers of hidden states: List[ (batch_size, max_seq_len, hidden_size) ]
            2. the valid length for each hidden states: List[ (batch_size, ) ]
        """
        if wavs.dim() == 3:
            wavs = wavs.squeeze(-1)

        original_wavs_len = wavs_len
        max_original_len = int(original_wavs_len.max())
        if max_original_len < MIN_SECOND * SAMPLE_RATE:
            # Too-short batches are zero-padded so the upstream receives at least
            # MIN_SECOND of audio; output lengths still use the original lengths.
            padded_samples = int(MIN_SECOND * SAMPLE_RATE) - max_original_len
            wavs = torch.cat(
                (wavs, wavs.new_zeros(wavs.size(0), padded_samples)),
                dim=1,
            )
            wavs_len = wavs_len + padded_samples

        wavs_list = [wav[:wav_len] for wav, wav_len in zip(wavs, wavs_len)]

        hidden_states = self.upstream(wavs_list)["hidden_states"]
        assert isinstance(hidden_states, (list, tuple))
        assert (
            len(hidden_states) == self.num_layers
        ), f"{len(hidden_states)}, {self.num_layers}"

        max_wav_len = int(wavs_len.max())
        all_hs = []
        all_lens = []
        for h, stride in zip(hidden_states, self.downsample_rates):
            # Number of frames produced when sampling every `stride` samples
            expected_max_h_len = len(range(0, max_wav_len, stride))
            h = self._match_length(h, expected_max_h_len)
            assert h.size(1) == expected_max_h_len

            h_len = torch.div(original_wavs_len - 1, stride, rounding_mode="floor") + 1
            h = h[:, : int(h_len.max()), :]
            if self.normalize:
                h = F.layer_norm(h, h.shape[-1:])

            all_hs.append(h)
            all_lens.append(h_len)

        return all_hs, all_lens
class Featurizer(nn.Module):
    """
    Featurizer takes the :obj:`S3PRLUpstream`'s multiple layers of hidden_states and
    reduces (standardizes) them into a single hidden_states, to connect with downstream NNs.

    This basic Featurizer expects all the layers to have the same stride and hidden_size.
    When the input upstream only has a single layer of hidden states, that layer is used directly.
    If multiple layers are present, a trainable weighted-sum is added on top of the selected layers.

    Args:
        upstream (:obj:`S3PRLUpstream`):
            the upstream to extract features, this upstream is used only for initialization
            and will not be kept in this Featurizer object
        layer_selections (List[int]):
            To select a subset of hidden states from the given upstream by layer ids (0-index).
            Duplicated ids are ignored. If None (default), then all the layers of hidden states are selected
        normalize (bool):
            Whether to apply layer norm on all the hidden states before weighted-sum
            This can help convergence in some cases, but not used in SUPERB to ensure the
            fidelity of each upstream's extracted representation.

    Raises:
        ValueError: if the upstream layers differ in hidden size or downsample rate,
            or if ``layer_selections`` is empty or contains out-of-range layer ids

    Example::

        >>> import torch
        >>> from s3prl.nn import S3PRLUpstream, Featurizer
        ...
        >>> model = S3PRLUpstream("hubert")
        >>> model.eval()
        ...
        >>> with torch.no_grad():
        ...     wavs = torch.randn(2, 16000 * 2)
        ...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
        ...     all_hs, all_hs_len = model(wavs, wavs_len)
        ...
        >>> featurizer = Featurizer(model)
        >>> hs, hs_len = featurizer(all_hs, all_hs_len)
        ...
        >>> assert isinstance(hs, torch.FloatTensor)
        >>> assert isinstance(hs_len, torch.LongTensor)
        >>> batch_size, max_seq_len, hidden_size = hs.shape
        >>> assert hs_len.dim() == 1
    """

    def __init__(
        self,
        upstream: "S3PRLUpstream",
        layer_selections: List[int] = None,
        normalize: bool = False,
    ):
        super().__init__()
        if len(set(upstream.hidden_sizes)) != 1:
            raise ValueError(
                f"All layers should share one hidden size, got {upstream.hidden_sizes}"
            )
        if len(set(upstream.downsample_rates)) != 1:
            raise ValueError(
                "All layers should share one downsample rate, "
                f"got {upstream.downsample_rates}"
            )
        self._output_size = upstream.hidden_sizes[0]
        self._downsample_rate = upstream.downsample_rates[0]
        self.normalize = normalize

        if upstream.num_layers > 1:
            if layer_selections is not None:
                # Deduplicate so that every selected layer owns exactly one weight
                selections = sorted(set(layer_selections))
                if not selections:
                    raise ValueError("layer_selections should not be empty")
                invalid = [i for i in selections if not 0 <= i < upstream.num_layers]
                if invalid:
                    raise ValueError(
                        f"layer_selections {invalid} out of range for an upstream "
                        f"with {upstream.num_layers} layers"
                    )
                self.layer_selections = selections
            else:
                self.layer_selections = list(range(upstream.num_layers))
            self.weights = nn.Parameter(torch.zeros(len(self.layer_selections)))

    @property
    def output_size(self) -> int:
        """
        The hidden size of the final weighted-sum output
        """
        return self._output_size

    @property
    def downsample_rate(self) -> int:
        """
        The downsample rate (from 16k Hz waveform) of the final weighted-sum output
        """
        return self._downsample_rate

    def _weighted_sum(self, all_hs, all_lens):
        """
        Softmax-weighted sum over the selected layers.

        A single selected layer is allowed: its softmax weight is 1, so the
        result equals that layer (after the optional layer norm).

        Raises:
            ValueError: if the selected layers report different valid lengths
        """
        assert len(all_hs) == len(all_lens) >= 1
        for lens in all_lens[1:]:
            if not torch.equal(all_lens[0], lens):
                raise ValueError(
                    "All selected layers should have identical valid lengths, "
                    f"got {all_lens[0]} and {lens}"
                )

        stacked_hs = torch.stack(all_hs, dim=0)

        if self.normalize:
            stacked_hs = F.layer_norm(stacked_hs, (stacked_hs.shape[-1],))

        # Flatten everything but the layer axis so one weight scales a whole layer
        _, *origin_shape = stacked_hs.shape
        stacked_hs = stacked_hs.view(len(self.layer_selections), -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_hs = (norm_weights.unsqueeze(-1) * stacked_hs).sum(dim=0)
        weighted_hs = weighted_hs.view(*origin_shape)

        return weighted_hs, all_lens[0]

    def forward(
        self, all_hs: List[torch.FloatTensor], all_lens: List[torch.LongTensor]
    ):
        """
        Args:
            all_hs (List[torch.FloatTensor]): List[ (batch_size, seq_len, hidden_size) ]
            all_lens (List[torch.LongTensor]): List[ (batch_size, ) ]

        Return:
            torch.FloatTensor, torch.LongTensor

            1. The weighted-sum result, (batch_size, seq_len, hidden_size)
            2. the valid length of the result, (batch_size, )
        """
        if len(all_hs) == 1:
            return all_hs[0], all_lens[0]

        # layer_selections is sorted, unique and in range (validated in __init__)
        selected_hs = [all_hs[idx] for idx in self.layer_selections]
        selected_lens = [all_lens[idx] for idx in self.layer_selections]
        hs, hs_len = self._weighted_sum(selected_hs, selected_lens)
        return hs, hs_len
class UpstreamDownstreamModel(nn.Module):
    """
    Chain an upstream, a featurizer and a downstream model into a single module.

    Args:
        upstream (:obj:`S3PRLUpstream`):
            called as ``upstream(wav, wav_len)`` to get all the layers of hidden states
        featurizer (:obj:`Featurizer`):
            reduces the upstream's hidden states into a single hidden state
        downstream:
            called as ``downstream(h, h_len, *args, **kwargs)``; its ``output_size``
            is exposed as this model's ``output_size``
        upstream_trainable (bool): (default, False)
            If False, the upstream is switched to eval mode on every forward and
            runs with gradients disabled
    """

    def __init__(
        self,
        upstream: S3PRLUpstream,
        featurizer: Featurizer,
        downstream,
        upstream_trainable: bool = False,
    ):
        super().__init__()
        # Submodules are registered in this order: upstream, featurizer, downstream
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream
        self.upstream_trainable = upstream_trainable

    @property
    def input_size(self):
        """The model consumes single-channel waveforms."""
        return 1

    @property
    def downsample_rate(self):
        """Downsample rate of the featurizer output."""
        return self.featurizer.downsample_rate

    @property
    def output_size(self):
        """Output size reported by the downstream."""
        return self.downstream.output_size

    def forward(self, wav, wav_len, *args, **kwargs):
        """
        Run upstream -> featurizer -> downstream. Extra positional and keyword
        arguments are forwarded to the downstream unchanged.
        """
        grad_context = torch.set_grad_enabled(self.upstream_trainable)
        with grad_context:
            if self.upstream_trainable:
                pass
            else:
                self.upstream.eval()
            layer_hs, layer_lens = self.upstream(wav, wav_len)

        # The featurizer runs outside the grad context so its weights stay trainable
        merged_h, merged_len = self.featurizer(layer_hs, layer_lens)
        return self.downstream(merged_h, merged_len, *args, **kwargs)
|