image-deblurring / src /rstor /data /augmentation.py
balthou's picture
initiate demo
cec5823
raw
history blame contribute delete
837 Bytes
import torch
from typing import Tuple, Optional
def augment_flip(
img: torch.Tensor,
flip: Optional[Tuple[bool, bool]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Roll pixels horizontally to avoid negative index
Args:
img (torch.Tensor): [N, 3, H, W] image tensor
lab (torch.Tensor): [N, 3, H, W] label tensor
flip (Optional[bool], optional): forced flip_h, flip_v value. Defaults to None.
If not provided, a random flip_h, flip_v values are used
Returns:
torch.Tensor, torch.Tensor: flipped image, labels
"""
if flip is None:
flip = torch.randint(0, 2, (2,))
flipped_img = img
if flip[0] > 0:
flipped_img = torch.flip(flipped_img, (-1,))
if flip[1] > 0:
flipped_img = torch.flip(flipped_img, (-2,))
return flipped_img