Spaces:
Running
Running
import torch | |
from typing import Tuple, Optional | |
def augment_flip(
    img: torch.Tensor,
    flip: Optional[Tuple[bool, bool]] = None
) -> torch.Tensor:
    """Flip an image tensor horizontally and/or vertically.

    Args:
        img (torch.Tensor): image tensor whose last two dimensions are
            flipped; the original documentation describes it as [N, 3, H, W].
        flip (Optional[Tuple[bool, bool]], optional): forced
            (flip_h, flip_v) values. ``flip[0]`` selects a flip along the
            last dimension (-1), ``flip[1]`` a flip along the second-to-last
            dimension (-2). Defaults to None, in which case both values are
            drawn at random with ``torch.randint``.

    Returns:
        torch.Tensor: the flipped image. When neither flip is selected the
        input tensor itself is returned (no copy is made).
    """
    if flip is None:
        # Two independent draws from {0, 1}: (flip_h, flip_v).
        flip = torch.randint(0, 2, (2,))

    # Collect the axes to reverse so a single torch.flip call (one copy)
    # handles every combination.
    dims = []
    if flip[0] > 0:
        dims.append(-1)
    if flip[1] > 0:
        dims.append(-2)

    if not dims:
        return img
    return torch.flip(img, dims)