# NOTE: upload-page residue preserved as a comment so the module parses:
#   "lmzjms's picture / Upload 1162 files / 0b32ad6 verified"
import sys
from typing import Callable, Dict, List, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from s3prl.utility.helper import show
SAMPLE_RATE = 16000
TOLERABLE_SEQLEN_DIFF = 5
class Hook:
    """A pending forward-hook description for a submodule of an upstream model.

    Attributes:
        module_path: python expression (evaluated against the upstream
            instance) locating the nn.Module to hook.
        transform: callable mapping the hooked module's (input, output)
            to the tensor that should be collected.
        unique_identifier: name used to label the collected tensor;
            defaults to ``module_path`` when falsy/None.
        handler: the torch RemovableHandle once registered, else None.
    """

    def __init__(self, module_path, transform, unique_identifier=None):
        # Not registered yet; the owner fills this in via register_forward_hook.
        self.handler = None
        self.module_path = module_path
        self.transform = transform
        self.unique_identifier = unique_identifier or module_path

        # Basic sanity checks on the supplied arguments.
        assert isinstance(self.module_path, str)
        assert callable(self.transform)
        assert isinstance(self.unique_identifier, str)
class initHook(type):
    """Metaclass that finishes hook registration right after ``__init__``.

    Subclass initializers populate ``instance.hooks`` with Hook records;
    once construction completes, every record that has no live handler
    yet is passed to ``instance._register_hook_handler``.
    """

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        unregistered = (h for h in instance.hooks if h.handler is None)
        for h in unregistered:
            instance._register_hook_handler(h)
        return instance
class UpstreamBase(nn.Module, metaclass=initHook):
    """Base class for upstream models that collect hidden states via forward hooks.

    Subclasses declare which submodules to tap (``hooks``); every forward pass
    gathers the hooked tensors and packs them into the returned dict under the
    reserved keys ``_hidden_states_info``, ``hidden_states``,
    ``last_hidden_state`` and ``hidden_state_{i}``.
    """

    def __init__(
        self,
        hooks: List[Tuple] = None,
        hook_postprocess: Callable[
            [List[Tuple[str, Tensor]]], List[Tuple[str, Tensor]]
        ] = None,
        **kwargs,
    ):
        """
        Args:
            hooks: each Tuple is an argument list for the Hook initializer
            hook_postprocess: optional callable applied to the collected
                (identifier, tensor) pairs before they are packed into the
                result dict (e.g. to reorder or filter layers).
        """
        super().__init__()
        self.hooks: List[Hook] = [Hook(*hook) for hook in hooks] if hooks else []
        self.hook_postprocess = hook_postprocess
        # FIX: the original annotation was List[Tuple(str, Tensor)] — a *call*
        # of typing.Tuple rather than a subscription, which is not a valid
        # type expression (it only went unnoticed because local variable
        # annotations are never evaluated at runtime).
        self._hook_hiddens: List[Tuple[str, Tensor]] = []

    def remove_all_hooks(self):
        """Detach every registered forward hook and forget all Hook records."""
        for hook in self.hooks:
            # FIX: guard against hooks whose registration was skipped
            # (handler is still None), which previously raised AttributeError.
            if hook.handler is not None:
                hook.handler.remove()
        self.hooks.clear()

    def remove_hook(self, unique_identifier: str):
        """Detach and drop all hooks matching ``unique_identifier``."""
        updated_hooks = []
        for hook in self.hooks:
            if hook.unique_identifier == unique_identifier:
                # FIX: same None-handler guard as remove_all_hooks.
                if hook.handler is not None:
                    hook.handler.remove()
            else:
                updated_hooks.append(hook)
        self.hooks = updated_hooks

    def add_hook(self, *args, **kwargs):
        """Create a Hook from the given arguments, register it, and track it."""
        hook = Hook(*args, **kwargs)
        self._register_hook_handler(hook)
        self.hooks.append(hook)

    def _register_hook_handler(self, hook: Hook):
        """Resolve ``hook.module_path`` and attach a forward hook to it.

        If the path does not resolve to an nn.Module, the hook is skipped
        with a warning and ``hook.handler`` stays None.
        """
        # SECURITY: module_path is evaluated with eval(); hook specifications
        # must come from trusted code, never from external/user input.
        module = eval(hook.module_path)
        if not isinstance(module, nn.Module):
            show(
                f"[UpstreamBase] - {hook.module_path} is not a valid nn.Module. Skip.",
                file=sys.stderr,
            )
            return

        if callable(hook.handler):
            show(
                f"[UpstreamBase] - Existing hook handler for {hook.unique_identifier} is found. Remove the existing one.",
                file=sys.stderr,
            )
            hook.handler.remove()

        def generate_hook_handler(hiddens: List, hook: Hook):
            # Closure captures the shared hidden-state buffer so every hooked
            # module appends (identifier, transformed output) during forward.
            def hook_handler(self, input, output):
                hiddens.append((hook.unique_identifier, hook.transform(input, output)))

            return hook_handler

        hook.handler = module.register_forward_hook(
            generate_hook_handler(self._hook_hiddens, hook)
        )

    def __call__(self, wavs: List[Tensor], *args, **kwargs):
        """Run the subclass forward, then merge hooked hidden states into the result.

        Raises:
            ValueError: if the subclass's return dict already uses one of the
                reserved hidden-state keys while hooks are registered.
        """
        self._hook_hiddens.clear()

        result = super().__call__(wavs, *args, **kwargs) or {}
        assert isinstance(result, dict)

        if len(self._hook_hiddens) > 0:
            if (
                result.get("_hidden_states_info") is not None
                or result.get("hidden_states") is not None
                or result.get("last_hidden_state") is not None
            ):
                show(
                    "[UpstreamBase] - If there are registered hooks, '_hidden_states_info', 'hidden_states', and "
                    "'last_hidden_state' are reserved and should not be included in child class's return dict.",
                    file=sys.stderr,
                )
                raise ValueError

            # Copy-then-clear so the buffer is empty for the next forward even
            # if hook_postprocess raises.
            hook_hiddens = self._hook_hiddens.copy()
            self._hook_hiddens.clear()

            if callable(self.hook_postprocess):
                hook_hiddens = self.hook_postprocess(hook_hiddens)

            result["_hidden_states_info"], result["hidden_states"] = zip(*hook_hiddens)
            result["last_hidden_state"] = result["hidden_states"][-1]
            for layer_id, hidden_state in enumerate(result["hidden_states"]):
                result[f"hidden_state_{layer_id}"] = hidden_state

        return result
class Featurizer(nn.Module):
    """Select a feature from an upstream's output and (optionally) weighted-sum its layers.

    At construction a one-second random waveform is pushed through the
    upstream to probe the available result keys, the number of layers of the
    selected feature, its dimension, and the model's downsample rate.
    """

    def __init__(
        self,
        upstream: UpstreamBase,
        feature_selection: str = "hidden_states",
        upstream_device: str = "cuda",
        layer_selection: int = None,
        normalize: bool = False,
        **kwargs,
    ):
        """
        Args:
            upstream: the upstream model to probe; switched to eval mode here.
            feature_selection: key in the upstream's result dict to extract.
            upstream_device: device on which the probe waveform is created.
            layer_selection: if given, pick this single layer instead of
                learning a weighted sum over all layers.
            normalize: if True, layer-norm each layer before weighting.

        Raises:
            ValueError: if neither ``feature_selection`` nor the fallback
                "hidden_states" key exists in the upstream's output.
        """
        super().__init__()
        self.name = "Featurizer"

        upstream.eval()
        # Probe with dummy audio so we can inspect output structure without data.
        paired_wavs = [torch.randn(SAMPLE_RATE).to(upstream_device)]
        with torch.no_grad():
            paired_features = upstream(paired_wavs)

        if feature_selection not in paired_features:
            if "hidden_states" in paired_features:
                show(
                    f"[{self.name}] - Warning: {feature_selection} is not a supported args.upstream_feature_selection."
                    f' Using "hidden_states" as the default key.',
                    file=sys.stderr,
                )
                feature_selection = "hidden_states"
            else:
                show(
                    f"[{self.name}] - Error: {feature_selection} is not a supported args.upstream_feature_selection."
                    f' The default key "hidden_states" is also not supported.'
                    # FIX: the original formatted paired_wavs.keys(), but
                    # paired_wavs is a list (AttributeError would mask the
                    # ValueError); the valid options are the feature keys.
                    f" Please specify -s with the following options: {list(paired_features.keys())}",
                    file=sys.stderr,
                )
                raise ValueError
        self.feature_selection = feature_selection
        self.layer_selection = layer_selection
        self.normalize = normalize

        feature = self._select_feature(paired_features)
        if isinstance(feature, (list, tuple)):
            # Multiple layers: learn softmax-normalized weights over them.
            self.layer_num = len(feature)
            show(
                f"[{self.name}] - Take a list of {self.layer_num} features and weighted sum them.",
                file=sys.stderr,
            )
            self.weights = nn.Parameter(torch.zeros(self.layer_num))
            feature = self._weighted_sum([f.cpu() for f in feature])
        else:
            feature = feature.cpu()
        self.output_dim = feature.size(-1)

        if hasattr(upstream, "get_downsample_rates"):
            self.downsample_rate = upstream.get_downsample_rates(feature_selection)
            show(
                f"[{self.name}] - The selected feature {feature_selection}'s downsample rate is {self.downsample_rate}",
                file=sys.stderr,
            )
        else:
            # Fall back to estimating samples-per-frame from the probe pass.
            self.downsample_rate = round(
                max(len(wav) for wav in paired_wavs) / feature.size(1)
            )
            show(
                # FIX: typo in the original message ("statis" -> "static").
                f"[{self.name}] - Warning: The provided upstream does not give static downsample rate"
                ' by the "get_downsample_rates" interface (see upstream/example/expert.py).'
                " The downsample rate is calculated dynamically basing on the shape of the"
                f" input waveforms v.s. the output features: {self.downsample_rate}",
                file=sys.stderr,
            )

    def _select_feature(self, features):
        """Extract the configured feature; unwrap singletons and apply layer_selection."""
        feature = features.get(self.feature_selection)

        if isinstance(feature, dict):
            feature = list(feature.values())

        # A one-element list/tuple is treated as a single tensor.
        if isinstance(feature, (list, tuple)) and len(feature) == 1:
            feature = feature[0]

        if isinstance(feature, (list, tuple)) and isinstance(self.layer_selection, int):
            feature = feature[self.layer_selection]

        return feature

    def _weighted_sum(self, feature):
        """Softmax-weighted sum of per-layer features; expects len(feature) == layer_num."""
        assert self.layer_num == len(feature), (
            "If you run into this error, there is a great chance"
            " you are finetuning the upstream with wav2vec2's transformer blocks"
            " in weighted-sum mode (default), including wav2vec2, hubert, and decoar2."
            " These models use the layerdrop technique which causes the different number"
            " of layer forwards between different model forwards, resulting in different"
            " number of hidden states for different model forwards. Hence, finetuning"
            " these upstreams is essentially incompatible with weight-sum mode unless"
            " you turn off the layerdrop option in fairseq. See:"
            " https://github.com/pytorch/fairseq/blob/f6abcc2a67328bee8b15c596bb626ce2d720aae6/fairseq/models/wav2vec/wav2vec2.py#L857"
            " However, since finetuning upstreams will backward the gradient through all layers"
            " which serves the same functionality as weighted-sum: all layers can be used for different"
            " downstream tasks. Hence instead of finetuning upstream with weighted-sum, we suggest to"
            " follow the more common setting: finetuning upstream with the last layer. Please use the"
            " following options: --upstream_trainable --upstream_feature_selection last_hidden_state."
            " Or: -f -s last_hidden_state"
        )
        stacked_feature = torch.stack(feature, dim=0)

        if self.normalize:
            stacked_feature = F.layer_norm(
                stacked_feature, (stacked_feature.shape[-1],)
            )

        # Flatten all non-layer dims so the weights broadcast over them.
        _, *origin_shape = stacked_feature.shape
        stacked_feature = stacked_feature.view(self.layer_num, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        weighted_feature = weighted_feature.view(*origin_shape)

        return weighted_feature

    def tolist(self, paired_wavs: List[Tensor], paired_feature: Tensor):
        """Split a padded batch feature into per-utterance tensors trimmed to true length.

        Raises:
            AssertionError: if the feature's time length disagrees with the
                waveform lengths by TOLERABLE_SEQLEN_DIFF frames or more.
        """
        assert paired_feature.dim() == 3, "(batch_size, max_seq_len, feat_dim)"
        feature_len = [round(len(wav) / self.downsample_rate) for wav in paired_wavs]
        length_diff = abs(
            paired_feature.size(1)
            - round(max([len(wav) for wav in paired_wavs]) / self.downsample_rate)
        )
        assert (
            length_diff < TOLERABLE_SEQLEN_DIFF
        ), f"{length_diff} >= {TOLERABLE_SEQLEN_DIFF}"
        feature = [f[:l] for f, l in zip(paired_feature, feature_len)]
        return feature

    def forward(
        self,
        paired_wavs: List[Tensor],
        paired_features: Dict[str, Union[Tensor, List[Tensor], Dict[str, Tensor]]],
    ):
        """Select (and weighted-sum if multi-layer) the feature, then unpad per utterance."""
        feature = self._select_feature(paired_features)
        if isinstance(feature, (list, tuple)):
            feature = self._weighted_sum(feature)

        return self.tolist(paired_wavs, feature)