# S3PRL-based speech pretrained-representation frontend (funasr_detach).
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch
import torch.nn as nn

from funasr_detach.frontends.utils.frontend import Frontend
from funasr_detach.models.transformer.utils.nets_utils import pad_list
def base_s3prl_setup(args):
    """Fill in any missing S3PRL upstream options on ``args`` with defaults.

    Mutates ``args`` in place and also returns it for convenience.
    """
    defaults = {
        "upstream_feature_selection": None,
        "upstream_model_config": None,
        "upstream_refresh": False,
        "upstream_ckpt": None,
        "init_ckpt": None,
        "verbose": False,
        "tile_factor": 1,
    }
    for option, fallback in defaults.items():
        setattr(args, option, getattr(args, option, fallback))
    return args
class S3prlFrontend(nn.Module):
    """Speech Pretrained Representation frontend structure for ASR.

    Wraps an S3PRL upstream model (e.g. wav2vec2, HuBERT) plus an S3PRL
    ``Featurizer`` so raw waveforms are turned into frame-level features
    for a downstream ASR model.

    Args:
        fs: Sampling rate; an int, or a human-friendly string such as
            ``"16k"`` (parsed with ``humanfriendly.parse_size``).
        frontend_conf: Options forwarded to the S3PRL upstream loader;
            must at least carry the ``upstream`` model name.
        download_dir: If given, redirects the ``torch.hub`` cache there.
        multilayer_feature: Whether to fuse features from all hidden
            layers instead of only the last one.  NOTE(review): see
            ``_get_upstream`` — the current check makes this flag behave
            as always-multilayer; kept as-is to preserve the behavior of
            released checkpoints.
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = None,
        download_dir: Optional[str] = None,
        multilayer_feature: bool = False,
    ):
        super().__init__()
        if isinstance(fs, str):
            # Accept strings like "16k" and convert to an integer rate.
            fs = humanfriendly.parse_size(fs)
        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        # Snapshot of the pretrained weights so they can be restored later
        # through reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Build and return ``(upstream, featurizer)`` from S3PRL.

        Requires a local s3prl checkout discoverable on ``PYTHONPATH``
        (an entry ending in ``s3prl``); the upstream is loaded on CPU.

        Raises:
            RuntimeError: If no s3prl path is found on ``PYTHONPATH``.
        """
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        # Locate a local s3prl repository on PYTHONPATH for torch.hub.
        s3prl_path = None
        for p in os.environ.get("PYTHONPATH", "").split(":"):
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        if s3prl_path is None:
            # Explicit error instead of a bare assert (asserts are stripped
            # when Python runs with -O).
            raise RuntimeError(
                "No 's3prl' directory found on PYTHONPATH; add your local "
                "s3prl repository path to PYTHONPATH."
            )

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        # Disable LayerDrop so fairseq wav2vec2/HuBERT upstreams behave
        # deterministically during feature extraction.
        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        # NOTE(review): self.multilayer_feature is a bool assigned in
        # __init__, so it is never None and this branch always selects
        # "hidden_states".  Left unchanged to keep parity with trained
        # checkpoints — confirm intent before changing the condition to
        # `if not self.multilayer_feature`.
        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by ``tile_factor``.

        Input - sequence of representations
            shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
            shape: (batch_size, seq_len * factor, feature_dim)

        Raises:
            ValueError: If ``feature`` is not a rank-3 tensor.
        """
        if len(feature.shape) != 3:
            # Explicit raise instead of assert (stripped under -O).
            raise ValueError(
                "Input argument `feature` has invalid shape: {}".format(feature.shape)
            )
        # Repeat along the feature axis, then fold the copies into the time
        # axis: each frame becomes `tile_factor` consecutive identical frames.
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        """Return the feature dimension produced by the featurizer."""
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from padded waveforms.

        Args:
            input: Padded waveform batch; each row is one utterance.
            input_lengths: Valid number of samples per utterance.

        Returns:
            Tuple of padded features and per-utterance frame lengths.
        """
        # Strip padding; the upstream consumes a list of 1-D waveforms.
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        # The pretrained upstream stays frozen: eval mode plus no_grad.
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            feats = self._tile_representations(feats)

        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        """Restore the upstream to the pretrained weights captured at init."""
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")