import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2
import glob
import scipy.io as io


class SHHA(Dataset):
    """Dataset of (image, head-point annotation) pairs read from list files.

    The list files named below are looked up under ``data_root``; each
    non-empty line holds an image path and an annotation path, both relative
    to ``data_root``.

    Args:
        data_root: Directory containing the list files and the data they
            reference.
        transform: Optional callable applied to the loaded PIL image. The
            rest of the pipeline reads ``img.shape`` and calls
            ``img.unsqueeze``, so it is expected to return a ``(C, H, W)``
            torch tensor.
        train: Use the training list and enable random-scale augmentation.
        patch: In training mode, cut random square patches from each image.
        flip: In training mode, mirror the sample horizontally with
            probability 0.5.
    """

    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        self.root_path = data_root
        self.train_lists = "shanghai_tech_part_a_train.list"
        self.eval_list = "shanghai_tech_part_a_test.list"
        # there may exist multiple list files, separated by commas
        list_names = self.train_lists if train else self.eval_list
        self.img_list_file = list_names.split(',')

        self.img_map = {}
        self.img_list = []
        # loads the image/gt pairs
        for list_name in self.img_list_file:
            list_name = list_name.strip()
            with open(os.path.join(self.root_path, list_name)) as fin:
                for line in fin:
                    fields = line.split()
                    # skip blank lines and lines without an (image, gt) pair
                    if len(fields) < 2:
                        continue
                    self.img_map[os.path.join(self.root_path, fields[0])] = \
                        os.path.join(self.root_path, fields[1])
        self.img_list = sorted(self.img_map)
        # number of samples
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        """Return the number of image/annotation pairs."""
        return self.nSamples

    def __getitem__(self, index):
        """Load, augment and pack the sample at ``index``.

        Returns:
            A tuple ``(img, target)``. ``img`` is a float tensor: the
            transformed image, or a stack of patches when training with
            ``patch=True``. ``target`` is a list with one dict per image or
            patch, holding ``'point'`` (N x 2 float tensor), ``'image_id'``
            (1-element long tensor) and ``'labels'`` (N ones, long).

        Raises:
            IndexError: If ``index`` is outside the dataset.
        """
        if not -len(self) <= index < len(self):
            raise IndexError('index range error')

        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]
        # load image and ground truth
        img, point = load_data((img_path, gt_path), self.train)
        # apply augmentation
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # scale the image and points, unless the short side would drop
            # to 128 pixels or below
            if scale * min_size > 128:
                # same operation as the deprecated upsample_bilinear
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale,
                    mode='bilinear', align_corners=True).squeeze(0)
                point *= scale

        # random crop augmentation; otherwise treat the full image as the
        # single "patch" so that targets are built per image, not per point
        if self.train and self.patch:
            img, point = random_crop(img, point)
        else:
            point = [point]
        point = [torch.as_tensor(p, dtype=torch.float32) for p in point]

        # random flipping; the draw happens unconditionally so the number of
        # values taken from the RNG does not depend on the flags
        flip_draw = random.random()
        if flip_draw > 0.5 and self.train and self.flip:
            width = img.shape[-1]
            if isinstance(img, np.ndarray):
                # copy() removes the negative stride torch cannot wrap
                img = img[..., ::-1].copy()
            else:
                img = torch.flip(img, dims=[-1])
            for p in point:
                p[:, 0] = width - p[:, 0]

        img = torch.as_tensor(img, dtype=torch.float32)

        # pack up related infos; the id is the number after the last '_' in
        # the file name, e.g. 'IMG_12.jpg' -> 12
        stem = os.path.basename(img_path).split('.')[0]
        image_id = int(stem.split('_')[-1])
        target = []
        for p in point:
            target.append({
                'point': p,
                'image_id': torch.tensor([image_id], dtype=torch.long),
                'labels': torch.ones([p.shape[0]]).long(),
            })

        return img, target


def load_data(img_gt_path, train):
    """Read one image and its point annotations from disk.

    Args:
        img_gt_path: ``(image_path, annotation_path)`` tuple. Each non-blank
            annotation line starts with whitespace-separated ``x y`` values.
        train: Unused here; kept so existing callers keep working.

    Returns:
        ``(img, points)``: an RGB ``PIL.Image`` and a float array of shape
        ``(N, 2)``, which is ``(0, 2)`` when the file holds no points.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img_path, gt_path = img_gt_path
    # load the images; cv2.imread signals failure by returning None
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError('cannot read image: {}'.format(img_path))
    img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    # load ground truth points
    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            fields = line.split()
            if not fields:
                continue
            points.append([float(fields[0]), float(fields[1])])

    # reshape keeps the (N, 2) layout even when there are no points
    return img, np.array(points, dtype=np.float64).reshape(-1, 2)


# random crop augmentation
def random_crop(img, den, num_patch=4, patch_size=128):
    """Cut ``num_patch`` random square patches and their points from an image.

    Args:
        img: ``(C, H, W)`` image tensor.
        den: Array of points whose first two columns are ``x`` and ``y``.
        num_patch: Number of patches to draw.
        patch_size: Side length of each patch in pixels.

    Returns:
        ``(result_img, result_den)``: a numpy array of shape
        ``(num_patch, C, patch_size, patch_size)`` and a list of
        ``num_patch`` arrays with the points of each patch, shifted to
        patch-local coordinates.

    Raises:
        ValueError: If the image is smaller than ``patch_size``.
    """
    half_h = patch_size
    half_w = patch_size
    height, width = img.shape[1], img.shape[2]
    if height < half_h or width < half_w:
        raise ValueError(
            'image of size {}x{} is smaller than the {}x{} crop'.format(
                height, width, half_h, half_w))

    den = np.asarray(den)
    if den.size == 0:
        # an empty annotation set may arrive as shape (0,)
        den = den.reshape(0, 2)

    result_img = np.zeros([num_patch, img.shape[0], half_h, half_w])
    result_den = []
    # crop num_patch for each image
    for i in range(num_patch):
        start_h = random.randint(0, height - half_h)
        start_w = random.randint(0, width - half_w)
        end_h = start_h + half_h
        end_w = start_w + half_w
        # copy the cropped rect
        result_img[i] = img[:, start_h:end_h, start_w:end_w]
        # copy the cropped points
        # NOTE(review): both bounds are inclusive, so a point lying exactly
        # on the far edge is kept with a local coordinate equal to the
        # patch size - confirm this is intended.
        idx = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & \
              (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # shift the coordinates (boolean indexing returns a copy)
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h

        result_den.append(record_den)

    return result_img, result_den