File size: 4,240 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from torch.nn import L1Loss

from s3prl.corpus.librispeech import librispeech_for_pretrain
from s3prl.dataset.pretrain_apc_pipe import PretrainApcPipe
from s3prl.nn.predictor_identity import PredictorIdentity
from s3prl.nn.rnn_apc import RnnApc
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler
from s3prl.task import Task
from s3prl.task.autoregressive_reconstruction_task import (
    AutoregressiveReconstructionTask,
)
from s3prl.util.configuration import override_parent_cfg
from s3prl.util.workspace import Workspace

from .base import SslProblem

# Dimensionality of the acoustic features. Used both as the extracted feature
# dimension and as the upstream model's input size.
_input_size = 80

# Audio feature-extraction settings.
_audio_config = {
    "feat_type": "fbank",  # Feature type
    "feat_dim": _input_size,  # Feature dimension
    "frame_length": 25,  # Window size in ms
    "frame_shift": 10,  # Hop size in ms
    "decode_wav": False,
    "cmvn": True,  # Apply utterance-wise CMVN on the Mel spectrogram
}

# Datapipe settings: the pipe class and its own options, followed by every
# audio setting from `_audio_config` (key order matters only for readability).
_pretrain_task_pipe_config = {
    "_cls": PretrainApcPipe,
    "n_future": 5,
    "n_jobs": 8,
    **_audio_config,
}


class Apc(SslProblem):
    """
    Apc pre-train problem

    Every stage method below delegates to the ``SslProblem`` implementation;
    this class only supplies the APC-specific configuration defaults through
    ``override_parent_cfg`` and exports the trained upstream in
    :meth:`save_additional`.
    """

    @override_parent_cfg(
        corpus=dict(
            _cls=librispeech_for_pretrain,
            dataset_root="???",
        ),
        train_datapipe=_pretrain_task_pipe_config,
        train_sampler=dict(
            _cls=MaxTimestampBatchSampler,
            # NOTE(review): presumably 20 seconds of 16 kHz audio - confirm
            max_timestamp=16000 * 20,
            shuffle=True,
        ),
        valid_datapipe=_pretrain_task_pipe_config,
        valid_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        test_datapipe=_pretrain_task_pipe_config,
        test_sampler=dict(
            _cls=FixedBatchSizeBatchSampler,
            batch_size=2,
        ),
        upstream=dict(
            _cls=RnnApc,
            input_size=_input_size,
            num_layers=3,
            hidden_size=512,
            dropout=0.1,
            residual=True,
        ),
        predictor=dict(
            _cls=PredictorIdentity,
        ),
        task=dict(
            _cls=AutoregressiveReconstructionTask,
            loss=L1Loss,
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
        """
        Set up the Apc problem, containing train/valid/test datasets &
        samplers and a task object.

        Args:
            **cfg: configuration forwarded unchanged to the parent
                ``setup_problem``; the decorator above lists the defaults
                this class overrides.
        """
        super().setup_problem(**cfg)

    @override_parent_cfg(
        optimizer=dict(
            _cls="torch.optim.AdamW",
            lr=0.0001,  # set to 0.00001 for some datasets if you encounter NaN during training
        ),
        trainer=dict(
            total_steps=1000000,
            eval_step=50000,
            save_step=50000,
            gradient_clipping=5.0,
            gradient_accumulate_steps=4,
            valid_metric="loss",
            valid_higher_better=False,
        ),
    )
    @classmethod
    def train(cls, **cfg):
        """
        Train the setup problem with the train/valid datasets & samplers and
        the task object.

        Args:
            **cfg: configuration forwarded unchanged to the parent ``train``;
                the decorator above lists the optimizer/trainer defaults this
                class overrides.
        """
        super().train(**cfg)

    @override_parent_cfg()
    @classmethod
    def inference(cls, **cfg):
        """
        Run inference by delegating to the parent ``inference``.

        No configuration defaults are overridden for this stage.

        Args:
            **cfg: configuration forwarded unchanged to the parent.
        """
        super().inference(**cfg)

    @classmethod
    def save_additional(
        cls,
        additional_dir: Workspace,
        workspace: Workspace,
        task: Task,
    ):
        """
        Export the upstream model as a standalone ``all_states.ckpt``.

        The checkpoint is written with ``torch.save`` into the resolved
        parent directory of ``additional_dir`` and holds three entries:
        ``config`` and ``Upstream_Config`` (the same object: upstream model
        parameters plus the audio settings) and ``model`` (the upstream
        ``state_dict``).

        Args:
            additional_dir: location whose *parent* directory receives the
                checkpoint file.
            workspace: workspace from which the ``setup_problem``
                configuration is read.
            task: task object whose ``upstream`` module is exported.

        Raises:
            KeyError: if the stored upstream config lacks ``_cls`` or
                ``input_size``.
        """
        # Function-scope import keeps this block self-contained.
        import copy

        setup_problem_cfg = workspace.get_cfg(cls.setup_problem)

        # Strip the keys from a shallow copy rather than from the object the
        # workspace handed out: popping in place would alter a config this
        # method does not own, and would raise KeyError on a later call if the
        # workspace returns the same object again.
        upstream_paras = copy.copy(setup_problem_cfg["upstream"])
        # NOTE(review): presumably the checkpoint consumer supplies the model
        # class and input size itself - confirm against the loader.
        upstream_paras.pop("_cls")
        upstream_paras.pop("input_size")

        apc_config = dict(
            model=dict(
                paras=upstream_paras,
            ),
            data=dict(
                audio=_audio_config,
            ),
        )
        all_states = dict(
            config=apc_config,
            model=task.upstream.state_dict(),
            Upstream_Config=apc_config,
        )
        torch.save(
            all_states, str(additional_dir.parent.resolve()) + "/all_states.ckpt"
        )

    @override_parent_cfg(
        start_stage=0,
        final_stage=2,
        stage_0=dict(
            _method="setup_problem",
        ),
        stage_1=dict(
            _method="train",
        ),
        stage_2=dict(
            _method="inference",
        ),
    )
    @classmethod
    def run_stages(cls, **cfg):
        """
        Run the configured stages by delegating to the parent ``run_stages``.

        The defaults declare three stages: stage 0 is ``setup_problem``,
        stage 1 is ``train`` and stage 2 is ``inference``, with
        ``start_stage=0`` and ``final_stage=2``.

        Args:
            **cfg: configuration forwarded unchanged to the parent.
        """
        super().run_stages(**cfg)