File size: 1,134 Bytes
241adf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from mmcv.runner import Hook
from mmpose.utils import get_root_logger
from torch.utils.data import DataLoader


class ShufflePairedSamplesHook(Hook):
    """Non-distributed hook that re-pairs dataset samples between epochs.

    After every ``interval`` training epochs, calls
    ``random_paired_samples()`` on ``dataloader.dataset`` (intended for a
    ``FewShotKeypointDataset``; the dataset type itself is not checked).

    Args:
        dataloader (DataLoader): The training dataloader whose ``dataset``
            provides a ``random_paired_samples()`` method.
        interval (int): Run ``random_paired_samples()`` every ``interval``
            epochs, as decided by ``Hook.every_n_epochs``. Default: 1.

    Raises:
        TypeError: If ``dataloader`` is not a ``torch.utils.data.DataLoader``.
    """

    def __init__(self,
                 dataloader,
                 interval=1):
        if not isinstance(dataloader, DataLoader):
            raise TypeError(f'dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')

        self.dataloader = dataloader
        self.interval = interval
        # Root logger handle exposed on the hook; not used by the methods
        # below but kept so existing attribute access keeps working.
        self.logger = get_root_logger()

    def after_train_epoch(self, runner):
        """Re-draw the dataset's paired samples after a training epoch.

        Does nothing unless ``Hook.every_n_epochs(runner, self.interval)``
        reports that this epoch is due.

        Args:
            runner: The mmcv runner driving training.
        """
        if not self.every_n_epochs(runner, self.interval):
            return
        self.dataloader.dataset.random_paired_samples()