import os
from pathlib import Path

import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict

from src.train import train
from tests.helpers.run_if import RunIf
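
# NOTE: the `cfg_train` fixture used below is expected to come from
# `tests/conftest.py`. A minimal sketch of what it is assumed to provide
# (names and overrides here are illustrative, not the exact fixture):
#
#   @pytest.fixture(scope="function")
#   def cfg_train(tmp_path: Path) -> DictConfig:
#       with initialize(version_base="1.3", config_path="../configs"):
#           cfg = compose(config_name="train", return_hydra_config=True)
#       with open_dict(cfg):
#           cfg.paths.output_dir = str(tmp_path)  # checkpoints land in tmp_path
#       return cfg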


def test_train_fast_dev_run(cfg_train: DictConfig) -> None:
    """Run for 1 train, val and test step.

    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    HydraConfig().set_config(cfg_train)
    with open_dict(cfg_train):
        cfg_train.trainer.fast_dev_run = True
        cfg_train.trainer.accelerator = "cpu"
    train(cfg_train)


@RunIf(min_gpus=1)
def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None:
    """Run for 1 train, val and test step on GPU.

    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    HydraConfig().set_config(cfg_train)
    with open_dict(cfg_train):
        cfg_train.trainer.fast_dev_run = True
        cfg_train.trainer.accelerator = "gpu"
    train(cfg_train)


@RunIf(min_gpus=1)
@pytest.mark.slow
def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None:
    """Train 1 epoch on GPU with mixed-precision.

    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    HydraConfig().set_config(cfg_train)
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        cfg_train.trainer.accelerator = "gpu"
        cfg_train.trainer.precision = 16
    train(cfg_train)


@pytest.mark.slow
def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None:
    """Train 1 epoch with validation loop twice per epoch.

    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    HydraConfig().set_config(cfg_train)
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        cfg_train.trainer.val_check_interval = 0.5
    train(cfg_train)


@pytest.mark.slow
def test_train_ddp_sim(cfg_train: DictConfig) -> None:
    """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.

    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    HydraConfig().set_config(cfg_train)
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 2
        cfg_train.trainer.accelerator = "cpu"
        cfg_train.trainer.devices = 2
        cfg_train.trainer.strategy = "ddp_spawn"
    train(cfg_train)


@pytest.mark.slow
def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None:
    """Run 1 epoch, finish, and resume for another epoch.

    :param tmp_path: The temporary logging path.
    :param cfg_train: A DictConfig containing a valid training configuration.
    """
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1

    HydraConfig().set_config(cfg_train)

    # First run: train for one epoch and write checkpoints under tmp_path.
    metric_dict_1, _ = train(cfg_train)

    files = os.listdir(tmp_path / "checkpoints")
    assert "last.ckpt" in files
    assert "epoch_000.ckpt" in files

    # Second run: resume from the last checkpoint and train one more epoch.
    with open_dict(cfg_train):
        cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
        cfg_train.trainer.max_epochs = 2

    metric_dict_2, _ = train(cfg_train)

    files = os.listdir(tmp_path / "checkpoints")
    assert "epoch_001.ckpt" in files
    assert "epoch_002.ckpt" not in files

    # Accuracy should improve after the additional epoch.
    assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"]
    assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]
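

# Usage sketch (assuming a `slow` marker is registered in the project's pytest
# config, as the `@pytest.mark.slow` decorators above imply; GPU-only tests are
# skipped automatically by `RunIf` when no GPU is available):
#
#   pytest tests/test_train.py -m "not slow"   # fast CPU smoke tests only
#   pytest tests/test_train.py                 # all tests, incl. slow/GPU ones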