import logging
import os
import pathlib
import xml.etree.ElementTree as ET

import h5py
import cv2
import numpy as np
import lmdb

from .caffe_pb2 import *


class VOCDataset:
    """Detection dataset backed by HDF5 shard files, with VOC-style helpers.

    ``__getitem__`` reads samples from the HDF5 shards listed in ``self.ids``.
    For the sample at position ``i`` inside a shard it reads the datasets
    ``"<i>_boxes"``, ``"<i>_difficult"``, ``"<i>_image"`` and ``"<i>labels"``.

    ``get_image`` / ``get_annotation`` are the older VOC-directory helpers:
    they read ``JPEGImages/<id>.jpg`` and ``Annotations/<id>.xml`` below
    ``root``.
    """

    # Number of samples __getitem__ assumes are stored in every shard file.
    _SAMPLES_PER_SHARD = 3
    # Number of HDF5 datasets stored per sample (boxes, difficult, image,
    # labels); used by __len__ to turn a key count into a sample count.
    _DATASETS_PER_SAMPLE = 4

    def __init__(self, root, transform=None, target_transform=None,
                 is_test=False, keep_difficult=False, label_file=None):
        """Set up the dataset rooted at ``root``.

        Args:
            root: directory holding the HDF5 shard files and, optionally, a
                ``labels.txt`` file with a comma separated list of classes.
            transform: optional callable applied as
                ``transform(image, boxes, labels)``.
            target_transform: optional callable applied as
                ``target_transform(boxes, labels)``.
            is_test: accepted for interface compatibility; currently unused.
            keep_difficult: when False, objects flagged as difficult are
                dropped from the returned boxes and labels.
            label_file: accepted for interface compatibility; currently
                unused (labels are always looked up at ``root/labels.txt``).
        """
        # Stored as a Path: _get_annotation/_read_image join with the "/"
        # operator, which fails on a plain string.
        self.root = pathlib.Path(root)
        self.transform = transform
        self.target_transform = target_transform

        # NOTE(review): the shard list is hard-coded; the image-set file
        # lookup (see _read_image_ids) was disabled in the original code.
        self.ids = ['1.hdf5']
        self.keep_difficult = keep_difficult
        self._num_samples = None  # lazily computed by __len__

        # If the labels file exists, read in the class names.
        label_file_name = self.root / "labels.txt"
        if label_file_name.is_file():
            with open(label_file_name, 'r') as infile:
                class_string = "".join(line.rstrip() for line in infile)

            # Classes are a comma separated list; BACKGROUND is always the
            # first class.
            classes = ['BACKGROUND'] + class_string.split(',')
            self.class_names = tuple(
                elem.replace(" ", "") for elem in classes)
            logging.info("VOC Labels read from file: " + str(self.class_names))
        else:
            logging.info("No labels file, using default VOC classes.")
            self.class_names = ('BACKGROUND', 'face')

        self.class_dict = {class_name: i
                           for i, class_name in enumerate(self.class_names)}

    def __getitem__(self, index):
        """Return ``(image, boxes, labels)`` for the sample at ``index``.

        The shard file is ``self.ids[index // 3]`` and the sample position
        inside it is ``index % 3``.

        Raises:
            IndexError: if ``index`` maps to a shard beyond ``self.ids``.
            KeyError: if the shard lacks one of the expected datasets.
        """
        file_idx, idx_in_file = divmod(index, self._SAMPLES_PER_SHARD)
        hdf_path = os.path.join(self.root, self.ids[file_idx])
        prefix = str(idx_in_file)

        with h5py.File(hdf_path, 'r') as f:
            # "[()]" copies each dataset into memory; bare h5py handles
            # become unusable once the file is closed and do not support
            # element-wise comparison.
            boxes = f[prefix + '_boxes'][()]
            is_difficult = f[prefix + '_difficult'][()]
            image = f[prefix + '_image'][()]
            # NOTE(review): this key has no underscore, unlike the other
            # three; kept as-is - confirm against the code writing the shards.
            labels = f[prefix + 'labels'][()]

        if not self.keep_difficult:
            keep = is_difficult == 0
            boxes = boxes[keep]
            labels = labels[keep]
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def get_image(self, index):
        """Return the (optionally transformed) JPEG image for ``self.ids[index]``."""
        image_id = self.ids[index]
        image = self._read_image(image_id)
        if self.transform:
            # Keep only the image; the transform may return a 2-tuple or the
            # (image, boxes, labels) triple that __getitem__ expects.
            image, *_ = self.transform(image)
        return image

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels, is_difficult))`` for ``self.ids[index]``."""
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        """Return the number of samples stored across all shard files.

        Every sample contributes four datasets to its shard, so the count is
        the total number of top-level keys divided by four. The result is
        cached after the first call.
        """
        if getattr(self, '_num_samples', None) is None:
            total = 0
            for file in self.ids:
                hdf_path = os.path.join(self.root, file)
                # Context manager so the handle is always released.
                with h5py.File(hdf_path, 'r') as f:
                    total += len(f.keys())
            self._num_samples = total // self._DATASETS_PER_SAMPLE
        return self._num_samples

    @staticmethod
    def _read_image_ids(image_sets_file):
        """Return the lines of ``image_sets_file`` with trailing whitespace removed."""
        with open(image_sets_file) as f:
            return [line.rstrip() for line in f]

    def _get_annotation(self, image_id):
        """Parse ``Annotations/<image_id>.xml`` under ``self.root``.

        Only objects whose (lower-cased, stripped) name is in
        ``self.class_dict`` are kept.

        Returns:
            A tuple ``(boxes, labels, is_difficult)`` of NumPy arrays with
            dtypes float32, int64 and uint8. ``boxes`` has shape ``(N, 4)``
            holding ``[x1, y1, x2, y2]`` rows, including when ``N`` is 0.
        """
        annotation_file = self.root / f"Annotations/{image_id}.xml"
        objects = ET.parse(annotation_file).findall("object")
        boxes = []
        labels = []
        is_difficult = []
        for obj in objects:
            class_name = obj.find('name').text.lower().strip()
            # We are only concerned with classes in our list.
            if class_name not in self.class_dict:
                continue

            bbox = obj.find('bndbox')
            # VOC coordinates follow the Matlab convention (1-based), so
            # shift them to 0-based pixel indexes.
            boxes.append([float(bbox.find(tag).text) - 1
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            labels.append(self.class_dict[class_name])

            # A missing or empty <difficult> element counts as "not
            # difficult".
            difficult_node = obj.find('difficult')
            difficult_text = (difficult_node.text
                              if difficult_node is not None else None)
            is_difficult.append(int(difficult_text) if difficult_text else 0)

        return (np.array(boxes, dtype=np.float32).reshape(-1, 4),
                np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, image_id):
        """Load ``JPEGImages/<image_id>.jpg`` under ``self.root`` as an RGB array.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        image_file = self.root / f"JPEGImages/{image_id}.jpg"
        image = cv2.imread(str(image_file))
        if image is None:
            # cv2.imread signals failure by returning None, not raising.
            raise OSError(f"Could not read image file: {image_file}")
        # OpenCV decodes to BGR; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)