# add crowd counting demo (commit f4634b9)
import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2
import glob
import scipy.io as io
class SHHA(Dataset):
    """ShanghaiTech Part A crowd-counting dataset.

    Reads an image/ground-truth list file under ``data_root`` and yields
    ``(image, target)`` pairs, where each target dict carries the annotated
    head points, an image id and per-point labels.

    Args:
        data_root: root directory containing the list files and data.
        transform: optional callable applied to the PIL image (expected to
            produce a (C, H, W) tensor for the augmentation path below).
        train: use the train list and enable augmentation when True.
        patch: crop 4 random 128x128 patches per sample (train only).
        flip: random horizontal flip (train only; the flip path assumes the
            patch crop already ran — see note in __getitem__).
    """

    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        self.root_path = data_root
        self.train_lists = "shanghai_tech_part_a_train.list"
        self.eval_list = "shanghai_tech_part_a_test.list"
        # there may exist multiple comma-separated list files
        # (fix: the original assigned the train split unconditionally first —
        # dead code removed)
        if train:
            self.img_list_file = self.train_lists.split(',')
        else:
            self.img_list_file = self.eval_list.split(',')

        # load the image/gt path pairs; each list line is "<img> <gt>"
        self.img_map = {}
        for train_list in self.img_list_file:
            train_list = train_list.strip()
            with open(os.path.join(self.root_path, train_list)) as fin:
                for line in fin:
                    if len(line) < 2:
                        continue
                    parts = line.strip().split()
                    self.img_map[os.path.join(self.root_path, parts[0].strip())] = \
                        os.path.join(self.root_path, parts[1].strip())
        self.img_list = sorted(self.img_map.keys())
        # number of samples
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # fix: original used `index <= len(self)`, which lets the
        # out-of-range index len(self) through to a later IndexError
        assert index < len(self), 'index range error'
        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]
        # load image and ground truth points
        img, point = load_data((img_path, gt_path), self.train)
        # apply transform (expected to yield a (C, H, W) tensor)
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale in [0.7, 1.3]
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # only rescale when the result stays larger than the crop size
            if scale * min_size > 128:
                # fix: upsample_bilinear is deprecated; interpolate with
                # mode='bilinear', align_corners=True is its documented
                # equivalent
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale,
                    mode='bilinear', align_corners=True).squeeze(0)
                point *= scale
        # random crop augmentation
        if self.train and self.patch:
            img, point = random_crop(img, point)
            for i, _ in enumerate(point):
                point[i] = torch.Tensor(point[i])
        # random horizontal flip.
        # NOTE(review): this indexes img as 4-D and mirrors x around the
        # hard-coded patch width 128, so it assumes the patch crop above
        # already ran (flip=True implies patch=True) — confirm with callers.
        if random.random() > 0.5 and self.train and self.flip:
            img = torch.Tensor(img[:, :, :, ::-1].copy())
            for i, _ in enumerate(point):
                point[i][:, 0] = 128 - point[i][:, 0]

        if not self.train:
            point = [point]
        img = torch.Tensor(img)
        # pack up related infos, one target dict per patch
        target = [{} for _ in range(len(point))]
        for i, _ in enumerate(point):
            target[i]['point'] = torch.Tensor(point[i])
            # image id parsed from the file name, e.g. ".../IMG_123.jpg" -> 123
            image_id = int(img_path.split('/')[-1].split('.')[0].split('_')[-1])
            target[i]['image_id'] = torch.Tensor([image_id]).long()
            target[i]['labels'] = torch.ones([point[i].shape[0]]).long()
        return img, target
def load_data(img_gt_path, train):
    """Load one (image, ground-truth) pair.

    Args:
        img_gt_path: (image_path, gt_path) tuple; the gt file holds one
            "x y" coordinate pair per line.
        train: unused, kept for interface compatibility with callers.

    Returns:
        (PIL.Image, np.ndarray): the RGB image and an (N, 2) float array
        of head-point coordinates.
    """
    img_path, gt_path = img_gt_path
    # load the image; cv2 reads BGR, so convert to RGB for PIL
    img = cv2.imread(img_path)
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # load ground truth points.
    # fix: split each line once, and use split() without an argument so
    # repeated whitespace is tolerated (the original split(' ') twice per
    # line and would crash on a trailing blank line)
    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            parts = line.split()
            if len(parts) < 2:
                continue
            points.append([float(parts[0]), float(parts[1])])
    return img, np.array(points)
# random crop augmentation
def random_crop(img, den, num_patch=4):
    """Randomly crop `num_patch` 128x128 patches from one image.

    Args:
        img: (C, H, W) torch tensor with H >= 128 and W >= 128.
        den: (N, 2) array of (x, y) point annotations (may be empty).
        num_patch: number of patches to produce (default 4).

    Returns:
        (np.ndarray, list[np.ndarray]): a (num_patch, C, 128, 128) array of
        crops, and for each crop the points falling inside it, shifted into
        patch-local coordinates.
    """
    half_h = 128
    half_w = 128
    # fix: an image without annotations yields a shape-(0,) array, and
    # den[:, 0] below would raise IndexError; normalize to (0, 2)
    den = np.asarray(den, dtype=float).reshape(-1, 2)
    result_img = np.zeros([num_patch, img.shape[0], half_h, half_w])
    result_den = []
    # crop num_patch patches for each image
    for i in range(num_patch):
        # top-left corner of the crop, sampled uniformly
        start_h = random.randint(0, img.size(1) - half_h)
        start_w = random.randint(0, img.size(2) - half_w)
        end_h = start_h + half_h
        end_w = start_w + half_w
        # copy the cropped rect
        result_img[i] = img[:, start_h:end_h, start_w:end_w]
        # keep the points inside the crop (inclusive bounds, as before)
        idx = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & \
              (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # shift the coordinates into patch-local space
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h
        result_den.append(record_den)
    return result_img, result_den