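# NOTE(review): the three lines below appear to be web-page residue captured
# with this file (not part of the module); kept as comments so the file parses.
# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified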
import logging
from s3prl.corpus.hear import hear_scene_kfolds
from s3prl.util.configuration import default_cfg, field
from .scene import HearScene
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fold count used both as the corpus `num_folds` default and as the
# `num_fold` default of `Stroke.cross_validation`.
NUM_FOLDS = 5
class Stroke(HearScene):
    """HearScene problem variant configured for k-fold evaluation.

    The class only overrides the *default configurations* of the inherited
    entry points: the corpus is built by ``hear_scene_kfolds`` with
    ``NUM_FOLDS`` folds, the task is declared multiclass, and validation
    selects on ``top1_acc``. Every method forwards its config unchanged to
    the parent implementation and discards whatever the parent returns.
    """

    @default_cfg(
        **HearScene.setup.default_except(
            corpus={
                "CLS": field(
                    hear_scene_kfolds,
                    "\nThe corpus class. You can add the **kwargs right below this CLS key",
                    str,
                ),
                # NOTE(review): "???" is presumably the "must be filled in by
                # the user" placeholder of the config helpers -- confirm.
                "dataset_root": field("???", "The root path of the corpus", str),
                "test_fold": field(
                    "???", "The testing fold id. Options: [0, 1, 2, 3, 4]"
                ),
                "num_folds": NUM_FOLDS,
            },
            train_sampler={"batch_size": 32},
            task={
                "prediction_type": "multiclass",
                "scores": ["top1_acc", "d_prime", "aucroc", "mAP"],
            },
        )
    )
    @classmethod
    def setup(cls, **cfg):
        """Run the parent class's ``setup`` stage with ``cfg``.

        Args:
            **cfg: Configuration forwarded verbatim; defaults come from the
                ``default_cfg`` decorator above.

        Returns:
            None. The parent's return value is not propagated.
        """
        super().setup(**cfg)

    @default_cfg(
        **HearScene.train.default_except(
            trainer={
                "valid_metric": "top1_acc",
                "valid_higher_better": True,
            },
        )
    )
    @classmethod
    def train(cls, **cfg):
        """Run the parent class's ``train`` stage with ``cfg``.

        The only default changed here is the trainer's validation metric
        (``top1_acc``, higher is better).

        Returns:
            None. The parent's return value is not propagated.
        """
        super().train(**cfg)

    @default_cfg(**HearScene.inference.default_cfg)
    @classmethod
    def inference(cls, **cfg):
        """Run the parent class's ``inference`` stage with ``cfg``.

        The parent's default configuration is reused without changes.

        Returns:
            None. The parent's return value is not propagated.
        """
        super().inference(**cfg)

    # The stage configs below read `setup`, `train` and `inference` from the
    # class namespace, so those methods must be defined before this point.
    @default_cfg(
        **HearScene.run.default_except(
            stages=["setup", "train", "inference"],
            start_stage="setup",
            final_stage="inference",
            # Each stage's defaults are embedded after asking `deselect` to
            # leave out its "workspace" and "resume" entries.
            setup=setup.default_cfg.deselect("workspace", "resume"),
            train=train.default_cfg.deselect("workspace", "resume"),
            inference=inference.default_cfg.deselect("workspace", "resume"),
        )
    )
    @classmethod
    def run(cls, **cfg):
        """Run the parent class's ``run`` entry point with ``cfg``.

        Defaults declare the stage list ``setup`` -> ``train`` ->
        ``inference``, starting at ``setup`` and ending at ``inference``.

        Returns:
            None. The parent's return value is not propagated.
        """
        super().run(**cfg)

    @default_cfg(
        num_fold=field(NUM_FOLDS, "The number of folds to run cross validation", int),
        **run.default_except(
            workspace=field(
                "???",
                "The root workspace for all folds.\n"
                "Each fold will use a 'fold_{id}' sub-workspace under this root workspace",
            ),
            setup={
                "corpus": {
                    "test_fold": field(
                        "TBD", "This will be auto-set by 'run_cross_validation'"
                    ),
                },
            },
        ),
    )
    @classmethod
    def cross_validation(cls, **cfg):
        """Run the parent class's ``cross_validation`` entry point with ``cfg``.

        Defaults extend those of ``run`` with a ``num_fold`` count
        (``NUM_FOLDS``), a root ``workspace`` field, and a placeholder
        ``test_fold`` that, per its help text, is filled in automatically.

        Returns:
            None. The parent's return value is not propagated.
        """
        super().cross_validation(**cfg)