import torch
from torch.nn import L1Loss

from s3prl.corpus.librispeech import librispeech_for_pretrain
from s3prl.dataset.pretrain_mockingjay_pipe import PretrainMockingjayPipe
from s3prl.nn.predictor_mockingjay import PredictorMockingjay
from s3prl.nn.transformer_mockingjay import TransformerMockingjay
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler
from s3prl.task import Task
from s3prl.task.feat_reconstruction_task import FeatReconstructionTask
from s3prl.util.configuration import override_parent_cfg
from s3prl.util.workspace import Workspace

from .base import SslProblem

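# Per-frame input dimensionality: 80 Mel filterbank bins (num_mel_bins below is
# _input_size // 3) stacked with delta and delta-delta features, i.e. 80 * 3 = 240.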
_input_size = 240

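# Masking settings for the masked-frame prediction objective: ~15% of frames are
# masked in consecutive spans of exactly 7 frames (min == max), overlaps are
# allowed, and no masking is applied along the frequency axis.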
_mask_args = dict(
    position_encoding_size=768,
    mask_proportion=0.15,
    mask_consecutive_min=7,
    mask_consecutive_max=7,
    mask_allow_overlap=True,
    mask_bucket_ratio=1.5,
    mask_frequency=0.0,
)

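# Kaldi-style acoustic feature extraction: 25 ms frames with a 10 ms shift.
# feat_type selects log Mel filterbanks (80 bins) here; the MFCC and spectrogram
# settings are kept as alternatives. Order-2 delta features and CMVN are applied.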
_audio_config = dict(
    kaldi={
        "feat_type": "fbank",
        "fbank": {
            "frame_length": 25.0,
            "frame_shift": 10.0,
            "num_mel_bins": _input_size // 3,
            "use_log_fbank": True,
        },
        "mfcc": {"frame_length": 25.0, "frame_shift": 10.0, "num_ceps": 13},
        "spectrogram": {"frame_length": 25.0, "frame_shift": 10.0},
    },
    delta={"order": 2, "win_length": 5},
    cmvn={"use_cmvn": True},
)

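# One datapipe config combining the masking and audio settings above; it is
# reused verbatim for the train, valid, and test splits below.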
pretrain_task_pipe_config = dict(
    _cls=PretrainMockingjayPipe,
    **_mask_args,
    **_audio_config,
)

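# Transformer encoder config: 3 layers, 768-dim hidden states, 12 attention
# heads; shared by both the upstream encoder and the reconstruction predictor.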
_transformer_config = dict(
    hidden_size=768,
    num_hidden_layers=3,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
    layer_norm_eps=1.0e-12,
    share_layer=False,
    pre_layer_norm=False,
)


class Mockingjay(SslProblem):
    """
    The Mockingjay pre-training problem: reconstruct masked acoustic features
    with a Transformer encoder.
    """

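    # Default problem config: LibriSpeech for pre-training (dataset_root is left
    # as the "???" placeholder and must be provided by the caller), the shared
    # Mockingjay datapipe for every split, duration-capped training batches
    # (at most 16000 * 20 timestamps, i.e. roughly 20 s of 16 kHz audio per
    # batch), and an L1 feature-reconstruction task on top of the upstream.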
    @override_parent_cfg(
        corpus=dict(
            _cls=librispeech_for_pretrain,
            dataset_root="???",
        ),
        train_datapipe=pretrain_task_pipe_config,
        train_sampler=dict(
            _cls=MaxTimestampBatchSampler,
            max_timestamp=16000 * 20,
            shuffle=True,
        ),
        valid_datapipe=pretrain_task_pipe_config,
        valid_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        test_datapipe=pretrain_task_pipe_config,
        test_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        upstream=dict(
            _cls=TransformerMockingjay,
            config=_transformer_config,
            input_dim=_input_size,
            output_attentions=False,
            keep_multihead_output=False,
            with_input_module=True,
        ),
        predictor=dict(
            _cls=PredictorMockingjay,
            config=_transformer_config,
            output_dim=_input_size,
            input_dim=None,
        ),
        task=dict(
            _cls=FeatReconstructionTask,
            loss=L1Loss,
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
        """
        Set up the Mockingjay problem: build the train/valid/test datasets and
        samplers together with the feature-reconstruction task object.
        """
        super().setup_problem(**cfg)

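    # Default optimization schedule: AdamW at 2e-4, 1M total steps with 4-step
    # gradient accumulation and gradient clipping at 5.0; validation and
    # checkpointing every 50k steps, selecting on the lowest validation loss.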
    @override_parent_cfg(
        optimizer=dict(
            _cls="torch.optim.AdamW",
            lr=2.0e-4,
        ),
        trainer=dict(
            total_steps=1000000,
            eval_step=50000,
            save_step=50000,
            gradient_clipping=5.0,
            gradient_accumulate_steps=4,
            valid_metric="loss",
            valid_higher_better=False,
        ),
    )
    @classmethod
    def train(cls, **cfg):
        """
        Train the problem set up above with the train/valid datasets & samplers
        and the task object.
        """
        super().train(**cfg)

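    # No overrides here: inference runs with the parent class's default config.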
    @override_parent_cfg()
    @classmethod
    def inference(cls, **cfg):
        super().inference(**cfg)

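    # Export the trained weights next to the workspace as all_states.ckpt, using
    # the Config / SpecHead / Transformer / Upstream_Config layout of s3prl's
    # legacy pre-training checkpoints (presumably so the result can be re-loaded
    # as a pre-trained upstream).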
    @classmethod
    def save_additional(
        cls,
        additional_dir: Workspace,
        workspace: Workspace,
        task: Task,
    ):
        all_states = dict(
            Config={},
            SpecHead=task.predictor.state_dict(),
            Transformer=task.upstream.state_dict(),
            Upstream_Config=dict(
                transformer=_transformer_config,
                audio=_audio_config,
                task=dict(sequence_length=0),
            ),
        )
        torch.save(
            all_states, str(additional_dir.parent.resolve()) + "/all_states.ckpt"
        )

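    # Stage pipeline: stages 0-2 run setup_problem, train, and inference in order.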
    @override_parent_cfg(
        start_stage=0,
        final_stage=2,
        stage_0=dict(
            _method="setup_problem",
        ),
        stage_1=dict(
            _method="train",
        ),
        stage_2=dict(
            _method="inference",
        ),
    )
    @classmethod
    def run_stages(cls, **cfg):
        super().run_stages(**cfg)