import logging
from collections import defaultdict

from s3prl.base.container import Container
from s3prl.base.workspace import Workspace
from s3prl.corpus.iemocap import iemocap_for_superb
from s3prl.dataset.utterance_classification_pipe import UtteranceClassificationPipe
from s3prl.nn import MeanPoolingLinear
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler
from s3prl.task.utterance_classification_task import UtteranceClassificationTask
from s3prl.util.configuration import default_cfg, field

from .base import SuperbProblem

logger = logging.getLogger(__name__)


class SuperbER(SuperbProblem):
"""
    The SUPERB Emotion Recognition (ER) problem: utterance-level emotion
    classification on IEMOCAP.
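
    A minimal end-to-end sketch (the workspace path is hypothetical; the
    "???" fields in the default config below are required)::

        SuperbER.run(
            workspace="result/superb_er/fold_0",
            setup=dict(
                corpus=dict(
                    dataset_root="/path/to/IEMOCAP",
                    test_fold=0,
                ),
            ),
        )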
"""
@default_cfg(
**SuperbProblem.setup.default_except(
corpus=dict(
CLS=iemocap_for_superb,
dataset_root="???",
test_fold=field(
"???",
"The session in IEMOCAP used for testing.\n"
"The other sessions will be used for training and validation.",
),
),
train_datapipe=dict(
CLS=UtteranceClassificationPipe,
train_category_encoder=True,
),
train_sampler=dict(
CLS=FixedBatchSizeBatchSampler,
batch_size=4,
shuffle=True,
),
valid_datapipe=dict(
CLS=UtteranceClassificationPipe,
),
valid_sampler=dict(
CLS=FixedBatchSizeBatchSampler,
batch_size=4,
),
test_datapipe=dict(
CLS=UtteranceClassificationPipe,
),
test_sampler=dict(
CLS=FixedBatchSizeBatchSampler,
batch_size=4,
),
downstream=dict(
CLS=MeanPoolingLinear,
hidden_size=256,
),
task=dict(
CLS=UtteranceClassificationTask,
),
)
)
@classmethod
def setup(cls, **cfg):
"""
        Set up the ER problem: build the train/valid/test datasets & samplers
        and the task object.
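
        A minimal usage sketch (hypothetical paths; the nested keys follow
        the default_cfg structure above)::

            SuperbER.setup(
                workspace="result/superb_er/fold_0",
                corpus=dict(
                    dataset_root="/path/to/IEMOCAP",
                    test_fold=0,
                ),
            )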
"""
        super().setup(**cfg)

    @default_cfg(
**SuperbProblem.train.default_except(
optimizer=dict(
CLS="torch.optim.Adam",
lr=1.0e-4,
),
trainer=dict(
total_steps=30000,
log_step=500,
eval_step=1000,
save_step=1000,
gradient_clipping=1.0,
gradient_accumulate_steps=8,
valid_metric="accuracy",
valid_higher_better=True,
),
)
)
@classmethod
def train(cls, **cfg):
"""
        Train the problem prepared by ``setup``, using the train/valid
        datasets & samplers and the task object.
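
        A minimal usage sketch (hypothetical path; assumes ``setup`` has
        already populated this workspace, and the ``trainer.total_steps``
        override is only an illustration)::

            SuperbER.train(
                workspace="result/superb_er/fold_0",
                trainer=dict(total_steps=10000),
            )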
"""
        super().train(**cfg)

    @default_cfg(**SuperbProblem.inference.default_cfg)
@classmethod
def inference(cls, **cfg):
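        """
        Run inference on the test dataset & sampler of the set-up problem.

        A minimal usage sketch (hypothetical path; assumes ``setup`` and
        ``train`` have already populated this workspace)::

            SuperbER.inference(workspace="result/superb_er/fold_0")
        """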
        super().inference(**cfg)

    @default_cfg(
**SuperbProblem.run.default_except(
stages=["setup", "train", "inference"],
start_stage="setup",
final_stage="inference",
setup=setup.default_cfg.deselect("workspace", "resume"),
train=train.default_cfg.deselect("workspace", "resume"),
inference=inference.default_cfg.deselect("workspace", "resume"),
)
)
@classmethod
def run(cls, **cfg):
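        """
        Run the 'setup', 'train', and 'inference' stages in sequence within
        one workspace. See the class docstring for a minimal usage sketch.
        """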
        super().run(**cfg)

    @default_cfg(
        num_fold=field(5, "The number of folds to run for cross validation", int),
**run.default_except(
workspace=field(
"???",
"The root workspace for all folds.\n"
"Each fold will use a 'fold_{id}' sub-workspace under this root workspace",
),
setup=dict(
corpus=dict(
                    test_fold=field(
                        "TBD", "This will be auto-set by 'cross_validation'"
                    )
)
),
),
)
@classmethod
def cross_validation(cls, **cfg):
"""
        Except for 'num_fold', every field is passed through to 'run' for each
        fold. That is, all folds share the same config (training
        hyperparameters, dataset root, etc.); only 'workspace' and 'test_fold'
        differ across folds.
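
        A minimal usage sketch (hypothetical paths; each fold runs in a
        'fold_{id}' sub-workspace, and the averaged test metrics are saved
        as 'avg_test_metrics' under the root workspace)::

            SuperbER.cross_validation(
                workspace="result/superb_er_cv",
                num_fold=5,
                setup=dict(
                    corpus=dict(dataset_root="/path/to/IEMOCAP"),
                ),
            )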
"""
cfg = Container(cfg)
workspaces = [
str(Workspace(cfg.workspace) / f"fold_{fold_id}")
for fold_id in range(cfg.num_fold)
]
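        # Run each fold end-to-end in its own sub-workspace, holding out a
        # different fold for testing each time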
for fold_id, workspace in enumerate(workspaces):
fold_cfg = cfg.clone().deselect("num_fold")
fold_cfg.workspace = workspace
fold_cfg.setup.corpus.test_fold = fold_id
cls.run(
**fold_cfg,
)
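        # Collect each fold's test metrics and average them across folds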
metrics = defaultdict(list)
for fold_id, workspace in enumerate(workspaces):
workspace = Workspace(workspace)
metric = workspace["test_metrics"]
for key, value in metric.items():
metrics[key].append(value)
avg_result = dict()
for key, values in metrics.items():
avg_score = sum(values) / len(values)
avg_result[key] = avg_score
logger.info(f"Average {key}: {avg_score}")
Workspace(cfg.workspace).put(avg_result, "avg_test_metrics", "yaml")