|
import logging |
|
from collections import defaultdict |
|
|
|
from s3prl.base.container import Container |
|
from s3prl.base.workspace import Workspace |
|
from s3prl.corpus.iemocap import iemocap_for_superb |
|
from s3prl.dataset.utterance_classification_pipe import UtteranceClassificationPipe |
|
from s3prl.nn import MeanPoolingLinear |
|
from s3prl.sampler import FixedBatchSizeBatchSampler, MaxTimestampBatchSampler |
|
from s3prl.task.utterance_classification_task import UtteranceClassificationTask |
|
from s3prl.util.configuration import default_cfg, field |
|
|
|
from .base import SuperbProblem |
|
|
|
# Module-level logger named after this module; used below to report the
# per-metric averages computed by SuperbER.cross_validation.
logger = logging.getLogger(__name__)
|
|
|
|
|
class SuperbER(SuperbProblem):
    """
    Superb Emotion Classification (ER) problem.

    Each stage (setup / train / inference / run) simply forwards to the stage of
    the same name on ``SuperbProblem``; what this class contributes is the
    ER-specific default configuration attached to every stage, plus a
    ``cross_validation`` entry point that repeats ``run`` once per fold.
    """

    # NOTE(review): "???" presumably marks a value the user has to provide —
    # confirm against the `default_cfg` / `field` implementation.
    @default_cfg(
        **SuperbProblem.setup.default_except(
            corpus={
                "CLS": iemocap_for_superb,
                "dataset_root": "???",
                "test_fold": field(
                    "???",
                    "The session in IEMOCAP used for testing.\n"
                    "The other sessions will be used for training and validation.",
                ),
            },
            train_datapipe={
                "CLS": UtteranceClassificationPipe,
                "train_category_encoder": True,
            },
            train_sampler={
                "CLS": FixedBatchSizeBatchSampler,
                "batch_size": 4,
                "shuffle": True,
            },
            valid_datapipe={"CLS": UtteranceClassificationPipe},
            valid_sampler={"CLS": FixedBatchSizeBatchSampler, "batch_size": 4},
            test_datapipe={"CLS": UtteranceClassificationPipe},
            test_sampler={"CLS": FixedBatchSizeBatchSampler, "batch_size": 4},
            downstream={"CLS": MeanPoolingLinear, "hidden_size": 256},
            task={"CLS": UtteranceClassificationTask},
        )
    )
    @classmethod
    def setup(cls, **cfg):
        """
        Set up the ER problem: the train/valid/test datasets & samplers and a task object.

        All keyword arguments are forwarded unchanged to ``SuperbProblem.setup``.
        """
        super().setup(**cfg)

    @default_cfg(
        **SuperbProblem.train.default_except(
            optimizer={
                "CLS": "torch.optim.Adam",
                "lr": 1.0e-4,
            },
            trainer={
                "total_steps": 30000,
                "log_step": 500,
                "eval_step": 1000,
                "save_step": 1000,
                "gradient_clipping": 1.0,
                "gradient_accumulate_steps": 8,
                "valid_metric": "accuracy",
                "valid_higher_better": True,
            },
        )
    )
    @classmethod
    def train(cls, **cfg):
        """
        Train the problem prepared by ``setup``, using its train/valid datasets &
        samplers and its task object.

        All keyword arguments are forwarded unchanged to ``SuperbProblem.train``.
        """
        super().train(**cfg)

    @default_cfg(**SuperbProblem.inference.default_cfg)
    @classmethod
    def inference(cls, **cfg):
        """
        Run the inference stage with the parent's default config, untouched.

        All keyword arguments are forwarded unchanged to ``SuperbProblem.inference``.
        """
        super().inference(**cfg)

    # The per-stage sub-configs reuse the defaults declared above, minus the
    # "workspace" and "resume" entries, which are deselected from each stage.
    @default_cfg(
        **SuperbProblem.run.default_except(
            stages=["setup", "train", "inference"],
            start_stage="setup",
            final_stage="inference",
            setup=setup.default_cfg.deselect("workspace", "resume"),
            train=train.default_cfg.deselect("workspace", "resume"),
            inference=inference.default_cfg.deselect("workspace", "resume"),
        )
    )
    @classmethod
    def run(cls, **cfg):
        """
        Run the configured stages (by default setup, train and inference).

        All keyword arguments are forwarded unchanged to ``SuperbProblem.run``.
        """
        super().run(**cfg)

    @default_cfg(
        num_fold=field(5, "The number of folds to run cross validation", int),
        **run.default_except(
            workspace=field(
                "???",
                "The root workspace for all folds.\n"
                "Each fold will use a 'fold_{id}' sub-workspace under this root workspace",
            ),
            setup={
                "corpus": {
                    "test_fold": field(
                        "TBD", "This will be auto-set by 'run_cross_validation'"
                    )
                }
            },
        ),
    )
    @classmethod
    def cross_validation(cls, **cfg):
        """
        Execute ``run`` once per fold, then average the folds' test metrics.

        Only 'num_fold' is consumed here; every other field is handed to ``run``
        for each fold. All folds therefore share one config (training hypers,
        dataset root, etc.) and differ only in 'workspace', which becomes the
        'fold_{id}' sub-workspace of the root workspace, and in
        'setup.corpus.test_fold', which becomes the fold index.

        Once every fold has finished, the 'test_metrics' entry of each fold
        workspace is read back, the values are averaged per metric key, each
        average is logged, and the result is put into the root workspace as
        'avg_test_metrics'.
        """
        cfg = Container(cfg)
        fold_dirs = [
            str(Workspace(cfg.workspace) / f"fold_{index}")
            for index in range(cfg.num_fold)
        ]

        # Phase 1: run the whole pipeline for every fold before aggregating.
        for index, fold_dir in enumerate(fold_dirs):
            # Work on a copy without 'num_fold', which 'run' does not declare.
            fold_cfg = cfg.clone().deselect("num_fold")
            fold_cfg.workspace = fold_dir
            fold_cfg.setup.corpus.test_fold = index
            cls.run(**fold_cfg)

        # Phase 2: collect each fold's test metrics, grouped by metric key.
        collected = defaultdict(list)
        for fold_dir in fold_dirs:
            fold_metrics = Workspace(fold_dir)["test_metrics"]
            for name, score in fold_metrics.items():
                collected[name].append(score)

        # Phase 3: average per key, logging each average as it is computed.
        averaged = {}
        for key, scores in collected.items():
            mean_score = sum(scores) / len(scores)
            averaged[key] = mean_score
            logger.info(f"Average {key}: {mean_score}")

        Workspace(cfg.workspace).put(averaged, "avg_test_metrics", "yaml")
|
|