# NOTE: hosting-page status text captured with this file ("Spaces: Runtime error"); not part of the module.
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
import torch.utils.data as data | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import numpy as np | |
import random | |
class BaseDataset(data.Dataset):
    """Thin base class on top of ``torch.utils.data.Dataset``.

    Both hooks defined here are no-ops: ``modify_commandline_options`` hands
    the parser back untouched and ``initialize`` does nothing.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged.

        Declared static because the signature has no ``self``; without the
        decorator the call only works on the class and raises ``TypeError``
        when made through an instance.

        Args:
            parser: the option parser to (potentially) extend.
            is_train: unused by this base implementation.

        Returns:
            The same ``parser`` object that was passed in.
        """
        return parser

    def initialize(self, opt):
        """Do nothing; ``opt`` is ignored by this base implementation."""
        pass
def get_params(opt, size):
    """Draw the random crop origin and flip flag for one sample.

    Args:
        opt: options object; ``preprocess_mode``, ``load_size`` and
            ``crop_size`` are read from it.
        size: ``(width, height)`` of the image before preprocessing.

    Returns:
        ``{"crop_pos": (x, y), "flip": bool}`` where ``x``/``y`` are drawn
        uniformly from ``0`` up to the slack between the rescaled image and
        ``opt.crop_size`` (or are ``0`` when there is no slack).
    """
    width, height = size
    mode = opt.preprocess_mode

    # Predict the image size after the resize stage of the pipeline; modes
    # not listed below leave the size untouched.
    scaled_w, scaled_h = width, height
    if mode == "resize_and_crop":
        scaled_w = scaled_h = opt.load_size
    elif mode == "scale_width_and_crop":
        scaled_w = opt.load_size
        scaled_h = opt.load_size * height // width
    elif mode == "scale_shortside_and_crop":
        short_side = min(width, height)
        long_side = int(opt.load_size * max(width, height) / short_side)
        # NOTE(review): the short side keeps its ORIGINAL length here rather
        # than becoming opt.load_size, so only the long side is rescaled.
        # __scale_shortside computes the same size, so the two must be
        # changed together if this is not intended -- confirm.
        if width == short_side:
            scaled_w, scaled_h = short_side, long_side
        else:
            scaled_w, scaled_h = long_side, short_side

    # Draw order (x, then y, then flip) is kept so seeded runs reproduce.
    crop_x = random.randint(0, max(0, scaled_w - opt.crop_size))
    crop_y = random.randint(0, max(0, scaled_h - opt.crop_size))
    do_flip = random.random() > 0.5
    return {"crop_pos": (crop_x, crop_y), "flip": do_flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    """Assemble the preprocessing pipeline selected by ``opt.preprocess_mode``.

    Args:
        opt: options object; ``preprocess_mode``, ``load_size``,
            ``crop_size``, ``aspect_ratio``, ``isTrain`` and ``no_flip`` are
            read from it (some only for particular modes).
        params: dict providing ``"crop_pos"`` and ``"flip"``.
        method: resampling filter forwarded to every resize step.
        normalize: append a 3-channel ``Normalize`` with mean/std 0.5.
        toTensor: append ``ToTensor``.

    Returns:
        A ``transforms.Compose`` of the selected steps, in the order
        resize/scale, crop, size fix-up, flip, tensor conversion, normalize.
    """
    mode = opt.preprocess_mode
    steps = []

    # Resize stage: at most one of the three variants applies.
    if "resize" in mode:
        steps.append(
            transforms.Resize([opt.load_size, opt.load_size], interpolation=method)
        )
    elif "scale_width" in mode:
        steps.append(
            transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))
        )
    elif "scale_shortside" in mode:
        steps.append(
            transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))
        )

    if "crop" in mode:
        steps.append(
            transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size))
        )

    if mode == "none":
        # No resize requested: only snap both sides to a multiple of 32.
        steps.append(transforms.Lambda(lambda img: __make_power_2(img, 32, method)))

    if mode == "fixed":
        fixed_w = opt.crop_size
        fixed_h = round(opt.crop_size / opt.aspect_ratio)
        steps.append(
            transforms.Lambda(lambda img: __resize(img, fixed_w, fixed_h, method))
        )

    if opt.isTrain and not opt.no_flip:
        steps.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))

    if toTensor:
        steps.append(transforms.ToTensor())
    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    return transforms.Compose(steps)
def normalize():
    """Return a ``Normalize`` transform with mean 0.5 and std 0.5 on three channels."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
def __resize(img, w, h, method=Image.BICUBIC):
    """Resize ``img`` to exactly ``(w, h)`` using resampling filter ``method``."""
    target_size = (w, h)
    return img.resize(target_size, method)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``base``.

    Despite the name, sides are snapped to the nearest multiple of ``base``,
    not to a power of two.  Each side is clamped to at least ``base``: a side
    shorter than ``base / 2`` would otherwise round down to 0 and yield a
    zero-sized resize target.

    Args:
        img: image object exposing ``size`` as ``(width, height)`` and
            ``resize(size, method)``.
        base: the multiple both sides are snapped to.
        method: resampling filter forwarded to ``img.resize``.

    Returns:
        ``img`` itself when it already has the target size, otherwise the
        resized image.
    """
    ow, oh = img.size
    w = max(base, int(round(ow / base) * base))
    h = max(base, int(round(oh / base) * base))
    if (w, h) == (ow, oh):
        # Already aligned: skip the resample to avoid a needless copy.
        return img
    return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Scale ``img`` to ``target_width`` wide, scaling the height by the same ratio.

    Returns ``img`` itself when its width already equals ``target_width``;
    the new height is truncated to an int.
    """
    ow, oh = img.size
    if ow != target_width:
        new_height = int(target_width * oh / ow)
        return img.resize((target_width, new_height), method)
    return img
def __scale_shortside(img, target_width, method=Image.BICUBIC):
    """Rescale ``img`` based on the ratio of ``target_width`` to its shorter side.

    Returns ``img`` itself when the shorter side already equals
    ``target_width``.  Otherwise the longer side is multiplied by
    ``target_width / short_side`` (truncated to an int) and the image is
    resized.
    """
    ow, oh = img.size
    short_side = min(ow, oh)
    if short_side == target_width:
        return img

    scaled_long = int(target_width * max(ow, oh) / short_side)
    # NOTE(review): the short side keeps its ORIGINAL length in the output
    # size (it is not set to target_width), so only the long side changes.
    # get_params predicts the same size for "scale_shortside_and_crop", so
    # the two must be changed together if this is not intended -- confirm.
    if ow == short_side:
        new_size = (short_side, scaled_long)
    else:
        new_size = (scaled_long, short_side)
    return img.resize(new_size, method)
def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` square out of ``img``.

    Args:
        img: image object exposing ``crop(box)``.
        pos: ``(x, y)`` of the square's top-left corner.
        size: side length of the square.

    Returns:
        Whatever ``img.crop`` returns for the box
        ``(x, y, x + size, y + size)``.
    """
    left, top = pos
    return img.crop((left, top, left + size, top + size))
def __flip(img, flip):
    """Return ``img`` mirrored left-to-right when ``flip`` is truthy, else ``img`` itself."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img