import copy
import glob
import os
from multiprocessing.dummy import Pool as ThreadPool

import torch
from torch import nn
from PIL import Image
from torchvision.transforms.functional import to_tensor

from ..Models import *


class ImageSplitter:
    # Key points:
    # Border padding and overlapping image splitting are used to avoid instability at patch edges.
    # Thanks to waifu2x's author nagadomi for the suggestion (https://github.com/nagadomi/waifu2x/issues/238)
    def __init__(self, seg_size=48, scale_factor=2, boarder_pad_size=3):
        self.seg_size = seg_size
        self.scale_factor = scale_factor
        self.pad_size = boarder_pad_size
        self.height = 0
        self.width = 0
        self.upsampler = nn.Upsample(scale_factor=scale_factor, mode="bilinear")

    def split_img_tensor(self, pil_img, scale_method=Image.BILINEAR, img_pad=0):
        # convert the image into a tensor and replication-pad its border
        img_tensor = to_tensor(pil_img).unsqueeze(0)
        img_tensor = nn.ReplicationPad2d(self.pad_size)(img_tensor)
        batch, channel, height, width = img_tensor.size()
        self.height = height
        self.width = width

        if scale_method is not None:
            # pre-upscale the image with PIL; use self.scale_factor rather than a
            # hard-coded 2 so non-default factors also work
            img_up = pil_img.resize(
                (self.scale_factor * pil_img.size[0], self.scale_factor * pil_img.size[1]),
                scale_method,
            )
            img_up = to_tensor(img_up).unsqueeze(0)
            img_up = nn.ReplicationPad2d(self.pad_size * self.scale_factor)(img_up)

        patch_box = []
        # avoid the case where the residual part is smaller than the padded size
        if (
            height % self.seg_size < self.pad_size
            or width % self.seg_size < self.pad_size
        ):
            self.seg_size += self.scale_factor * self.pad_size

        # split the image into overlapping pieces
        for i in range(self.pad_size, height, self.seg_size):
            for j in range(self.pad_size, width, self.seg_size):
                part = img_tensor[
                    :,
                    :,
                    (i - self.pad_size) : min(i + self.pad_size + self.seg_size, height),
                    (j - self.pad_size) : min(j + self.pad_size + self.seg_size, width),
                ]
                if img_pad > 0:
                    part = nn.ZeroPad2d(img_pad)(part)
                if scale_method is not None:
                    # part_up = self.upsampler(part)
                    part_up = img_up[
                        :,
                        :,
                        self.scale_factor * (i - self.pad_size) : self.scale_factor
                        * min(i + self.pad_size + self.seg_size, height),
                        self.scale_factor * (j - self.pad_size) : self.scale_factor
                        * min(j + self.pad_size + self.seg_size, width),
                    ]
                    patch_box.append((part, part_up))
                else:
                    patch_box.append(part)
        return patch_box

    def merge_img_tensor(self, list_img_tensor):
        out = torch.zeros(
            (1, 3, self.height * self.scale_factor, self.width * self.scale_factor)
        )
        img_tensors = copy.copy(list_img_tensor)
        rem = self.pad_size * 2
        pad_size = self.scale_factor * self.pad_size
        seg_size = self.scale_factor * self.seg_size
        height = self.scale_factor * self.height
        width = self.scale_factor * self.width
        for i in range(pad_size, height, seg_size):
            for j in range(pad_size, width, seg_size):
                part = img_tensors.pop(0)
                # crop the overlapping border off each upscaled patch
                part = part[:, :, rem:-rem, rem:-rem]
                # note: this block may still have errors in edge cases
                if len(part.size()) > 3:
                    _, _, p_h, p_w = part.size()
                    out[:, :, i : i + p_h, j : j + p_w] = part
                # out[:,:,
                #     self.scale_factor*i:self.scale_factor*i+p_h,
                #     self.scale_factor*j:self.scale_factor*j+p_w] = part
        out = out[:, :, rem:-rem, rem:-rem]
        return out
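

# A minimal usage sketch (not part of the original file): split an image into
# overlapping patches, run a super-resolution model on each patch, and merge
# the outputs back into a single tensor. `model` is assumed to be an nn.Module
# that upscales by `scale_factor`; the file path is hypothetical.
def upscale_image(model, img_file, seg_size=48, scale_factor=2):
    img = Image.open(img_file).convert("RGB")
    splitter = ImageSplitter(seg_size=seg_size, scale_factor=scale_factor, boarder_pad_size=3)
    patches = splitter.split_img_tensor(img, scale_method=None, img_pad=0)
    with torch.no_grad():
        out_patches = [model(p) for p in patches]
    return splitter.merge_img_tensor(out_patches)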


def load_single_image(
    img_file,
    up_scale=False,
    up_scale_factor=2,
    up_scale_method=Image.BILINEAR,
    zero_padding=False,
):
    img = Image.open(img_file).convert("RGB")
    out = to_tensor(img).unsqueeze(0)
    if zero_padding:
        # zero_padding doubles as the pad width when it is truthy
        out = nn.ZeroPad2d(zero_padding)(out)
    if up_scale:
        size = tuple(map(lambda x: x * up_scale_factor, img.size))
        img_up = img.resize(size, up_scale_method)
        img_up = to_tensor(img_up).unsqueeze(0)
        out = (out, img_up)
    return out
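

# Example (a sketch; the path is hypothetical): with up_scale=True the function
# returns a (low-res tensor, pre-upscaled tensor) pair, otherwise a single tensor.
#   lr = load_single_image("sample.png")
#   lr, bilinear_up = load_single_image("sample.png", up_scale=True, up_scale_factor=2)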


def standardize_img_format(img_folder):
    def process(img_file):
        img_path = os.path.dirname(img_file)
        # use splitext so file names containing extra dots do not break the rename
        img_name, _ = os.path.splitext(os.path.basename(img_file))
        out = os.path.join(img_path, img_name + ".JPEG")
        # note: this only renames the extension; the image data is not re-encoded
        os.rename(img_file, out)

    list_imgs = []
    for i in ["png", "jpeg", "jpg"]:
        list_imgs.extend(glob.glob(os.path.join(img_folder, "**/*." + i), recursive=True))
    print("Found {} images.".format(len(list_imgs)))
    pool = ThreadPool(4)
    pool.map(process, list_imgs)
    pool.close()
    pool.join()
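

# Usage sketch (the folder path is hypothetical): recursively rename every
# *.png/*.jpeg/*.jpg under the folder so they all carry the ".JPEG" extension.
#   standardize_img_format("./train_images/")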