|
import os |
|
import os.path |
|
import sys |
|
import torch |
|
import torch.utils.data as data |
|
import cv2 |
|
import numpy as np |
|
|
|
class WiderFaceDetection(data.Dataset):
    """Dataset over WIDER FACE annotations stored in a ``label.txt`` file.

    The annotation file interleaves image headers (lines starting with
    ``"# "`` followed by a relative image path) with per-face label lines of
    whitespace-separated floats.  Image files are resolved against an
    ``images/`` directory derived from the annotation file's own path.
    """

    def __init__(self, txt_path, preproc=None):
        """Parse ``txt_path`` into parallel image-path and label lists.

        Args:
            txt_path: Path to the annotation file.  Must contain the literal
                substring ``'label.txt'`` so the image root can be derived
                via ``txt_path.replace('label.txt', 'images/')``.
            preproc: Optional callable ``(img, target) -> (img, target)``
                applied in ``__getitem__``.
        """
        self.preproc = preproc
        self.imgs_path = []
        self.words = []

        labels = []
        is_first = True
        # Context manager closes the handle deterministically (the previous
        # implementation opened the file and never closed it).
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.rstrip()
                if line.startswith('#'):
                    # New image header: flush the labels accumulated for the
                    # previous image.  The very first header has no previous
                    # image, so nothing is flushed for it.
                    if is_first:
                        is_first = False
                    else:
                        self.words.append(labels.copy())
                        labels.clear()
                    path = line[2:]  # drop the leading "# "
                    path = txt_path.replace('label.txt', 'images/') + path
                    self.imgs_path.append(path)
                else:
                    # Label line: one face, whitespace-separated floats.
                    labels.append([float(x) for x in line.split(' ')])

        # Flush the labels of the last image in the file.
        self.words.append(labels)

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Load image ``index`` and build its ``(N, 15)`` annotation array.

        Each annotation row is ``[x1, y1, x2, y2, lm_x1, lm_y1, ...,
        lm_x5, lm_y5, flag]`` where the bbox corners come from the source
        ``(x, y, w, h)`` and ``flag`` is -1 when the first landmark x is
        negative (no landmark annotation), else 1.

        Returns:
            Tuple ``(image_tensor, target)`` where ``target`` is an
            ``(N, 15)`` float array (possibly transformed by ``preproc``).
            NOTE(review): when the image has no labels this returns only the
            empty annotation array, not a tuple — preserved as-is; confirm
            callers expect that.
        """
        img = cv2.imread(self.imgs_path[index])
        # Unpacking .shape also implicitly validates the read: cv2.imread
        # returns None on failure and this line would raise.
        height, width, _ = img.shape

        labels = self.words[index]
        annotations = np.zeros((0, 15))
        if len(labels) == 0:
            # NOTE(review): inconsistent return type (bare array, no image);
            # kept byte-compatible with the original behavior.
            return annotations

        for label in labels:
            annotation = np.zeros((1, 15))

            # bbox: convert (x, y, w, h) -> (x1, y1, x2, y2)
            annotation[0, 0] = label[0]
            annotation[0, 1] = label[1]
            annotation[0, 2] = label[0] + label[2]
            annotation[0, 3] = label[1] + label[3]

            # Five landmarks: the source stores triplets starting at index 4
            # (presumably x, y, visibility — TODO confirm against the WIDER
            # annotation spec); copy x and y, skipping every third value.
            # Destination columns 4..13 <- source 4,5 7,8 10,11 13,14 16,17.
            for i in range(5):
                annotation[0, 4 + 2 * i] = label[4 + 3 * i]
                annotation[0, 5 + 2 * i] = label[5 + 3 * i]

            # Validity flag: -1 means "no landmark annotation".
            annotation[0, 14] = -1 if annotation[0, 4] < 0 else 1

            annotations = np.append(annotations, annotation, axis=0)

        target = np.array(annotations)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target
|
|
|
def detection_collate(batch):
    """Custom collate fn for batches whose images carry a variable number of
    object annotations (bounding boxes).

    Args:
        batch: iterable of ``(image_tensor, annotation_ndarray)`` samples.

    Returns:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) float annotations, one tensor per image
    """
    targets = []
    imgs = []
    for sample in batch:
        for item in sample:
            if torch.is_tensor(item):
                imgs.append(item)
            elif isinstance(item, np.ndarray):
                # Direct ndarray check replaces the original
                # `isinstance(item, type(np.empty(0)))` round-trip.
                targets.append(torch.from_numpy(item).float())

    return torch.stack(imgs, 0), targets
|
|