# pg56714's picture
# Upload 115 files
# 8e5cc83 verified
import json
from typing import Any, Optional, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from pycocotools import mask as mask_utils
from skimage import io
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from efficientvit.apps.data_provider import DataProvider
from efficientvit.samcore.data_provider.utils import (
    Normalize_and_Pad,
    RandomHFlip,
    ResizeLongestSide,
    SAMDistributedSampler,
)
# Names exported by `from <this module> import *`; only the data provider is listed.
__all__ = ["SAMDataProvider"]
class OnlineDataset(Dataset):
    """Dataset yielding one image plus a fixed number of its mask annotations.

    Files read below ``root``:

    * ``sa_images_ids.txt`` -- one integer image id per line;
    * ``images/sa_<id>.jpg`` -- the image for that id;
    * ``masks/sa_<id>.json`` -- JSON object with an ``"annotations"`` list whose
      entries provide ``"segmentation"``, ``"point_coords"`` and ``"bbox"``.

    The first 99% of the listed ids form the training split and the remaining
    ids form the evaluation split.
    """

    def __init__(self, root, train=True, num_masks=64, transform=None):
        """Read the id list and keep the slice belonging to the requested split.

        Args:
            root: Directory containing ``sa_images_ids.txt``, ``images/`` and ``masks/``.
            train: If true, use the first 99% of ids and sample masks randomly;
                otherwise use the remaining ids and pick masks deterministically.
            num_masks: Number of mask entries returned per item.
            transform: Optional callable applied to the sample dict.
        """
        self.root = root
        self.train = train
        self.num_masks = num_masks
        self.transform = transform

        # Context manager so the handle is closed right after reading.
        with open(f"{self.root}/sa_images_ids.txt", "r") as ids_file:
            # Skip blank lines (e.g. trailing newlines): they are not ids and
            # would otherwise skew the split and fail in int() later on.
            ids = [line for line in ids_file.read().splitlines() if line.strip()]

        split = int(len(ids) * 0.99)
        self.data = ids[:split] if self.train else ids[split:]

    def __len__(self):
        """Return the number of image ids in this split."""
        return len(self.data)

    def _select_indices(self, num_annotations):
        """Return ``self.num_masks`` annotation indices for one image.

        With more annotations than ``num_masks``, training draws a random
        subset without replacement and evaluation takes the first
        ``num_masks``. Otherwise every annotation is repeated as many whole
        times as fit and the remainder is filled with a random subset
        (training) or the leading indices (evaluation).

        Raises:
            ZeroDivisionError: If ``num_annotations`` is 0.
        """
        if num_annotations > self.num_masks:
            if self.train:
                return np.random.choice(num_annotations, size=self.num_masks, replace=False)
            return np.arange(self.num_masks)

        repeat, residue = divmod(self.num_masks, num_annotations)
        if self.train:
            # Called even when residue == 0 so the RNG is consumed as before.
            tail = np.random.choice(num_annotations, size=residue, replace=False)
        else:
            tail = np.arange(residue)
        return np.concatenate([np.arange(num_annotations)] * repeat + [tail], axis=0)

    def __getitem__(self, idx):
        """
        Note: We provide the simplest data organization here. You can modify the code according to your data organization.

        Load the image and annotation file for ``self.data[idx]`` and build a
        sample dict of float32 tensors:

        * ``"image"``  -- the image with its last axis moved to the front;
        * ``"masks"``  -- ``num_masks`` decoded masks stacked along axis 0;
        * ``"points"`` -- the first entry of each selected annotation's ``"point_coords"``;
        * ``"bboxs"``  -- each selected annotation's ``"bbox"``;
        * ``"shape"``  -- the last two dimensions of ``"image"``.

        The dict is passed through ``self.transform`` when one is set.
        """
        index = int(self.data[idx])

        image = io.imread(f"{self.root}/images/sa_{index}.jpg")

        # Close the annotation file as soon as it is parsed; this runs once per
        # item, so a leaked handle here adds up quickly.
        with open(f"{self.root}/masks/sa_{index}.json", "r") as ann_file:
            annotations = json.load(ann_file)["annotations"]

        r = self._select_indices(len(annotations))

        masks = np.stack([mask_utils.decode(annotations[i]["segmentation"]) for i in r])
        points = np.stack([annotations[i]["point_coords"][0] for i in r])
        bboxs = np.stack([annotations[i]["bbox"] for i in r])

        image = torch.tensor(image, dtype=torch.float32)
        # Axes (0, 1, 2) -> (2, 0, 1); presumably (H, W, C) -> (C, H, W) -- confirm
        # against what io.imread returns for the dataset's images.
        image = image.permute(2, 0, 1)
        masks = torch.tensor(masks, dtype=torch.float32)
        points = torch.tensor(points, dtype=torch.float32)
        bboxs = torch.tensor(bboxs, dtype=torch.float32)

        sample = {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
            "shape": torch.tensor(image.shape[-2:]),
        }

        if self.transform:
            sample = self.transform(sample)

        return sample
class SAMDataProvider(DataProvider):
    """Data provider that builds ``OnlineDataset`` splits and their data loaders.

    Training uses ``SAMDistributedSampler`` (configured with
    ``sub_epochs_per_epoch``); evaluation uses a non-shuffling
    ``DistributedSampler``.
    """

    name = "sam"

    def __init__(
        self,
        root: str,
        sub_epochs_per_epoch: int,
        num_masks: int,
        train_batch_size: int,
        test_batch_size: int,
        valid_size: Union[int, float, None] = None,
        n_worker: int = 8,
        image_size: int = 1024,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        """Store the SAM-specific settings and delegate the rest to ``DataProvider``.

        Args:
            root: Dataset root directory handed to ``OnlineDataset``.
            sub_epochs_per_epoch: Forwarded to ``SAMDistributedSampler``.
            num_masks: Masks per image for the training dataset.
            train_batch_size, test_batch_size, valid_size, n_worker, image_size,
            num_replicas, rank, train_ratio, drop_last: Passed positionally, in
                this order, to ``DataProvider.__init__``.
        """
        # These are assigned before super().__init__(); presumably the base
        # initializer calls build_datasets()/build_dataloader(), which read
        # them -- confirm against DataProvider before reordering.
        self.root = root
        self.num_masks = num_masks
        self.sub_epochs_per_epoch = sub_epochs_per_epoch

        super().__init__(
            train_batch_size,
            test_batch_size,
            valid_size,
            n_worker,
            image_size,
            num_replicas,
            rank,
            train_ratio,
            drop_last,
        )

    def build_train_transform(self):
        """Compose ``RandomHFlip``, ``ResizeLongestSide`` and ``Normalize_and_Pad``."""
        # self.image_size is indexed here, so it is presumably a sequence set
        # up by the base class -- TODO confirm.
        train_transforms = [
            RandomHFlip(),
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(train_transforms)

    def build_valid_transform(self):
        """Compose ``ResizeLongestSide`` and ``Normalize_and_Pad`` (no flip)."""
        valid_transforms = [
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(valid_transforms)

    def build_datasets(self) -> tuple[Any, Any, Any]:
        """Return ``(train_dataset, val_dataset, test_dataset)``.

        The validation dataset always uses ``num_masks=2`` and the test
        dataset is always ``None``.
        """
        train_transform = self.build_train_transform()
        valid_transform = self.build_valid_transform()

        train_dataset = OnlineDataset(root=self.root, train=True, num_masks=self.num_masks, transform=train_transform)
        val_dataset = OnlineDataset(root=self.root, train=False, num_masks=2, transform=valid_transform)
        test_dataset = None

        return train_dataset, val_dataset, test_dataset

    def build_dataloader(
        self, dataset: Optional[Any], batch_size: int, n_worker: int, drop_last: bool, train: bool
    ):
        """Wrap ``dataset`` in a ``DataLoader`` with the sampler for its split.

        Returns ``None`` when ``dataset`` is ``None``.

        NOTE(review): the ``drop_last`` argument is not consulted; training
        loaders always use ``drop_last=True`` and evaluation loaders always use
        ``drop_last=False``.
        """
        if dataset is None:
            return None

        if train:
            sampler = SAMDistributedSampler(dataset, sub_epochs_per_epoch=self.sub_epochs_per_epoch)
            return DataLoader(dataset, batch_size, sampler=sampler, drop_last=True, num_workers=n_worker)

        sampler = DistributedSampler(dataset, shuffle=False)
        return DataLoader(dataset, batch_size, sampler=sampler, drop_last=False, num_workers=n_worker)

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        """Forward ``epoch`` and ``sub_epoch`` to the training sampler.

        Does nothing unless ``self.train.sampler`` is a ``SAMDistributedSampler``.
        """
        # self.train is not defined in this class; presumably the training
        # DataLoader exposed by the base class -- TODO confirm.
        if isinstance(self.train.sampler, SAMDistributedSampler):
            self.train.sampler.set_epoch_and_sub_epoch(epoch, sub_epoch)