|
import os |
|
import cv2 |
|
import sys |
|
|
|
import numpy as np |
|
import torchvision.transforms as transforms |
|
|
|
from torch.utils.data.dataset import Dataset |
|
from PIL import Image |
|
from threading import Thread |
|
|
|
# Put the parent of this file's directory on the module search path;
# presumably this is what lets the project-local imports below resolve when
# the module is run from another working directory.
filepath = os.path.dirname(__file__)
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
|
|
|
from data.custom_transforms import * |
|
from utils.misc import * |
|
|
|
# Clear Pillow's pixel-count limit (the decompression-bomb guard) so that very
# large images can be opened by the loaders below.
Image.MAX_IMAGE_PIXELS = None
|
|
|
def get_transform(tfs):
    """Build a composed transform pipeline from a configuration mapping.

    Every key of ``tfs`` is passed to ``eval`` to obtain a callable. That
    callable is invoked with the mapped value unpacked as keyword arguments,
    or with no arguments when the value is ``None``.

    Args:
        tfs: Mapping of transform expression (str) to a kwargs dict or
            ``None``.

    Returns:
        A ``transforms.Compose`` over the created transforms, in the
        iteration order of ``tfs``.
    """
    # NOTE(review): ``eval`` runs arbitrary expressions, so ``tfs`` must only
    # ever come from a trusted configuration source.
    pipeline = []
    for name, kwargs in tfs.items():
        factory = eval(name)
        pipeline.append(factory() if kwargs is None else factory(**kwargs))
    return transforms.Compose(pipeline)
|
|
|
class RGB_Dataset(Dataset):
    """Map-style dataset of RGB images paired with single-channel masks.

    For each entry of ``sets``, images are read from ``<root>/<set>/images``
    and masks from ``<root>/<set>/masks`` (``.jpg``/``.png`` files, matched
    case-insensitively). Both listings are ordered with ``sort`` and paired by
    position; pairs whose image and mask sizes differ are discarded.

    Args:
        root: Dataset root directory.
        sets: Iterable of sub-directory names located under ``root``.
        tfs: Transform configuration forwarded to ``get_transform``.
    """

    # File extensions (lower-case) accepted for both images and masks.
    _EXTENSIONS = ('.jpg', '.png')

    def __init__(self, root, sets, tfs):
        self.images, self.gts = [], []

        # ``set_name`` rather than ``set`` so the builtin is not shadowed.
        for set_name in sets:
            image_root = os.path.join(root, set_name, 'images')
            gt_root = os.path.join(root, set_name, 'masks')
            self.images.extend(self._list_files(image_root))
            self.gts.extend(self._list_files(gt_root))

        self.filter_files()

        self.size = len(self.images)
        self.transform = get_transform(tfs)

    @classmethod
    def _list_files(cls, directory):
        """Return the ordered paths of the supported files in ``directory``."""
        paths = [os.path.join(directory, f) for f in os.listdir(directory)
                 if f.lower().endswith(cls._EXTENSIONS)]
        return sort(paths)

    def __getitem__(self, index):
        """Load the pair at ``index`` and run it through the transforms.

        Returns:
            The output of the transform pipeline applied to a dict with the
            keys ``'image'`` (RGB PIL image), ``'gt'`` (mode ``'L'`` PIL
            image), ``'name'`` (image file name without its extension) and
            ``'shape'`` (the mask's ``size`` tuple reversed).
        """
        image_path = self.images[index]
        # ``convert`` returns an independent, fully loaded image, so the
        # source files can be closed as soon as the conversion is done.
        with Image.open(image_path) as src:
            image = src.convert('RGB')
        with Image.open(self.gts[index]) as src:
            gt = src.convert('L')

        name = os.path.splitext(os.path.basename(image_path))[0]
        sample = {'image': image, 'gt': gt, 'name': name, 'shape': gt.size[::-1]}
        return self.transform(sample)

    def filter_files(self):
        """Drop every image/mask pair whose two sizes do not match.

        Raises:
            AssertionError: If the number of images and masks differ.
        """
        assert len(self.images) == len(self.gts), (
            'number of images ({}) and masks ({}) differ'.format(
                len(self.images), len(self.gts)))

        images, gts = [], []
        for img_path, gt_path in zip(self.images, self.gts):
            # Only the header is needed for ``size``; the context manager
            # closes both files immediately instead of leaking the handles.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images, self.gts = images, gts

    def __len__(self):
        """Return the number of retained image/mask pairs."""
        return self.size
|
|
|
class ImageLoader:
    """Iterate over one image file or over all images in a directory.

    When ``root`` is a directory, its ``.jpg``/``.png``/``.jpeg`` files
    (matched case-insensitively) are collected and ordered with ``sort``.
    When it is a file, that single path is used.

    Args:
        root: Path of an image file or of a directory containing images.
        tfs: Transform configuration forwarded to ``get_transform``.

    Raises:
        FileNotFoundError: If ``root`` is neither an existing directory nor
            an existing file.
    """

    def __init__(self, root, tfs):
        if os.path.isdir(root):
            self.images = sort([os.path.join(root, f) for f in os.listdir(root)
                                if f.lower().endswith(('.jpg', '.png', '.jpeg'))])
        elif os.path.isfile(root):
            self.images = [root]
        else:
            # Previously ``self.images`` was simply left unbound here and the
            # constructor died with an unrelated AttributeError.
            raise FileNotFoundError(
                f'image source {root!r} is neither a directory nor a file')

        self.size = len(self.images)
        # Bound here so that ``next(loader)`` also works before ``iter()``.
        self.index = 0
        self.transform = get_transform(tfs)

    def __iter__(self):
        """Restart iteration from the first image and return ``self``."""
        self.index = 0
        return self

    def __next__(self):
        """Return the transformed sample for the next image.

        The sample is built from the keys ``'image'``, ``'name'`` (file name
        without extension), ``'shape'`` (the image ``size`` tuple reversed)
        and ``'original'`` (the untransformed RGB image). After the
        transforms, ``'image'`` and, when present, ``'image_resized'`` get a
        leading dimension via ``unsqueeze(0)``.

        Raises:
            StopIteration: When every image has been returned.
        """
        if self.index >= self.size:
            raise StopIteration

        path = self.images[self.index]
        # ``convert`` yields an independent image, so the file can be closed.
        with Image.open(path) as src:
            image = src.convert('RGB')
        name = os.path.splitext(os.path.basename(path))[0]

        sample = {'image': image, 'name': name, 'shape': image.size[::-1], 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        self.index += 1
        return sample

    def __len__(self):
        """Return the number of images found at construction time."""
        return self.size
|
|
|
class VideoLoader:
    """Iterate frame by frame over one video file or a directory of videos.

    Each call to ``next`` yields one transformed frame of the current video.
    When a read fails (e.g. the video has no more frames), a sentinel sample
    whose ``'image'``, ``'shape'`` and ``'original'`` are ``None`` is returned
    and the loader moves on to the next video. ``self.fps`` holds the frame
    rate reported for the video that is currently open.

    Args:
        root: Path of a video file or of a directory containing videos
            (``.mp4``, ``.avi``, ``.mov``; matched case-insensitively).
        tfs: Transform configuration forwarded to ``get_transform``.

    Raises:
        FileNotFoundError: If ``root`` is neither an existing directory nor
            an existing file.
    """

    # The original tuple contained 'mov' without the dot, which also matched
    # any file whose name merely ended in "mov".
    _EXTENSIONS = ('.mp4', '.avi', '.mov')

    def __init__(self, root, tfs):
        if os.path.isdir(root):
            # ``os.listdir`` order is arbitrary; sort for a reproducible order.
            self.videos = sorted(os.path.join(root, f) for f in os.listdir(root)
                                 if f.lower().endswith(self._EXTENSIONS))
        elif os.path.isfile(root):
            self.videos = [root]
        else:
            # Previously ``self.videos`` was left unbound in this case and the
            # constructor died with an unrelated AttributeError.
            raise FileNotFoundError(
                f'video source {root!r} is neither a directory nor a file')

        self.size = len(self.videos)
        # Iteration state is bound here as well so that ``next(loader)`` works
        # even if ``iter()`` was never called.
        self.index = 0
        self.cap = None
        self.fps = None
        self.transform = get_transform(tfs)

    def __iter__(self):
        """Reset the iteration state and return ``self``."""
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        """Return the next frame sample, or an end-of-video sentinel sample.

        Raises:
            StopIteration: When all videos have been consumed.
        """
        if self.index >= self.size:
            raise StopIteration

        path = self.videos[self.index]
        if self.cap is None:
            # Lazily open the current video on its first frame request.
            self.cap = cv2.VideoCapture(path)
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        ret, frame = self.cap.read()
        name = os.path.splitext(os.path.basename(path))[0]

        if not ret:
            # Current video is finished (or unreadable): close it, advance to
            # the next one and hand the caller an all-None sentinel sample.
            self.cap.release()
            self.cap = None
            self.index += 1
            return {'image': None, 'shape': None, 'name': name, 'original': None}

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame).convert('RGB')
        sample = {'image': image, 'shape': image.size[::-1], 'name': name, 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        return sample

    def __len__(self):
        """Return the number of videos (not frames) to process."""
        return self.size
|
|
|
|
|
class WebcamLoader:
    """Endless iterator over the most recent frame of a capture device.

    A daemon thread keeps reading frames from the device into ``self.imgs``;
    ``next`` always serves the newest one. Iteration ends when that thread
    has stopped or when ``cv2.waitKey`` reports the ``q`` key.

    Args:
        ID: Capture device index (anything accepted by ``int``).
        tfs: Transform configuration forwarded to ``get_transform``.
    """

    def __init__(self, ID, tfs):
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.transform = get_transform(tfs)

        self.imgs = []
        # Only keep the first frame if the grab succeeded. A failed read
        # returns ``None``, which used to be stored and could later reach
        # ``Image.fromarray``.
        ret, frame = self.cap.read()
        if ret and frame is not None:
            self.imgs.append(frame)

        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Background loop: append frames until the device closes or a read fails."""
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if not ret:
                break
            self.imgs.append(frame)

    def __iter__(self):
        """Return ``self``; the loader is its own iterator."""
        return self

    def __next__(self):
        """Return the transformed sample for the newest captured frame.

        Raises:
            StopIteration: When the reader thread is no longer alive or the
                ``q`` key was pressed; all OpenCV windows are closed first.
        """
        # Fall back to a black 480x640 frame while nothing has been captured.
        if self.imgs:
            frame = self.imgs[-1]
        else:
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

        if not self.thread.is_alive() or cv2.waitKey(1) == ord('q'):
            # NOTE(review): the capture device is not released here; confirm
            # whether callers rely on ``self.cap`` staying open.
            cv2.destroyAllWindows()
            raise StopIteration

        # NOTE(review): unlike VideoLoader, no BGR->RGB conversion is applied
        # to the OpenCV frame before it is handed to PIL; confirm this is
        # intended.
        image = Image.fromarray(frame).convert('RGB')
        sample = {'image': image, 'shape': image.size[::-1], 'name': 'webcam', 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        # Discard all buffered frames except the newest one.
        del self.imgs[:-1]
        return sample

    def __len__(self):
        """Always 0: the stream has no known length."""
        return 0
|
|
|
class RefinementLoader:
    """Iterate over images together with their existing segmentation maps.

    The ``.jpg``/``.png``/``.jpeg`` files (matched case-insensitively) of both
    directories are ordered with ``sort`` and paired by position.

    Args:
        image_dir: Directory containing the input images.
        seg_dir: Directory containing the segmentation maps.
        tfs: Transform configuration forwarded to ``get_transform``.
    """

    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, image_dir, seg_dir, tfs):
        self.images = self._list_files(image_dir)
        self.segs = self._list_files(seg_dir)

        # NOTE(review): the two listings are paired purely by position and
        # their lengths are not compared; only the image count drives
        # iteration.
        self.size = len(self.images)
        # Bound here so that ``next(loader)`` also works before ``iter()``;
        # it used to raise AttributeError in that case.
        self.index = 0
        self.transform = get_transform(tfs)

    @classmethod
    def _list_files(cls, directory):
        """Return the ordered paths of the supported files in ``directory``."""
        paths = [os.path.join(directory, f) for f in os.listdir(directory)
                 if f.lower().endswith(cls._EXTENSIONS)]
        return sort(paths)

    def __iter__(self):
        """Restart iteration from the first pair and return ``self``."""
        self.index = 0
        return self

    def __next__(self):
        """Return the transformed sample for the next image/segmentation pair.

        The segmentation map enters the transforms under the key ``'gt'``.
        Afterwards it is exposed as ``'mask'`` (with a leading dimension added
        via ``unsqueeze(0)``) and the ``'gt'`` entry is removed. ``'image'``
        and, when present, ``'image_resized'`` are unsqueezed the same way.

        Raises:
            StopIteration: When every image has been returned.
        """
        if self.index >= self.size:
            raise StopIteration

        image_path = self.images[self.index]
        # ``convert`` yields independent images, so the files can be closed.
        with Image.open(image_path) as src:
            image = src.convert('RGB')
        with Image.open(self.segs[self.index]) as src:
            seg = src.convert('L')
        name = os.path.splitext(os.path.basename(image_path))[0]

        sample = {'image': image, 'gt': seg, 'name': name,
                  'shape': image.size[::-1], 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        sample['mask'] = sample['gt'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        del sample['gt']

        self.index += 1
        return sample

    def __len__(self):
        """Return the number of images found at construction time."""
        return self.size
|
|