FaceRecognition-LivenessDetection-SDK
/
face_recognition
/face_detect
/vision
/datasets
/voc_dataset.py
import logging | |
import os | |
import pathlib | |
import xml.etree.ElementTree as ET | |
import h5py | |
import cv2 | |
import numpy as np | |
import lmdb | |
from .caffe_pb2 import * | |
class VOCDataset:
    """Face-detection dataset read from HDF5 shards, plus legacy VOC XML/JPEG helpers.

    Samples are read from the HDF5 files listed in ``self.ids`` (paths relative to
    ``root``). Sample ``i`` inside a shard is stored as the datasets ``"<i>_boxes"``,
    ``"<i>_difficult"``, ``"<i>_image"`` and a labels dataset (see ``__getitem__``).
    Samples are numbered consecutively across shards in ``self.ids`` order.
    """

    # Number of HDF5 datasets stored per sample (boxes, difficult, image, labels);
    # used to derive how many samples a shard holds from its key count.
    _KEYS_PER_SAMPLE = 4

    def __init__(self, root, transform=None, target_transform=None, is_test=False,
                 keep_difficult=False, label_file=None):
        """Dataset for VOC-style face data stored in HDF5 shards.

        Args:
            root: directory containing the HDF5 shard files and, optionally,
                ``labels.txt`` and the legacy ``Annotations``/``JPEGImages`` folders.
            transform: optional callable ``(image, boxes, labels) -> (image, boxes, labels)``.
            target_transform: optional callable ``(boxes, labels) -> (boxes, labels)``.
            is_test: currently has no effect; the shard list is fixed (see ``self.ids``).
            keep_difficult: if False, objects flagged as difficult are dropped.
            label_file: optional path to a class-name file; defaults to
                ``<root>/labels.txt``.
        """
        # The root used to be hard-coded to "D:/test", silently ignoring ``root``.
        self.root = os.fspath(root)
        self.transform = transform
        self.target_transform = target_transform
        # NOTE(review): the shard list is hard-coded. The previous code built an
        # image-sets path from ``is_test`` but never read it. TODO: load shard names
        # from a list file once its format is settled.
        self.ids = ['1.hdf5']
        self.keep_difficult = keep_difficult

        # If a labels file exists, read the class names from it.
        if label_file is None:
            label_file = os.path.join(self.root, "labels.txt")
        if os.path.isfile(label_file):
            self.class_names = self._read_class_names(label_file)
            logging.info("VOC Labels read from file: " + str(self.class_names))
        else:
            logging.info("No labels file, using default VOC classes.")
            self.class_names = ('BACKGROUND', 'face')
        self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}

    @staticmethod
    def _read_class_names(label_file_name):
        """Parse a class-name file and return the names with 'BACKGROUND' prepended.

        Names may be comma-separated, one per line, or a mix of both. Lines are
        joined with commas so that one-per-line files are not merged into a single
        name. All spaces are removed from names, and empty entries (e.g. from
        trailing commas) are dropped.
        """
        with open(label_file_name, 'r') as infile:
            text = ",".join(line.rstrip() for line in infile)
        classes = (elem.replace(" ", "") for elem in text.split(','))
        return ('BACKGROUND',) + tuple(elem for elem in classes if elem)

    def _shard_sizes(self):
        """Return the number of samples in each shard of ``self.ids``.

        Each count is the shard's key count divided by ``_KEYS_PER_SAMPLE``.
        Results are cached and recomputed only when ``root`` or ``ids`` change.
        """
        key = (self.root, tuple(self.ids))
        cached = getattr(self, "_shard_size_cache", None)
        if cached is None or cached[0] != key:
            sizes = []
            for file_name in key[1]:
                with h5py.File(os.path.join(self.root, file_name), 'r') as f:
                    sizes.append(len(f.keys()) // self._KEYS_PER_SAMPLE)
            cached = (key, sizes)
            self._shard_size_cache = cached
        return cached[1]

    def _locate(self, index):
        """Map a global sample index to ``(shard file name, index within shard)``.

        Negative indices count from the end. Raises IndexError when out of range.
        """
        sizes = self._shard_sizes()
        total = sum(sizes)
        position = index + total if index < 0 else index
        if position >= 0:
            for file_name, size in zip(self.ids, sizes):
                if position < size:
                    return file_name, position
                position -= size
        raise IndexError(f"index {index} out of range for dataset of length {total}")

    def __getitem__(self, index):
        """Return ``(image, boxes, labels)`` for the sample at global ``index``.

        Difficult objects are removed unless ``keep_difficult`` is set. Then
        ``transform`` and ``target_transform`` are applied, when given.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        file_name, idx_in_file = self._locate(index)
        prefix = str(idx_in_file)
        hdf_path = os.path.join(self.root, file_name)
        with h5py.File(hdf_path, 'r') as f:
            # ``[()]`` reads each dataset into memory as a NumPy array. Raw h5py
            # Dataset handles become invalid once the file is closed, and they do
            # not support element-wise comparison.
            boxes = f[prefix + '_boxes'][()]
            is_difficult = f[prefix + '_difficult'][()]
            image = f[prefix + '_image'][()]
            # The previous reader used "<i>labels" (no underscore, unlike the other
            # keys). The writer is not visible here, so accept either spelling.
            labels_key = prefix + 'labels' if prefix + 'labels' in f else prefix + '_labels'
            labels = f[labels_key][()]
        if not self.keep_difficult:
            keep = is_difficult == 0
            boxes = boxes[keep]
            labels = labels[keep]
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def get_image(self, index):
        """Load the JPEG for ``self.ids[index]`` and apply ``transform`` if set.

        NOTE(review): ``self.ids`` now holds HDF5 shard names, so this legacy VOC
        helper looks for ``JPEGImages/<shard name>.jpg``; confirm it is still used.
        """
        image_id = self.ids[index]
        image = self._read_image(image_id)
        if self.transform:
            image, _ = self.transform(image)
        return image

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels, is_difficult))`` for ``self.ids[index]``.

        NOTE(review): same legacy-VOC caveat as ``get_image``.
        """
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        """Return the total number of samples across all shards in ``self.ids``."""
        return sum(self._shard_sizes())

    @staticmethod
    def _read_image_ids(image_sets_file):
        """Return the right-stripped lines of ``image_sets_file`` as a list."""
        with open(image_sets_file) as f:
            return [line.rstrip() for line in f]

    def _get_annotation(self, image_id):
        """Parse ``<root>/Annotations/<image_id>.xml`` into NumPy arrays.

        Objects whose (lower-cased, stripped) class name is not in ``class_dict``
        are skipped. A missing or empty ``<difficult>`` tag counts as 0.

        Returns:
            ``(boxes, labels, is_difficult)``. ``boxes`` is float32 with shape
            ``(N, 4)``, one ``[x1, y1, x2, y2]`` row per object. ``labels`` is int64
            class ids. ``is_difficult`` is uint8 flags.
        """
        annotation_file = pathlib.Path(self.root) / f"Annotations/{image_id}.xml"
        objects = ET.parse(annotation_file).findall("object")
        boxes = []
        labels = []
        is_difficult = []
        for obj in objects:
            class_name = obj.find('name').text.lower().strip()
            # We're only concerned with classes in our list.
            if class_name not in self.class_dict:
                continue
            bbox = obj.find('bndbox')
            # VOC coordinates follow Matlab's 1-based convention; shift to 0-based.
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            boxes.append([x1, y1, x2, y2])
            labels.append(self.class_dict[class_name])
            difficult_node = obj.find('difficult')
            difficult_text = difficult_node.text if difficult_node is not None else None
            is_difficult.append(int(difficult_text) if difficult_text else 0)
        # reshape keeps an (0, 4) box array when no objects matched.
        return (np.array(boxes, dtype=np.float32).reshape(-1, 4),
                np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, image_id):
        """Read ``<root>/JPEGImages/<image_id>.jpg`` and return it in RGB order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_file = pathlib.Path(self.root) / f"JPEGImages/{image_id}.jpg"
        image = cv2.imread(str(image_file))
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {image_file}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)