File size: 5,364 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
from torch.nn import L1Loss

from s3prl.corpus.librispeech import librispeech_for_pretrain
from s3prl.dataset.pretrain_npc_pipe import PretrainNpcPipe
from s3prl.nn.cnn_npc import CnnNpc
from s3prl.nn.predictor_identity import PredictorIdentity
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler
from s3prl.task import Task
from s3prl.task.feat_reconstruction_task import FeatReconstructionTask
from s3prl.util.configuration import override_parent_cfg
from s3prl.util.workspace import Workspace

from .base import SslProblem

_input_size = 80
_audio_config = dict(
    feat_type="fbank",  # Feature type
    feat_dim=_input_size,  # Feature dimension
    frame_length=25,  # Window size in ms
    frame_shift=10,  # Hop size in ms
    decode_wav=False,
    cmvn=True,  # Apply uttr.-wised CMVN on Mel spectrogram
)
_pretrain_task_pipe_config = dict(
    _cls=PretrainNpcPipe,
    n_jobs=8,
    **_audio_config,
)


class Npc(SslProblem):
    """
    Npc pre-train problem
    """

    @override_parent_cfg(
        corpus=dict(
            _cls=librispeech_for_pretrain,
            dataset_root="???",
        ),
        train_datapipe=_pretrain_task_pipe_config,
        train_sampler=dict(
            _cls=MaxTimestampBatchSampler,
            max_timestamp=16000 * 20,
            shuffle=True,
        ),
        valid_datapipe=_pretrain_task_pipe_config,
        valid_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        test_datapipe=_pretrain_task_pipe_config,
        test_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        upstream=dict(
            _cls=CnnNpc,
            input_size=_input_size,
            kernel_size=15,  # Receptive field size (R) = kernel_size + 2*(n_blocks)
            mask_size=5,  # Desired input mask size (M_in) as described in NPC paper
            n_blocks=4,  # Number of ConvBlocks stacked in NPC model
            hidden_size=512,  # Dimension of feature of all layers
            dropout=0.1,  # Dropout in ConvBlock
            residual=True,  # Residual connection in ConvBlock
            batch_norm=True,  # Apply BatchNorm in ConvBlock
            activate="relu",  # Activation function of ConvBlock
            disable_cross_layer=False,  # Apply Masked ConvBlock at last layer only
            vq=dict(
                codebook_size=[
                    64,
                    64,
                    64,
                    64,
                ],  # Codebook size of each group in VQ-layer
                code_dim=[
                    128,
                    128,
                    128,
                    128,
                ],  # Dim of each group summing up to hidden_size
                gumbel_temperature=1.0,  # Temperature of Gumbel Softmax in VQ-layer
            ),
        ),
        predictor=dict(
            _cls=PredictorIdentity,
        ),
        task=dict(
            _cls=FeatReconstructionTask,
            loss=L1Loss,
            loss_config=dict(
                reduction="mean"
            ),  # the npc official implementation use reduction='none', then calculates the mean loss on valid part manually, here we use a label mask to replace it.
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
        """
        This setups the Npc problem, containing train/valid/test datasets & samplers and a task object
        """
        super().setup_problem(**cfg)

    @override_parent_cfg(
        optimizer=dict(
            _cls="torch.optim.Adam",
            lr=0.001,
        ),
        trainer=dict(
            total_steps=1000000,
            eval_step=50000,
            save_step=50000,
            gradient_clipping=5.0,
            gradient_accumulate_steps=4,
            valid_metric="loss",
            valid_higher_better=False,
        ),
    )
    @classmethod
    def train(cls, **cfg):
        """
        Train the setup problem with the train/valid datasets & samplers and the task object
        """
        super().train(**cfg)

    @override_parent_cfg()
    @classmethod
    def inference(cls, **cfg):
        super().inference(**cfg)

    @classmethod
    def save_additional(
        cls,
        additional_dir: Workspace,
        workspace: Workspace,
        task: Task,
    ):
        setup_problem_cfg = workspace.get_cfg(cls.setup_problem)
        setup_problem_cfg["upstream"].pop("_cls")
        setup_problem_cfg["upstream"].pop("input_size")
        apc_config = dict(
            model=dict(
                paras=setup_problem_cfg["upstream"],
            ),
            data=dict(
                audio=_audio_config,
            ),
        )
        all_states = dict(
            config=apc_config,
            model=task.upstream.state_dict(),
            Upstream_Config=apc_config,
        )
        torch.save(
            all_states, str(additional_dir.parent.resolve()) + "/all_states.ckpt"
        )

    @override_parent_cfg(
        start_stage=0,
        final_stage=2,
        stage_0=dict(
            _method="setup_problem",
        ),
        stage_1=dict(
            _method="train",
        ),
        stage_2=dict(
            _method="inference",
        ),
    )
    @classmethod
    def run_stages(cls, **cfg):
        super().run_stages(**cfg)