from s3prl.corpus.speech_commands import gsc_v1_for_superb
from s3prl.dataset.utterance_classification_pipe import UtteranceClassificationPipe
from s3prl.nn import MeanPoolingLinear
from s3prl.sampler import BalancedWeightedSampler, FixedBatchSizeBatchSampler
from s3prl.task.utterance_classification_task import UtteranceClassificationTask
from s3prl.util.configuration import default_cfg

from .base import SuperbProblem

# sox effects applied to every utterance: downmix to mono, resample to 16 kHz,
# and apply a -3 dB gain.
EFFECTS = [["channels", "1"], ["rate", "16000"], ["gain", "-3.0"]]


class SuperbKS(SuperbProblem):
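    """SUPERB Keyword Spotting (KS) on Google Speech Commands v1."""
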
    @default_cfg(
        **SuperbProblem.setup.default_except(
            corpus=dict(
                CLS=gsc_v1_for_superb,
                dataset_root="???",  # required: path to the Google Speech Commands v1 root
            ),
            train_datapipe=dict(
                CLS=UtteranceClassificationPipe,
                train_category_encoder=True,
                sox_effects=EFFECTS,
            ),
            train_sampler=dict(
                CLS=BalancedWeightedSampler,
                batch_size=32,
            ),
            valid_datapipe=dict(
                CLS=UtteranceClassificationPipe,
                sox_effects=EFFECTS,
            ),
            valid_sampler=dict(
                CLS=BalancedWeightedSampler,
                batch_size=32,
            ),
            test_datapipe=dict(
                CLS=UtteranceClassificationPipe,
                sox_effects=EFFECTS,
            ),
            test_sampler=dict(
                CLS=FixedBatchSizeBatchSampler,
                batch_size=32,
            ),
            downstream=dict(
                CLS=MeanPoolingLinear,  # mean-pool frame-level features, then a linear classifier
                hidden_size=256,
            ),
            task=dict(
                CLS=UtteranceClassificationTask,
            ),
        )
    )
    @classmethod
    def setup(cls, **cfg):
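        """Prepare the corpus, datapipes, samplers, downstream model, and task."""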
        super().setup(**cfg)

    @default_cfg(
        **SuperbProblem.train.default_except(
            optimizer=dict(
                CLS="torch.optim.Adam",
                lr=1.0e-4,
            ),
            trainer=dict(
                total_steps=200000,
                log_step=100,
                eval_step=5000,
                save_step=1000,
                gradient_clipping=1.0,
                gradient_accumulate_steps=1,
                valid_metric="accuracy",
                valid_higher_better=True,
            ),
        )
    )
    @classmethod
    def train(cls, **cfg):
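        """Train with Adam at lr 1e-4 for 200k steps, validating on accuracy."""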
        super().train(**cfg)

    @default_cfg(**SuperbProblem.inference.default_cfg)
    @classmethod
    def inference(cls, **cfg):
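        """Run inference with the inherited default configuration."""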
        super().inference(**cfg)

    @default_cfg(
        stages=["setup", "train", "inference"],
        start_stage="setup",
        final_stage="inference",
        setup=setup.default_cfg.deselect("workspace", "resume"),
        train=train.default_cfg.deselect("workspace", "resume"),
        inference=inference.default_cfg.deselect("workspace", "resume"),
    )
    @classmethod
    def run(cls, **cfg):
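        """Run the setup, train, and inference stages end-to-end."""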
        super().run(**cfg)