import logging

import torch.nn as nn

from s3prl import Container, field
from s3prl.dataset.base import AugmentedDynamicItemDataset
from s3prl.nn import S3PRLUpstream
from s3prl.problem.base import Problem
from s3prl.problem.trainer import Trainer
from s3prl.util.configuration import default_cfg
from s3prl.util.workspace import Workspace

logger = logging.getLogger(__name__)


class SslProblem(Problem, Trainer):
    @default_cfg(
        workspace=field(
            "???",
            "\nWill put the following keys into this workspace:\n"
            "  'train_dataset', 'train_sampler', 'valid_dataset', 'valid_sampler',\n"
            "  'test_dataset', 'test_sampler', and 'task'",
            "str or Path or Workspace",
        ),
        corpus=dict(
            _cls=field(
                "???",
                "\nThe corpus class. You can add the **kwargs right below this _cls key",
                str,
            ),
            dataset_root=field("???", "The root path of the corpus", str),
        ),
        train_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        train_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        valid_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        valid_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        test_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        test_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        upstream=dict(
            _cls=field(
                S3PRLUpstream,
                "\nThe class of the upstream NN model. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        predictor=dict(
            _cls=field(
                "???",
                "\nThe predictor NN model class for pre-training. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        task=dict(
            _cls=field(
                "???",
                "\nThe task class defining what to do for each train/valid/test step in the train/valid/test dataloader loop."
                "\nYou can add the **kwargs right below this _cls key",
                str,
            ),
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
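        """Set up all components for SSL pre-training and save them.

        Instantiates the corpus, datapipes, samplers, upstream, predictor,
        and task from ``cfg``, then stores the resulting datasets, samplers,
        and task in the given workspace for the Trainer to consume.
        """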
        cfg = Container(cfg)
        workspace = Workspace(cfg.workspace)

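        # cfg.upstream may already be an instantiated nn.Module; otherwise it
        # is a config entry whose _cls and kwargs are instantiated here.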
        if not isinstance(cfg.upstream, nn.Module):
            upstream = cfg.upstream._cls(**cfg.upstream.kwds())
        else:
            upstream = cfg.upstream

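        # `stats` accumulates corpus statistics and dataset tools that later
        # steps (datapipes, predictor, task) consume as keyword arguments.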
        stats = Container()

        logger.info("Preparing corpus")
        train_data, valid_data, test_data, corpus_stats = cfg.corpus._cls(
            **cfg.corpus.kwds()
        ).split(3)
        stats.add(corpus_stats)

        logger.info("Preparing train data")
        train_dataset = AugmentedDynamicItemDataset(train_data, tools=stats)
        train_dataset = cfg.train_datapipe._cls(**cfg.train_datapipe.kwds())(
            train_dataset
        )
        train_sampler = cfg.train_sampler._cls(
            train_dataset, **cfg.train_sampler.kwds()
        )
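        # Tools produced by the train datapipe are merged into `stats` so the
        # valid/test pipelines and the model construction below can reuse them.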
        stats.add(train_dataset.all_tools())

        logger.info("Preparing valid data")
        valid_dataset = AugmentedDynamicItemDataset(valid_data, tools=stats)
        valid_dataset = cfg.valid_datapipe._cls(**cfg.valid_datapipe.kwds())(
            valid_dataset
        )
        valid_sampler = cfg.valid_sampler._cls(
            valid_dataset, **cfg.valid_sampler.kwds()
        )

        logger.info("Preparing test data")
        test_dataset = AugmentedDynamicItemDataset(test_data, tools=stats)
        test_dataset = cfg.test_datapipe._cls(**cfg.test_datapipe.kwds())(test_dataset)
        test_sampler = cfg.test_sampler._cls(test_dataset, **cfg.test_sampler.kwds())

        logger.info("Preparing model and task")
        predictor = cfg.predictor._cls(
            **stats,
            **cfg.predictor.kwds(),
        )
        task = cfg.task._cls(
            upstream, predictor, workspace=workspace, **stats, **cfg.task.kwds()
        )

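        # Persist every component the Trainer will need; these keys match the
        # ones documented in the `workspace` field above.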
        workspace["train_dataset"] = train_dataset
        workspace["train_sampler"] = train_sampler
        workspace["valid_dataset"] = valid_dataset
        workspace["valid_sampler"] = valid_sampler
        workspace["test_dataset"] = test_dataset
        workspace["test_sampler"] = test_sampler
        workspace["task"] = task
        workspace.environ.update(stats)
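

# A minimal usage sketch (hypothetical values: the "..." class paths and the
# dataset root are placeholders, since the available registry names depend on
# the installed s3prl version). `upstream` is omitted because it defaults to
# S3PRLUpstream; override it with upstream=dict(_cls=..., ...) if needed.
#
# SslProblem.setup_problem(
#     workspace="result/ssl_pretrain",
#     corpus=dict(_cls="...", dataset_root="/path/to/corpus"),
#     train_datapipe=dict(_cls="..."),
#     train_sampler=dict(_cls="..."),
#     valid_datapipe=dict(_cls="..."),
#     valid_sampler=dict(_cls="..."),
#     test_datapipe=dict(_cls="..."),
#     test_sampler=dict(_cls="..."),
#     predictor=dict(_cls="..."),
#     task=dict(_cls="..."),
# )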