promptsai's picture
Upload 30 files
023485e verified
raw
history blame
8.54 kB
from PIL import Image
import torch.utils.data as data
import os
from glob import glob
import torch
import torchvision.transforms.functional as F
from torchvision import transforms
import random
import numpy as np
import scipy.io as sio
def random_crop(im_h, im_w, crop_h, crop_w):
    """Choose a uniformly random position for a crop window inside an image.

    Args:
        im_h: image height.
        im_w: image width.
        crop_h: crop window height.
        crop_w: crop window width.

    Returns:
        Tuple ``(top, left, crop_h, crop_w)`` where ``top`` is in
        ``[0, im_h - crop_h]`` and ``left`` is in ``[0, im_w - crop_w]``
        (both bounds inclusive).

    Raises:
        ValueError: from ``random.randint`` when the crop is larger than the
            image along either axis.
    """
    # The generator is consumed in order, so the row offset is drawn from the
    # global RNG before the column offset.
    top, left = (
        random.randint(0, slack)
        for slack in (im_h - crop_h, im_w - crop_w)
    )
    return top, left, crop_h, crop_w
def gen_discrete_map(im_height, im_width, points):
    """
    Generate the discrete (per-pixel count) map for a set of point annotations.

    Every point is rounded to the nearest pixel, clamped into the image bounds
    and adds 1 to that pixel, so the returned map always sums to the number of
    points.

    Args:
        im_height: height of the map in pixels.
        im_width: width of the map in pixels.
        points: array-like of shape [num_gt, 2]; each row is [width, height],
            i.e. the column coordinate comes first.

    Returns:
        np.ndarray of shape [im_height, im_width] and dtype float32.
    """
    points_np = np.asarray(points)
    num_gt = points_np.shape[0]
    if num_gt == 0:
        return np.zeros([im_height, im_width], dtype=np.float32)

    # Explicit int64 keeps the index dtype identical on every platform.
    coords = points_np.round().astype(np.int64)
    # Clamp on BOTH sides: without the lower bound a slightly negative
    # coordinate would wrap into a neighbouring row through the flat index.
    rows = np.clip(coords[:, 1], 0, im_height - 1)
    cols = np.clip(coords[:, 0], 0, im_width - 1)

    # Histogram over flattened pixel indices; duplicates accumulate.
    flat_index = rows * im_width + cols
    counts = np.bincount(flat_index, minlength=im_height * im_width)
    # Checked on the integer counts so the comparison is exact.
    assert counts.sum() == num_gt
    return counts.reshape(im_height, im_width).astype(np.float32)
class Base(data.Dataset):
    """Shared base class for the crowd datasets defined in this file.

    Stores the crop configuration, builds the image-to-tensor transform and
    provides the default training augmentation (random square crop plus
    random horizontal flip). Subclasses must implement ``__len__`` and
    ``__getitem__``.
    """

    def __init__(self, root_path, crop_size, downsample_ratio=8):
        """
        Args:
            root_path: dataset directory; how it is searched is up to the
                subclass.
            crop_size: side length of the square training crop, in pixels.
            downsample_ratio: factor by which the count map is reduced
                relative to the crop; must divide ``crop_size``.
        """
        self.root_path = root_path
        self.c_size = crop_size
        self.d_ratio = downsample_ratio
        assert self.c_size % self.d_ratio == 0
        self.dc_size = self.c_size // self.d_ratio
        # ToTensor followed by per-channel normalisation (these constants are
        # the ones commonly used with ImageNet-pretrained backbones).
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Number of samples; must be provided by the subclass."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Return sample ``item``; must be provided by the subclass."""
        raise NotImplementedError

    def train_transform(self, img, keypoints):
        """Randomly crop and flip an image together with its annotations.

        Args:
            img: PIL image; must be at least ``crop_size`` on its shorter side.
            keypoints: array of shape [num_gt, 2], each row [x, y] in pixel
                coordinates of ``img``. May be empty.

        Returns:
            Tuple of
            - normalised image tensor of the crop,
            - float tensor [num_kept, 2] of the keypoints inside the crop,
              expressed in crop coordinates,
            - float tensor [1, crop/ratio, crop/ratio] holding the
              downsampled count map.
        """
        wd, ht = img.size
        st_size = 1.0 * min(wd, ht)
        assert st_size >= self.c_size, \
            'image ({}x{}) is smaller than crop size {}'.format(wd, ht, self.c_size)

        i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
        img = F.crop(img, i, j, h, w)

        if len(keypoints) > 0:
            # Shift into crop coordinates (this creates a new array, so the
            # caller's keypoints are never modified) and keep points inside.
            keypoints = keypoints - [j, i]
            idx_mask = (keypoints[:, 0] >= 0) & (keypoints[:, 0] <= w) & \
                       (keypoints[:, 1] >= 0) & (keypoints[:, 1] <= h)
            keypoints = keypoints[idx_mask]
        else:
            keypoints = np.empty([0, 2])

        gt_discrete = gen_discrete_map(h, w, keypoints)
        down_w = w // self.d_ratio
        down_h = h // self.d_ratio
        # Sum every d_ratio x d_ratio pixel block into one output cell.
        gt_discrete = gt_discrete.reshape(
            [down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
        assert np.sum(gt_discrete) == len(keypoints)

        # One RNG draw decides the flip whether or not there are keypoints.
        if random.random() > 0.5:
            img = F.hflip(img)
            gt_discrete = np.fliplr(gt_discrete)
            # Pixel column c moves to w - 1 - c under a horizontal flip, so
            # the keypoints are mirrored the same way to stay aligned with
            # the flipped count map (no-op when there are no keypoints).
            keypoints[:, 0] = w - keypoints[:, 0] - 1

        gt_discrete = np.expand_dims(gt_discrete, 0)
        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(
            gt_discrete.copy()).float()
class Crowd_qnrf(Base):
    """Dataset over ``<root_path>/*.jpg`` images with ``.npy`` point files.

    Each image is expected to have a sibling file with the same stem and a
    ``.npy`` extension holding its keypoints. (Presumably the UCF-QNRF
    dataset, going by the class name; confirm against the preprocessing
    script.)
    """

    def __init__(self, root_path, crop_size,
                 downsample_ratio=8,
                 method='train'):
        """
        Args:
            root_path: directory containing the ``.jpg`` / ``.npy`` pairs.
            crop_size: side length of the square training crop.
            downsample_ratio: count-map downsampling factor.
            method: ``'train'`` or ``'val'``.

        Raises:
            NotImplementedError: if ``method`` is not supported.
        """
        # Reject bad input before touching the filesystem.
        if method not in ['train', 'val']:
            raise NotImplementedError("not implement")
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))

    def __len__(self):
        """Number of images found under ``root_path``."""
        return len(self.im_list)

    def __getitem__(self, item):
        """Load one sample.

        Returns:
            train: the result of ``train_transform`` (image tensor, keypoints
                tensor, downsampled count map).
            val: ``(image tensor, number of keypoints, image name)``.
        """
        img_path = self.im_list[item]
        # Swap only the extension; str.replace('jpg', 'npy') would also
        # rewrite any "jpg" occurring in a directory name or file stem.
        gd_path = os.path.splitext(img_path)[0] + '.npy'
        img = Image.open(img_path).convert('RGB')
        if self.method == 'train':
            keypoints = np.load(gd_path)
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            keypoints = np.load(gd_path)
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, len(keypoints), name
class Crowd_nwpu(Base):
    """Dataset over ``<root_path>/*.jpg`` images with optional ``.npy`` points.

    In ``'train'`` and ``'val'`` mode each image needs a sibling file with the
    same stem and a ``.npy`` extension; ``'test'`` mode reads images only.
    (Presumably the NWPU-Crowd dataset, going by the class name; confirm
    against the preprocessing script.)
    """

    def __init__(self, root_path, crop_size,
                 downsample_ratio=8,
                 method='train'):
        """
        Args:
            root_path: directory containing the images (and ``.npy`` files).
            crop_size: side length of the square training crop.
            downsample_ratio: count-map downsampling factor.
            method: ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            NotImplementedError: if ``method`` is not supported.
        """
        # Reject bad input before touching the filesystem.
        if method not in ['train', 'val', 'test']:
            raise NotImplementedError("not implement")
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))

    def __len__(self):
        """Number of images found under ``root_path``."""
        return len(self.im_list)

    def __getitem__(self, item):
        """Load one sample.

        Returns:
            train: the result of ``train_transform`` (image tensor, keypoints
                tensor, downsampled count map).
            val: ``(image tensor, number of keypoints, image name)``.
            test: ``(image tensor, image name)``; no annotation is read.
        """
        img_path = self.im_list[item]
        # Swap only the extension; str.replace('jpg', 'npy') would also
        # rewrite any "jpg" occurring in a directory name or file stem.
        gd_path = os.path.splitext(img_path)[0] + '.npy'
        img = Image.open(img_path).convert('RGB')
        if self.method == 'train':
            keypoints = np.load(gd_path)
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            keypoints = np.load(gd_path)
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, len(keypoints), name
        elif self.method == 'test':
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, name
class Crowd_sh(Base):
    """Dataset over ``<root_path>/images/*.jpg`` with MATLAB ground truth.

    The annotation for image ``NAME.jpg`` is read from
    ``<root_path>/ground-truth/GT_NAME.mat``. (Presumably the ShanghaiTech
    layout, going by the class name; confirm against the dataset on disk.)
    Unlike the base class, training images smaller than the crop size are
    upscaled instead of rejected.
    """

    def __init__(self, root_path, crop_size,
                 downsample_ratio=8,
                 method='train'):
        """
        Args:
            root_path: directory containing ``images/`` and ``ground-truth/``.
            crop_size: side length of the square training crop.
            downsample_ratio: count-map downsampling factor.
            method: ``'train'`` or ``'val'``.

        Raises:
            NotImplementedError: if ``method`` is not supported.
        """
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        if method not in ['train', 'val']:
            raise NotImplementedError("not implement")
        self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))

    def __len__(self):
        """Number of images found under ``<root_path>/images``."""
        return len(self.im_list)

    def __getitem__(self, item):
        """Load one sample.

        Returns:
            train: the result of ``train_transform`` (image tensor, keypoints
                tensor, downsampled count map).
            val: ``(image tensor, number of keypoints, image name)``.
        """
        img_path = self.im_list[item]
        name = os.path.basename(img_path).split('.')[0]
        gd_path = os.path.join(self.root_path, 'ground-truth', 'GT_{}.mat'.format(name))
        img = Image.open(img_path).convert('RGB')
        # Digs the point array out of the nested 'image_info' struct; the
        # exact nesting is taken as-is from the original code (assumes the
        # stock .mat layout of this dataset -- TODO confirm).
        keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]
        if self.method == 'train':
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            img = self.trans(img)
            return img, len(keypoints), name

    def train_transform(self, img, keypoints):
        """Resize if needed, then randomly crop and flip image + annotations.

        Args:
            img: PIL image of any size.
            keypoints: array of shape [num_gt, 2], each row [x, y] in pixel
                coordinates of ``img``. May be empty.

        Returns:
            Tuple of
            - normalised image tensor of the crop,
            - float tensor [num_kept, 2] of the keypoints inside the crop,
              expressed in crop coordinates,
            - float tensor [1, crop/ratio, crop/ratio] holding the
              downsampled count map.
        """
        wd, ht = img.size
        st_size = 1.0 * min(wd, ht)
        # Upscale so the shorter side reaches the crop size; the keypoints
        # are scaled by the same factor to stay aligned with the image.
        if st_size < self.c_size:
            rr = 1.0 * self.c_size / st_size
            wd = round(wd * rr)
            ht = round(ht * rr)
            st_size = 1.0 * min(wd, ht)
            img = img.resize((wd, ht), Image.BICUBIC)
            keypoints = keypoints * rr
        assert st_size >= self.c_size, \
            'resized image ({}x{}) is smaller than crop size {}'.format(wd, ht, self.c_size)

        i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
        img = F.crop(img, i, j, h, w)

        if len(keypoints) > 0:
            # Shift into crop coordinates (this creates a new array, so the
            # caller's keypoints are never modified) and keep points inside.
            keypoints = keypoints - [j, i]
            idx_mask = (keypoints[:, 0] >= 0) & (keypoints[:, 0] <= w) & \
                       (keypoints[:, 1] >= 0) & (keypoints[:, 1] <= h)
            keypoints = keypoints[idx_mask]
        else:
            keypoints = np.empty([0, 2])

        gt_discrete = gen_discrete_map(h, w, keypoints)
        down_w = w // self.d_ratio
        down_h = h // self.d_ratio
        # Sum every d_ratio x d_ratio pixel block into one output cell.
        gt_discrete = gt_discrete.reshape(
            [down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
        assert np.sum(gt_discrete) == len(keypoints)

        # One RNG draw decides the flip whether or not there are keypoints.
        if random.random() > 0.5:
            img = F.hflip(img)
            gt_discrete = np.fliplr(gt_discrete)
            # Mirror x the same way pixel columns move (c -> w - 1 - c);
            # a no-op when there are no keypoints.
            keypoints[:, 0] = w - keypoints[:, 0] - 1

        gt_discrete = np.expand_dims(gt_discrete, 0)
        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(
            gt_discrete.copy()).float()