File size: 986 Bytes
6e9c433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import abc
from typing import List, Optional, Tuple
from torch import nn, Tensor
from PIL import Image


class BaseDataAugmentation(nn.Module):
    """Interface for augmentation modules that take an image and a depth map.

    Subclasses are expected to override :meth:`forward`. The
    ``abc.abstractmethod`` marker is kept to signal that intent, but it is
    not enforced at instantiation time because ``nn.Module`` does not use
    ``abc.ABCMeta`` as its metaclass. The base implementation therefore
    raises ``NotImplementedError`` explicitly instead of silently returning
    ``None``.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(
        self,
        image: Image.Image,
        depth: Image.Image,
        gt: Optional[Image.Image] = None,
        ranking_gt: Optional[Image.Image] = None,
        multi_gts: Optional[List[Image.Image]] = None,
        is_transform: bool = True,  # is augmented?
        is_debug: bool = False,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """Transform the inputs; must be overridden by subclasses.

        Args:
            image: Input image.
            depth: Depth input paired with ``image``.
            gt: Optional ground truth.
            ranking_gt: Optional ranking ground truth. When provided, ``gt``
                is ignored.
            multi_gts: Optional list of ground truths.
                NOTE(review): its exact use is defined by subclasses.
            is_transform: Presumably toggles whether augmentation is applied
                (the original inline comment reads "is augmented?") --
                confirm against the concrete subclasses.
            is_debug: Request the extended debugging output described below.

        Returns:
            Usual case:
                If gt is provided, return [image, depth, gt]
                Otherwise, return [image, depth]

            When ranking_gt is provided, gt will be ignored
                Return [image, depth, ranking_gt]

            For debugging:
                Return [image, depth, gt|ranking_gt, unnormalized,
                Optional[ranking_gts]]

            NOTE(review): the return annotation only describes the
            three-element case; the two-element and debugging variants
            listed above have a different length.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Fail loudly: without this, a subclass that forgets to override
        # ``forward`` would hand ``None`` to its caller.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )