lmzjms's picture
Upload 1162 files
0b32ad6 verified
"""
The backbone run procedure for the common train/valid/test
Authors
* Leo 2022
"""
import logging
import pickle
import shutil
from pathlib import Path
import pandas as pd
import torch
import yaml
from s3prl.problem.base import Problem
from s3prl.task.utterance_classification_task import UtteranceClassificationTask
# Names exported by `from <this module> import *`.
__all__ = ["Common"]
# Logger named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
class Common(Problem):
    """
    The common train/valid/test recipe.

    :obj:`run` wires the per-step methods (``prepare_data``, ``build_encoder``,
    ``build_dataset``, ``build_batch_sampler``, ``build_collate_fn``,
    ``build_model``, ``train``, ``evaluate``, ``load_model_and_task``) into
    numbered stages. Those methods are not defined in this class; presumably they
    come from :obj:`Problem` or from subclasses. By default the task is an
    :obj:`UtteranceClassificationTask` (see :obj:`build_task`).
    """

    def run(
        self,
        target_dir: str,
        cache_dir: str = None,
        remove_all_cache: bool = False,
        start: int = 0,
        stop: int = None,
        num_workers: int = 6,
        eval_batch: int = -1,
        device: str = "cuda",
        world_size: int = 1,
        rank: int = 0,
        test_ckpt_dir: str = None,
        prepare_data: dict = None,
        build_encoder: dict = None,
        build_dataset: dict = None,
        build_batch_sampler: dict = None,
        build_collate_fn: dict = None,
        build_upstream: dict = None,
        build_featurizer: dict = None,
        build_downstream: dict = None,
        build_model: dict = None,
        build_task: dict = None,
        build_optimizer: dict = None,
        build_scheduler: dict = None,
        save_model: dict = None,
        save_task: dict = None,
        train: dict = None,
        evaluate: dict = None,
    ):
        """
        ======== ====================
        stage    description
        ======== ====================
        0        Parse the corpus and save the metadata file (waveform path, label...)
        1        Build the encoder to encode the labels
        2        Train the model
        3        Evaluate the model on multiple test sets
        ======== ====================

        Args:
            target_dir (str):
                The directory that stores the script result.
            cache_dir (str):
                The directory that caches the processed data.
                Default: ~/.cache/s3prl/data
            remove_all_cache (bool):
                Whether to remove all the cache stored under `cache_dir` before
                running any stage. The directory is re-created (empty) afterwards.
                Default: False
            start (int):
                The starting stage of the problem script.
                Default: 0
            stop (int):
                The stopping stage of the problem script, set `None` to reach the final stage.
                Default: None
            num_workers (int): num_workers for all the torch DataLoader
            eval_batch (int):
                During evaluation (valid or test), limit the number of batch.
                This is helpful for the fast development to check everything won't crash.
                If is -1, disable this feature and evaluate the entire epoch.
                Default: -1
            device (str):
                The device type for all torch-related operation: "cpu" or "cuda"
                Default: "cuda"
            world_size (int):
                How many processes are running this script simultaneously (in parallel).
                Usually this is just 1, however if you are running distributed training,
                this should be > 1.
                Default: 1
            rank (int):
                When distributed training, world_size > 1. Take :code:`world_size == 8` for
                example, this means 8 processes (8 GPUs) are running in parallel. The script
                needs to know which process among 8 processes it is. In this case, :code:`rank`
                can range from 0~7. All the 8 processes have the same :code:`world_size` but
                different :code:`rank` (process id).
            test_ckpt_dir (str):
                Specify the checkpoint path for testing. If not, use the validation best
                checkpoint under the given :code:`target_dir` directory.
                Results are written to ``target_dir/evaluate/<ckpt tag>/<test csv stem>/``.
                The tag is the checkpoint path relative to ``target_dir/train``. If the
                checkpoint lies outside that directory, the tag is its absolute path
                without the root. In both cases "/" is replaced by "-".
            prepare_data, build_encoder, build_dataset, build_batch_sampler,
            build_collate_fn, build_upstream, build_featurizer, build_downstream,
            build_model, build_task, build_optimizer, build_scheduler, save_model,
            save_task, train, evaluate (dict):
                Method-specific arguments for methods like :obj:`prepare_data` and
                :obj:`build_model`. They are forwarded to those methods and are not
                used in the core :obj:`run` logic. ``None`` is treated as ``{}``.
                See the specific method documentation for their supported arguments and
                meaning.
        """
        # NOTE(review): keep this first and directly inside `run`, before any
        # argument below is re-bound. `_get_current_arguments` is defined outside
        # this file and presumably inspects this frame's argument values; confirm.
        yaml_path = Path(target_dir) / "configs" / f"{self._get_time_tag()}.yaml"
        yaml_path.parent.mkdir(exist_ok=True, parents=True)
        with yaml_path.open("w") as f:
            yaml.safe_dump(self._get_current_arguments(), f)

        cache_dir: Path = Path(cache_dir or Path.home() / ".cache" / "s3prl" / "data")
        prepare_data: dict = prepare_data or {}
        build_encoder: dict = build_encoder or {}
        build_dataset: dict = build_dataset or {}
        build_batch_sampler: dict = build_batch_sampler or {}
        build_collate_fn: dict = build_collate_fn or {}
        build_upstream: dict = build_upstream or {}
        build_featurizer: dict = build_featurizer or {}
        build_downstream: dict = build_downstream or {}
        build_model: dict = build_model or {}
        build_task: dict = build_task or {}
        build_optimizer: dict = build_optimizer or {}
        build_scheduler: dict = build_scheduler or {}
        save_model: dict = save_model or {}
        save_task: dict = save_task or {}
        train: dict = train or {}
        evaluate: dict = evaluate or {}

        target_dir: Path = Path(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)

        # Wipe the cache *before* creating it. Removing it after mkdir (as before)
        # left every later stage without a cache directory.
        if remove_all_cache and cache_dir.exists():
            shutil.rmtree(cache_dir)
        cache_dir.mkdir(exist_ok=True, parents=True)

        stage_id = 0
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: prepare data")
            train_csv, valid_csv, test_csvs = self.prepare_data(
                prepare_data, target_dir, cache_dir, get_path_only=False
            )

        # Always resolve the output paths, so later stages work when this one is skipped
        train_csv, valid_csv, test_csvs = self.prepare_data(
            prepare_data, target_dir, cache_dir, get_path_only=True
        )

        def check_fn():
            assert Path(train_csv).is_file() and Path(valid_csv).is_file()
            for test_csv in test_csvs:
                assert Path(test_csv).is_file()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 1
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: build encoder")
            encoder_path = self.build_encoder(
                build_encoder,
                target_dir,
                cache_dir,
                train_csv,
                valid_csv,
                test_csvs,
                get_path_only=False,
            )

        encoder_path = self.build_encoder(
            build_encoder,
            target_dir,
            cache_dir,
            train_csv,
            valid_csv,
            test_csvs,
            get_path_only=True,
        )

        def check_fn():
            assert Path(encoder_path).is_file()

        self._stage_check(stage_id, stop, check_fn)

        # Load the encoder once; it is reused by stage 2 below.
        with open(encoder_path, "rb") as f:
            encoder = pickle.load(f)
        model_output_size = len(encoder)

        # The model is built here only to read its `downsample_rate`, which the
        # datasets receive as `frame_shift`. Training and testing build/load their
        # own model, so this instance is not kept alive.
        frame_shift = self.build_model(
            build_model,
            model_output_size,
            build_upstream,
            build_featurizer,
            build_downstream,
        ).downsample_rate

        stage_id = 2
        train_dir = target_dir / "train"
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Train Model")
            train_ds, train_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "train",
                train_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )
            valid_ds, valid_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "valid",
                valid_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )

            build_model_all_args = dict(
                build_model=build_model,
                model_output_size=model_output_size,
                build_upstream=build_upstream,
                build_featurizer=build_featurizer,
                build_downstream=build_downstream,
            )
            build_task_all_args_except_model = dict(
                build_task=build_task,
                encoder=encoder,
                valid_df=pd.read_csv(valid_csv),
            )
            self.train(
                train,
                train_dir,
                build_model_all_args,
                build_task_all_args_except_model,
                save_model,
                save_task,
                build_optimizer,
                build_scheduler,
                evaluate,
                train_ds,
                train_bs,
                self.build_collate_fn(build_collate_fn, "train"),
                valid_ds,
                valid_bs,
                self.build_collate_fn(build_collate_fn, "valid"),
                device=device,
                eval_batch=eval_batch,
                num_workers=num_workers,
                world_size=world_size,
                rank=rank,
            )

        def check_fn():
            assert (train_dir / "valid_best").is_dir()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 3
        if start <= stage_id:
            test_ckpt_dir: Path = Path(test_ckpt_dir or train_dir / "valid_best")
            assert (
                test_ckpt_dir.is_dir()
            ), f"Checkpoint directory not found: {test_ckpt_dir}"
            logger.info(f"Stage {stage_id}: Test model: {test_ckpt_dir}")

            # Flat directory name identifying the checkpoint under `evaluate/`
            ckpt_tag = self._get_ckpt_tag(test_ckpt_dir, train_dir)

            for test_idx, test_csv in enumerate(test_csvs):
                test_name = Path(test_csv).stem
                test_dir: Path = target_dir / "evaluate" / ckpt_tag / test_name
                test_dir.mkdir(exist_ok=True, parents=True)

                logger.info(f"Stage {stage_id}.{test_idx}: Test model on {test_csv}")
                test_ds, test_bs = self._build_dataset_and_sampler(
                    target_dir,
                    cache_dir,
                    "test",
                    test_csv,
                    encoder_path,
                    frame_shift,
                    build_dataset,
                    build_batch_sampler,
                )

                # Reloaded per test set because `test_df` differs for each one
                _, valid_best_task = self.load_model_and_task(
                    test_ckpt_dir, task_overrides={"test_df": pd.read_csv(test_csv)}
                )
                logs = self.evaluate(
                    evaluate,
                    "test",
                    valid_best_task,
                    test_ds,
                    test_bs,
                    self.build_collate_fn(build_collate_fn, "test"),
                    eval_batch,
                    test_dir,
                    device,
                    num_workers,
                )
                test_metrics = {name: float(value) for name, value in logs.items()}
                logger.info(f"test results: {test_metrics}")

                with (test_dir / "result.yaml").open("w") as f:
                    yaml.safe_dump(test_metrics, f)

    @staticmethod
    def _get_ckpt_tag(ckpt_dir: Path, train_dir: Path) -> str:
        """
        Turn a checkpoint directory into a single flat directory name.

        When possible, the checkpoint is expressed relative to :code:`train_dir`
        (e.g. ``<target_dir>/train/valid_best`` -> ``valid_best``). Both paths are
        resolved before retrying, to handle relative vs. absolute spellings. A
        checkpoint stored elsewhere falls back to its absolute path without the
        root/drive anchor, instead of raising ``ValueError``. Path separators
        are replaced by "-".

        Args:
            ckpt_dir (Path): the checkpoint directory to describe
            train_dir (Path): the training output directory (``target_dir/train``)

        Returns:
            str: a name without "/" usable as one path component
        """
        try:
            relative = ckpt_dir.relative_to(train_dir)
        except ValueError:
            resolved = ckpt_dir.resolve()
            try:
                relative = resolved.relative_to(train_dir.resolve())
            except ValueError:
                # Drop the anchor ("/" or "C:\\") so the tag contains no root/drive
                relative = Path(*resolved.parts[1:])
        return relative.as_posix().replace("/", "-")

    def _build_dataset_and_sampler(
        self,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        encoder_path: str,
        frame_shift: int,
        build_dataset: dict,
        build_batch_sampler: dict,
    ):
        """
        Build the dataset for :code:`mode` and a batch sampler over it.

        Args:
            target_dir (str): the directory that stores the script result
            cache_dir (str): the directory that caches the processed data
            mode (str): the split name passed through, e.g. "train", "valid" or "test"
            data_csv (str): the metadata csv of this split
            encoder_path (str): path of the pickled encoder from :obj:`build_encoder`
            frame_shift (int): the model's :code:`downsample_rate`
            build_dataset (dict): arguments forwarded to :obj:`build_dataset`
            build_batch_sampler (dict): arguments forwarded to :obj:`build_batch_sampler`

        Returns:
            tuple: (dataset, batch_sampler)
        """
        logger.info(f"Build {mode} dataset")
        dataset = self.build_dataset(
            build_dataset,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            encoder_path,
            frame_shift,
        )
        logger.info(f"Build {mode} batch sampler")
        batch_sampler = self.build_batch_sampler(
            build_batch_sampler,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            dataset,
        )
        return dataset, batch_sampler

    def build_task(
        self,
        build_task: dict,
        model: torch.nn.Module,
        encoder,
        valid_df: pd.DataFrame = None,
        test_df: pd.DataFrame = None,
    ):
        """
        Build the task for the :code:`model`. The task defines:

        - the logic of every train/valid/test forward step;
        - how results from multiple train/valid/test steps are reduced into metrics.

        By default build :obj:`UtteranceClassificationTask`

        Args:
            build_task (dict): same in :obj:`default_config`, no argument supported for now
            model (torch.nn.Module): the model built by :obj:`build_model`
            encoder: the encoder built by :obj:`build_encoder`
            valid_df (pd.DataFrame): validation metadata; ignored by this default implementation
            test_df (pd.DataFrame): test metadata; ignored by this default implementation

        Returns:
            Task
        """
        return UtteranceClassificationTask(model, encoder)