exorcist123's picture
add crowd counting demo
f4634b9
raw
history blame
913 Bytes
import torchvision.transforms as standard_transforms
from .SHHA import SHHA
class DeNormalize:
    """Reverse a per-channel normalization, e.g. to recover viewable images.

    Each slice along the tensor's first dimension is mapped back with
    ``x * std + mean``. That is the algebraic inverse of
    ``(x - mean) / std``.

    Args:
        mean: per-slice offsets to add back, one per entry of ``tensor``'s
            first dimension.
        std: per-slice scales to multiply by, paired positionally with
            ``mean``.

    Note:
        The tensor passed to ``__call__`` is modified IN PLACE and then
        returned. Pass a copy (e.g. ``t.clone()``) to keep the original.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # NOTE(review): presumably ``tensor`` is channel-first (C, H, W), as
        # ToTensor produces, so each slice is one channel. Confirm at call sites.
        pairs = zip(self.mean, self.std)
        # zip stops at the shortest input. If mean/std have fewer entries than
        # tensor's first dimension, the remaining slices are left untouched.
        for plane, (shift, scale) in zip(tensor, pairs):
            # Each slice is a view into ``tensor``, so the in-place ops below
            # write straight back into the caller's tensor.
            plane.mul_(scale).add_(shift)
        return tensor
def loading_data(
    data_root,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Build the SHHA training and validation datasets.

    Both splits share a single preprocessing pipeline: ``ToTensor`` followed
    by a per-channel ``Normalize``.

    Args:
        data_root: dataset root directory, passed unchanged to ``SHHA`` for
            both splits.
        mean: per-channel means for ``Normalize``. The default is the value
            this loader has always hard-coded. These numbers are the widely
            used ImageNet channel statistics; presumably chosen to match an
            ImageNet-pretrained backbone (confirm before changing).
        std: per-channel standard deviations for ``Normalize``. The default is
            likewise the previously hard-coded value.

    Returns:
        A ``(train_set, val_set)`` tuple of ``SHHA`` datasets.
    """
    # The pre-processing transform, shared by both splits. mean/std are
    # copied into lists so Normalize stores the same kind of value as before
    # and does not alias a caller's mutable sequence.
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    # Training split: enables SHHA's ``patch`` and ``flip`` options
    # (presumably random patch cropping and horizontal flipping — see SHHA).
    train_set = SHHA(data_root, train=True, transform=transform, patch=True, flip=True)
    # Validation split: same transform, no patch/flip options passed.
    val_set = SHHA(data_root, train=False, transform=transform)
    return train_set, val_set