my-cool-model / data /dataloader.py
crapthings's picture
Upload folder using huggingface_hub
f7f604d
import os
import cv2
import sys
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from PIL import Image
from threading import Thread
# Directory that contains this file.
filepath = os.path.split(__file__)[0]
# Parent of that directory (presumably the repository root). It is appended
# to sys.path before the two project imports below so that the `data` and
# `utils` packages can be found.
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)
# Star imports: names used later in this file but not defined here (e.g.
# `sort`, and the transform names resolved in `get_transform`) presumably
# come from these two modules.
from data.custom_transforms import *
from utils.misc import *
# Setting Pillow's pixel limit to None turns off its decompression-bomb size
# check, so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None
def get_transform(tfs):
    """Build a composed transform pipeline from a configuration mapping.

    Args:
        tfs: Mapping from transform name to its keyword arguments. A value
            of ``None`` means the transform is constructed with no
            arguments; any other value is unpacked as ``**kwargs``.

    Returns:
        A ``transforms.Compose`` holding the constructed transforms in the
        mapping's iteration order.
    """
    comp = []
    for key, value in tfs.items():
        # SECURITY NOTE(review): ``eval`` runs ``key`` as an arbitrary Python
        # expression in this module's namespace. That is only acceptable
        # while ``tfs`` comes from a trusted configuration; consider an
        # explicit name -> class registry if configs can be user-supplied.
        factory = eval(key)
        comp.append(factory() if value is None else factory(**value))
    return transforms.Compose(comp)
class RGB_Dataset(Dataset):
    """Dataset of RGB images paired with single-channel masks.

    Every entry of ``sets`` names a sub-directory of ``root`` that holds an
    ``images`` folder and a ``masks`` folder. Images and masks are paired by
    their position after each folder listing has been passed through
    ``sort``.
    """

    def __init__(self, root, sets, tfs):
        """Collect image/mask paths and build the transform pipeline.

        Args:
            root: Directory containing one sub-directory per dataset.
            sets: Iterable of sub-directory names under ``root``.
            tfs: Transform configuration handed to ``get_transform``.
        """
        self.images, self.gts = [], []
        extensions = ('.jpg', '.png')

        # The loop variable used to be called `set`, shadowing the builtin.
        for set_name in sets:
            image_root = os.path.join(root, set_name, 'images')
            gt_root = os.path.join(root, set_name, 'masks')

            images = [os.path.join(image_root, f) for f in os.listdir(image_root)
                      if f.lower().endswith(extensions)]
            images = sort(images)

            gts = [os.path.join(gt_root, f) for f in os.listdir(gt_root)
                   if f.lower().endswith(extensions)]
            gts = sort(gts)

            self.images.extend(images)
            self.gts.extend(gts)

        self.filter_files()

        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __getitem__(self, index):
        """Load, package and transform the sample at ``index``.

        Returns:
            The result of applying ``self.transform`` to a dict with keys
            ``image`` (RGB PIL image), ``gt`` (mode ``L`` PIL image),
            ``name`` (file name without extension) and ``shape``
            (``(height, width)`` of the mask).
        """
        image = Image.open(self.images[index]).convert('RGB')
        gt = Image.open(self.gts[index]).convert('L')
        # PIL reports (width, height); reverse it to (height, width).
        shape = gt.size[::-1]
        name = os.path.splitext(os.path.basename(self.images[index]))[0]

        sample = {'image': image, 'gt': gt, 'name': name, 'shape': shape}
        return self.transform(sample)

    def filter_files(self):
        """Drop image/mask pairs whose pixel dimensions differ.

        Raises:
            AssertionError: If the number of images and masks differs, since
                positional pairing would then be meaningless.
        """
        assert len(self.images) == len(self.gts), (
            'number of images ({}) and masks ({}) must match'.format(
                len(self.images), len(self.gts)))

        images, gts = [], []
        for img_path, gt_path in zip(self.images, self.gts):
            # Only the header is needed for `.size`; close both files
            # deterministically instead of leaving them to the garbage
            # collector while scanning a large dataset.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images, self.gts = images, gts

    def __len__(self):
        """Return the number of image/mask pairs kept after filtering."""
        return self.size
class ImageLoader:
    """Iterator over the images in a directory, or over a single image file.

    Each step yields the transformed sample dict with a leading batch
    dimension added to ``image`` (and to ``image_resized`` when present).
    """

    def __init__(self, root, tfs):
        """Collect the image paths and build the transform pipeline.

        Args:
            root: Directory of ``.jpg``/``.png``/``.jpeg`` files, or the
                path of a single file.
            tfs: Transform configuration handed to ``get_transform``.

        Raises:
            FileNotFoundError: If ``root`` is neither a directory nor a
                file. Previously this surfaced as an ``AttributeError``
                because ``self.images`` was never assigned.
        """
        if os.path.isdir(root):
            self.images = [os.path.join(root, f) for f in os.listdir(root)
                           if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
            self.images = sort(self.images)
        elif os.path.isfile(root):
            self.images = [root]
        else:
            raise FileNotFoundError(
                'image source is neither a directory nor a file: {!r}'.format(root))

        self.size = len(self.images)
        # Initialised here as well so that next(loader) works without a
        # preceding iter(loader).
        self.index = 0
        self.transform = get_transform(tfs)

    def __iter__(self):
        """Restart iteration from the first image."""
        self.index = 0
        return self

    def __next__(self):
        """Load and transform the next image.

        Raises:
            StopIteration: When every image has been yielded.
        """
        if self.index >= self.size:
            raise StopIteration

        path = self.images[self.index]
        image = Image.open(path).convert('RGB')
        # PIL reports (width, height); reverse it to (height, width).
        shape = image.size[::-1]
        name = os.path.splitext(os.path.basename(path))[0]

        sample = {'image': image, 'name': name, 'shape': shape, 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        self.index += 1
        return sample

    def __len__(self):
        """Return the number of images found at construction time."""
        return self.size
class VideoLoader:
    """Iterator over the frames of one or more video files.

    Frames are yielded video by video. When a video runs out of frames, one
    sentinel sample whose ``image``, ``shape`` and ``original`` are ``None``
    is yielded before moving on to the next video.
    """

    # Every entry needs its leading dot: the original tuple contained 'mov',
    # which matched any file name merely ending in the letters "mov".
    _VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov')

    def __init__(self, root, tfs):
        """Collect the video paths and build the transform pipeline.

        Args:
            root: Directory of video files, or the path of a single file.
            tfs: Transform configuration handed to ``get_transform``.

        Raises:
            FileNotFoundError: If ``root`` is neither a directory nor a
                file. Previously this surfaced as an ``AttributeError``
                because ``self.videos`` was never assigned.
        """
        if os.path.isdir(root):
            self.videos = self._list_videos(root)
        elif os.path.isfile(root):
            self.videos = [root]
        else:
            raise FileNotFoundError(
                'video source is neither a directory nor a file: {!r}'.format(root))

        self.size = len(self.videos)
        # Initialised here as well so that next(loader) works without a
        # preceding iter(loader).
        self.index = 0
        self.cap = None
        self.fps = None
        self.transform = get_transform(tfs)

    @classmethod
    def _list_videos(cls, root):
        """Return the paths of the files in ``root`` with a video extension."""
        return [os.path.join(root, f) for f in os.listdir(root)
                if f.lower().endswith(cls._VIDEO_EXTENSIONS)]

    def __iter__(self):
        """Restart iteration from the first frame of the first video."""
        # Release a capture left open by an unfinished earlier pass.
        previous = getattr(self, 'cap', None)
        if previous is not None:
            previous.release()
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        """Return the next frame sample, or the end-of-video sentinel.

        Raises:
            StopIteration: When every video has been exhausted.
        """
        if self.index >= self.size:
            raise StopIteration

        path = self.videos[self.index]
        if self.cap is None:
            # Open lazily so only one capture is held at a time.
            self.cap = cv2.VideoCapture(path)
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        ret, frame = self.cap.read()
        name = os.path.splitext(os.path.basename(path))[0]

        if ret is False:
            # Current video is finished: close it, emit the sentinel and
            # advance so the next call opens the following video.
            self.cap.release()
            self.cap = None
            sample = {'image': None, 'shape': None, 'name': name, 'original': None}
            self.index += 1
        else:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame).convert('RGB')
            # PIL reports (width, height); reverse it to (height, width).
            shape = image.size[::-1]
            sample = {'image': image, 'shape': shape, 'name': name, 'original': image}
            sample = self.transform(sample)
            sample['image'] = sample['image'].unsqueeze(0)
            if 'image_resized' in sample:
                sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        return sample

    def __len__(self):
        """Return the number of videos (not frames) found at construction."""
        return self.size
class WebcamLoader:
    """Endless iterator over the most recent frame of a webcam.

    A daemon thread keeps reading frames into ``self.imgs``; each step of
    the iterator transforms the newest frame and discards older ones.
    Iteration ends when the reader thread has stopped or when
    ``cv2.waitKey`` reports that ``q`` was pressed.
    """

    def __init__(self, ID, tfs):
        """Open the capture device and start the background reader.

        Args:
            ID: Device index; converted with ``int``.
            tfs: Transform configuration handed to ``get_transform``.
        """
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.transform = get_transform(tfs)

        self.imgs = []
        # Keep the first frame only if the grab succeeded. The original
        # appended the result unconditionally, so a failed read stored None
        # and __next__ could crash in Image.fromarray.
        ret, frame = self.cap.read()
        if ret and frame is not None:
            self.imgs.append(frame)

        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Reader-thread body: append frames until a read fails."""
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret is True:
                self.imgs.append(frame)
            else:
                break

    def __iter__(self):
        """Return the loader itself; there is no position to reset."""
        return self

    def __next__(self):
        """Return the transformed newest frame.

        Raises:
            StopIteration: When the reader thread is no longer alive or
                ``q`` was pressed.
        """
        if len(self.imgs) > 0:
            frame = self.imgs[-1]
        else:
            # No frame captured yet: fall back to a black 480x640 image.
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

        if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            raise StopIteration

        # NOTE(review): unlike VideoLoader, the frame is handed to PIL
        # without a BGR->RGB conversion. Confirm whether downstream display
        # code depends on this channel order before changing it.
        image = Image.fromarray(frame).convert('RGB')
        # PIL reports (width, height); reverse it to (height, width).
        shape = image.size[::-1]
        sample = {'image': image, 'shape': shape, 'name': 'webcam', 'original': image}
        sample = self.transform(sample)
        sample['image'] = sample['image'].unsqueeze(0)
        if 'image_resized' in sample:
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)

        # Drop everything but the newest frame so the buffer stays small.
        del self.imgs[:-1]
        return sample

    def __len__(self):
        """Return 0; the stream has no predetermined length."""
        return 0
class RefinementLoader:
    """Iterator that pairs images with segmentation maps by sorted position.

    Each step yields the transformed sample with a batch dimension added to
    ``image`` (and ``image_resized`` when present); the transformed ``gt``
    entry is re-exposed under the key ``mask``.
    """

    @staticmethod
    def _collect(directory):
        """List the ``.jpg``/``.png``/``.jpeg`` files found in ``directory``."""
        wanted = ('.jpg', '.png', '.jpeg')
        found = []
        for entry in os.listdir(directory):
            if entry.lower().endswith(wanted):
                found.append(os.path.join(directory, entry))
        return found

    def __init__(self, image_dir, seg_dir, tfs):
        """Gather both file lists and build the transform pipeline.

        Args:
            image_dir: Directory holding the input images.
            seg_dir: Directory holding the segmentation maps.
            tfs: Transform configuration handed to ``get_transform``.
        """
        self.images = sort(self._collect(image_dir))
        self.segs = sort(self._collect(seg_dir))

        self.size = len(self.images)
        self.transform = get_transform(tfs)

    def __iter__(self):
        """Rewind to the first pair and return the loader itself."""
        self.index = 0
        return self

    def __next__(self):
        """Load, transform and return the next image/segmentation pair.

        Raises:
            StopIteration: Once ``self.size`` pairs have been produced.
        """
        if self.index == self.size:
            raise StopIteration

        cursor = self.index
        image_path = self.images[cursor]
        picture = Image.open(image_path).convert('RGB')
        segmentation = Image.open(self.segs[cursor]).convert('L')
        # PIL gives (width, height); flip to (height, width).
        height_width = picture.size[::-1]
        stem, _ = os.path.splitext(image_path.split(os.sep)[-1])

        sample = self.transform({
            'image': picture,
            'gt': segmentation,
            'name': stem,
            'shape': height_width,
            'original': picture,
        })

        sample['image'] = sample['image'].unsqueeze(0)
        sample['mask'] = sample['gt'].unsqueeze(0)
        if 'image_resized' in sample.keys():
            sample['image_resized'] = sample['image_resized'].unsqueeze(0)
        del sample['gt']

        self.index = cursor + 1
        return sample

    def __len__(self):
        """Return the number of images found at construction time."""
        return self.size