# Copyright by HQ-SAM team
# All rights reserved.

## data loader
from __future__ import print_function, division

import numpy as np
import random
from copy import deepcopy
from skimage import io
import os
from glob import glob

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, utils
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
#### --------------------- dataloader online ---------------------####

def get_im_gt_name_dict(datasets, flag='valid'):
    print("------------------------------", flag, "--------------------------------")
    name_im_gt_list = []
    for i in range(len(datasets)):
        print("--->>>", flag, " dataset ", i, "/", len(datasets), " ", datasets[i]["name"], "<<<---")
        tmp_im_list, tmp_gt_list = [], []
        tmp_im_list = glob(datasets[i]["im_dir"] + os.sep + '*' + datasets[i]["im_ext"])
        print('-im-', datasets[i]["name"], datasets[i]["im_dir"], ': ', len(tmp_im_list))

        if datasets[i]["gt_dir"] == "":
            print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
            tmp_gt_list = []
        else:
            tmp_gt_list = [datasets[i]["gt_dir"] + os.sep + x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0] + datasets[i]["gt_ext"] for x in tmp_im_list]
            print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', len(tmp_gt_list))

        name_im_gt_list.append({"dataset_name": datasets[i]["name"],
                                "im_path": tmp_im_list,
                                "gt_path": tmp_gt_list,
                                "im_ext": datasets[i]["im_ext"],
                                "gt_ext": datasets[i]["gt_ext"]})
    return name_im_gt_list
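
# Usage sketch (assumption, not part of the original file): get_im_gt_name_dict
# expects a list of dataset dicts with the keys read above. The name, directories
# and extensions below are illustrative placeholders only.
# example_train_datasets = [{
#     "name": "example-train",         # hypothetical dataset name
#     "im_dir": "./data/example/im",   # hypothetical image directory
#     "gt_dir": "./data/example/gt",   # hypothetical ground-truth directory
#     "im_ext": ".jpg",
#     "gt_ext": ".png",
# }]
# train_im_gt_list = get_im_gt_name_dict(example_train_datasets, flag="train")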
def create_dataloaders(name_im_gt_list, my_transforms=[], batch_size=1, training=False):
    gos_dataloaders = []
    gos_datasets = []

    if len(name_im_gt_list) == 0:
        return gos_dataloaders, gos_datasets

    # scale the number of workers with the batch size
    num_workers_ = 1
    if batch_size > 1:
        num_workers_ = 2
    if batch_size > 4:
        num_workers_ = 4
    if batch_size > 8:
        num_workers_ = 8

    if training:
        # training: concatenate all datasets into one and draw batches with a
        # distributed batch sampler
        for i in range(len(name_im_gt_list)):
            gos_dataset = OnlineDataset([name_im_gt_list[i]], transform=transforms.Compose(my_transforms))
            gos_datasets.append(gos_dataset)

        gos_dataset = ConcatDataset(gos_datasets)
        sampler = DistributedSampler(gos_dataset)
        batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
        dataloader = DataLoader(gos_dataset, batch_sampler=batch_sampler_train, num_workers=num_workers_)

        gos_dataloaders = dataloader
        gos_datasets = gos_dataset
    else:
        # validation: keep one dataloader per dataset and preserve the original
        # resolution of each ground-truth mask for evaluation
        for i in range(len(name_im_gt_list)):
            gos_dataset = OnlineDataset([name_im_gt_list[i]], transform=transforms.Compose(my_transforms), eval_ori_resolution=True)
            sampler = DistributedSampler(gos_dataset, shuffle=False)
            dataloader = DataLoader(gos_dataset, batch_size, sampler=sampler, drop_last=False, num_workers=num_workers_)

            gos_dataloaders.append(dataloader)
            gos_datasets.append(gos_dataset)

    return gos_dataloaders, gos_datasets
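
# Usage sketch (assumption): wiring the transforms defined below into training
# and validation dataloaders. create_dataloaders builds a DistributedSampler, so
# torch.distributed must already be initialized; the batch size and transform
# choices are illustrative, and valid_im_gt_list would be built the same way as
# train_im_gt_list via get_im_gt_name_dict.
# train_dataloaders, train_datasets = create_dataloaders(
#     train_im_gt_list,
#     my_transforms=[RandomHFlip(), LargeScaleJitter()],
#     batch_size=4,
#     training=True)
# valid_dataloaders, valid_datasets = create_dataloaders(
#     valid_im_gt_list,
#     my_transforms=[Resize([1024, 1024])],
#     batch_size=1,
#     training=False)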
class RandomHFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        # random horizontal flip
        if random.random() >= self.prob:
            image = torch.flip(image, dims=[2])
            label = torch.flip(label, dims=[2])

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape}
class Resize(object):
    def __init__(self, size=[320, 320]):
        self.size = size

    def __call__(self, sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        image = torch.squeeze(F.interpolate(torch.unsqueeze(image, 0), self.size, mode='bilinear'), dim=0)
        label = torch.squeeze(F.interpolate(torch.unsqueeze(label, 0), self.size, mode='bilinear'), dim=0)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': torch.tensor(self.size)}
class RandomCrop(object):
    def __init__(self, size=[288, 288]):
        self.size = size

    def __call__(self, sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        h, w = image.shape[1:]
        new_h, new_w = self.size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[:, top:top + new_h, left:left + new_w]
        label = label[:, top:top + new_h, left:left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': torch.tensor(self.size)}
class Normalize(object):
    # defaults are the standard ImageNet normalization statistics
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        imidx, image, label, shape = sample['imidx'], sample['image'], sample['label'], sample['shape']

        image = normalize(image, self.mean, self.std)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape}
class LargeScaleJitter(object):
    """
    implementation of large scale jitter from copy_paste
    https://github.com/gaopengcuhk/Pretrained-Pix2Seq/blob/7d908d499212bfabd33aeaa838778a6bfb7b84cc/datasets/transforms.py
    """

    def __init__(self, output_size=1024, aug_scale_min=0.1, aug_scale_max=2.0):
        self.desired_size = torch.tensor(output_size)
        self.aug_scale_min = aug_scale_min
        self.aug_scale_max = aug_scale_max

    def pad_target(self, padding, target):
        target = target.copy()
        if "masks" in target:
            target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[1], 0, padding[0]))
        return target

    def __call__(self, sample):
        imidx, image, label, image_size = sample['imidx'], sample['image'], sample['label'], sample['shape']

        # resize, keeping the aspect ratio
        out_desired_size = (self.desired_size * image_size / max(image_size)).round().int()

        random_scale = torch.rand(1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
        scaled_size = (random_scale * self.desired_size).round()

        scale = torch.minimum(scaled_size / image_size[0], scaled_size / image_size[1])
        scaled_size = (image_size * scale).round().long()

        scaled_image = torch.squeeze(F.interpolate(torch.unsqueeze(image, 0), scaled_size.tolist(), mode='bilinear'), dim=0)
        scaled_label = torch.squeeze(F.interpolate(torch.unsqueeze(label, 0), scaled_size.tolist(), mode='bilinear'), dim=0)

        # random crop
        crop_size = (min(self.desired_size, scaled_size[0]), min(self.desired_size, scaled_size[1]))

        margin_h = max(scaled_size[0] - crop_size[0], 0).item()
        margin_w = max(scaled_size[1] - crop_size[1], 0).item()
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)
        crop_y1, crop_y2 = offset_h, offset_h + crop_size[0].item()
        crop_x1, crop_x2 = offset_w, offset_w + crop_size[1].item()

        scaled_image = scaled_image[:, crop_y1:crop_y2, crop_x1:crop_x2]
        scaled_label = scaled_label[:, crop_y1:crop_y2, crop_x1:crop_x2]

        # pad
        padding_h = max(self.desired_size - scaled_image.size(1), 0).item()
        padding_w = max(self.desired_size - scaled_image.size(2), 0).item()
        image = F.pad(scaled_image, [0, padding_w, 0, padding_h], value=128)
        label = F.pad(scaled_label, [0, padding_w, 0, padding_h], value=0)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': torch.tensor(image.shape[-2:])}
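
# Sanity-check sketch (assumption): whatever random scale is drawn, the crop and
# pad steps above bring the output back to a square of output_size x output_size.
# The dummy tensor sizes below are illustrative only.
# _dummy = {
#     "imidx": torch.tensor(0),
#     "image": torch.rand(3, 480, 640),
#     "label": torch.rand(1, 480, 640),
#     "shape": torch.tensor([480, 640]),
# }
# _out = LargeScaleJitter(output_size=1024)(_dummy)
# assert _out["image"].shape == (3, 1024, 1024)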
class OnlineDataset(Dataset):
    def __init__(self, name_im_gt_list, transform=None, eval_ori_resolution=False):
        self.transform = transform
        self.dataset = {}

        ## combine different datasets into one
        dataset_names = []
        dt_name_list = []  # dataset name per image
        im_name_list = []  # image name
        im_path_list = []  # im path
        gt_path_list = []  # gt path
        im_ext_list = []   # im ext
        gt_ext_list = []   # gt ext
        for i in range(0, len(name_im_gt_list)):
            dataset_names.append(name_im_gt_list[i]["dataset_name"])
            # dataset name repeated based on the number of images in this dataset
            dt_name_list.extend([name_im_gt_list[i]["dataset_name"] for x in name_im_gt_list[i]["im_path"]])
            im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_list[i]["im_ext"])[0] for x in name_im_gt_list[i]["im_path"]])
            im_path_list.extend(name_im_gt_list[i]["im_path"])
            gt_path_list.extend(name_im_gt_list[i]["gt_path"])
            im_ext_list.extend([name_im_gt_list[i]["im_ext"] for x in name_im_gt_list[i]["im_path"]])
            gt_ext_list.extend([name_im_gt_list[i]["gt_ext"] for x in name_im_gt_list[i]["gt_path"]])

        self.dataset["data_name"] = dt_name_list
        self.dataset["im_name"] = im_name_list
        self.dataset["im_path"] = im_path_list
        self.dataset["ori_im_path"] = deepcopy(im_path_list)
        self.dataset["gt_path"] = gt_path_list
        self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
        self.dataset["im_ext"] = im_ext_list
        self.dataset["gt_ext"] = gt_ext_list

        self.eval_ori_resolution = eval_ori_resolution

    def __len__(self):
        return len(self.dataset["im_path"])
    def __getitem__(self, idx):
        im_path = self.dataset["im_path"][idx]
        gt_path = self.dataset["gt_path"][idx]

        im = io.imread(im_path)
        gt = io.imread(gt_path)

        # reduce multi-channel masks to a single channel and replicate
        # grayscale images to 3 channels
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)

        # HWC -> CHW float tensors
        im = torch.tensor(im.copy(), dtype=torch.float32)
        im = torch.transpose(torch.transpose(im, 1, 2), 0, 1)
        gt = torch.unsqueeze(torch.tensor(gt, dtype=torch.float32), 0)

        sample = {
            "imidx": torch.from_numpy(np.array(idx)),
            "image": im,
            "label": gt,
            "shape": torch.tensor(im.shape[-2:]),
        }

        if self.transform:
            sample = self.transform(sample)

        if self.eval_ori_resolution:
            sample["ori_label"] = gt.type(torch.uint8)  # NOTE for evaluation only. And no flip here
            sample['ori_im_path'] = self.dataset["im_path"][idx]
            sample['ori_gt_path'] = self.dataset["gt_path"][idx]

        return sample
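
# Usage sketch (assumption): each item (or batch) is a dict with keys "imidx",
# "image" (3xHxW float tensor), "label" (1xHxW float tensor) and "shape"; when
# eval_ori_resolution=True it also carries "ori_label", "ori_im_path" and
# "ori_gt_path". The loop and variable names below are illustrative.
# for batch in valid_dataloaders[0]:
#     images, labels = batch["image"], batch["label"]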