# Source: BOPBTL/Face_Enhancement/data/base_dataset.py
# Commit 7fab858 ("Add code") by manhkhanhUIT
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
    """Minimal ``torch.utils.data.Dataset`` subclass with two no-op hooks.

    ``modify_commandline_options`` and ``initialize`` do nothing here;
    concrete datasets presumably override them (NOTE(review): confirm
    against the subclasses).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for registering dataset-specific command-line options.

        The base version ignores ``is_train`` and returns ``parser``
        untouched.
        """
        return parser

    def initialize(self, opt):
        """Hook for configuring the dataset from ``opt``; a no-op here."""
        return None
def _shortside_scaled_size(width, height, target):
    """Return ``(new_w, new_h)`` with the shorter side resized to *target*.

    The longer side is scaled by the same factor and truncated to an int, so
    the aspect ratio is (approximately) kept. Shared by ``get_params`` and
    ``__scale_shortside`` so the crop range sampled by the former always
    matches the image size produced by the latter.
    """
    short_side, long_side = min(width, height), max(width, height)
    long_side = int(target * long_side / short_side)
    if width == short_side:
        return target, long_side
    return long_side, target


def get_params(opt, size):
    """Sample the random crop position and flip flag consumed by ``get_transform``.

    The image size *after* the resize/scale step selected by
    ``opt.preprocess_mode`` is predicted here, so that the crop origin is
    drawn from ``[0, new_size - opt.crop_size]`` on each axis (or 0 when the
    image is smaller than the crop).

    Args:
        opt: options namespace; reads ``preprocess_mode``, ``load_size`` and
            ``crop_size``.
        size: ``(width, height)`` of the source image.

    Returns:
        dict with ``"crop_pos"`` -> ``(x, y)`` top-left crop corner and
        ``"flip"`` -> bool (True with probability ~0.5).
    """
    w, h = size
    new_w, new_h = w, h
    if opt.preprocess_mode == "resize_and_crop":
        new_h = new_w = opt.load_size
    elif opt.preprocess_mode == "scale_width_and_crop":
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    elif opt.preprocess_mode == "scale_shortside_and_crop":
        # Must mirror __scale_shortside: the short side becomes load_size and
        # the long side is scaled by the same factor.
        new_w, new_h = _shortside_scaled_size(w, h, opt.load_size)

    x = random.randint(0, max(0, new_w - opt.crop_size))
    y = random.randint(0, max(0, new_h - opt.crop_size))
    flip = random.random() > 0.5
    return {"crop_pos": (x, y), "flip": flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    """Build the ``transforms.Compose`` pipeline applied to one image.

    Args:
        opt: options namespace; reads ``preprocess_mode``, ``load_size``,
            ``crop_size``, ``aspect_ratio`` (``"fixed"`` mode only),
            ``isTrain`` and ``no_flip``.
        params: dict from ``get_params``; ``"crop_pos"`` and ``"flip"`` are
            read lazily, each time the pipeline runs.
        method: PIL resampling filter used by every resize step.
        normalize: if true, append ``Normalize`` computing
            ``(x - 0.5) / 0.5`` on each of three channels. (This parameter
            shadows the module-level ``normalize()`` inside this function.)
        toTensor: if true, append ``ToTensor``.

    Returns:
        ``transforms.Compose`` of the selected steps, in order: resize/scale,
        crop, multiple-of-32 snap (``"none"``), fixed resize (``"fixed"``),
        flip (training only), ToTensor, Normalize.
    """
    transform_list = []
    if "resize" in opt.preprocess_mode:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, interpolation=method))
    elif "scale_width" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
    elif "scale_shortside" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))

    if "crop" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))

    if opt.preprocess_mode == "none":
        base = 32
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.preprocess_mode == "fixed":
        w = opt.crop_size
        h = round(opt.crop_size / opt.aspect_ratio)
        transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))

    if toTensor:
        transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    """Return a ``Normalize`` transform computing ``(x - 0.5) / 0.5`` on three channels."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __resize(img, w, h, method=Image.BICUBIC):
    """Resize *img* to exactly ``(w, h)`` using resampling filter *method*."""
    return img.resize((w, h), method)


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides are multiples of *base*.

    Despite the name, each side is rounded to the *nearest* multiple of
    ``base`` (not to a power of two). Each side is clamped to at least
    ``base`` so a side shorter than ``base / 2`` is not rounded down to zero
    length. Returns *img* itself when no resize is needed.
    """
    ow, oh = img.size
    h = max(base, int(round(oh / base) * base))
    w = max(base, int(round(ow / base) * base))
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize *img* to width *target_width*, scaling the height proportionally.

    The new height is truncated to an int. Returns *img* unchanged when its
    width already equals *target_width*.
    """
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __scale_shortside(img, target_width, method=Image.BICUBIC):
    """Resize *img* so its shorter side equals *target_width*.

    The longer side is scaled by the same factor (truncated to an int), so
    the aspect ratio is kept. The size comes from ``_shortside_scaled_size``,
    which ``get_params`` also uses. Returns *img* unchanged when its shorter
    side already equals *target_width*.
    """
    ow, oh = img.size
    if min(ow, oh) == target_width:
        return img
    return img.resize(_shortside_scaled_size(ow, oh, target_width), method)
def __crop(img, pos, size):
    """Cut a window out of *img* whose top-left corner is *pos*.

    Args:
        img: PIL-style image exposing ``crop(box)``.
        pos: ``(x, y)`` top-left corner of the window.
        size: side length of a square window, or a ``(width, height)``
            pair (tuple or list) for a rectangular one.

    Returns:
        ``img.crop((x, y, x + width, y + height))``.
    """
    x1, y1 = pos
    if isinstance(size, (tuple, list)):
        tw, th = size
    else:
        tw = th = size
    return img.crop((x1, y1, x1 + tw, y1 + th))
def __flip(img, flip):
    """Mirror *img* left-to-right when *flip* is truthy.

    Falsy *flip* returns the very same object; otherwise the result of
    ``img.transpose(Image.FLIP_LEFT_RIGHT)`` is returned.
    """
    if not flip:
        return img
    return img.transpose(Image.FLIP_LEFT_RIGHT)