# NOTE(review): the three lines below are web-upload page artifacts, not code;
# commented out so the module parses.
# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
import torch
from torch.nn import L1Loss
from s3prl.corpus.librispeech import librispeech_for_pretrain
from s3prl.dataset.pretrain_mockingjay_pipe import PretrainMockingjayPipe
from s3prl.nn.predictor_mockingjay import PredictorMockingjay
from s3prl.nn.transformer_mockingjay import TransformerMockingjay
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler
from s3prl.task import Task
from s3prl.task.feat_reconstruction_task import FeatReconstructionTask
from s3prl.util.configuration import override_parent_cfg
from s3prl.util.workspace import Workspace
from .base import SslProblem
# Input feature dimension: 80 mel bins x 3 (static + delta + delta-delta).
_input_size = 240

# Masked Acoustic Modeling (MAM) masking hyper-parameters.
_mask_args = {
    "position_encoding_size": 768,  # int, should be identical to the transformer's `hidden_size`
    "mask_proportion": 0.15,  # float, fraction of spectrogram frames masked per sequence during MAM training
    "mask_consecutive_min": 7,  # int, minimum run of consecutive masked frames
    "mask_consecutive_max": 7,  # int, maximum run of consecutive masked frames
    "mask_allow_overlap": True,  # bool, whether masked spans may overlap
    "mask_bucket_ratio": 1.5,  # float, bucket size factor used only when overlap is disallowed
    "mask_frequency": 0.0,  # float, max fraction of frequency bands to mask; 0 disables frequency masking
}

# Acoustic feature extraction settings (kaldi-style fbank + deltas + CMVN).
_audio_config = {
    "kaldi": {
        "feat_type": "fbank",
        "fbank": {
            "frame_length": 25.0,
            "frame_shift": 10.0,
            "num_mel_bins": _input_size // 3,  # because delta={"order": 2}
            "use_log_fbank": True,
        },
        "mfcc": {"frame_length": 25.0, "frame_shift": 10.0, "num_ceps": 13},
        "spectrogram": {"frame_length": 25.0, "frame_shift": 10.0},
    },
    "delta": {"order": 2, "win_length": 5},
    "cmvn": {"use_cmvn": True},
}
# Datapipe configuration shared by the train/valid/test splits: the Mockingjay
# pretraining pipe plus every masking and audio-feature option from above.
pretrain_task_pipe_config = {
    "_cls": PretrainMockingjayPipe,
    **_mask_args,
    **_audio_config,
}
# Architecture hyper-parameters shared by the upstream transformer encoder
# and the reconstruction predictor head.
_transformer_config = {
    "hidden_size": 768,  # size of the encoder layers and the pooler layer
    "num_hidden_layers": 3,  # number of hidden layers in the Transformer encoder
    "num_attention_heads": 12,  # attention heads per attention layer in the encoder
    "intermediate_size": 3072,  # size of the "intermediate" (feed-forward) layer
    "hidden_act": "gelu",  # non-linear activation in encoder/pooler; "gelu", "relu" and "swish" are supported
    "hidden_dropout_prob": 0.1,  # dropout for fully connected layers in embeddings, encoder and pooler
    "attention_probs_dropout_prob": 0.1,  # dropout for the attention probabilities
    "initializer_range": 0.02,  # stddev of the truncated_normal_initializer for weight matrices
    "layer_norm_eps": 1.0e-12,  # epsilon used by LayerNorm
    "share_layer": False,  # share layer weights across layers
    "pre_layer_norm": False,  # pre layer normalization (https://arxiv.org/abs/2002.04745)
}
class Mockingjay(SslProblem):
    """
    Mockingjay pre-train problem

    Wires together the LibriSpeech pretraining corpus, the masked-reconstruction
    datapipes/samplers, the Mockingjay transformer upstream, the spectrogram
    predictor head, and an L1 feature-reconstruction task, delegating the actual
    stage execution to the parent ``SslProblem``.
    """

    @override_parent_cfg(
        corpus=dict(
            _cls=librispeech_for_pretrain,
            # NOTE(review): "???" looks like an OmegaConf-style mandatory field
            # the caller must override with a real path — confirm.
            dataset_root="???",
        ),
        # All three splits reuse the same masked-reconstruction datapipe config.
        train_datapipe=pretrain_task_pipe_config,
        train_sampler=dict(
            _cls=MaxTimestampBatchSampler,
            # presumably a cap on total waveform samples per batch
            # (~20 s at 16 kHz) — confirm sampler semantics.
            max_timestamp=16000 * 20,
            shuffle=True,
        ),
        valid_datapipe=pretrain_task_pipe_config,
        valid_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        test_datapipe=pretrain_task_pipe_config,
        test_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        # Upstream encoder: the Mockingjay transformer operating on
        # `_input_size`-dimensional acoustic features.
        upstream=dict(
            _cls=TransformerMockingjay,
            config=_transformer_config,
            input_dim=_input_size,
            output_attentions=False,
            keep_multihead_output=False,
            with_input_module=True,
        ),
        # Predictor head that reconstructs the input features from hidden states.
        predictor=dict(
            _cls=PredictorMockingjay,
            config=_transformer_config,
            output_dim=_input_size,
            input_dim=None,  # automatically use `hidden_size` from `_transformer_config`
        ),
        # Reconstruction objective using L1 loss.
        task=dict(
            _cls=FeatReconstructionTask,
            loss=L1Loss,
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
        """
        This setups the Mockingjay problem, containing train/valid/test datasets & samplers and a task object
        """
        super().setup_problem(**cfg)

    @override_parent_cfg(
        optimizer=dict(
            _cls="torch.optim.AdamW",
            lr=2.0e-4,
        ),
        trainer=dict(
            total_steps=1000000,
            eval_step=50000,
            save_step=50000,
            gradient_clipping=5.0,
            gradient_accumulate_steps=4,
            valid_metric="loss",  # checkpoint selection tracks validation loss
            valid_higher_better=False,  # lower loss is better
        ),
    )
    @classmethod
    def train(cls, **cfg):
        """
        Train the setup problem with the train/valid datasets & samplers and the task object
        """
        super().train(**cfg)

    @override_parent_cfg()
    @classmethod
    def inference(cls, **cfg):
        """Run inference; adds no extra config and delegates to the parent."""
        super().inference(**cfg)

    @classmethod
    def save_additional(
        cls,
        additional_dir: Workspace,
        workspace: Workspace,
        task: Task,
    ):
        """
        Export the trained predictor and upstream weights, plus the configs
        needed to rebuild them, as a single ``all_states.ckpt`` file.

        Note: the checkpoint is written into the *parent* directory of
        ``additional_dir``, not into ``additional_dir`` itself.
        """
        # NOTE(review): the Config/SpecHead/Transformer/Upstream_Config layout
        # looks like s3prl's legacy upstream checkpoint format — confirm
        # against the loader before changing any key.
        all_states = dict(
            Config={},  # placeholder
            SpecHead=task.predictor.state_dict(),
            Transformer=task.upstream.state_dict(),
            Upstream_Config=dict(
                transformer=_transformer_config,
                audio=_audio_config,
                task=dict(sequence_length=0),
            ),
        )
        torch.save(
            all_states, str(additional_dir.parent.resolve()) + "/all_states.ckpt"
        )

    @override_parent_cfg(
        # Default pipeline: stage 0 sets up the problem, stage 1 trains,
        # stage 2 runs inference.
        start_stage=0,
        final_stage=2,
        stage_0=dict(
            _method="setup_problem",
        ),
        stage_1=dict(
            _method="train",
        ),
        stage_2=dict(
            _method="inference",
        ),
    )
    @classmethod
    def run_stages(cls, **cfg):
        """Run the configured stages from ``start_stage`` through ``final_stage``."""
        super().run_stages(**cfg)