import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2

class SHHA(Dataset):
    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        self.root_path = data_root
        self.train_lists = "shanghai_tech_part_a_train.list"
        self.eval_list = "shanghai_tech_part_a_test.list"
        # there may be multiple comma-separated list files
        if train:
            self.img_list_file = self.train_lists.split(',')
        else:
            self.img_list_file = self.eval_list.split(',')

        # load the image/ground-truth path pairs
        self.img_map = {}
        for train_list in self.img_list_file:
            train_list = train_list.strip()
            with open(os.path.join(self.root_path, train_list)) as fin:
                for line in fin:
                    if len(line) < 2:
                        continue
                    line = line.strip().split()
                    self.img_map[os.path.join(self.root_path, line[0].strip())] = \
                        os.path.join(self.root_path, line[1].strip())
        self.img_list = sorted(self.img_map.keys())
        # number of samples
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip
    def __len__(self):
        return self.nSamples
    def __getitem__(self, index):
        assert index < len(self), 'index range error'

        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]
        # load image and ground-truth points
        img, point = load_data((img_path, gt_path), self.train)
        # apply the (optional) transform, e.g. ToTensor + Normalize
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # rescale the image and points only if the result stays
            # large enough for the 128x128 random crops below
            if scale * min_size > 128:
                # F.upsample_bilinear is deprecated; interpolate with
                # align_corners=True is its documented equivalent
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale,
                    mode='bilinear', align_corners=True).squeeze(0)
                point *= scale
        # random crop augmentation
        if self.train and self.patch:
            img, point = random_crop(img, point)
            for i, _ in enumerate(point):
                point[i] = torch.Tensor(point[i])
        # random horizontal flip; assumes the patched 4-D layout
        # [num_patch, C, H, W] produced by random_crop above
        if random.random() > 0.5 and self.train and self.flip:
            img = torch.Tensor(img[:, :, :, ::-1].copy())
            for i, _ in enumerate(point):
                point[i][:, 0] = 128 - point[i][:, 0]

        if not self.train:
            point = [point]

        img = torch.Tensor(img)
        # pack up related info: per patch, the point coordinates,
        # the numeric image id, and all-ones class labels
        target = [{} for i in range(len(point))]
        for i, _ in enumerate(point):
            target[i]['point'] = torch.Tensor(point[i])
            image_id = int(img_path.split('/')[-1].split('.')[0].split('_')[-1])
            image_id = torch.Tensor([image_id]).long()
            target[i]['image_id'] = image_id
            target[i]['labels'] = torch.ones([point[i].shape[0]]).long()

        return img, target

def load_data(img_gt_path, train):
    img_path, gt_path = img_gt_path
    # load the image as RGB (cv2 reads BGR)
    img = cv2.imread(img_path)
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # load ground-truth points: one "x y" pair per line
    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            x = float(line.strip().split(' ')[0])
            y = float(line.strip().split(' ')[1])
            points.append([x, y])

    return img, np.array(points)

# random crop augmentation
def random_crop(img, den, num_patch=4):
    half_h = 128
    half_w = 128
    result_img = np.zeros([num_patch, img.shape[0], half_h, half_w])
    result_den = []
    # crop num_patch patches from each image
    for i in range(num_patch):
        start_h = random.randint(0, img.size(1) - half_h)
        start_w = random.randint(0, img.size(2) - half_w)
        end_h = start_h + half_h
        end_w = start_w + half_w
        # copy the cropped rect
        result_img[i] = img[:, start_h:end_h, start_w:end_w]
        # keep only the points that fall inside the crop
        idx = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & \
              (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # shift the coordinates into the crop's frame
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h

        result_den.append(record_den)

    return result_img, result_den
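
# Minimal usage sketch (illustrative, not part of the original file): the
# "./ShanghaiTech" data root and the ImageNet normalization stats below are
# assumptions; point them at wherever the .list files actually live.
if __name__ == '__main__':
    import torchvision.transforms as T
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = SHHA('./ShanghaiTech', transform=transform,
                   train=True, patch=True, flip=True)
    img, target = dataset[0]
    print(img.shape)                 # e.g. torch.Size([4, 3, 128, 128])
    print(target[0]['point'].shape)  # (num_points_in_patch_0, 2)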