orhir's picture
Upload 97 files
241adf2
raw
history blame
1.13 kB
from mmcv.runner import Hook
from mmpose.utils import get_root_logger
from torch.utils.data import DataLoader
class ShufflePairedSamplesHook(Hook):
    """Non-distributed hook that re-runs paired-sample selection.

    After a training epoch finishes, this hook calls
    ``random_paired_samples()`` on the dataset held by the supplied
    dataloader (``FewShotKeypointDataset.random_paired_samples()``).

    Args:
        dataloader (DataLoader): PyTorch dataloader whose ``dataset``
            exposes a ``random_paired_samples()`` method.
        interval (int): Forwarded to ``every_n_epochs`` to decide after
            which epochs the call is made. Default: 1.

    Raises:
        TypeError: If ``dataloader`` is not a PyTorch ``DataLoader``.
    """

    def __init__(self, dataloader, interval=1):
        # Reject anything that is not a real DataLoader before storing state.
        if not isinstance(dataloader, DataLoader):
            message = (f'dataloader must be a pytorch DataLoader, '
                       f'but got {type(dataloader)}')
            raise TypeError(message)

        self.dataloader = dataloader
        self.interval = interval
        # Not used by the methods of this class; kept as a public attribute.
        self.logger = get_root_logger()

    def after_train_epoch(self, runner):
        """Call ``random_paired_samples()`` on the dataset when due.

        The call is gated by ``self.every_n_epochs(runner, self.interval)``
        (inherited from ``Hook``); when that check is falsy nothing happens.

        Args:
            runner: The runner driving training; only passed on to
                ``every_n_epochs``.
        """
        is_due = self.every_n_epochs(runner, self.interval)
        if is_due:
            dataset = self.dataloader.dataset
            dataset.random_paired_samples()