File size: 837 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from typing import Tuple, Optional


def augment_flip(
    img: torch.Tensor,
    flip: Optional[Tuple[bool, bool]] = None
) -> torch.Tensor:
    """Flip an image tensor horizontally and/or vertically.

    The horizontal flip reverses the last dimension (width). The vertical
    flip reverses the second-to-last dimension (height).

    Args:
        img (torch.Tensor): image tensor whose last two dims are (H, W),
            e.g. [N, 3, H, W].
        flip (Optional[Tuple[bool, bool]], optional): forced
            ``(flip_h, flip_v)`` flags. A flag counts as set when it
            compares ``> 0``. Defaults to None. In that case each flag is
            drawn independently from {0, 1} with ``torch.randint`` (global
            torch RNG).

    Returns:
        torch.Tensor: the flipped image. When neither flag is set, ``img``
        itself is returned unchanged (no copy).
    """
    if flip is None:
        flip = torch.randint(0, 2, (2,))
    dims = []
    if flip[0] > 0:
        dims.append(-1)  # horizontal flip: reverse the width axis
    if flip[1] > 0:
        dims.append(-2)  # vertical flip: reverse the height axis
    if not dims:
        return img
    # Flipping both axes in one call gives the same result as two chained
    # flips (the operations commute) but allocates only one copy.
    return torch.flip(img, tuple(dims))