zhangtao-whu's picture
Upload folder using huggingface_hub
476ac07 verified
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union
from mmengine.runner import IterBasedTrainLoop
from torch.utils.data import DataLoader
class TrainLoop(IterBasedTrainLoop):
    """Iteration-based training loop configurable by iterations or epochs.

    Exactly one of ``max_iters`` or ``max_epochs`` must be given. When
    ``max_epochs`` is used, it is converted into an iteration budget of
    ``max_epochs * len(dataloader)`` (rounded to a whole number of
    iterations), after which the loop behaves as a plain
    :class:`IterBasedTrainLoop`.

    Args:
        runner: The runner that owns this loop. When ``max_epochs`` is set
            and ``dataloader`` is a config dict, the runner's
            ``build_dataloader``, ``seed`` and ``_randomness_cfg`` are used
            to build the dataloader so its length can be measured.
        dataloader (DataLoader | dict): A built dataloader or its config.
        max_iters (int, optional): Total number of training iterations.
            Must be integral (``10`` and ``10.0`` are both accepted).
        max_epochs (int | float, optional): Number of passes over
            ``dataloader``; may be fractional.
        **kwargs: Forwarded unchanged to :class:`IterBasedTrainLoop`.

    Raises:
        RuntimeError: If neither or both of ``max_iters`` and
            ``max_epochs`` are given.
        AssertionError: If ``max_iters`` is not an integral number.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 max_iters: Optional[int] = None,
                 max_epochs: Optional[Union[int, float]] = None,
                 **kwargs) -> None:
        if max_iters is None and max_epochs is None:
            raise RuntimeError('Please specify the `max_iters` or '
                               '`max_epochs` in `train_cfg`.')
        if max_iters is not None and max_epochs is not None:
            raise RuntimeError('Only one of `max_iters` or `max_epochs` can '
                               'exist in `train_cfg`.')

        if max_iters is not None:
            iters = int(max_iters)
            assert iters == max_iters, ('`max_iters` should be a integer '
                                        f'number, but get {max_iters}')
        else:
            if isinstance(dataloader, dict):
                # The epoch length is only known once the dataloader exists,
                # so build it here using the runner's seeding settings.
                diff_rank_seed = runner._randomness_cfg.get(
                    'diff_rank_seed', False)
                dataloader = runner.build_dataloader(
                    dataloader,
                    seed=runner.seed,
                    diff_rank_seed=diff_rank_seed)
            # A fractional ``max_epochs`` (e.g. 0.5 * 101 == 50.5), or plain
            # float noise (0.29 * 100 == 28.999999999999996), gives a
            # non-integer product. IterBasedTrainLoop rejects that, so round
            # to the nearest whole iteration. ``int`` guards against
            # ``__round__`` implementations that return a float type.
            iters = int(round(max_epochs * len(dataloader)))

        super().__init__(
            runner=runner, dataloader=dataloader, max_iters=iters, **kwargs)