|
import json |
|
import cv2 |
|
import numpy as np |
|
import os |
|
|
|
from torch.utils.data import Dataset |
|
import pycocotools.mask as maskUtils |
|
from torchvision import transforms |
|
import utils.transforms as custom_transforms |
|
from PIL import Image |
|
class SAMDataset(Dataset):
    """Dataset yielding (image, prompt, segmentation-hint) training samples.

    The index file at ``txt_path`` holds one record per line; each record is a
    mapping with the keys ``'source'`` (annotation JSON file), ``'target'``
    (image file) and ``'prompt'`` (text).  Both file names are resolved
    relative to ``data_path``.

    Each sample is returned as ``dict(jpg=..., txt=..., hint=...)`` where
    ``jpg`` is the transformed target image, ``txt`` the prompt and ``hint``
    the transformed instance-id map built from the RLE annotations.
    """

    # Largest instance id that fits in the uint16 label map (and therefore in
    # the two 8-bit channels used by the PIL encoding).
    _MAX_INSTANCES = int(np.iinfo(np.uint16).max)

    def __init__(self, data_path='../data/files', txt_path='../data/data_85616.txt'):
        """Read the sample index and build the augmentation pipeline.

        Args:
            data_path: Directory that the record file names are relative to.
            txt_path: Index file with one record per line.
        """
        self.data = self._load_records(txt_path)
        self.data_path = data_path

        randomresizedcrop = custom_transforms.RandomResizedCrop(
            512,
            scale=(0.9, 1),
        )
        # NOTE(review): the pipeline is called with (target, source) in
        # __getitem__, so custom_transforms.Compose is presumably a paired
        # transform that keeps image and hint spatially aligned -- confirm
        # against utils/transforms.py.
        self.transform = custom_transforms.Compose([
            randomresizedcrop,
            custom_transforms.RandomHorizontalFlip(p=0.5),
            custom_transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

    @staticmethod
    def _load_records(txt_path):
        """Parse the index file into a list with one record per non-blank line."""
        records = []
        with open(txt_path, 'rt', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # A blank line (e.g. a stray trailing newline) would make
                    # eval() raise SyntaxError, so skip it.
                    continue
                # SECURITY NOTE: eval() executes arbitrary Python found in the
                # index file.  Only use index files from a trusted source; if
                # the records are plain literals, ast.literal_eval or
                # json.loads would be a safer parser.
                records.append(eval(line))
        return records

    def __len__(self):
        """Return the number of records read from the index file."""
        return len(self.data)

    def load_rle_annotations_from_json(self, json_file_path, return_pil=True):
        """Rasterise the RLE annotations of one JSON file into an instance-id map.

        Pixels covered by annotation ``k`` (0-based position in the file's
        ``'annotations'`` list) get the id ``k + 1``; uncovered pixels stay 0.
        Where masks overlap, the later annotation wins.

        Args:
            json_file_path: Path of a JSON file with an ``'annotations'`` list
                (each entry having a ``'segmentation'`` RLE) and an ``'image'``
                mapping providing ``'height'`` and ``'width'``.
            return_pil: If true, return a 3-channel 8-bit PIL image in which
                channel 0 is the low byte of the id, channel 1 the high byte
                and channel 2 is zero.  Otherwise return the raw
                ``(height, width)`` uint16 array.

        Raises:
            ValueError: If the file has more annotations than a uint16 id can
                represent.
        """
        with open(json_file_path, 'r', encoding='utf-8') as f:
            anno_data = json.load(f)
        annotations = anno_data['annotations']
        height = int(anno_data['image']['height'])
        width = int(anno_data['image']['width'])

        if len(annotations) > self._MAX_INSTANCES:
            # Ids beyond this would overflow the uint16 map and could not be
            # encoded in two 8-bit channels either.
            raise ValueError(
                f"{json_file_path}: {len(annotations)} annotations exceed the "
                f"maximum of {self._MAX_INSTANCES} supported instance ids"
            )

        label_map = np.zeros((height, width), dtype=np.uint16)
        for instance_id, ann in enumerate(annotations, start=1):
            mask = maskUtils.decode(ann['segmentation'])
            label_map[mask != 0] = instance_id

        if not return_pil:
            return label_map

        # Split each 16-bit id into low/high bytes so it survives an 8-bit image.
        encoded = np.zeros((height, width, 3), dtype=np.uint8)
        encoded[:, :, 0] = (label_map % 256).astype(np.uint8)
        encoded[:, :, 1] = (label_map // 256).astype(np.uint8)
        return Image.fromarray(encoded)

    def __getitem__(self, idx):
        """Load, augment and return the sample at position ``idx``.

        Returns:
            ``dict(jpg=target, txt=prompt, hint=source)`` where ``target`` is
            the transformed image with its first dimension moved last,
            ``prompt`` is the record's text and ``source`` is the transformed
            instance-id image.
        """
        item = self.data[idx]
        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self.load_rle_annotations_from_json(
            os.path.join(self.data_path, source_filename)
        )
        target = Image.open(os.path.join(self.data_path, target_filename))

        target, source = self.transform(target, source)

        # Move dim 0 to the end (presumably CHW -> HWC after ToTensor -- confirm
        # against the custom transforms); the hint is left as produced.
        target = target.permute(1, 2, 0)

        return dict(jpg=target, txt=prompt, hint=source)
|
|
|
|