# (removed extraction artifacts: file-size banner, commit hash, line-number run)
import logging
import torch.nn as nn
from s3prl import Container, field
from s3prl.base import Logs
from s3prl.dataset.base import AugmentedDynamicItemDataset
from s3prl.nn import S3PRLUpstream, UpstreamDownstreamModel
from s3prl.problem.base import Problem
from s3prl.problem.trainer import Trainer
from s3prl.util import workspace
from s3prl.util.configuration import default_cfg
from s3prl.util.workspace import Workspace
logger = logging.getLogger(__name__)
class SslProblem(Problem, Trainer):
    @default_cfg(
        workspace=field(
            "???",
            "\nWill put the following keys into this workspace:\n"
            "  'train_dataset', 'train_sampler', 'valid_dataset', 'valid_sampler', and 'task'",
            "str or Path or Workspace",
        ),
        corpus=dict(
            _cls=field(
                "???",
                "\nThe corpus class. You can add the **kwargs right below this _cls key",
                str,
            ),
            dataset_root=field("???", "The root path of the corpus", str),
        ),
        train_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        train_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        valid_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        valid_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        test_datapipe=dict(
            _cls=field(
                "???",
                "\nThe datapipe class to be applied to the corpus. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        test_sampler=dict(
            _cls=field(
                "???",
                "\nThe batch sampler class. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        upstream=dict(
            _cls=field(
                S3PRLUpstream,
                "\nThe class of the upstream NN model. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        predictor=dict(
            _cls=field(
                "???",
                "\nThe class of the predictor NN model class for pre-train. You can add the **kwargs right below this _cls key",
                str,
            ),
        ),
        task=dict(
            _cls=field(
                "???",
                "\nThe task class defining what to do for each train/valid/test step in the train/valid/test dataloader loop"
                "\nYou can add the **kwargs right below this _cls key",
                str,
            ),
        ),
    )
    @classmethod
    def setup_problem(cls, **cfg):
        """Build everything an SSL pre-training run needs and stash it in a workspace.

        Instantiates the corpus, the train/valid/test datasets (with their
        datapipes and batch samplers), the upstream + predictor models, and the
        task object, then writes all of them into the workspace given by
        ``cfg.workspace``.
        """
        cfg = Container(cfg)
        # `ws` avoids shadowing the `workspace` module imported at file level.
        ws = Workspace(cfg.workspace)

        # Accept either an already-constructed nn.Module or a config entry
        # describing which upstream class to instantiate.
        if isinstance(cfg.upstream, nn.Module):
            upstream = cfg.upstream
        else:
            upstream = cfg.upstream._cls(**cfg.upstream.kwds())

        # Accumulates corpus statistics and pipeline tools shared downstream.
        stats = Container()

        logger.info("Preparing corpus")
        corpus = cfg.corpus._cls(**cfg.corpus.kwds())
        train_data, valid_data, test_data, corpus_stats = corpus.split(3)
        stats.add(corpus_stats)

        logger.info("Preparing train data")
        train_dataset = cfg.train_datapipe._cls(**cfg.train_datapipe.kwds())(
            AugmentedDynamicItemDataset(train_data, tools=stats)
        )
        train_sampler = cfg.train_sampler._cls(
            train_dataset, **cfg.train_sampler.kwds()
        )
        # Tools produced by the train pipeline are reused for valid/test below.
        stats.add(train_dataset.all_tools())

        logger.info("Preparing valid data")
        valid_dataset = cfg.valid_datapipe._cls(**cfg.valid_datapipe.kwds())(
            AugmentedDynamicItemDataset(valid_data, tools=stats)
        )
        valid_sampler = cfg.valid_sampler._cls(
            valid_dataset, **cfg.valid_sampler.kwds()
        )

        logger.info("Preparing test data")
        test_dataset = cfg.test_datapipe._cls(**cfg.test_datapipe.kwds())(
            AugmentedDynamicItemDataset(test_data, tools=stats)
        )
        test_sampler = cfg.test_sampler._cls(test_dataset, **cfg.test_sampler.kwds())

        logger.info("Preparing model and task")
        predictor = cfg.predictor._cls(**stats, **cfg.predictor.kwds())
        task = cfg.task._cls(
            upstream, predictor, workspace=ws, **stats, **cfg.task.kwds()
        )

        # Persist every prepared object under its conventional key.
        for key, value in (
            ("train_dataset", train_dataset),
            ("train_sampler", train_sampler),
            ("valid_dataset", valid_dataset),
            ("valid_sampler", valid_sampler),
            ("test_dataset", test_dataset),
            ("test_sampler", test_sampler),
            ("task", task),
        ):
            ws[key] = value
        ws.environ.update(stats)