|
import sys |
|
from typing import Callable, Dict, List, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
|
|
from s3prl.utility.helper import show |
|
|
|
SAMPLE_RATE = 16000 |
|
TOLERABLE_SEQLEN_DIFF = 5 |
|
|
|
|
|
class Hook:
    """Bookkeeping record describing a forward hook on an upstream submodule.

    Attributes:
        module_path: a Python expression (evaluated by the owner) locating
            the target submodule, e.g. "self.model.encoder".
        transform: callable mapping (input, output) of the hooked module to
            the tensor that should be collected.
        unique_identifier: name under which collected tensors are stored;
            defaults to the module path itself.
        handler: the torch RemovableHandle once registered, else None.
    """

    def __init__(self, module_path, transform, unique_identifier=None):
        self.module_path = module_path
        self.transform = transform
        self.unique_identifier = unique_identifier or module_path
        # Not registered yet; the owning module fills this in later.
        self.handler = None

        assert isinstance(self.module_path, str)
        assert callable(self.transform)
        assert isinstance(self.unique_identifier, str)
|
|
|
|
|
class initHook(type):
    """Metaclass that finishes hook registration right after construction.

    After ``__init__`` returns, every Hook in ``instance.hooks`` that has not
    yet been attached (``handler is None``) is registered through the
    instance's ``_register_hook_handler``.
    """

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        pending = (hook for hook in instance.hooks if hook.handler is None)
        for hook in pending:
            instance._register_hook_handler(hook)
        return instance
|
|
|
|
|
class UpstreamBase(nn.Module, metaclass=initHook):
    """Base class for upstream models exposing hidden states via forward hooks.

    Subclasses implement ``forward``. Hooks registered on named submodules
    collect ``(identifier, tensor)`` pairs during the forward pass, which
    ``__call__`` packs into the reserved result keys ``_hidden_states_info``,
    ``hidden_states``, ``last_hidden_state`` and ``hidden_state_{i}``.
    """

    def __init__(
        self,
        hooks: List[Tuple] = None,
        hook_postprocess: Callable[
            [List[Tuple[str, Tensor]]], List[Tuple[str, Tensor]]
        ] = None,
        **kwargs,
    ):
        """
        Args:
            hooks: each Tuple is an argument list for the Hook initializer
            hook_postprocess: optional callable applied to the collected
                (identifier, tensor) pairs before packing them into the
                result dict
        """
        super().__init__()
        self.hooks: List[Hook] = [Hook(*hook) for hook in hooks] if hooks else []
        self.hook_postprocess = hook_postprocess
        # Fix: annotation must subscript Tuple, not call it (was Tuple(str, Tensor)).
        self._hook_hiddens: List[Tuple[str, Tensor]] = []

    def remove_all_hooks(self):
        """Detach every registered hook and clear the hook list."""
        for hook in self.hooks:
            # handler stays None when registration was skipped (invalid module
            # path) - guard so cleanup never raises AttributeError.
            if hook.handler is not None:
                hook.handler.remove()
        self.hooks.clear()

    def remove_hook(self, unique_identifier: str):
        """Detach and drop every hook whose identifier matches.

        Args:
            unique_identifier: the Hook.unique_identifier to remove.
        """
        updated_hooks = []
        for hook in self.hooks:
            if hook.unique_identifier == unique_identifier:
                if hook.handler is not None:
                    hook.handler.remove()
            else:
                updated_hooks.append(hook)
        self.hooks = updated_hooks

    def add_hook(self, *args, **kwargs):
        """Create a Hook from the given arguments, register it, and track it.

        Args are forwarded verbatim to the Hook initializer.
        """
        hook = Hook(*args, **kwargs)
        self._register_hook_handler(hook)
        self.hooks.append(hook)

    def _register_hook_handler(self, hook: Hook):
        """Attach ``hook`` to the submodule named by its module path.

        Skips (leaving ``hook.handler`` as None) when the path does not
        resolve to an nn.Module; replaces any previously attached handler.
        """
        # NOTE(security): eval is used deliberately so module_path can be an
        # arbitrary attribute/index expression (e.g. "self.model.layers[2]").
        # Only code-controlled, trusted paths must ever reach this point.
        module = eval(hook.module_path)
        if not isinstance(module, nn.Module):
            show(
                f"[UpstreamBase] - {hook.module_path} is not a valid nn.Module. Skip.",
                file=sys.stderr,
            )
            return

        if callable(hook.handler):
            show(
                f"[UpstreamBase] - Existing hook handler for {hook.unique_identifier} is found. Remove the existing one.",
                file=sys.stderr,
            )
            hook.handler.remove()

        def generate_hook_handler(hiddens: List, hook: Hook):
            # Closure factory: binds the shared hiddens list and this specific
            # hook so each registered handler appends to the same buffer.
            def hook_handler(self, input, output):
                hiddens.append((hook.unique_identifier, hook.transform(input, output)))

            return hook_handler

        hook.handler = module.register_forward_hook(
            generate_hook_handler(self._hook_hiddens, hook)
        )

    def __call__(self, wavs: List[Tensor], *args, **kwargs):
        """Run the subclass forward pass and pack hook-collected hidden states.

        Returns:
            The subclass's result dict, augmented (when hooks fired) with the
            reserved keys '_hidden_states_info', 'hidden_states',
            'last_hidden_state' and 'hidden_state_{i}'.

        Raises:
            ValueError: if hooks are registered but the subclass result dict
                already contains any of the reserved keys.
        """
        # Drop any leftovers from a previous (possibly failed) forward pass.
        self._hook_hiddens.clear()

        result = super().__call__(wavs, *args, **kwargs) or {}
        assert isinstance(result, dict)

        if len(self._hook_hiddens) > 0:
            if (
                result.get("_hidden_states_info") is not None
                or result.get("hidden_states") is not None
                or result.get("last_hidden_state") is not None
            ):
                show(
                    "[UpstreamBase] - If there are registered hooks, '_hidden_states_info', 'hidden_states', and "
                    "'last_hidden_state' are reserved and should not be included in child class's return dict.",
                    file=sys.stderr,
                )
                raise ValueError

            hook_hiddens = self._hook_hiddens.copy()
            self._hook_hiddens.clear()

            if callable(self.hook_postprocess):
                hook_hiddens = self.hook_postprocess(hook_hiddens)

            # Unzip [(name, tensor), ...] into parallel tuples.
            result["_hidden_states_info"], result["hidden_states"] = zip(*hook_hiddens)
            result["last_hidden_state"] = result["hidden_states"][-1]

            for layer_id, hidden_state in enumerate(result["hidden_states"]):
                result[f"hidden_state_{layer_id}"] = hidden_state

        return result
|
|
|
|
|
class Featurizer(nn.Module):
    """Select, weight, and unpad features from an upstream model's output dict.

    On construction, probes the upstream with one second of random audio to
    discover the feature key, feature dimension, layer count (for weighted
    sum), and downsample rate. At forward time it selects the configured
    feature, optionally weighted-sums the layers, and slices the padded batch
    back into per-utterance tensors.
    """

    def __init__(
        self,
        upstream: UpstreamBase,
        feature_selection: str = "hidden_states",
        upstream_device: str = "cuda",
        layer_selection: int = None,
        normalize: bool = False,
        **kwargs,
    ):
        """
        Args:
            upstream: the upstream model whose output dict is featurized.
            feature_selection: key into the upstream's output dict.
            upstream_device: device used for the probing forward pass.
            layer_selection: if an int, pick that single layer instead of
                weighted-summing all layers.
            normalize: if True, layer-normalize each layer before weighting.

        Raises:
            ValueError: if neither ``feature_selection`` nor the fallback
                "hidden_states" key exists in the upstream's output.
        """
        super().__init__()
        self.name = "Featurizer"

        # Probe the upstream with 1 second of noise to learn the output
        # dict's structure, the feature dim, and the downsample rate.
        upstream.eval()
        paired_wavs = [torch.randn(SAMPLE_RATE).to(upstream_device)]
        with torch.no_grad():
            paired_features = upstream(paired_wavs)

        if feature_selection not in paired_features:
            if "hidden_states" in paired_features:
                show(
                    f"[{self.name}] - Warning: {feature_selection} is not a supported args.upstream_feature_selection."
                    f' Using "hidden_states" as the default key.',
                    file=sys.stderr,
                )
                feature_selection = "hidden_states"
            else:
                show(
                    f"[{self.name}] - Error: {feature_selection} is not a supported args.upstream_feature_selection."
                    f' The default key "hidden_states" is also not supported.'
                    # Fix: list the upstream's output keys. The original used
                    # paired_wavs (a list, no .keys()) and raised AttributeError
                    # instead of printing this message.
                    f" Please specify -s with the following options: {list(paired_features.keys())}",
                    file=sys.stderr,
                )
                raise ValueError(f"Unsupported feature selection: {feature_selection}")
        self.feature_selection = feature_selection
        self.layer_selection = layer_selection
        self.normalize = normalize

        feature = self._select_feature(paired_features)
        if isinstance(feature, (list, tuple)):
            self.layer_num = len(feature)
            show(
                f"[{self.name}] - Take a list of {self.layer_num} features and weighted sum them.",
                file=sys.stderr,
            )
            # Zero-initialized weights => uniform softmax over layers at start.
            self.weights = nn.Parameter(torch.zeros(self.layer_num))
            feature = self._weighted_sum([f.cpu() for f in feature])
        else:
            feature = feature.cpu()

        self.output_dim = feature.size(-1)
        if hasattr(upstream, "get_downsample_rates"):
            self.downsample_rate = upstream.get_downsample_rates(feature_selection)
            show(
                f"[{self.name}] - The selected feature {feature_selection}'s downsample rate is {self.downsample_rate}",
                file=sys.stderr,
            )
        else:
            # Infer the rate from waveform length vs. feature sequence length.
            self.downsample_rate = round(
                max(len(wav) for wav in paired_wavs) / feature.size(1)
            )
            show(
                # Fix: "statis" -> "static" typo in the warning message.
                f"[{self.name}] - Warning: The provided upstream does not give static downsample rate"
                ' by the "get_downsample_rates" interface (see upstream/example/expert.py).'
                " The downsample rate is calculated dynamically basing on the shape of the"
                f" input waveforms v.s. the output features: {self.downsample_rate}",
                file=sys.stderr,
            )

    def _select_feature(self, features):
        """Pull the configured feature out of the upstream's output dict.

        Returns either a single Tensor or a list of per-layer Tensors,
        depending on what the upstream produced and on layer_selection.
        """
        feature = features.get(self.feature_selection)

        # Some upstreams return a dict of named layers; use its values.
        if isinstance(feature, dict):
            feature = list(feature.values())

        # A single-element list degenerates to that element.
        if isinstance(feature, (list, tuple)) and len(feature) == 1:
            feature = feature[0]

        if isinstance(feature, (list, tuple)) and isinstance(self.layer_selection, int):
            feature = feature[self.layer_selection]

        return feature

    def _weighted_sum(self, feature):
        """Softmax-weighted sum of per-layer features.

        Args:
            feature: list of layer_num Tensors with identical shapes.

        Returns:
            A single Tensor with the per-layer shape.
        """
        assert self.layer_num == len(feature), (
            "If you run into this error, there is a great chance"
            " you are finetuning the upstream with wav2vec2's transformer blocks"
            " in weighted-sum mode (default), including wav2vec2, hubert, and decoar2."
            " These models use the layerdrop technique which causes the different number"
            " of layer forwards between different model forwards, resulting in different"
            " number of hidden states for different model forwards. Hence, finetuning"
            " these upstreams is essentially incompatible with weight-sum mode unless"
            " you turn off the layerdrop option in fairseq. See:"
            " https://github.com/pytorch/fairseq/blob/f6abcc2a67328bee8b15c596bb626ce2d720aae6/fairseq/models/wav2vec/wav2vec2.py#L857"
            " However, since finetuning upstreams will backward the gradient through all layers"
            " which serves the same functionality as weighted-sum: all layers can be used for different"
            " downstream tasks. Hence instead of finetuning upstream with weighted-sum, we suggest to"
            " follow the more common setting: finetuning upstream with the last layer. Please use the"
            " following options: --upstream_trainable --upstream_feature_selection last_hidden_state."
            " Or: -f -s last_hidden_state"
        )
        stacked_feature = torch.stack(feature, dim=0)

        if self.normalize:
            # Normalize each layer over the feature dimension before weighting.
            stacked_feature = F.layer_norm(
                stacked_feature, (stacked_feature.shape[-1],)
            )

        _, *origin_shape = stacked_feature.shape
        # Flatten everything but the layer axis so one matmul-style broadcast
        # handles arbitrary trailing shapes.
        stacked_feature = stacked_feature.view(self.layer_num, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        weighted_feature = weighted_feature.view(*origin_shape)

        return weighted_feature

    def tolist(self, paired_wavs: List[Tensor], paired_feature: Tensor):
        """Slice a padded feature batch back into per-utterance features.

        Args:
            paired_wavs: the raw waveforms that produced the batch.
            paired_feature: (batch_size, max_seq_len, feat_dim) Tensor.

        Returns:
            A list of (seq_len_i, feat_dim) Tensors, one per waveform.
        """
        assert paired_feature.dim() == 3, "(batch_size, max_seq_len, feat_dim)"
        feature_len = [round(len(wav) / self.downsample_rate) for wav in paired_wavs]
        # Sanity check: predicted max length should match the actual feature
        # length up to a small tolerance (conv padding etc.).
        length_diff = abs(
            paired_feature.size(1)
            - round(max([len(wav) for wav in paired_wavs]) / self.downsample_rate)
        )
        assert (
            length_diff < TOLERABLE_SEQLEN_DIFF
        ), f"{length_diff} >= {TOLERABLE_SEQLEN_DIFF}"
        feature = [f[:l] for f, l in zip(paired_feature, feature_len)]
        return feature

    def forward(
        self,
        paired_wavs: List[Tensor],
        paired_features: Dict[str, Union[Tensor, List[Tensor], Dict[str, Tensor]]],
    ):
        """Featurize one upstream output batch.

        Returns:
            A list of per-utterance feature Tensors (see ``tolist``).
        """
        feature = self._select_feature(paired_features)
        if isinstance(feature, (list, tuple)):
            feature = self._weighted_sum(feature)

        return self.tolist(paired_wavs, feature)
|
|