Spaces:
Running
Running
from mmcv.runner import Hook | |
from mmpose.utils import get_root_logger | |
from torch.utils.data import DataLoader | |
class ShufflePairedSamplesHook(Hook): | |
"""Non-Distributed ShufflePairedSamples. | |
After each training epoch, run FewShotKeypointDataset.random_paired_samples() | |
""" | |
def __init__(self, | |
dataloader, | |
interval=1): | |
if not isinstance(dataloader, DataLoader): | |
raise TypeError(f'dataloader must be a pytorch DataLoader, ' | |
f'but got {type(dataloader)}') | |
self.dataloader = dataloader | |
self.interval = interval | |
self.logger = get_root_logger() | |
def after_train_epoch(self, runner): | |
"""Called after every training epoch to evaluate the results.""" | |
if not self.every_n_epochs(runner, self.interval): | |
return | |
# self.logger.info("Run random_paired_samples()") | |
# self.logger.info(f"Before: {self.dataloader.dataset.paired_samples[0]}") | |
self.dataloader.dataset.random_paired_samples() | |
# self.logger.info(f"After: {self.dataloader.dataset.paired_samples[0]}") | |