pg56714's picture
Upload 115 files
8e5cc83 verified
import math
import random
from copy import deepcopy
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
class SAMDistributedSampler(DistributedSampler):
    """
    Modified from https://github.com/pytorch/pytorch/blob/97261be0a8f09bed9ab95d0cee82e75eebd249c3/torch/utils/data/distributed.py.

    Behaves like ``DistributedSampler`` but additionally splits every epoch into
    ``sub_epochs_per_epoch`` sub-epochs. All sub-epochs of one epoch share the same
    (``seed + epoch``) ordering. Sub-epoch ``k`` yields every
    ``sub_epochs_per_epoch``-th index of this rank's share, starting at offset
    ``k % sub_epochs_per_epoch``. The sub-epochs of one epoch therefore partition
    the rank's share.

    Args:
        dataset: Dataset to sample from (only its length is used here).
        num_replicas, rank, shuffle, seed, drop_last: Forwarded to ``DistributedSampler``.
        sub_epochs_per_epoch: Number of sub-epochs each epoch is split into; must be >= 1.

    Raises:
        ValueError: If ``sub_epochs_per_epoch`` is smaller than 1.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        sub_epochs_per_epoch: int = 1,
    ) -> None:
        if sub_epochs_per_epoch < 1:
            raise ValueError(f"sub_epochs_per_epoch must be >= 1, got {sub_epochs_per_epoch}")
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.sub_epoch = 0
        self.sub_epochs_per_epoch = sub_epochs_per_epoch
        self.set_sub_num_samples()

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch and seed only (not the sub-epoch),
            # so every sub-epoch of an epoch slices the same permutation.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                # the dataset is smaller than the padding: repeat it as often as needed
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: this rank's strided share of the padded/truncated list
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        # sub-epoch slice: every ``sub_epochs_per_epoch``-th index, offset by the sub-epoch
        offset = self.sub_epoch % self.sub_epochs_per_epoch
        indices = indices[offset :: self.sub_epochs_per_epoch]
        return iter(indices)

    def __len__(self) -> int:
        return self.sub_num_samples

    def set_sub_num_samples(self) -> int:
        """Recompute, store and return the number of indices the current sub-epoch yields.

        ``__iter__`` takes ``indices[offset::S]`` from ``num_samples`` indices. That slice
        has ``num_samples // S`` elements, plus one more when ``num_samples % S > offset``.
        """
        offset = self.sub_epoch % self.sub_epochs_per_epoch
        full_rounds, remainder = divmod(self.num_samples, self.sub_epochs_per_epoch)
        self.sub_num_samples = full_rounds + (1 if remainder > offset else 0)
        return self.sub_num_samples

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        r"""
        Set the epoch and sub-epoch for this sampler.

        When :attr:`shuffle=True`, the epoch selects the random ordering, so each epoch
        uses a different ordering. Otherwise, the next iteration of this sampler yields
        the same ordering. The sub-epoch only selects which strided slice of this
        rank's share is yielded. ``len(self)`` is updated to match that slice.

        Args:
            epoch (int): Epoch number.
            sub_epoch (int): Sub epoch number.
        """
        self.epoch = epoch
        self.sub_epoch = sub_epoch
        self.set_sub_num_samples()
class RandomHFlip(object):
    """Mirror a sample left-to-right with probability ``prob``.

    ``sample`` is a dict with keys ``"image"``, ``"masks"``, ``"points"``, ``"bboxs"`` and
    ``"shape"``. Image and masks are flipped along dim 2, and the x coordinates
    (column 0) of points and boxes are mirrored against ``shape[-1]``.

    Layout assumptions, to be confirmed against the dataset:
    - Dim 2 is the width, i.e. image ``(C, H, W)`` and masks ``(N, H, W)``.
    - ``shape[-1]`` is the image width.
    - Boxes are ``(x, y, w, h)``; the box formula below implies this layout.
    """

    def __init__(self, prob=0.5):
        # probability that a given sample is flipped
        self.prob = prob

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )
        # random.random() is in [0, 1), so this flips with probability exactly ``prob``
        if random.random() < self.prob:
            image = torch.flip(image, dims=[2])
            masks = torch.flip(masks, dims=[2])
            # copy before the in-place writes so the caller's tensors are not mutated
            points = deepcopy(points).to(torch.float)
            bboxs = deepcopy(bboxs).to(torch.float)
            width = shape[-1]
            points[:, 0] = width - points[:, 0]
            # new left edge = width - (x + w)
            bboxs[:, 0] = width - bboxs[:, 2] - bboxs[:, 0]
        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}
class ResizeLongestSide(object):
    """
    Modified from https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/transforms.py.

    Rescales a sample so that its longer side becomes ``target_length``, keeping the
    aspect ratio. Sizes passed to the methods are ``(H, W)``.
    """

    def __init__(self, target_length: int) -> None:
        # desired length of the longer side after resizing
        self.target_length = target_length

    def apply_image(self, image: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """Bilinearly resize a batched tensor (antialiased) to the preprocess shape.

        Bilinear ``F.interpolate`` needs a 4-D input, so callers pass ``(N, C, H, W)``.
        """
        new_size = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
        return F.interpolate(image, size=new_size, mode="bilinear", align_corners=False, antialias=True)

    def apply_boxes(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """
        Rescale a Bx4 tensor of boxes. Requires the original image size in (H, W) format.
        Each box is treated as two (x, y) pairs and scaled like coordinates.
        """
        return self.apply_coords(boxes.reshape(-1, 2, 2), original_size).reshape(-1, 4)

    def apply_coords(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """
        Rescale a float copy of ``coords``, whose last dimension has length 2 in (x, y)
        order. Requires the original image size in (H, W) format.
        """
        src_h, src_w = original_size
        dst_h, dst_w = self.get_preprocess_shape(src_h, src_w, self.target_length)
        scaled = deepcopy(coords).to(torch.float)
        # x is scaled by the width ratio, y by the height ratio
        scaled[..., 0] = scaled[..., 0] * (dst_w / src_w)
        scaled[..., 1] = scaled[..., 1] * (dst_h / src_h)
        return scaled

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Return ``(new_h, new_w)``: both sides scaled so the longer one equals
        ``long_side_length``, each rounded half-up to an int.
        """
        ratio = long_side_length * 1.0 / max(oldh, oldw)
        return tuple(int(side * ratio + 0.5) for side in (oldh, oldw))

    def __call__(self, sample):
        shape = sample["shape"]
        # image gets a batch dim (dim 0); masks get a channel dim (dim 1)
        resized_image = self.apply_image(sample["image"].unsqueeze(0), shape).squeeze(0)
        resized_masks = self.apply_image(sample["masks"].unsqueeze(1), shape).squeeze(1)
        return {
            "image": resized_image,
            "masks": resized_masks,
            "points": self.apply_coords(sample["points"], shape),
            "bboxs": self.apply_boxes(sample["bboxs"], shape),
            "shape": shape,
        }
class Normalize_and_Pad(object):
    """Normalize the image per channel, then zero-pad image and masks on the bottom/right.

    After padding, the last two dims of both tensors equal ``target_length``. A side
    already longer than ``target_length`` gets a negative pad amount, which ``F.pad``
    treats as cropping. Points, boxes and shape pass through unchanged.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length
        # These equal the common ImageNet mean/std multiplied by 255, which suggests the
        # image is expected on a 0-255 scale -- NOTE(review): confirm with the data loader.
        self.transform = transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])

    def __call__(self, sample):
        height, width = sample["image"].shape[-2:]
        # F.pad reads (left, right, top, bottom) for the last two dims: pad right and bottom only
        padding = (0, self.target_length - width, 0, self.target_length - height)
        normalized = self.transform(sample["image"])
        return {
            "image": F.pad(normalized.unsqueeze(0), padding, value=0).squeeze(0),
            "masks": F.pad(sample["masks"].unsqueeze(1), padding, value=0).squeeze(1),
            "points": sample["points"],
            "bboxs": sample["bboxs"],
            "shape": sample["shape"],
        }