"""Encoder definition.""" |
|
import contextlib |
|
import copy |
|
from filelock import FileLock |
|
import logging |
|
import os |
|
from typing import Optional |
|
from typing import Tuple |
|
|
|
import torch |
|
from typeguard import check_argument_types |
|
|
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm |
|
from espnet2.asr.encoder.abs_encoder import AbsEncoder |
|
|
|
|
|
class FairSeqWav2Vec2Encoder(AbsEncoder): |
|
"""FairSeq Wav2Vec2 encoder module. |
|
|
|
Args: |
|
input_size: input dim |
|
output_size: dimension of attention |
|
w2v_url: url to Wav2Vec2.0 pretrained model |
|
w2v_dir_path: directory to download the Wav2Vec2.0 pretrained model. |
|
normalize_before: whether to use layer_norm before the first block |
|
finetune_last_n_layers: last n layers to be finetuned in Wav2Vec2.0 |
|
0 means to finetune every layer if freeze_w2v=False. |
|
""" |
|
|
|
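
    # A minimal usage sketch (illustrative, not part of the original module).
    # The checkpoint URL and the tensor names `speech`/`speech_lengths` are
    # assumptions; substitute your own fairseq Wav2Vec2.0 checkpoint and data.
    #
    #   encoder = FairSeqWav2Vec2Encoder(
    #       input_size=1,
    #       w2v_url="https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt",
    #       w2v_dir_path="./downloads",
    #       output_size=256,
    #   )
    #   feats, feat_lens, _ = encoder(speech, speech_lengths)
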
    def __init__(
        self,
        input_size: int,
        w2v_url: str,
        w2v_dir_path: str = "./",
        output_size: int = 256,
        normalize_before: bool = False,
        freeze_finetune_updates: int = 0,
    ):
        assert check_argument_types()
        super().__init__()

        # fairseq is imported lazily so that ESPnet does not hard-depend on it
        # being installed.
        if w2v_url != "":
            try:
                import fairseq
                from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
            except Exception as e:
                print("Error: FairSeq is not properly installed.")
                print(
                    "Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done"
                )
                raise e
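
        # Download the checkpoint into w2v_dir_path (if not cached) and load it
        # with fairseq's checkpoint utilities; the "data" override points
        # fairseq at the directory where the dictionary file was downloaded.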
        self.w2v_model_path = download_w2v(w2v_url, w2v_dir_path)

        self._output_size = output_size

        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [self.w2v_model_path],
            arg_overrides={"data": w2v_dir_path},
        )
        model = models[0]
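
        # Fine-tuned checkpoints (e.g. Wav2VecCTC) wrap the acoustic model, so
        # unwrap them to reach the underlying Wav2Vec2Model.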
        if not isinstance(model, Wav2Vec2Model):
            try:
                model = model.w2v_encoder.w2v_model
            except Exception as e:
                print(
                    "Error: pretrained models should be within: "
                    "'Wav2Vec2Model, Wav2VecCTC' classes, etc."
                )
                raise e

        self.encoders = model
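
        # Keep a copy of the pretrained weights so reload_pretrained_parameters()
        # can restore them later.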
        self.pretrained_params = copy.deepcopy(model.state_dict())

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
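
        # Project the wav2vec output to output_size only when the dimensions
        # differ; otherwise the encoder output is used as-is.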
        if model.cfg.encoder_embed_dim != output_size:
            self.output_layer = torch.nn.Sequential(
                torch.nn.Linear(model.cfg.encoder_embed_dim, output_size),
            )
        else:
            self.output_layer = None

        self.freeze_finetune_updates = freeze_finetune_updates
        # num_updates is a buffer so the freeze/fine-tune progress is saved and
        # restored together with the model state_dict.
        self.register_buffer("num_updates", torch.LongTensor([0]))
|
    def output_size(self) -> int:
        return self._output_size
|
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward FairSeqWav2Vec2 Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            encoded tensor (B, T, output_size), output lengths (B), and None
        """
        masks = make_pad_mask(ilens).to(xs_pad.device)
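
        # Keep the wav2vec parameters frozen for the first
        # freeze_finetune_updates updates; afterwards gradients are allowed to
        # flow through the pretrained encoder.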
        ft = self.freeze_finetune_updates <= self.num_updates
        if self.num_updates <= self.freeze_finetune_updates:
            self.num_updates += 1
        elif ft and self.num_updates == self.freeze_finetune_updates + 1:
            self.num_updates += 1
            logging.info("Start fine-tuning wav2vec parameters!")
        with torch.no_grad() if not ft else contextlib.nullcontext():
            enc_outputs = self.encoders(
                xs_pad,
                masks,
                features_only=True,
            )
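
        # enc_outputs["x"] holds the frame-level representations and
        # enc_outputs["padding_mask"] marks padded frames after downsampling.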
xs_pad = enc_outputs["x"] |
|
masks = enc_outputs["padding_mask"] |
|
|
|
olens = (~masks).sum(dim=1) |
|
|
|
if self.output_layer is not None: |
|
xs_pad = self.output_layer(xs_pad) |
|
|
|
if self.normalize_before: |
|
xs_pad = self.after_norm(xs_pad) |
|
|
|
return xs_pad, olens, None |
|
|
|
    def reload_pretrained_parameters(self):
        self.encoders.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Wav2Vec model parameters reloaded!")
|
|
def download_w2v(model_url, dir_path):
    """Download the Wav2Vec2.0 checkpoint (and its dict file) into dir_path."""
    os.makedirs(dir_path, exist_ok=True)

    model_name = model_url.split("/")[-1]
    model_path = os.path.join(dir_path, model_name)

    dict_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
    dict_path = os.path.join(dir_path, dict_url.split("/")[-1])

    # The file lock keeps concurrent jobs (e.g. distributed training workers)
    # from downloading the same checkpoint at the same time.
    with FileLock(model_path + ".lock"):
        if not os.path.exists(model_path):
            torch.hub.download_url_to_file(model_url, model_path)
            torch.hub.download_url_to_file(dict_url, dict_path)
            logging.info(f"Wav2Vec model downloaded {model_path}")
        else:
            logging.info(f"Wav2Vec model {model_path} already exists.")

    return model_path