Spaces:
Runtime error
Runtime error
File size: 1,699 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union
from mmengine.runner import IterBasedTrainLoop
from torch.utils.data import DataLoader
class TrainLoop(IterBasedTrainLoop):
    """Iteration-based train loop whose length may be given in epochs.

    Exactly one of ``max_iters`` and ``max_epochs`` must be provided. When
    ``max_epochs`` is given, the total number of iterations is computed as
    ``max_epochs * len(dataloader)`` and rounded to the nearest integer.

    Args:
        runner: The runner that owns this loop. When ``dataloader`` is a
            dict and ``max_epochs`` is given, its ``build_dataloader``
            method, ``seed`` attribute and ``_randomness_cfg`` mapping are
            used to build the dataloader.
        dataloader (DataLoader or dict): A dataloader object or a config
            dict from which the runner builds one.
        max_iters (int, optional): Total number of training iterations.
            Must be integer-valued. Defaults to None.
        max_epochs (int or float, optional): Training length expressed in
            passes over ``dataloader``; may be fractional.
            Defaults to None.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``IterBasedTrainLoop.__init__``.

    Raises:
        RuntimeError: If neither or both of ``max_iters`` and
            ``max_epochs`` are given.
        AssertionError: If ``max_iters`` is not integer-valued.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 max_iters: Optional[int] = None,
                 max_epochs: Optional[Union[int, float]] = None,
                 **kwargs) -> None:
        """Resolve the iteration budget and initialise the parent loop."""
        if max_iters is None and max_epochs is None:
            raise RuntimeError('Please specify the `max_iters` or '
                               '`max_epochs` in `train_cfg`.')
        if max_iters is not None and max_epochs is not None:
            raise RuntimeError('Only one of `max_iters` or `max_epochs` can '
                               'exist in `train_cfg`.')

        # From here on exactly one of the two arguments is set.
        if max_iters is not None:
            iters = int(max_iters)
            assert iters == max_iters, ('`max_iters` should be a integer '
                                        f'number, but get {max_iters}')
        else:
            # The dataloader length is needed to convert epochs into
            # iterations, so a config dict has to be built here first.
            if isinstance(dataloader, dict):
                diff_rank_seed = runner._randomness_cfg.get(
                    'diff_rank_seed', False)
                dataloader = runner.build_dataloader(
                    dataloader,
                    seed=runner.seed,
                    diff_rank_seed=diff_rank_seed)
            # A float `max_epochs` makes the product a float (possibly
            # fractional, or off by float noise such as 7.000000000000001).
            # Round to the nearest whole iteration so that an integer count
            # is always handed to the parent loop.
            iters = round(max_epochs * len(dataloader))

        super().__init__(
            runner=runner, dataloader=dataloader, max_iters=iters, **kwargs)
|